# NOTE(review): a "Spaces: Configuration error" page banner was captured here by
# the scrape that produced this file; it is not part of the Python module.
| import os | |
| import folder_paths | |
| from copy import deepcopy | |
| from .conf import hydit_conf | |
| from .loader import load_hydit | |
class HYDiTCheckpointLoader:
    """Load a Hunyuan DiT checkpoint and expose it as a ComfyUI MODEL output."""

    # ComfyUI calls this on the class (class_def.INPUT_TYPES()), so it must be
    # a classmethod; without the decorator that call raises TypeError.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                # Size/config preset; valid keys come from hydit_conf (e.g. "G/2").
                "model": (list(hydit_conf.keys()), {"default": "G/2"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Checkpoint Loader"

    def load_checkpoint(self, ckpt_name, model):
        """Resolve `ckpt_name` to a full path and load it with the selected
        hydit_conf preset.

        Returns a one-tuple `(model,)` as ComfyUI expects.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        # Use a distinct name for the result instead of shadowing the `model`
        # input (the preset key), which made the original harder to follow.
        loaded_model = load_hydit(
            model_path=ckpt_path,
            model_conf=hydit_conf[model],
        )
        return (loaded_model,)
| #### temp stuff for the text encoder #### | |
| import torch | |
| from .tenc import load_clip, load_t5 | |
| from ..utils.dtype import string_to_dtype | |
# dtype choices exposed in the UI; mapped to real torch dtypes by
# string_to_dtype at load time.
dtypes = ["default", "auto (comfy)", "FP32", "FP16", "BF16"]
class HYDiTTextEncoderLoader:
    """Load the CLIP and mT5 text encoders used by Hunyuan DiT."""

    # Must be a classmethod: ComfyUI invokes INPUT_TYPES directly on the class.
    @classmethod
    def INPUT_TYPES(s):
        devices = ["auto", "cpu", "gpu"]
        # hack for using a second GPU as an offload target
        for k in range(1, torch.cuda.device_count()):
            devices.append(f"cuda:{k}")
        return {
            "required": {
                "clip_name": (folder_paths.get_filename_list("clip"),),
                "mt5_name": (folder_paths.get_filename_list("t5"),),
                "device": (devices, {"default": "cpu"}),
                "dtype": (dtypes,),
            }
        }

    RETURN_TYPES = ("CLIP", "T5")
    FUNCTION = "load_model"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encoder Loader"

    def load_model(self, clip_name, mt5_name, device, dtype):
        """Load both encoders onto `device` with the requested dtype.

        Returns `(clip, t5)` wrapper objects from the project loaders.
        """
        dtype = string_to_dtype(dtype, "text_encoder")
        if device == "cpu":
            # fp16 is rejected on CPU; fp32 and bf16 (and the default None)
            # are fine. The original message omitted FP32 even though it is
            # accepted by the check.
            assert dtype in [None, torch.float32, torch.bfloat16], \
                f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default', 'FP32' or 'BF16'."
        clip = load_clip(
            model_path=folder_paths.get_full_path("clip", clip_name),
            device=device,
            dtype=dtype,
        )
        t5 = load_t5(
            model_path=folder_paths.get_full_path("t5", mt5_name),
            device=device,
            dtype=dtype,
        )
        return (clip, t5)
class HYDiTTextEncode:
    """Encode a prompt with both Hunyuan DiT text encoders (CLIP-like + mT5)
    and pack the results into a single CONDITIONING entry."""

    # Must be a classmethod: ComfyUI invokes INPUT_TYPES directly on the class.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "text_t5": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Text Encode"

    def encode(self, text, text_t5, CLIP, T5):
        """Tokenize and run both encoders, returning ComfyUI conditioning.

        `CLIP`/`T5` are the wrapper objects produced by HYDiTTextEncoderLoader;
        each is assumed to expose `load_model()`, `tokenizer`,
        `cond_stage_model` and `load_device` — TODO confirm against loaders.
        """
        # T5 branch
        T5.load_model()
        t5_pre = T5.tokenizer(
            text_t5,
            max_length=T5.cond_stage_model.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        t5_mask = t5_pre["attention_mask"]
        with torch.no_grad():
            t5_outs = T5.cond_stage_model.transformer(
                input_ids=t5_pre["input_ids"].to(T5.load_device),
                attention_mask=t5_mask.to(T5.load_device),
                output_hidden_states=True,
            )
            # to-do: replace -1 for clip skip
            t5_embs = t5_outs["hidden_states"][-1].float().cpu()
        # "clip" branch
        CLIP.load_model()
        clip_pre = CLIP.tokenizer(
            text,
            max_length=CLIP.cond_stage_model.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        clip_mask = clip_pre["attention_mask"]
        with torch.no_grad():
            clip_outs = CLIP.cond_stage_model.transformer(
                input_ids=clip_pre["input_ids"].to(CLIP.load_device),
                attention_mask=clip_mask.to(CLIP.load_device),
            )
            # to-do: add hidden states
            clip_embs = clip_outs[0].float().cpu()
        # Combined cond: CLIP embeddings as the main tensor, T5 embeddings and
        # both attention masks carried in the extras dict for the model.
        return ([[
            clip_embs, {
                "context_t5": t5_embs,
                "context_mask": clip_mask.float(),
                "context_t5_mask": t5_mask.float(),
            }
        ]],)
class HYDiTTextEncodeSimple(HYDiTTextEncode):
    """Convenience variant of HYDiTTextEncode that feeds one prompt to both
    encoders."""

    # Must be a classmethod: ComfyUI invokes INPUT_TYPES directly on the class.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "CLIP": ("CLIP",),
                "T5": ("T5",),
            }
        }

    FUNCTION = "encode_simple"
    TITLE = "Hunyuan DiT Text Encode (simple)"

    def encode_simple(self, text, **args):
        """Delegate to encode(), using `text` for both CLIP and T5 prompts."""
        return self.encode(text=text, text_t5=text, **args)
class HYDiTSrcSizeCond:
    """Attach a source-size hint (`src_size_cond`) to existing conditioning."""

    # Must be a classmethod: ComfyUI invokes INPUT_TYPES directly on the class.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                # INT widgets take integer defaults; the originals were the
                # floats 1024.0 by mistake.
                "width": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
                "height": ("INT", {"default": 1024, "min": 0, "max": 8192, "step": 16}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/HunyuanDiT"
    TITLE = "Hunyuan DiT Size Conditioning (advanced)"

    def add_cond(self, cond, width, height):
        """Return a deep copy of `cond` with `src_size_cond` = [[height, width]]
        added to every entry's extras dict; the input is left untouched.
        """
        cond = deepcopy(cond)
        for entry in cond:
            entry[1]["src_size_cond"] = [[height, width]]
        return (cond,)
# Node registration for ComfyUI; every key equals its class's __name__.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        HYDiTCheckpointLoader,
        HYDiTTextEncoderLoader,
        HYDiTTextEncode,
        HYDiTTextEncodeSimple,
        HYDiTSrcSizeCond,
    )
}