Spaces:
Sleeping
Sleeping
| # import io | |
| # import sys | |
| # import shutil | |
| # import tempfile | |
| # from pathlib import Path | |
| # from typing import Optional, Dict, Any, Tuple | |
| # import gdown | |
| # from PIL import Image, ImageFilter | |
| # import torch | |
| # from torchvision import transforms as T | |
| # # ----------------------------- | |
| # # Repo + weight constants | |
| # # ----------------------------- | |
| # REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing" | |
| # REPO_DIR = Path("PixHtLab-Src") | |
| # WEIGHT_FILE = REPO_DIR / "Demo" / "PixhtLab" / "weights" / "human_baseline_all_21-July-04-52-AM.pt" | |
| # _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024 | |
| # # ----------------------------- | |
| # # Logging utility | |
| # # ----------------------------- | |
| # def _log(*args): | |
| # print("[shadow]", *args, flush=True) | |
| # def _print_image(img: Image.Image, message: str): | |
| # _log(f"πΌοΈ {message} | Size: {img.size}, Mode: {img.mode}") | |
| # # ----------------------------- | |
| # # Repo download & validation | |
| # # ----------------------------- | |
| # def _download_repo(): | |
| # """Download the full repo folder from Google Drive if it does not exist.""" | |
| # if REPO_DIR.exists(): | |
| # _log("Repo already exists:", REPO_DIR) | |
| # return | |
| # _log("Downloading repository from Google Drive folder...") | |
| # temp_dir = Path(tempfile.gettempdir()) / "pixhtlab_repo" | |
| # if temp_dir.exists(): | |
| # shutil.rmtree(temp_dir) | |
| # temp_dir.mkdir(parents=True, exist_ok=True) | |
| # try: | |
| # gdown.download_folder(REPO_DRIVE_FOLDER, output=str(temp_dir), quiet=False, use_cookies=False) | |
| # candidates = list(temp_dir.glob("*")) | |
| # if not candidates: | |
| # raise RuntimeError("No files downloaded from the repo folder.") | |
| # shutil.move(str(candidates[0]), str(REPO_DIR)) | |
| # _log("Repo downloaded successfully to:", REPO_DIR) | |
| # except Exception as e: | |
| # _log("β Failed to download repo:", repr(e)) | |
| # raise RuntimeError("Cannot proceed without repository.") | |
| # def _validate_weights(): | |
| # if not WEIGHT_FILE.exists() or WEIGHT_FILE.stat().st_size < _MIN_EXPECTED_WEIGHT_BYTES: | |
| # raise RuntimeError(f"SSN weight file missing or too small: {WEIGHT_FILE}") | |
| # _log("Weight file exists and looks valid:", WEIGHT_FILE) | |
| # # ----------------------------- | |
| # # SSN model wrapper | |
| # # ----------------------------- | |
| # class SSNWrapper: | |
| # def __init__(self): | |
| # self.model = None | |
| # self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # _download_repo() | |
| # _validate_weights() | |
| # sys.path.insert(0, str(REPO_DIR.resolve())) | |
| # try: | |
| # from Demo.SSN.models.SSN_Model import SSN_Model # dynamically loaded from repo | |
| # self.model = SSN_Model() | |
| # state = torch.load(str(WEIGHT_FILE), map_location=self.device) | |
| # if isinstance(state, dict): | |
| # if "model_state_dict" in state: | |
| # sd = state["model_state_dict"] | |
| # elif "state_dict" in state: | |
| # sd = state["state_dict"] | |
| # elif "model" in state and isinstance(state["model"], dict): | |
| # sd = state["model"] | |
| # else: | |
| # sd = state | |
| # else: | |
| # sd = state | |
| # self.model.load_state_dict(sd, strict=False) | |
| # self.model.eval().to(self.device) | |
| # _log(f"β SSN model loaded on {self.device}") | |
| # except Exception as e: | |
| # _log("β Failed to load SSN model:", repr(e)) | |
| # self.model = None | |
| # def available(self) -> bool: | |
| # return self.model is not None | |
| # @torch.no_grad() | |
| # def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]: | |
| # if self.model is None: | |
| # _log("SSN model not available.") | |
| # return None | |
| # _log("Preparing image for inference...") | |
| # target_size = (512, 512) | |
| # to_tensor = T.ToTensor() | |
| # fg_rgb_img = rgba_img.convert("RGB") | |
| # fg_t = to_tensor(fg_rgb_img.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device) | |
| # if bg_rgb is None: | |
| # bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255)) | |
| # bg_t = to_tensor(bg_rgb.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device) | |
| # _print_image(fg_rgb_img.resize(target_size, Image.BICUBIC), "Input Foreground Image") | |
| # try: | |
| # _log("Running SSN inference...") | |
| # try: | |
| # out = self.model(fg_t, bg_t) | |
| # except TypeError: | |
| # out = self.model.forward(fg=fg_t, bg=bg_t) | |
| # if out is None: | |
| # _log("SSN inference returned None.") | |
| # return None | |
| # out_t = out[0] if isinstance(out, (tuple, list)) else out | |
| # out_t = torch.clamp(out_t, 0.0, 1.0) | |
| # matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0) | |
| # matte_img = T.ToPILImage()(matte.cpu()) | |
| # _print_image(matte_img, "Generated Shadow Matte (512x512)") | |
| # return matte_img.resize(rgba_img.size, Image.BILINEAR) | |
| # except Exception as e: | |
| # _log("β SSN inference error:", repr(e)) | |
| # return None | |
| # # ----------------------------- | |
| # # Singleton | |
| # # ----------------------------- | |
| # _ssn_wrapper: Optional[SSNWrapper] = None | |
| # def load_ssn_once() -> SSNWrapper: | |
| # global _ssn_wrapper | |
| # if _ssn_wrapper is None: | |
| # _ssn_wrapper = SSNWrapper() | |
| # return _ssn_wrapper | |
| # # ----------------------------- | |
| # # Shadow compositing | |
| # # ----------------------------- | |
| # def _hex_to_rgb(color: str) -> Tuple[int, int, int]: | |
| # c = color.strip().lstrip("#") | |
| # if len(c) == 3: | |
| # c = "".join(ch * 2 for ch in c) | |
| # if len(c) != 6: | |
| # raise ValueError("Invalid color hex.") | |
| # return tuple(int(c[i:i+2],16) for i in (0,2,4)) | |
| # def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image: | |
| # _log("Compositing shadow and image...") | |
| # w, h = original_img.size | |
| # r,g,b = _hex_to_rgb(params["color"]) | |
| # matte_rgba = Image.merge("RGBA", ( | |
| # Image.new("L", (w,h), r), | |
| # Image.new("L", (w,h), g), | |
| # Image.new("L", (w,h), b), | |
| # matte_img | |
| # )) | |
| # opacity = max(0.0, min(1.0, float(params["opacity"]))) | |
| # if opacity < 1.0: | |
| # a = matte_rgba.split()[-1].point(lambda p: int(p*opacity)) | |
| # matte_rgba.putalpha(a) | |
| # softness = max(0.0, float(params["softness"])) | |
| # if softness>0: | |
| # _log(f"Applying Gaussian blur: {softness}") | |
| # matte_rgba = matte_rgba.filter(ImageFilter.GaussianBlur(radius=softness)) | |
| # out = Image.new("RGBA",(w,h),(0,0,0,0)) | |
| # out.alpha_composite(matte_rgba) | |
| # out.alpha_composite(original_img) | |
| # _log("Composition complete.") | |
| # return out | |
| # # ----------------------------- | |
| # # Public API | |
| # # ----------------------------- | |
| # def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]: | |
| # defaults = dict(softness=28.0, opacity=0.7, color="#000000") | |
| # merged = {**defaults, **(params or {})} | |
| # merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"]))) | |
| # merged["softness"] = max(0.0, float(merged["softness"])) | |
| # return merged | |
| # def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes: | |
| # _log("--- Starting shadow generation ---") | |
| # params = _apply_params_defaults(params) | |
| # img = Image.open(io.BytesIO(rgba_file_bytes)).convert("RGBA") | |
| # _print_image(img, "Input RGBA Image") | |
| # ssn = load_ssn_once() | |
| # if not ssn.available(): | |
| # _log("β SSN model unavailable") | |
| # raise RuntimeError("SSN model not available.") | |
| # matte_img = ssn.infer_shadow_matte(img) | |
| # if matte_img is None: | |
| # _log("β Failed to generate shadow matte") | |
| # raise RuntimeError("Failed to generate shadow matte.") | |
| # final_img = _composite_shadow_and_image(img, matte_img, params) | |
| # buf = io.BytesIO() | |
| # final_img.save(buf, format="PNG") | |
| # buf.seek(0) | |
| # _log("β Shadow generation complete") | |
| # return buf.read() | |
| # def initialize_once(): | |
| # _log("Initializing assets and SSN model...") | |
| # load_ssn_once() | |
| import os | |
| os.environ["GDOWN_CACHE_DIR"] = "/tmp/.gdown" | |
| os.makedirs(os.environ["GDOWN_CACHE_DIR"], exist_ok=True) | |
| import io | |
| import sys | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any, Tuple | |
| import subprocess | |
| import gdown | |
| from PIL import Image, ImageFilter | |
| import torch | |
| from torchvision import transforms as T | |
| # ----------------------------- | |
| # Paths & constants | |
| # ----------------------------- | |
# Every downloaded asset lives under the system temp directory.
TMP_DIR = Path(tempfile.gettempdir()) / "pixhtlab"
# Target directory of the `git clone` performed by _download_repo().
REPO_DIR = TMP_DIR / "PixHtLab-Src"
# Local path of the SSN checkpoint fetched by _download_weights().
WEIGHT_FILE = TMP_DIR.joinpath("weights", "human_baseline_all_21-July-04-52-AM.pt")

# Remote sources: the code comes from a Hugging Face git repo, the
# checkpoint from a Google Drive file id.
REPO_HF_URL = "https://huggingface.co/karthikeya1212/test"
DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"

# Size floor (50 MiB) below which the checkpoint is treated as missing or a
# failed download.
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 2 ** 20
| # ----------------------------- | |
| # Logging | |
| # ----------------------------- | |
def _log(*args):
    """Write ``args`` to stdout, space-separated, behind a ``[shadow]`` tag.

    Each argument is rendered with ``str()`` and the stream is flushed
    immediately.
    """
    line = " ".join(str(part) for part in ("[shadow]", *args))
    print(line, flush=True)
def _print_image(img: Image.Image, message: str):
    """Log ``message`` alongside the image's pixel size and PIL mode."""
    line = f"πΌοΈ {message} | Size: {img.size}, Mode: {img.mode}"
    _log(line)
| # ----------------------------- | |
| # Repo & weight handling | |
| # ----------------------------- | |
def _download_repo():
    """Clone the PixHtLab source repository from Hugging Face into ``REPO_DIR``.

    Does nothing when ``REPO_DIR`` already exists. Requires ``git`` on PATH.
    ``git lfs install`` is run best-effort (its exit status is ignored) before
    the clone.

    Raises:
        RuntimeError: if git cannot be run or the clone fails. Any directory
            the failed clone left behind is removed first, so the next call
            retries instead of accepting a partial checkout.
    """
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return
    _log("Cloning repository from Hugging Face...")
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # Best-effort: a missing/failed LFS setup should not block the clone.
        subprocess.run(["git", "lfs", "install"], check=False)
        subprocess.run(["git", "clone", REPO_HF_URL, str(REPO_DIR)], check=True)
    except (OSError, subprocess.SubprocessError) as e:
        _log("β Failed to clone repo:", repr(e))
        # git can exit non-zero after creating REPO_DIR (e.g. checkout/LFS
        # failure, or an interrupted run). Without this cleanup the
        # exists() short-circuit above would accept the broken tree forever.
        shutil.rmtree(REPO_DIR, ignore_errors=True)
        raise RuntimeError("Cannot proceed without repository.") from e
    _log("Repo cloned successfully to:", REPO_DIR)
def _download_weights():
    """Fetch the SSN checkpoint from Google Drive into ``WEIGHT_FILE``.

    The download is skipped when ``WEIGHT_FILE`` exists and is at least
    ``_MIN_EXPECTED_WEIGHT_BYTES`` long. That is the same ``>=`` threshold
    ``_validate_assets`` applies, so a file that validates is never fetched
    again.

    Raises:
        RuntimeError: if, after the download, the file is missing or below the
            size threshold. The undersized file is deleted before raising.
    """
    def _is_valid() -> bool:
        # Same acceptance rule as _validate_assets (size >= floor).
        return WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES

    if _is_valid():
        _log("Weights already exist:", WEIGHT_FILE)
        return
    _log("Downloading weight file from Google Drive...")
    WEIGHT_FILE.parent.mkdir(parents=True, exist_ok=True)
    # gdown's failure behaviour varies by version (it may return None instead
    # of raising), so success is judged by what actually landed on disk.
    gdown.download(DRIVE_WEIGHT_FILE, str(WEIGHT_FILE), quiet=False, use_cookies=False)
    if not _is_valid():
        if WEIGHT_FILE.exists():
            WEIGHT_FILE.unlink()
        raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
    _log("β Weight file downloaded:", WEIGHT_FILE)
def _validate_assets():
    """Confirm the cloned repo and a plausibly-sized checkpoint are on disk.

    Raises:
        RuntimeError: if ``REPO_DIR`` does not exist, or if ``WEIGHT_FILE`` is
            absent or smaller than ``_MIN_EXPECTED_WEIGHT_BYTES``.
    """
    if not REPO_DIR.exists():
        raise RuntimeError("Repo folder missing.")
    weights_ok = WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES
    if not weights_ok:
        raise RuntimeError("Weight file missing or too small.")
    _log("Assets validated successfully.")
| # ----------------------------- | |
| # Auto-search for SSN_Model.py | |
| # ----------------------------- | |
def _find_ssn_model_file():
    """Locate ``SSN_Model.py`` under ``REPO_DIR`` and return its parent directory.

    ``Path.rglob`` yields entries in filesystem-dependent order, so the matches
    are sorted. This makes the choice deterministic when the repo contains more
    than one copy.

    Returns:
        The ``Path`` of the directory containing the chosen ``SSN_Model.py``.

    Raises:
        RuntimeError: if no ``SSN_Model.py`` exists under ``REPO_DIR``.
    """
    candidates = sorted(REPO_DIR.rglob("SSN_Model.py"))
    if not candidates:
        raise RuntimeError("SSN_Model.py not found in repo.")
    if len(candidates) > 1:
        _log(f"Found {len(candidates)} SSN_Model.py files; using the first in sorted order.")
    _log("Found SSN_Model.py:", candidates[0])
    return candidates[0].parent
| # ----------------------------- | |
| # SSN model wrapper | |
| # ----------------------------- | |
class SSNWrapper:
    """Load the PixHtLab SSN model and turn RGBA cut-outs into shadow mattes.

    Construction fetches the source repo and checkpoint if needed, validates
    them, imports ``SSN_Model`` from the cloned repo and loads the weights.
    Download/validation errors propagate as ``RuntimeError``. A failure while
    importing or loading the model is only logged and leaves ``self.model``
    as ``None``, so check :meth:`available` before inferring.
    """

    def __init__(self):
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        _download_repo()
        _download_weights()
        _validate_assets()
        ssn_model_dir = _find_ssn_model_file()
        # Make SSN_Model importable as a top-level module, ahead of anything
        # else on the path.
        sys.path.insert(0, str(ssn_model_dir.resolve()))
        try:
            from SSN_Model import SSN_Model  # dynamic import from found directory
            model = SSN_Model()
            state_dict = self._extract_state_dict(self._load_checkpoint(WEIGHT_FILE, self.device))
            result = model.load_state_dict(state_dict, strict=False)
            # strict=False silently skips keys that don't match the model;
            # report them so a wrong or partial checkpoint is visible in logs.
            missing = getattr(result, "missing_keys", None) or []
            unexpected = getattr(result, "unexpected_keys", None) or []
            if missing or unexpected:
                _log(f"SSN weights: {len(missing)} missing key(s), {len(unexpected)} unexpected key(s)")
            model.eval().to(self.device)
            self.model = model
            _log(f"β SSN model loaded on {self.device}")
        except Exception as e:
            _log("β Failed to load SSN model:", repr(e))
            self.model = None

    @staticmethod
    def _load_checkpoint(path, device):
        """Load the pickled checkpoint at ``path`` with tensors mapped to ``device``.

        NOTE(review/security): ``weights_only=False`` executes arbitrary pickle
        code from the file. That is only acceptable if the Drive file behind
        ``DRIVE_WEIGHT_FILE`` is under your control. Never point this at
        user-supplied files.
        """
        # torch.serialization.safe_globals only influences weights_only=True
        # loads, so it is intentionally not used here. It also exists only in
        # recent torch releases.
        try:
            return torch.load(str(path), map_location=device, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only keyword.
            return torch.load(str(path), map_location=device)

    @staticmethod
    def _extract_state_dict(state):
        """Return the parameter mapping stored inside a loaded checkpoint.

        Looks for ``model_state_dict``, then ``state_dict``, then a dict under
        ``model``. If none is found (or ``state`` is not a dict), ``state``
        itself is returned.
        """
        if not isinstance(state, dict):
            return state
        for key in ("model_state_dict", "state_dict"):
            if key in state:
                return state[key]
        if isinstance(state.get("model"), dict):
            return state["model"]
        return state

    def available(self) -> bool:
        """Return True when the model was loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
        """Predict a single-channel shadow matte for ``rgba_img``.

        The foreground (alpha channel discarded) and the background are resized
        to 512x512 and fed to the model. The output is clamped to [0, 1],
        converted with ``T.ToPILImage`` and resized back to ``rgba_img.size``.
        Runs under ``torch.no_grad()`` so no autograd graph is built.

        Args:
            rgba_img: foreground cut-out.
            bg_rgb: optional background image; defaults to plain white. It is
                converted to RGB so RGBA/greyscale inputs still give 3 channels.

        Returns:
            The matte image, or ``None`` if the model is unavailable or
            inference fails (the error is logged).
        """
        if self.model is None:
            _log("SSN model not available.")
            return None
        _log("Preparing image for inference...")
        target_size = (512, 512)
        to_tensor = T.ToTensor()
        fg_small = rgba_img.convert("RGB").resize(target_size, Image.BICUBIC)
        fg_t = to_tensor(fg_small).unsqueeze(0).to(self.device)
        if bg_rgb is None:
            bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255))
        bg_small = bg_rgb.convert("RGB").resize(target_size, Image.BICUBIC)
        bg_t = to_tensor(bg_small).unsqueeze(0).to(self.device)
        _print_image(fg_small, "Input Foreground Image")
        try:
            _log("Running SSN inference...")
            try:
                out = self.model(fg_t, bg_t)
            except TypeError:
                # Fall back to keyword arguments if the model's forward()
                # doesn't take two positional tensors.
                out = self.model.forward(fg=fg_t, bg=bg_t)
            if out is None:
                _log("SSN inference returned None.")
                return None
            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            # assumes a batched (N, C, H, W) output -- TODO confirm against SSN_Model
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")
            return matte_img.resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("β SSN inference error:", repr(e))
            return None
| # ----------------------------- | |
| # Singleton | |
| # ----------------------------- | |
# Process-wide SSNWrapper instance, built lazily by load_ssn_once().
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    """Return the shared SSNWrapper, constructing it on first use.

    The wrapper is cached even when its model failed to load, so later calls
    do not retry. If the constructor raises, nothing is cached. No lock guards
    the lazy initialisation.
    """
    global _ssn_wrapper
    if _ssn_wrapper is not None:
        return _ssn_wrapper
    _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper
| # ----------------------------- | |
| # Shadow compositing | |
| # ----------------------------- | |
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Parse a hex colour string (``#rgb`` or ``#rrggbb``) into an ``(r, g, b)`` tuple.

    Surrounding whitespace and leading ``#`` characters are stripped, and the
    digits are case-insensitive. The 3-digit form is expanded by doubling each
    digit (``"#abc"`` -> ``"#aabbcc"``).

    Raises:
        ValueError: if what remains is not exactly 3 or 6 hexadecimal digits.
    """
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    # int(x, 16) also accepts "+", "-", "_" and surrounding whitespace, so the
    # digits are checked explicitly instead of relying on int() to reject them.
    if len(c) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in c):
        raise ValueError(f"Invalid color hex: {color!r}")
    return (int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16))
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    """Draw a coloured, softened shadow layer underneath ``original_img``.

    The matte's intensity becomes the shadow layer's alpha, scaled by
    ``params["opacity"]`` (clamped to [0, 1]). The layer is filled with
    ``params["color"]`` and Gaussian-blurred with radius ``params["softness"]``
    (clamped to >= 0; 0 means no blur). The original image is then
    alpha-composited on top.

    Args:
        original_img: foreground image; converted to RGBA if it is not already.
        matte_img: shadow intensity; converted to mode "L" and resized to
            ``original_img.size`` when needed (``Image.merge``/``putalpha``
            require matching mode and size).
        params: dict with ``color``, ``opacity`` and ``softness`` keys.

    Returns:
        A new RGBA image the size of ``original_img``.

    Raises:
        ValueError: if ``params["color"]`` is not a valid hex colour.
    """
    _log("Compositing shadow and image...")
    if original_img.mode != "RGBA":
        original_img = original_img.convert("RGBA")
    w, h = original_img.size
    r, g, b = _hex_to_rgb(params["color"])

    alpha = matte_img if matte_img.mode == "L" else matte_img.convert("L")
    if alpha.size != (w, h):
        alpha = alpha.resize((w, h), Image.BILINEAR)

    opacity = max(0.0, min(1.0, float(params["opacity"])))
    if opacity < 1.0:
        alpha = alpha.point(lambda p: int(p * opacity))

    # Solid shadow colour whose alpha channel is the (scaled) matte.
    shadow = Image.new("RGBA", (w, h), (r, g, b, 0))
    shadow.putalpha(alpha)

    softness = max(0.0, float(params["softness"]))
    if softness > 0:
        _log(f"Applying Gaussian blur: {softness}")
        shadow = shadow.filter(ImageFilter.GaussianBlur(radius=softness))

    out = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    out.alpha_composite(shadow)
    out.alpha_composite(original_img)
    _log("Composition complete.")
    return out
| # ----------------------------- | |
| # Public API | |
| # ----------------------------- | |
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge caller overrides onto the default shadow parameters.

    Defaults: ``softness=28.0``, ``opacity=0.7``, ``color="#000000"``.
    ``opacity`` is clamped to [0, 1] and ``softness`` to >= 0, both as floats.
    Overrides whose value is ``None`` are ignored, so the default applies (for
    example an unset form field). Other keys pass through unchanged.

    Raises:
        ValueError / TypeError: if ``softness`` or ``opacity`` cannot be
            converted with ``float()``.
    """
    defaults = dict(softness=28.0, opacity=0.7, color="#000000")
    overrides = {k: v for k, v in (params or {}).items() if v is not None}
    merged = {**defaults, **overrides}
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    return merged
def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    """Render a shadow behind an RGBA cut-out and return the result as PNG bytes.

    Args:
        rgba_file_bytes: encoded image bytes in any format PIL can open; the
            image is converted to RGBA.
        params: optional overrides for ``softness``, ``opacity`` and ``color``
            (merged onto defaults by ``_apply_params_defaults``).

    Returns:
        The composited image encoded as PNG.

    Raises:
        RuntimeError: if the SSN model is unavailable or yields no matte.
    """
    _log("--- Starting shadow generation ---")
    settings = _apply_params_defaults(params)
    source = Image.open(io.BytesIO(rgba_file_bytes)).convert("RGBA")
    _print_image(source, "Input RGBA Image")

    wrapper = load_ssn_once()
    if not wrapper.available():
        _log("β SSN model unavailable")
        raise RuntimeError("SSN model not available.")

    shadow_matte = wrapper.infer_shadow_matte(source)
    if shadow_matte is None:
        _log("β Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")

    composed = _composite_shadow_and_image(source, shadow_matte, settings)
    with io.BytesIO() as sink:
        composed.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    _log("β Shadow generation complete")
    return png_bytes
def initialize_once():
    """Eagerly build the shared SSN wrapper.

    On the first call this triggers the asset download and model load. The
    wrapper is cached by ``load_ssn_once``, so later calls are cheap.
    """
    _log("Initializing assets and SSN model...")
    load_ssn_once()
    return None