import os

import torch
from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    LTXPipeline,
    AutoencoderKL,
)
from hidiffusion import apply_hidiffusion
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from image_gen_aux import UpscaleWithModel
# Hugging Face model ids for the SDXL img2img pipeline and the LTX video pipeline.
BASE_MODEL = "stabilityai/sdxl-turbo"
VIDEO_MODEL = "Lightricks/LTX-Video"

# Target device for every model: prefer the GPU, fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class ModelHandler:
    """Load and hold the models used by the application.

    Constructing an instance eagerly calls ``load_models``, which loads the
    SDXL-Turbo img2img pipeline, the MediaPipe segmenter and both upscalers.

    Attributes:
        base_pipe: SDXL img2img pipeline (fp16 VAE fix, DDPM scheduler,
            HiDiffusion applied); see ``load_base``.
        video_pipe: LTX-Video pipeline. NOTE(review): ``load_models`` never
            calls ``load_video_pipe``, so this stays ``None`` unless a caller
            assigns it explicitly -- confirm this lazy behaviour is intended.
        compiled_model: NOTE(review): initialised to ``None`` and never
            assigned anywhere in this class -- confirm it is still needed.
        segmenter: MediaPipe ``ImageSegmenter`` configured to output a
            category mask.
        upscaler: Upscaler whose model id comes from ``UPSCALE_MODEL``.
        upscaler4SD: Second upscaler whose model id comes from
            ``UPSCALE_FOR_SD_MODEL``.
    """

    def __init__(self):
        self.base_pipe = None
        self.video_pipe = None
        self.compiled_model = None
        self.segmenter = None
        self.upscaler = None
        self.upscaler4SD = None
        self.load_models()

    def load_base(self):
        """Build the SDXL-Turbo img2img pipeline on ``device``.

        Uses the fp16-fixed SDXL VAE, replaces the scheduler with a
        ``DDPMScheduler`` built from the base model's scheduler config, and
        applies HiDiffusion to the pipeline.

        Returns:
            The configured ``StableDiffusionXLImg2ImgPipeline``.
        """
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # Weights stay fp16 even when device == "cpu"; silence_dtype_warnings
        # suppresses diffusers' warning that the dtype may not suit the device.
        base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
        base_pipe.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        apply_hidiffusion(base_pipe)
        return base_pipe

    def load_video_pipe(self):
        """Load the LTX-Video pipeline in bfloat16 and move it to ``device``.

        Not called by ``load_models``; callers must invoke it explicitly.

        Returns:
            The ``LTXPipeline`` instance.
        """
        pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
        pipe.to(device)
        return pipe

    def load_segmenter(self):
        """Create the MediaPipe image segmenter with category-mask output.

        The ``.tflite`` model path defaults to
        ``checkpoints/selfie_multiclass_256x256.tflite`` (resolved relative to
        the current working directory) and can be overridden with the
        ``SEGMENT_MODEL`` environment variable, mirroring how the upscaler
        model ids are configured.

        Returns:
            A MediaPipe ``vision.ImageSegmenter``.
        """
        segment_model = os.environ.get(
            "SEGMENT_MODEL", "checkpoints/selfie_multiclass_256x256.tflite"
        )
        base_options = python.BaseOptions(model_asset_path=segment_model)
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True,
        )
        return vision.ImageSegmenter.create_from_options(options)

    @staticmethod
    def _load_upscale_model(env_var, default_model):
        """Load an ``UpscaleWithModel`` and move it to ``device``.

        Args:
            env_var: Name of the environment variable that may override the
                model id.
            default_model: Model id used when ``env_var`` is not set.

        Returns:
            The loaded upscaler.
        """
        model_name = os.environ.get(env_var, default_model)
        return UpscaleWithModel.from_pretrained(model_name).to(device)

    def load_upscaler(self):
        """Load the upscaler configured by ``UPSCALE_MODEL``."""
        return self._load_upscale_model(
            "UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR"
        )

    def load_upscaler4SD(self):
        """Load the upscaler configured by ``UPSCALE_FOR_SD_MODEL``."""
        return self._load_upscale_model(
            "UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf"
        )

    def load_models(self):
        """Load the base pipeline, segmenter and both upscalers onto ``self``.

        Everything is loaded into locals first and assigned only at the end,
        so if any loader raises, this call leaves the attributes unchanged.
        ``video_pipe`` is not loaded here.
        """
        base_pipe = self.load_base()
        segmenter = self.load_segmenter()
        upscaler = self.load_upscaler()
        upscaler4SD = self.load_upscaler4SD()
        self.base_pipe = base_pipe
        self.segmenter = segmenter
        self.upscaler = upscaler
        self.upscaler4SD = upscaler4SD
# Module-level shared handler. Constructing it runs ModelHandler.load_models,
# so the models that method loads are loaded as soon as this module is imported.
MODELS: ModelHandler = ModelHandler()