Spaces:
Configuration error
Configuration error
| import comfy.supported_models_base | |
| import comfy.latent_formats | |
| import comfy.model_patcher | |
| import comfy.model_base | |
| import comfy.utils | |
| import comfy.conds | |
| import torch | |
| from comfy import model_management | |
| from tqdm import tqdm | |
| class EXM_HYDiT(comfy.supported_models_base.BASE): | |
| unet_config = {} | |
| unet_extra_config = {} | |
| latent_format = comfy.latent_formats.SDXL | |
| def __init__(self, model_conf): | |
| self.unet_config = model_conf.get("unet_config", {}) | |
| self.sampling_settings = model_conf.get("sampling_settings", {}) | |
| self.latent_format = self.latent_format() | |
| # UNET is handled by extension | |
| self.unet_config["disable_unet_model_creation"] = True | |
| def model_type(self, state_dict, prefix=""): | |
| return comfy.model_base.ModelType.V_PREDICTION | |
| class EXM_HYDiT_Model(comfy.model_base.BaseModel): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| def extra_conds(self, **kwargs): | |
| out = super().extra_conds(**kwargs) | |
| for name in ["context_t5", "context_mask", "context_t5_mask"]: | |
| out[name] = comfy.conds.CONDRegular(kwargs[name]) | |
| src_size_cond = kwargs.get("src_size_cond", None) | |
| if src_size_cond is not None: | |
| out["src_size_cond"] = comfy.conds.CONDRegular(torch.tensor(src_size_cond)) | |
| return out | |
| def load_hydit(model_path, model_conf): | |
| state_dict = comfy.utils.load_torch_file(model_path) | |
| state_dict = state_dict.get("model", state_dict) | |
| parameters = comfy.utils.calculate_parameters(state_dict) | |
| unet_dtype = model_management.unet_dtype(model_params=parameters) | |
| load_device = comfy.model_management.get_torch_device() | |
| offload_device = comfy.model_management.unet_offload_device() | |
| # ignore fp8/etc and use directly for now | |
| manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) | |
| if manual_cast_dtype: | |
| print(f"HunYuanDiT: falling back to {manual_cast_dtype}") | |
| unet_dtype = manual_cast_dtype | |
| model_conf = EXM_HYDiT(model_conf) | |
| model = EXM_HYDiT_Model( | |
| model_conf, | |
| model_type=comfy.model_base.ModelType.V_PREDICTION, | |
| device=model_management.get_torch_device() | |
| ) | |
| from .models.models import HunYuanDiT | |
| model.diffusion_model = HunYuanDiT( | |
| **model_conf.unet_config, | |
| log_fn=tqdm.write, | |
| ) | |
| model.diffusion_model.load_state_dict(state_dict) | |
| model.diffusion_model.dtype = unet_dtype | |
| model.diffusion_model.eval() | |
| model.diffusion_model.to(unet_dtype) | |
| model_patcher = comfy.model_patcher.ModelPatcher( | |
| model, | |
| load_device = load_device, | |
| offload_device = offload_device, | |
| current_device = "cpu", | |
| ) | |
| return model_patcher | |