import io
import sys
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

import gdown
from PIL import Image, ImageFilter
import torch
from torchvision import transforms as T

# -----------------------------
# Repo + weight constants
# -----------------------------
REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing"
REPO_DIR = Path("PixHtLab-Src")
WEIGHT_FILE = REPO_DIR / "Demo" / "PixhtLab" / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024

# -----------------------------
# Logging utility
# -----------------------------
def _log(*args):
    print("[shadow]", *args, flush=True)


def _print_image(img: Image.Image, message: str):
    _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")

# -----------------------------
# Repo download & validation
# -----------------------------
def _download_repo():
    """Download the full repo folder from Google Drive if it does not exist."""
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return
    _log("Downloading repository from Google Drive folder...")
    temp_dir = Path(tempfile.gettempdir()) / "pixhtlab_repo"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)
    try:
        gdown.download_folder(REPO_DRIVE_FOLDER, output=str(temp_dir), quiet=False, use_cookies=False)
        candidates = list(temp_dir.glob("*"))
        if not candidates:
            raise RuntimeError("No files downloaded from the repo folder.")
        shutil.move(str(candidates[0]), str(REPO_DIR))
        _log("Repo downloaded successfully to:", REPO_DIR)
    except Exception as e:
        _log("❌ Failed to download repo:", repr(e))
        raise RuntimeError("Cannot proceed without repository.")


def _validate_weights():
    if not WEIGHT_FILE.exists() or WEIGHT_FILE.stat().st_size < _MIN_EXPECTED_WEIGHT_BYTES:
        raise RuntimeError(f"SSN weight file missing or too small: {WEIGHT_FILE}")
    _log("Weight file exists and looks valid:", WEIGHT_FILE)

# -----------------------------
# SSN model wrapper
# -----------------------------
class SSNWrapper:
    def __init__(self):
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        _download_repo()
        _validate_weights()
        sys.path.insert(0, str(REPO_DIR.resolve()))
        try:
            from Demo.SSN.models.SSN_Model import SSN_Model  # dynamically loaded from repo
            self.model = SSN_Model()
            state = torch.load(str(WEIGHT_FILE), map_location=self.device)
            if isinstance(state, dict):
                if "model_state_dict" in state:
                    sd = state["model_state_dict"]
                elif "state_dict" in state:
                    sd = state["state_dict"]
                elif "model" in state and isinstance(state["model"], dict):
                    sd = state["model"]
                else:
                    sd = state
            else:
                sd = state
            self.model.load_state_dict(sd, strict=False)
            self.model.eval().to(self.device)
            _log(f"✅ SSN model loaded on {self.device}")
        except Exception as e:
            _log("❌ Failed to load SSN model:", repr(e))
            self.model = None

    def available(self) -> bool:
        return self.model is not None

    def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
        if self.model is None:
            _log("SSN model not available.")
            return None
        _log("Preparing image for inference...")
        target_size = (512, 512)
        to_tensor = T.ToTensor()
        fg_rgb_img = rgba_img.convert("RGB")
        fg_t = to_tensor(fg_rgb_img.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device)
        if bg_rgb is None:
            bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255))
        bg_t = to_tensor(bg_rgb.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device)
        _print_image(fg_rgb_img.resize(target_size, Image.BICUBIC), "Input Foreground Image")
        try:
            _log("Running SSN inference...")
            with torch.no_grad():  # inference only; no autograd graph needed
                try:
                    out = self.model(fg_t, bg_t)
                except TypeError:
                    out = self.model.forward(fg=fg_t, bg=bg_t)
            if out is None:
                _log("SSN inference returned None.")
                return None
            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")
            return matte_img.resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("❌ SSN inference error:", repr(e))
            return None

# -----------------------------
# Singleton
# -----------------------------
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    global _ssn_wrapper
    if _ssn_wrapper is None:
        _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper


# -----------------------------
# Shadow compositing
# -----------------------------
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    if len(c) != 6:
        raise ValueError("Invalid color hex.")
    return tuple(int(c[i:i + 2], 16) for i in (0, 2, 4))
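
# Illustrative behaviour of _hex_to_rgb (example values, not taken from the
# original source): "#333" expands to "333333" and yields (51, 51, 51), while
# "ff0000" yields (255, 0, 0); anything other than 3 or 6 hex digits raises
# ValueError.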

def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    _log("Compositing shadow and image...")
    w, h = original_img.size
    r, g, b = _hex_to_rgb(params["color"])
    matte_rgba = Image.merge("RGBA", (
        Image.new("L", (w, h), r),
        Image.new("L", (w, h), g),
        Image.new("L", (w, h), b),
        matte_img,
    ))
    opacity = max(0.0, min(1.0, float(params["opacity"])))
    if opacity < 1.0:
        a = matte_rgba.split()[-1].point(lambda p: int(p * opacity))
        matte_rgba.putalpha(a)
    softness = max(0.0, float(params["softness"]))
    if softness > 0:
        _log(f"Applying Gaussian blur: {softness}")
        matte_rgba = matte_rgba.filter(ImageFilter.GaussianBlur(radius=softness))
    out = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    out.alpha_composite(matte_rgba)
    out.alpha_composite(original_img)
    _log("Composition complete.")
    return out

# -----------------------------
# Public API
# -----------------------------
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    defaults = dict(softness=28.0, opacity=0.7, color="#000000")
    merged = {**defaults, **(params or {})}
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    return merged
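
# Illustrative behaviour of _apply_params_defaults (example values only):
# _apply_params_defaults(None)             -> {"softness": 28.0, "opacity": 0.7, "color": "#000000"}
# _apply_params_defaults({"opacity": 2.5}) -> opacity clamped to 1.0, other keys keep their defaults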

def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    _log("--- Starting shadow generation ---")
    params = _apply_params_defaults(params)
    img = Image.open(io.BytesIO(rgba_file_bytes)).convert("RGBA")
    _print_image(img, "Input RGBA Image")
    ssn = load_ssn_once()
    if not ssn.available():
        _log("❌ SSN model unavailable")
        raise RuntimeError("SSN model not available.")
    matte_img = ssn.infer_shadow_matte(img)
    if matte_img is None:
        _log("❌ Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")
    final_img = _composite_shadow_and_image(img, matte_img, params)
    buf = io.BytesIO()
    final_img.save(buf, format="PNG")
    buf.seek(0)
    _log("✅ Shadow generation complete")
    return buf.read()


def initialize_once():
    _log("Initializing assets and SSN model...")
    load_ssn_once()
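
# Minimal local smoke test (a sketch, not part of the original module): the file
# names "input_rgba.png" and "output_shadow.png" and the parameter values below
# are assumptions for illustration only.
if __name__ == "__main__":
    initialize_once()
    with open("input_rgba.png", "rb") as f:
        rgba_bytes = f.read()
    png_bytes = generate_shadow_rgba(
        rgba_bytes,
        params={"softness": 20.0, "opacity": 0.6, "color": "#1a1a1a"},
    )
    with open("output_shadow.png", "wb") as f:
        f.write(png_bytes)
    _log("Wrote output_shadow.png")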