"""Soft-shadow generation for RGBA cut-outs using the PixHtLab SSN model.

On first use the module clones the PixHtLab source repository from Hugging
Face, downloads the SSN checkpoint from Google Drive, loads the model once
(``load_ssn_once``) and exposes ``generate_shadow_rgba``, which returns PNG
bytes with a model-predicted shadow composited underneath the input image.
"""
# NOTE: an earlier, fully commented-out implementation (repository fetched as
# a Google Drive folder, background-image conditioning) used to live at the
# top of this file. It was superseded by the code below and has been removed;
# consult version control history if it is needed again.

import os

# Set before ``import gdown`` — presumably so gdown picks the cache location up
# when it is imported (assumption; confirm against the installed gdown version).
os.environ["GDOWN_CACHE_DIR"] = "/tmp/.gdown"
os.makedirs(os.environ["GDOWN_CACHE_DIR"], exist_ok=True)

import io
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gdown
import numpy as np
import torch
from PIL import Image, ImageChops, ImageFilter
from torchvision import transforms as T

# -----------------------------
# Paths & constants
# -----------------------------
TMP_DIR = Path(tempfile.gettempdir()) / "pixhtlab"
REPO_DIR = TMP_DIR / "PixHtLab-Src"
WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
REPO_HF_URL = "https://huggingface.co/karthikeya1212/test"
DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
# Sanity floor: a weight file smaller than this is treated as a failed or
# partial download rather than the real checkpoint.
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024


# -----------------------------
# Logging
# -----------------------------
def _log(*args):
    """Print *args* to stdout with a ``[shadow]`` prefix, flushing immediately."""
    print("[shadow]", *args, flush=True)


def _print_image(img: Image.Image, message: str):
    """Log *message* together with the size and mode of *img*."""
    _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")


# -----------------------------
# Repo & weight handling
# -----------------------------
def _weights_look_valid() -> bool:
    """Return True if WEIGHT_FILE exists and is at least the minimum size.

    Shared by the download and validation steps so both apply the same rule.
    """
    return WEIGHT_FILE.is_file() and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES


def _download_repo() -> None:
    """Clone the PixHtLab source repository into REPO_DIR unless it exists.

    Raises:
        RuntimeError: if ``git`` is missing or ``git clone`` fails. Any partial
            checkout is removed so that the next call retries the clone instead
            of trusting a broken tree.
    """
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return
    _log("Cloning repository from Hugging Face...")
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # Best effort: a non-zero exit here (e.g. LFS not installed) is
        # tolerated; only the clone itself must succeed.
        subprocess.run(["git", "lfs", "install"], check=False)
        subprocess.run(["git", "clone", REPO_HF_URL, str(REPO_DIR)], check=True)
    except (OSError, subprocess.SubprocessError) as e:
        _log("❌ Failed to clone repo:", repr(e))
        shutil.rmtree(REPO_DIR, ignore_errors=True)
        raise RuntimeError("Cannot proceed without repository.") from e
    _log("Repo cloned successfully to:", REPO_DIR)


def _download_weights() -> None:
    """Download the SSN checkpoint to WEIGHT_FILE unless a valid copy exists.

    Raises:
        RuntimeError: if the file is missing or too small after downloading.
            Errors raised by ``gdown.download`` itself propagate unchanged.
    """
    if _weights_look_valid():
        _log("Weights already exist:", WEIGHT_FILE)
        return
    _log("Downloading weight file from Google Drive...")
    WEIGHT_FILE.parent.mkdir(parents=True, exist_ok=True)
    gdown.download(DRIVE_WEIGHT_FILE, str(WEIGHT_FILE), quiet=False, use_cookies=False)
    if not _weights_look_valid():
        raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
    _log("✅ Weight file downloaded:", WEIGHT_FILE)


def _validate_assets() -> None:
    """Check that the repository and a plausible weight file are present.

    Raises:
        RuntimeError: if either asset is missing (or the weights are too small).
    """
    if not REPO_DIR.exists():
        raise RuntimeError("Repo folder missing.")
    if not _weights_look_valid():
        raise RuntimeError("Weight file missing or too small.")
    _log("Assets validated successfully.")


# -----------------------------
# Auto-search for SSN_Model.py
# -----------------------------
def _find_ssn_model_file() -> Path:
    """Return the directory that contains ``SSN_Model.py`` inside REPO_DIR.

    Raises:
        RuntimeError: if no ``SSN_Model.py`` exists anywhere in the repository.
    """
    # Sorted so the pick is deterministic when several copies exist; raw rglob
    # order depends on the filesystem.
    candidates = sorted(REPO_DIR.rglob("SSN_Model.py"))
    if not candidates:
        raise RuntimeError("SSN_Model.py not found in repo.")
    _log("Found SSN_Model.py:", candidates[0])
    return candidates[0].parent


# -----------------------------
# Utility: Resize with aspect ratio (padding)
# -----------------------------
def _letterbox(
    img: Image.Image,
    target_size: Tuple[int, int] = (512, 512),
    fill_color=(255, 255, 255),
) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
    """Fit *img* inside *target_size*, centred on an RGB canvas of *fill_color*.

    The image is only ever shrunk (``Image.thumbnail`` semantics), never
    enlarged, and keeps its aspect ratio. *img* itself is not modified.

    Returns:
        ``(canvas, box)`` where *box* is the ``(left, top, right, bottom)``
        region of the canvas covered by the fitted image.
    """
    fitted = img.copy()  # thumbnail() resizes in place; protect the caller's image
    fitted.thumbnail(target_size, Image.BICUBIC)
    left = (target_size[0] - fitted.width) // 2
    top = (target_size[1] - fitted.height) // 2
    canvas = Image.new("RGB", target_size, fill_color)
    canvas.paste(fitted, (left, top))
    return canvas, (left, top, left + fitted.width, top + fitted.height)


def _resize_with_aspect(img: Image.Image, target_size=(512, 512), fill_color=(255, 255, 255)):
    """Return *img* letterboxed onto a *target_size* RGB canvas (see ``_letterbox``).

    Unlike a bare ``Image.thumbnail`` call, the input image is left untouched.
    """
    canvas, _box = _letterbox(img, target_size=target_size, fill_color=fill_color)
    return canvas


def _scale_box(
    box: Tuple[int, int, int, int],
    from_size: Tuple[int, int],
    to_size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    """Map a ``(left, top, right, bottom)`` box from a *from_size* canvas to *to_size*.

    The result always spans at least one pixel in each dimension.
    """
    sx = to_size[0] / from_size[0]
    sy = to_size[1] / from_size[1]
    left = round(box[0] * sx)
    top = round(box[1] * sy)
    right = max(left + 1, round(box[2] * sx))
    bottom = max(top + 1, round(box[3] * sy))
    return left, top, right, bottom


# -----------------------------
# Utility: Create IBL Tensor for directional light
# -----------------------------
def _create_ibl_tensor(direction: float, device: str) -> torch.Tensor:
    """Build a flattened 16x32 light map with a single bright 3x3 patch.

    The column of the patch encodes azimuth: ``(direction + 180) % 360`` is
    mapped linearly onto the 32 columns, and the patch sits on the middle row.

    Args:
        direction: angle in degrees. The +180° offset suggests it is the
            direction the shadow should fall, with the light placed opposite
            (assumption — confirm against the API's callers).
        device: torch device string for the returned tensor.

    Returns:
        float32 tensor of shape ``(1, 16 * 32)``.
    """
    _log(f"Creating IBL for direction: {direction} degrees")
    ibl_h, ibl_w = 16, 32
    ibl = np.zeros((ibl_h, ibl_w), dtype=np.float32)

    light_direction = (direction + 180) % 360
    x = int(light_direction / 360.0 * ibl_w)
    y = ibl_h // 2

    rows = slice(max(0, y - 1), min(ibl_h, y + 2))
    # Columns span the full 0-360° circle, so the patch wraps across the
    # left/right seam instead of being clipped there. The modulo also covers
    # float rounding that can yield light_direction == 360.0 (x == ibl_w).
    cols = np.arange(x - 1, x + 2) % ibl_w
    ibl[rows, cols] = 1.0

    return torch.from_numpy(ibl.flatten()).unsqueeze(0).to(device)


# -----------------------------
# SSN model wrapper
# -----------------------------
class SSNWrapper:
    """Fetch PixHtLab assets and wrap the SSN model for shadow-matte inference.

    Construction performs network and disk I/O (repository clone, weight
    download) and raises ``RuntimeError`` if those steps fail. A failure while
    importing or loading the model itself is logged instead: ``model`` stays
    ``None`` (see ``available()``) and the exception is kept in ``load_error``.
    """

    def __init__(self):
        self.model = None
        self.load_error: Optional[Exception] = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        _download_repo()
        _download_weights()
        _validate_assets()

        # SSN_Model.py is imported as a top-level module from its own folder.
        model_dir = str(_find_ssn_model_file().resolve())
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        try:
            from SSN_Model import SSN_Model

            model = SSN_Model()
            state_dict = self._extract_state_dict(self._load_checkpoint())
            result = model.load_state_dict(state_dict, strict=False)
            # strict=False never raises on key mismatches, so surface them:
            # a fully mismatched checkpoint would otherwise leave random weights.
            if result.missing_keys or result.unexpected_keys:
                _log(
                    f"⚠️ Checkpoint key mismatch: {len(result.missing_keys)} missing, "
                    f"{len(result.unexpected_keys)} unexpected"
                )
            model.eval().to(self.device)
            self.model = model
            _log(f"✅ SSN model loaded on {self.device}")
        except Exception as e:
            _log("❌ Failed to load SSN model:", repr(e))
            self.model = None
            self.load_error = e

    def _load_checkpoint(self):
        """Load the raw checkpoint object from WEIGHT_FILE onto ``self.device``."""
        # SECURITY: weights_only=False unpickles arbitrary objects, i.e. it can
        # run code embedded in the file. This is only acceptable because the
        # file comes from a fixed, trusted URL; never use it on user uploads.
        return torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)

    @staticmethod
    def _extract_state_dict(state):
        """Return the parameter mapping from a loaded checkpoint object.

        Accepts a pickled ``nn.Module``, a training dict wrapping the weights
        under ``model_state_dict`` / ``state_dict`` / ``model``, or a bare state
        dict. If every key carries a ``module.`` prefix (as written by
        ``nn.DataParallel``), the prefix is stripped.
        """
        if isinstance(state, torch.nn.Module):
            state_dict = state.state_dict()
        elif isinstance(state, dict) and "model_state_dict" in state:
            state_dict = state["model_state_dict"]
        elif isinstance(state, dict) and "state_dict" in state:
            state_dict = state["state_dict"]
        elif isinstance(state, dict) and isinstance(state.get("model"), dict):
            state_dict = state["model"]
        else:
            state_dict = state

        prefix = "module."
        if (
            isinstance(state_dict, dict)
            and state_dict
            and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict)
        ):
            _log("Stripping 'module.' prefix from checkpoint keys")
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        return state_dict

    def available(self) -> bool:
        """Return True if the SSN model was loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def infer_shadow_matte(self, rgba_img: Image.Image, direction: float) -> Optional[Image.Image]:
        """Predict a single-channel shadow matte for *rgba_img*.

        The image is letterboxed onto a 512x512 white canvas and fed to the
        model together with an IBL light map built from *direction*. The
        letterbox padding is cropped from the prediction before it is resized
        back, so the returned matte is aligned with *rgba_img*.

        Returns:
            An ``L``-mode (for a 2-D float matte) image the size of *rgba_img*,
            or ``None`` if the model is unavailable, returns ``None``, or
            raises during inference (the error is logged).
        """
        if self.model is None:
            _log("SSN model not available.")
            return None

        _log("Preparing image for inference...")
        target_size = (512, 512)
        # NOTE(review): alpha is dropped here and the model sees only RGB on a
        # white canvas — confirm this matches what SSN_Model.forward expects.
        fg_rgb_img, content_box = _letterbox(
            rgba_img.convert("RGB"), target_size=target_size, fill_color=(255, 255, 255)
        )
        fg_t = T.ToTensor()(fg_rgb_img).unsqueeze(0).to(self.device)
        _print_image(fg_rgb_img, "Input Foreground Image")

        try:
            _log(f"Running SSN inference with direction {direction}...")
            ibl_t = _create_ibl_tensor(direction, self.device)
            out = self.model(fg_t, ibl_t)
            if out is None:
                _log("SSN inference returned None.")
                return None

            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            # Single-channel output is used as-is; multi-channel is averaged.
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")

            # Undo the letterbox: keep only the region that corresponds to the
            # input image (scaled in case the model's output resolution differs
            # from its input), then restore the original size.
            crop_box = _scale_box(content_box, target_size, matte_img.size)
            return matte_img.crop(crop_box).resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("❌ SSN inference error:", repr(e))
            return None


# -----------------------------
# Singleton
# -----------------------------
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    """Return the process-wide SSNWrapper, creating it on first call.

    If construction raises, nothing is cached and the next call retries.
    """
    global _ssn_wrapper
    if _ssn_wrapper is None:
        _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper


# -----------------------------
# Shadow compositing (REALISTIC SSN)
# -----------------------------
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Convert ``#rgb`` / ``#rrggbb`` (leading ``#`` optional) to an RGB tuple.

    Raises:
        ValueError: if the string does not contain 3 or 6 hex digits.
    """
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    if len(c) != 6:
        raise ValueError("Invalid color hex.")
    return tuple(int(c[i:i + 2], 16) for i in (0, 2, 4))


def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    """Composite a coloured, optionally blurred shadow layer under *original_img*.

    The matte's grey levels, scaled by ``opacity``, become the alpha of a
    solid ``color`` layer, which is Gaussian-blurred by ``softness`` pixels.
    The original image is then composited on top.

    Args:
        original_img: foreground image; converted to RGBA if necessary.
        matte_img: shadow matte, resized to the foreground's size.
        params: reads ``color`` (hex string), ``opacity`` (clamped to [0, 1])
            and ``softness`` (blur radius, clamped to >= 0).

    Returns:
        A new RGBA image the size of *original_img*.
    """
    _log("Compositing shadow and image (REAL SSN)...")
    base = original_img if original_img.mode == "RGBA" else original_img.convert("RGBA")
    w, h = base.size
    r, g, b = _hex_to_rgb(params.get("color", "#000000"))
    # Clamp here as well so direct callers cannot push alpha outside 0-255.
    opacity = max(0.0, min(1.0, float(params.get("opacity", 0.7))))

    matte_resized = matte_img.resize((w, h)).convert("L")
    shadow_rgba = Image.new("RGBA", (w, h), (r, g, b, 0))
    shadow_rgba.putalpha(matte_resized.point(lambda p: int(p * opacity)))

    softness = max(0.0, float(params.get("softness", 20.0)))
    if softness > 0:
        shadow_rgba = shadow_rgba.filter(ImageFilter.GaussianBlur(radius=softness))

    final_rgba = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    final_rgba.alpha_composite(shadow_rgba)
    final_rgba.alpha_composite(base)
    _log("✅ Composition complete with realistic SSN shadow.")
    return final_rgba


# -----------------------------
# Public API
# -----------------------------
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge *params* over the defaults and normalise numeric values.

    Keys whose value is ``None`` (e.g. JSON ``null``) fall back to the default.
    ``opacity`` is clamped to [0, 1] and ``softness`` to >= 0. ``direction``
    and ``distance`` are coerced to float.

    NOTE: ``distance`` is normalised here but is not read by the inference or
    compositing steps in this module.
    """
    defaults = dict(softness=20.0, opacity=0.7, color="#000000", direction=45.0, distance=80.0)
    overrides = {k: v for k, v in (params or {}).items() if v is not None}
    merged = {**defaults, **overrides}
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    merged["direction"] = float(merged["direction"])
    merged["distance"] = float(merged["distance"])
    return merged


def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    """Return PNG bytes of the input image with an SSN-predicted shadow beneath it.

    Args:
        rgba_file_bytes: encoded image bytes in any format Pillow can open;
            the image is converted to RGBA.
        params: optional overrides for ``softness``, ``opacity``, ``color``,
            ``direction`` and ``distance`` (see ``_apply_params_defaults``).

    Raises:
        RuntimeError: if the model could not be loaded (chained to the load
            error), if asset download fails, or if no matte was produced.
    """
    _log("--- Starting shadow generation ---")
    params = _apply_params_defaults(params)

    with Image.open(io.BytesIO(rgba_file_bytes)) as src:
        img = src.convert("RGBA")
    _print_image(img, "Input RGBA Image")

    ssn = load_ssn_once()
    if not ssn.available():
        _log("❌ SSN model unavailable")
        raise RuntimeError("SSN model not available.") from ssn.load_error

    matte_img = ssn.infer_shadow_matte(img, params["direction"])
    if matte_img is None:
        _log("❌ Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")

    final_img = _composite_shadow_and_image(img, matte_img, params)

    buf = io.BytesIO()
    final_img.save(buf, format="PNG")
    _log("✅ Shadow generation complete")
    return buf.getvalue()


def initialize_once():
    """Eagerly download assets and load the SSN model (e.g. at server start-up)."""
    _log("Initializing assets and SSN model...")
    load_ssn_once()