import torch
from torch import nn
import torchvision.transforms as T
import torch.nn.functional as F
import os
import math

import folder_paths
import comfy.utils
import comfy.model_management

from insightface.app import FaceAnalysis
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from comfy.ldm.modules.attention import optimized_attention

from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .encoders import IDEncoder
INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface")
MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid")

if "pulid" not in folder_paths.folder_names_and_paths:
    current_paths = [MODELS_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
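
# PuLID checkpoints are discovered under ComfyUI's models/pulid/ directory (any
# extension listed in folder_paths.supported_pt_extensions); insightface models
# are expected under models/insightface/.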


class PulidModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.image_proj_model = self.init_id_adapter()
        self.image_proj_model.load_state_dict(model["image_proj"])
        self.ip_layers = To_KV(model["ip_adapter"])

    def init_id_adapter(self):
        image_proj_model = IDEncoder()
        return image_proj_model

    def get_image_embeds(self, face_embed, clip_embeds):
        embeds = self.image_proj_model(face_embed, clip_embeds)
        return embeds


class To_KV(nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        self.to_kvs = nn.ModuleDict()
        for key, value in state_dict.items():
            # nn.ModuleDict keys cannot contain ".", so e.g. "1.to_k_ip.weight" -> "1_to_k_ip"
            k = key.replace(".weight", "").replace(".", "_")
            self.to_kvs[k] = nn.Linear(value.shape[1], value.shape[0], bias=False)
            self.to_kvs[k].weight.data = value
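
# The ip_adapter state dict typically contains pairs like "1.to_k_ip.weight" /
# "1.to_v_ip.weight", "3.to_k_ip.weight", ... so the ModuleDict ends up with keys
# "1_to_k_ip", "1_to_v_ip", and so on. pulid_attention() looks them up as
# module_key + "_to_k_ip" / "_to_v_ip".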


def tensor_to_image(tensor):
    # ComfyUI IMAGE (float 0..1, RGB) -> uint8 numpy array in BGR channel order,
    # as expected by the cv2-based insightface/facexlib pipelines
    image = tensor.mul(255).clamp(0, 255).byte().cpu()
    image = image[..., [2, 1, 0]].numpy()
    return image


def image_to_tensor(image):
    # uint8 BGR numpy array -> float 0..1 RGB tensor
    tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)
    tensor = tensor[..., [2, 1, 0]]
    return tensor


def tensor_to_size(source, dest_size):
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    source_size = source.shape[0]

    if source_size < dest_size:
        shape = [dest_size - source_size] + [1]*(source.dim()-1)
        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]

    return source
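
# Example: matching a 2-mask batch to 4 samples repeats the last mask twice
# (2 -> 4); matching 4 masks to 2 samples keeps only the first 2.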


def set_model_patch_replace(model, patch_kwargs, key):
    # copy-on-write so the patch does not leak into other clones of the model
    to = model.model_options["transformer_options"].copy()
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    else:
        to["patches_replace"]["attn2"] = to["patches_replace"]["attn2"].copy()

    if key not in to["patches_replace"]["attn2"]:
        to["patches_replace"]["attn2"][key] = Attn2Replace(pulid_attention, **patch_kwargs)
        model.model_options["transformer_options"] = to
    else:
        to["patches_replace"]["attn2"][key].add(pulid_attention, **patch_kwargs)


class Attn2Replace:
    def __init__(self, callback=None, **kwargs):
        self.callback = [callback]
        self.kwargs = [kwargs]

    def add(self, callback, **kwargs):
        self.callback.append(callback)
        self.kwargs.append(kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __call__(self, q, k, v, extra_options):
        dtype = q.dtype
        out = optimized_attention(q, k, v, extra_options["n_heads"])
        sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9

        for i, callback in enumerate(self.callback):
            if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]:
                out = out + callback(out, q, k, v, extra_options, **self.kwargs[i])

        return out.to(dtype=dtype)
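
# Several Apply nodes can target the same attention layer: each add() appends a
# callback, the base attention is computed once, and every callback whose
# [sigma_end, sigma_start] window contains the current sigma adds its residual.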


def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond=None, uncond=None, weight=1.0, ortho=False, ortho_v2=False, mask=None, **kwargs):
    k_key = module_key + "_to_k_ip"
    v_key = module_key + "_to_v_ip"

    dtype = q.dtype
    seq_len = q.shape[1]
    cond_or_uncond = extra_options["cond_or_uncond"]
    b = q.shape[0]
    batch_prompt = b // len(cond_or_uncond)
    _, _, oh, ow = extra_options["original_shape"]

    #conds = torch.cat([uncond.repeat(batch_prompt, 1, 1), cond.repeat(batch_prompt, 1, 1)], dim=0)
    #zero_tensor = torch.zeros((conds.size(0), num_zero, conds.size(-1)), dtype=conds.dtype, device=conds.device)
    #conds = torch.cat([conds, zero_tensor], dim=1)
    #ip_k = pulid.ip_layers.to_kvs[k_key](conds)
    #ip_v = pulid.ip_layers.to_kvs[v_key](conds)

    k_cond = pulid.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1)
    k_uncond = pulid.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1)
    v_cond = pulid.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1)
    v_uncond = pulid.ip_layers.to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1)
    ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
    ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)

    out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
    if ortho:
        out = out.to(dtype=torch.float32)
        out_ip = out_ip.to(dtype=torch.float32)
        # keep only the component of out_ip orthogonal to the base attention output
        projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out)
        orthogonal = out_ip - projection
        out_ip = weight * orthogonal
    elif ortho_v2:
        out = out.to(dtype=torch.float32)
        out_ip = out_ip.to(dtype=torch.float32)
        attn_map = q @ ip_k.transpose(-2, -1)
        attn_mean = attn_map.softmax(dim=-1).mean(dim=1, keepdim=True)
        attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
        projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out)
        orthogonal = out_ip + (attn_mean - 1) * projection
        out_ip = weight * orthogonal
    else:
        out_ip = out_ip * weight
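
    # Worked example (hypothetical numbers): for a 1024x1024 SDXL generation the
    # latent original_shape is 128x128; in a block where seq_len == 4096 tokens,
    # mask_h = 128 / sqrt(128*128 / 4096) = 64 and mask_w = 4096 // 64 = 64, i.e.
    # the user mask is resampled to that layer's 64x64 token grid.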
    if mask is not None:
        # resize the mask to the token grid of this attention layer
        mask_h = oh / math.sqrt(oh * ow / seq_len)
        mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
        mask_w = seq_len // mask_h

        mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
        mask = tensor_to_size(mask, batch_prompt)

        mask = mask.repeat(len(cond_or_uncond), 1, 1)
        mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])

        # covers cases where extreme aspect ratios can cause the mask to have a wrong size
        mask_len = mask_h * mask_w
        if mask_len < seq_len:
            pad_len = seq_len - mask_len
            pad1 = pad_len // 2
            pad2 = pad_len - pad1
            mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0)
        elif mask_len > seq_len:
            crop_start = (mask_len - seq_len) // 2
            mask = mask[:, crop_start:crop_start+seq_len, :]

        out_ip = out_ip * mask

    return out_ip.to(dtype=dtype)


def to_gray(img):
    # ITU-R BT.601 luma weights, replicated back to 3 channels
    x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    x = x.repeat(1, 3, 1, 1)
    return x
| """ | |
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| Nodes | |
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| """ | |


class PulidModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pulid_file": (folder_paths.get_filename_list("pulid"), )}}

    RETURN_TYPES = ("PULID",)
    FUNCTION = "load_model"
    CATEGORY = "pulid"

    def load_model(self, pulid_file):
        ckpt_path = folder_paths.get_full_path("pulid", pulid_file)
        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
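
        # A .safetensors PuLID checkpoint is typically flat, with keys such as
        # "image_proj.<...>" and "ip_adapter.1.to_k_ip.weight"; the block below
        # regroups them into the nested {"image_proj": ..., "ip_adapter": ...}
        # layout that PulidModel expects (pickled checkpoints are already nested).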
        if ckpt_path.lower().endswith(".safetensors"):
            st_model = {"image_proj": {}, "ip_adapter": {}}
            for key in model.keys():
                if key.startswith("image_proj."):
                    st_model["image_proj"][key.replace("image_proj.", "")] = model[key]
                elif key.startswith("ip_adapter."):
                    st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
            model = st_model

        # also initialize the model here; loading takes longer, but then it doesn't
        # have to be repeated every time a parameter changes in the apply node
        model = PulidModel(model)

        return (model,)


class PulidInsightFaceLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "provider": (["CPU", "CUDA", "ROCM"], ),
            },
        }

    RETURN_TYPES = ("FACEANALYSIS",)
    FUNCTION = "load_insightface"
    CATEGORY = "pulid"

    def load_insightface(self, provider):
        model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l
        model.prepare(ctx_id=0, det_size=(640, 640))

        return (model,)


class PulidEvaClipLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
        }

    RETURN_TYPES = ("EVA_CLIP",)
    FUNCTION = "load_eva_clip"
    CATEGORY = "pulid"

    def load_eva_clip(self):
        from .eva_clip.factory import create_model_and_transforms

        model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)
        model = model.visual

        eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)
        eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)
        if not isinstance(eva_transform_mean, (list, tuple)):
            eva_transform_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            eva_transform_std = (eva_transform_std,) * 3
        # store as attributes (item assignment on an nn.Module raises a TypeError)
        model.image_mean = eva_transform_mean
        model.image_std = eva_transform_std

        return (model,)


class ApplyPulid:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "method": (["fidelity", "style", "neutral"],),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_pulid"
    CATEGORY = "pulid"

    def apply_pulid(self, model, pulid, eva_clip, face_analysis, image, weight, start_at, end_at, method=None, noise=0.0, fidelity=None, projection=None, attn_mask=None):
        work_model = model.clone()

        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32

        eva_clip.to(device, dtype=dtype)
        pulid_model = pulid.to(device, dtype=dtype)

        if attn_mask is not None:
            if attn_mask.dim() > 3:
                attn_mask = attn_mask.squeeze(-1)
            elif attn_mask.dim() < 3:
                attn_mask = attn_mask.unsqueeze(0)
            attn_mask = attn_mask.to(device, dtype=dtype)

        # the method presets (simple node) map onto projection/fidelity (advanced node):
        #   fidelity -> ortho_v2 projection, 8 zero tokens
        #   style    -> ortho projection, 16 zero tokens
        #   neutral  -> no projection, no zero tokens
        if method == "fidelity" or projection == "ortho_v2":
            num_zero = 8
            ortho = False
            ortho_v2 = True
        elif method == "style" or projection == "ortho":
            num_zero = 16
            ortho = True
            ortho_v2 = False
        else:
            num_zero = 0
            ortho = False
            ortho_v2 = False

        # the advanced node's fidelity input overrides the preset's zero-token count
        if fidelity is not None:
            num_zero = fidelity

        #face_analysis.det_model.input_size = (640,640)

        image = tensor_to_image(image)

        face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )

        face_helper.face_parse = None
        face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device)

        # BiSeNet labels treated as non-face: background, cloth, hat, ears,
        # earring, neck, necklace
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        cond = []
        uncond = []

        for i in range(image.shape[0]):
            # get insightface embeddings, retrying at progressively smaller detection sizes
            iface_embeds = None
            for size in [(size, size) for size in range(640, 256, -64)]:
                face_analysis.det_model.input_size = size
                face = face_analysis.get(image[i])
                if face:
                    # keep the largest detected face
                    face = sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse=True)[0]
                    iface_embeds = torch.from_numpy(face.embedding).unsqueeze(0).to(device, dtype=dtype)
                    break
            else:
                # for-else: no detection size yielded a face
                raise Exception('insightface: No face detected.')

            # get eva_clip embeddings
            face_helper.clean_all()
            face_helper.read_image(image[i])
            face_helper.get_face_landmarks_5(only_center_face=True)
            face_helper.align_warp_face()

            if len(face_helper.cropped_faces) == 0:
                raise Exception('facexlib: No face detected.')

            face = face_helper.cropped_faces[0]
            face = image_to_tensor(face).unsqueeze(0).permute(0,3,1,2).to(device)
            parsing_out = face_helper.face_parse(T.functional.normalize(face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            bg = sum(parsing_out == i for i in bg_label).bool()
            white_image = torch.ones_like(face)
            # blank out the background and feed a grayscale face to EVA-CLIP
            face_features_image = torch.where(bg, white_image, to_gray(face))
            # apparently MPS only supports NEAREST interpolation?
            face_features_image = T.functional.resize(face_features_image, eva_clip.image_size, T.InterpolationMode.BICUBIC if 'cuda' in device.type else T.InterpolationMode.NEAREST).to(device, dtype=dtype)
            face_features_image = T.functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)

            id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
            id_cond_vit = id_cond_vit.to(device, dtype=dtype)
            for idx in range(len(id_vit_hidden)):
                id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype)

            # L2-normalize the EVA-CLIP embedding
            id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))
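
            # id_cond: the ArcFace-style recognition embedding concatenated with the
            # normalized EVA-CLIP embedding; id_vit_hidden: intermediate ViT features.
            # Both feed the ID encoder that produces the per-image PuLID tokens.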
            # combine embeddings
            id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)
            if noise == 0:
                id_uncond = torch.zeros_like(id_cond)
            else:
                id_uncond = torch.rand_like(id_cond) * noise
            id_vit_hidden_uncond = []
            for idx in range(len(id_vit_hidden)):
                if noise == 0:
                    id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[idx]))
                else:
                    id_vit_hidden_uncond.append(torch.rand_like(id_vit_hidden[idx]) * noise)

            cond.append(pulid_model.get_image_embeds(id_cond, id_vit_hidden))
            uncond.append(pulid_model.get_image_embeds(id_uncond, id_vit_hidden_uncond))

        # average embeddings
        cond = torch.cat(cond).to(device, dtype=dtype)
        uncond = torch.cat(uncond).to(device, dtype=dtype)
        if cond.shape[0] > 1:
            cond = torch.mean(cond, dim=0, keepdim=True)
            uncond = torch.mean(uncond, dim=0, keepdim=True)

        if num_zero > 0:
            if noise == 0:
                zero_tensor = torch.zeros((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device)
            else:
                zero_tensor = torch.rand((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device) * noise
            cond = torch.cat([cond, zero_tensor], dim=1)
            uncond = torch.cat([uncond, zero_tensor], dim=1)
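
        # The appended zero (or low-noise) tokens act as inert keys/values in the
        # ID attention: they absorb attention mass while contributing nothing to
        # the output, softening the identity signal (the "style" preset uses 16 of
        # them vs. 8 for "fidelity").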
        sigma_start = work_model.get_model_object("model_sampling").percent_to_sigma(start_at)
        sigma_end = work_model.get_model_object("model_sampling").percent_to_sigma(end_at)

        patch_kwargs = {
            "pulid": pulid_model,
            "weight": weight,
            "cond": cond,
            "uncond": uncond,
            "sigma_start": sigma_start,
            "sigma_end": sigma_end,
            "ortho": ortho,
            "ortho_v2": ortho_v2,
            "mask": attn_mask,
        }

        number = 0
        for id in [4,5,7,8]: # id of input_blocks that have cross attention
            block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
                number += 1
        for id in range(6): # id of output_blocks that have cross attention
            block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
                number += 1
        for index in range(10):
            patch_kwargs["module_key"] = str(number*2+1)
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
            number += 1
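
        # module_key follows the IP-Adapter-style checkpoint numbering, where
        # cross-attention (attn2) layers sit at the odd indices ("1.to_k_ip",
        # "3.to_k_ip", ...), hence str(number*2 + 1) for the number-th patched layer.
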
        return (work_model,)


class ApplyPulidAdvanced(ApplyPulid):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "projection": (["ortho_v2", "ortho", "none"],),
                "fidelity": ("INT", {"default": 8, "min": 0, "max": 32, "step": 1 }),
                "noise": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }


NODE_CLASS_MAPPINGS = {
    "PulidModelLoader": PulidModelLoader,
    "PulidInsightFaceLoader": PulidInsightFaceLoader,
    "PulidEvaClipLoader": PulidEvaClipLoader,
    "ApplyPulid": ApplyPulid,
    "ApplyPulidAdvanced": ApplyPulidAdvanced,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PulidModelLoader": "Load PuLID Model",
    "PulidInsightFaceLoader": "Load InsightFace (PuLID)",
    "PulidEvaClipLoader": "Load Eva Clip (PuLID)",
    "ApplyPulid": "Apply PuLID",
    "ApplyPulidAdvanced": "Apply PuLID Advanced",
}