| | import base64 |
| | import json |
| | import sys |
| | from collections import defaultdict |
| | from io import BytesIO |
| | from pprint import pprint |
| | from typing import Any, Dict, List |
| | import os |
| | import re |
| | from pathlib import Path |
| | from typing import Union |
| | from concurrent.futures import ThreadPoolExecutor |
| | import numpy as np |
| | from PIL import ImageFilter |
| | from transformers import CLIPImageProcessor, CLIPTokenizer, CLIPModel |
| |
|
| | import torch |
| | from diffusers import ( |
| | DiffusionPipeline, |
| | DPMSolverMultistepScheduler, |
| | DPMSolverSinglestepScheduler, |
| | EulerAncestralDiscreteScheduler, |
| | StableDiffusionPipeline, |
| | utils, |
| | ) |
| | from safetensors.torch import load_file |
| | from torch import autocast, tensor |
| | import torchvision.transforms |
| | from PIL import Image |
| |
|
# Directory containing this file; used to resolve bundled model/LoRA/embedding paths.
REPO_DIR = Path(__file__).resolve().parent

# All pipelines below run in fp16; CPU execution is not supported.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type != "cuda":
    # Fail fast at import time rather than erroring mid-inference.
    raise ValueError("need to run on GPU")
| |
|
| |
|
class EndpointHandler:
    """Stateful text2img endpoint.

    Holds two LPW Stable Diffusion pipelines (realistic + anime), runs at
    most one generation at a time on a single worker thread, and exposes
    progress polling that returns blurred intermediate previews.
    """

    # LoRA name -> [safetensors path, trigger words to use in prompts].
    LORA_PATHS = {
        "hairdetailer": [str(REPO_DIR / "lora/hairdetailer.safetensors"), ""],
        "lora_leica": [str(REPO_DIR / "lora/lora_leica.safetensors"), "leica_style"],
        "epiNoiseoffset_v2": [str(REPO_DIR / "lora/epiNoiseoffset_v2.safetensors"), ""],
        "MBHU-TT2FRS": [
            str(REPO_DIR / "lora/MBHU-TT2FRS.safetensors"),
            "flat breast, small breast, big breast, fake breast",
        ],
        "polyhedron_new_skin_v1.1": [
            str(REPO_DIR / "lora/polyhedron_new_skin_v1.1.safetensors"),
            "skin blemish, detailed skin ",
        ],
        "ShinyOiledSkin_v20": [
            str(REPO_DIR / "lora/ShinyOiledSkin_v20-LoRA.safetensors"),
            "shiny skin",
        ],
        "detailed_eye-10": [str(REPO_DIR / "lora/detailed_eye-10.safetensors"), ""],
        "add_detail": [str(REPO_DIR / "lora/add_detail.safetensors"), ""],
        "MuscleGirl_v1": [str(REPO_DIR / "lora/MuscleGirl_v1.safetensors"), "abs"],
        "nurse_v11-05": [str(REPO_DIR / "lora/nurse_v11-05.safetensors"), "nurse"],
        "shibari_v20": [str(REPO_DIR / "lora/shibari_v20.safetensors"), "shibari,rope"],
        "tajnaclub_high_heelsv1.2": [
            str(REPO_DIR / "lora/tajnaclub_high_heelsv1.2.safetensors"),
            "high heels",
        ],
        "CyberPunkAI": [
            str(REPO_DIR / "lora/CyberPunkAI.safetensors"),
            "neon CyberpunkAI",
        ],
        "FutaCockCloseUp-v1": [
            str(REPO_DIR / "lora/FutaCockCloseUp-v1.safetensors"),
            "huge penis",
        ],
        "PovBlowjob-v3": [
            str(REPO_DIR / "lora/PovBlowjob-v3.safetensors"),
            "blowjob, deepthroat, kneeling, runny makeup, creampie",
        ],
        "dp_from_behind_v0.1b": [
            str(REPO_DIR / "lora/dp_from_behind_v0.1b.safetensors"),
            "1girl, 2boys, double penetration, multiple penises",
        ],
        "EkuneSideDoggy": [
            str(REPO_DIR / "lora/EkuneSideDoggy.safetensors"),
            "sidedoggystyle, doggystyle",
        ],
        "qqq-grabbing_from_behind-v2-000006": [
            str(REPO_DIR / "lora/qqq-grabbing_from_behind-v2-000006.safetensors"),
            "grabbing from behind, breast grab",
        ],
        "ftm-v0": [
            str(REPO_DIR / "lora/ftm-v0.safetensors"),
            "big mouth, tongue, long tongue",
        ],
        "tgirls_V3_5": [
            str(REPO_DIR / "lora/tgirls_V3_5.safetensors"),
            "large penis, penis, erect penis",
        ],
        "fapp9": [
            str(REPO_DIR / "lora/fapp9.safetensors"),
            "large penis, penis, erect penis",
        ],
        "pov-doggy-graphos": [
            str(REPO_DIR / "lora/pov-doggy-graphos.safetensors"),
            "penis in vagina, white man grabbing her ass",
        ],
        "reelmech1v2": [
            str(REPO_DIR / "lora/reelmech1v2.safetensors"),
            "reelmech",
        ],
    }

    # Textual-inversion embeddings.  Each "token" must be the trigger word of
    # its own weight file, since clean_negative_prompt() re-injects them all.
    TEXTUAL_INVERSION = [
        {
            "weight_name": str(REPO_DIR / "embeddings/EasyNegative.safetensors"),
            "token": "easynegative",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/kkw-NativeAmerican.pt"),
            # Bug fix: this token and the one for badhandv4.pt were swapped,
            # so each embedding was registered under the other's trigger word.
            "token": "kkw-Afro, kkw-Asian, kkw-Euro ",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/badhandv4.pt"),
            "token": "badhandv4",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/bad-artist-anime.pt"),
            "token": "bad-artist-anime",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/NegfeetV2.pt"),
            "token": "negfeetv2",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/ng_deepnegative_v1_75t.pt"),
            "token": "ng_deepnegative_v1_75t",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/bad-hands-5.pt"),
            "token": "bad-hands-5",
        },
    ]
| |
|
| | def __init__(self, path="."): |
| | self.inference_progress = {} |
| | self.inference_images = {} |
| | self.total_steps = {} |
| | self.active_request_ids = list() |
| | self.inference_in_progress = False |
| |
|
| | self.executor = ThreadPoolExecutor( |
| | max_workers=1 |
| | ) |
| |
|
| | realistic_path = str(REPO_DIR / "realistic/") |
| | self.pipe_realistic, self.safety_checker = self.load_realistic(realistic_path) |
| |
|
| | anime_path = str(REPO_DIR / "anime/") |
| | self.pipe_anime, self.pipe_anime_safety_checker = self.load_anime(anime_path) |
| |
|
| | |
| | self.image_processor = CLIPImageProcessor.from_pretrained( |
| | "openai/clip-vit-base-patch16" |
| | ) |
| |
|
| | def load_model_essentials(self, model_path): |
| | """common to all models""" |
| |
|
| | |
| |
|
| | if "realistic" in model_path: |
| | pipe = DiffusionPipeline.from_pretrained( |
| | pretrained_model_name_or_path=model_path, |
| | custom_pipeline="lpw_stable_diffusion", |
| | torch_dtype=torch.float16, |
| | ) |
| |
|
| | safety_checker = pipe.safety_checker.to(device).to(torch.float16) |
| | else: |
| | safety_checker = None |
| |
|
| | pipe = DiffusionPipeline.from_pretrained( |
| | pretrained_model_name_or_path=model_path, |
| | custom_pipeline="lpw_stable_diffusion", |
| | torch_dtype=torch.float16, |
| | safety_checker=None, |
| | ) |
| |
|
| | pipe = pipe.to(device) |
| |
|
| | |
| | pipe.set_progress_bar_config(disable=True) |
| |
|
| | |
| | self.load_embeddings(pipe) |
| |
|
| | |
| | pipe.enable_xformers_memory_efficient_attention() |
| | pipe.enable_attention_slicing() |
| |
|
| | return pipe, safety_checker |
| |
|
| | def load_anime(self, path): |
| | """Load anime model""" |
| |
|
| | |
| | pipe, safety_checker = self.load_model_essentials(path) |
| |
|
| | |
| |
|
| | |
| | pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( |
| | pipe.scheduler.config, |
| | ) |
| |
|
| | |
| | |
| | |
| | self.load_selected_loras( |
| | pipe, |
| | [ |
| | |
| | |
| | ["MuscleGirl_v1", 0.1], |
| | ["tgirls_V3_5", 0.02], |
| | ["PovBlowjob-v3", 0.02], |
| | ["pov-doggy-graphos", 0.02], |
| | ["shibari_v20", 0.02], |
| | ["ftm-v0", 0.02], |
| | ["reelmech1v2", 0.02], |
| | ], |
| | ) |
| |
|
| | return pipe, safety_checker |
| |
|
| | def load_realistic(self, path): |
| | """Load realistic model""" |
| |
|
| | |
| | pipe, safety_checker = self.load_model_essentials(path) |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | pipe.scheduler = DPMSolverMultistepScheduler.from_config( |
| | pipe.scheduler.config, |
| | algorithm_type="sde-dpmsolver++", |
| | use_karras_sigmas=True, |
| | ) |
| |
|
| | |
| | |
| | |
| | self.load_selected_loras( |
| | pipe, |
| | [ |
| | ["polyhedron_new_skin_v1.1", 0.15], |
| | ["detailed_eye-10", 0.2], |
| | ["add_detail", 0.1], |
| | ["MuscleGirl_v1", 0.2], |
| | ["tgirls_V3_5", 0.02], |
| | ["PovBlowjob-v3", 0.02], |
| | ["pov-doggy-graphos", 0.02], |
| | ["shibari_v20", 0.02], |
| | ["ftm-v0", 0.02], |
| | ["reelmech1v2", 0.02], |
| | ], |
| | ) |
| |
|
| | return pipe, safety_checker |
| |
|
    def load_lora(self, pipeline, lora_path, lora_weight=0.5):
        """Merge a kohya-style LoRA from *lora_path* into *pipeline* in place.

        Each up/down weight pair is multiplied together, scaled by
        *lora_weight*, and added directly onto the matching UNet or
        text-encoder layer — so the merge is permanent for this pipeline.
        Returns the (mutated) pipeline.
        """
        state_dict = load_file(lora_path)
        # Key prefixes used by the kohya-ss LoRA export format.
        LORA_PREFIX_UNET = "lora_unet"
        LORA_PREFIX_TEXT_ENCODER = "lora_te"

        alpha = lora_weight
        visited = []  # keys already merged (both halves of each pair)

        # Move every tensor to the GPU up front.
        for key in state_dict:
            state_dict[key] = state_dict[key].to(device)

        for key in state_dict:
            # Skip alpha scalars and pair partners that were handled already.
            if ".alpha" in key or key in visited:
                continue

            # Choose the target model and split the flattened layer path
            # (e.g. "lora_unet_down_blocks_0_attentions_...") into fragments.
            if "text" in key:
                layer_infos = (
                    key.split(".")[0]
                    .split(LORA_PREFIX_TEXT_ENCODER + "_")[-1]
                    .split("_")
                )
                curr_layer = pipeline.text_encoder
            else:
                layer_infos = (
                    key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
                )
                curr_layer = pipeline.unet

            # Walk down the module tree.  Layer names may themselves contain
            # underscores, so on a failed getattr we glue the next fragment
            # onto temp_name and retry.  NOTE: `len(...) > -1` is always true;
            # the loop only exits via the `break` once the path is consumed.
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if len(layer_infos) > 0:
                        temp_name = layer_infos.pop(0)
                    elif len(layer_infos) == 0:
                        break
                except Exception:
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)

            # Collect this layer's (up, down) key pair, up-weight first.
            pair_keys = []
            if "lora_down" in key:
                pair_keys.append(key.replace("lora_down", "lora_up"))
                pair_keys.append(key)
            else:
                pair_keys.append(key)
                pair_keys.append(key.replace("lora_up", "lora_down"))

            # Merge the pair: 4-D conv weights are flattened to 2-D for the
            # matmul and restored afterwards; linear weights multiply directly.
            if len(state_dict[pair_keys[0]].shape) == 4:
                weight_up = (
                    state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
                )
                weight_down = (
                    state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
                )
                curr_layer.weight.data += alpha * torch.mm(
                    weight_up, weight_down
                ).unsqueeze(2).unsqueeze(3)
            else:
                weight_up = state_dict[pair_keys[0]].to(torch.float32)
                weight_down = state_dict[pair_keys[1]].to(torch.float32)
                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

            # Mark both halves of the pair as done.
            for item in pair_keys:
                visited.append(item)

        return pipeline
| |
|
| | def load_embeddings(self, pipeline): |
| | """Load textual inversions, avoid bad prompts""" |
| | for model in EndpointHandler.TEXTUAL_INVERSION: |
| | pipeline.load_textual_inversion( |
| | ".", weight_name=model["weight_name"], token=model["token"] |
| | ) |
| |
|
| | def load_selected_loras(self, pipeline, selections): |
| | """Load Loras models, can lead to marvelous creations""" |
| | for model_name, weight in selections: |
| | lora_path = EndpointHandler.LORA_PATHS[model_name][0] |
| | |
| | self.load_lora(pipeline, lora_path, weight) |
| |
|
| | def clean_negative_prompt(self, negative_prompt): |
| | """Clean negative prompt to remove already used negative prompt handlers""" |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | tokens = [item["token"] for item in self.TEXTUAL_INVERSION] |
| |
|
| | |
| | for token in tokens: |
| | |
| | negative_prompt = re.sub( |
| | r"\b" + re.escape(token) + r"\b", |
| | "", |
| | negative_prompt, |
| | flags=re.IGNORECASE, |
| | ).strip() |
| |
|
| | |
| | negative_prompt += " " + " ".join(tokens) |
| |
|
| | return negative_prompt |
| |
|
| | def check_fields(self, data): |
| | """check for fields, if some missing return error""" |
| |
|
| | |
| | required_fields = [ |
| | "prompt", |
| | "negative_prompt", |
| | "width", |
| | "num_inference_steps", |
| | "height", |
| | "guidance_scale", |
| | "request_id", |
| | ] |
| |
|
| | missing_fields = [field for field in required_fields if field not in data] |
| |
|
| | if missing_fields: |
| | return { |
| | "flag": "error", |
| | "message": f"Missing fields: {', '.join(missing_fields)}", |
| | } |
| |
|
| | return False |
| |
|
| | def clean_request_data(self): |
| | """Clean up the data related to a specific request ID.""" |
| |
|
| | |
| | self.inference_progress.clear() |
| |
|
| | |
| | self.inference_images.clear() |
| |
|
| | |
| | self.total_steps.clear() |
| |
|
| | |
| | self.active_request_ids.clear() |
| |
|
| | |
| | self.inference_in_progress = False |
| |
|
| | def progress_callback( |
| | self, |
| | step: int, |
| | timestep: int, |
| | latents: Any, |
| | request_id: str, |
| | status: str, |
| | pipeline: Any, |
| | ): |
| | try: |
| | if status == "progress": |
| | |
| | img_data = pipeline.decode_latents(latents) |
| | img_data = (img_data.squeeze() * 255).astype(np.uint8) |
| | img = Image.fromarray(img_data, "RGB") |
| |
|
| | |
| | |
| | if step < int(self.total_steps[self.active_request_ids[0]] / 1.5): |
| | img = img.filter(ImageFilter.GaussianBlur(radius=30)) |
| | else: |
| | img = img.filter(ImageFilter.GaussianBlur(radius=10)) |
| |
|
| | |
| | else: |
| | |
| | |
| |
|
| | img = latents |
| |
|
| | buffered = BytesIO() |
| | img.save(buffered, format="PNG") |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | img_str = base64.b64encode(buffered.getvalue()).decode() |
| |
|
| | except Exception as e: |
| | print(f"Error: {e}") |
| |
|
| | |
| | progress_percentage = ( |
| | step / self.total_steps[request_id] |
| | ) * 100 |
| |
|
| | self.inference_progress[request_id] = progress_percentage |
| | self.inference_images[request_id] = img_str |
| |
|
| | def check_progress(self, request_id: str) -> Dict[str, Union[str, float]]: |
| | progress = self.inference_progress.get(request_id, 0) |
| | latest_image = self.inference_images.get(request_id, None) |
| |
|
| | |
| |
|
| | if progress >= 100: |
| | status = "complete" |
| |
|
| | |
| | image_data = base64.b64decode(latest_image) |
| | image_io = BytesIO(image_data) |
| | is_nsfw = self.check_nsfw(Image.open(image_io))[0] |
| | |
| | else: |
| | status = "in-progress" |
| | is_nsfw = "" |
| |
|
| | return { |
| | "flag": "success", |
| | "status": status, |
| | "progress": int(progress), |
| | "image": latest_image, |
| | "is_nsfw": is_nsfw, |
| | } |
| |
|
| | def check_nsfw(self, image): |
| | """Check if image is NSFW""" |
| |
|
| | safety_checker_input = self.image_processor(image, return_tensors="pt").to( |
| | device |
| | ) |
| |
|
| | image, has_nsfw_concept = self.safety_checker( |
| | images=np.array(image), |
| | clip_input=safety_checker_input.pixel_values.to(torch.float16), |
| | ) |
| |
|
| | return has_nsfw_concept |
| |
|
| | def start_inference(self, pipeline, data: Dict) -> Dict: |
| | """Start a new inference.""" |
| |
|
| | global device |
| |
|
| | |
| | prompt = data["prompt"] |
| | negative_prompt = data["negative_prompt"] |
| | loras_model = data.get("loras_model", None) |
| | seed = data.get("seed", None) |
| | width = data["width"] |
| | num_inference_steps = data["num_inference_steps"] |
| | height = data["height"] |
| | guidance_scale = data["guidance_scale"] |
| | request_id = data["request_id"] |
| |
|
| | |
| | self.total_steps[request_id] = num_inference_steps |
| |
|
| | |
| | forced_negative = self.clean_negative_prompt(negative_prompt) |
| |
|
| | |
| | generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None |
| |
|
| | |
| | |
| | |
| |
|
| | try: |
| | |
| | with autocast(device.type): |
| | image = pipeline.text2img( |
| | prompt=prompt, |
| | guidance_scale=guidance_scale, |
| | num_inference_steps=num_inference_steps, |
| | height=height, |
| | width=width, |
| | negative_prompt=forced_negative, |
| | generator=generator, |
| | max_embeddings_multiples=5, |
| | callback=lambda step, timestep, latents: self.progress_callback( |
| | step, timestep, latents, request_id, "progress", pipeline |
| | ), |
| | callback_steps=5, |
| | |
| | ) |
| |
|
| | |
| | self.progress_callback( |
| | num_inference_steps, |
| | 0, |
| | image.images[0], |
| | request_id, |
| | "complete", |
| | pipeline, |
| | ) |
| |
|
| | self.inference_in_progress = False |
| |
|
| | |
| | |
| |
|
| | except Exception as e: |
| | |
| | return {"flag": "error", "message": str(e)} |
| |
|
| | def __call__(self, data: Any) -> Dict: |
| | """Handle incoming requests.""" |
| |
|
| | action = data.get("action", None) |
| | request_id = data.get("request_id") |
| | genre = data.get("genre") |
| |
|
| | |
| | if not request_id: |
| | return {"flag": "error", "message": "Missing request_id."} |
| |
|
| | if action == "check_progress": |
| | if request_id not in self.active_request_ids: |
| | return { |
| | "flag": "error", |
| | "message": "Request id doesn't match any active request.", |
| | } |
| | return self.check_progress(request_id) |
| |
|
| | elif action == "inference": |
| | |
| | check_fields = self.check_fields(data) |
| | if check_fields: |
| | return check_fields |
| |
|
| | |
| | if self.inference_in_progress: |
| | return { |
| | "flag": "error", |
| | "message": "Another inference is already in progress. Please wait.", |
| | } |
| |
|
| | |
| | self.clean_request_data() |
| | self.inference_in_progress = True |
| | self.inference_progress[request_id] = 0 |
| | self.inference_images[request_id] = None |
| | self.active_request_ids.append(request_id) |
| |
|
| | |
| | if genre == "anime": |
| | pipe = self.pipe_anime |
| | else: |
| | pipe = self.pipe_realistic |
| |
|
| | self.executor.submit(self.start_inference, pipe, data) |
| | |
| |
|
| | return { |
| | "flag": "success", |
| | "status": "started", |
| | "message": "Inference started", |
| | "request_id": request_id, |
| | } |
| |
|
| | else: |
| | return {"flag": "error", "message": f"Unsupported action: {action}"} |
| |
|