test / shadow_generator.py
karthikeya1212's picture
Update shadow_generator.py
6bb15a5 verified
raw
history blame
17.6 kB
# import io
# import sys
# import shutil
# import tempfile
# from pathlib import Path
# from typing import Optional, Dict, Any, Tuple
# import gdown
# from PIL import Image, ImageFilter
# import torch
# from torchvision import transforms as T
# # -----------------------------
# # Repo + weight constants
# # -----------------------------
# REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing"
# REPO_DIR = Path("PixHtLab-Src")
# WEIGHT_FILE = REPO_DIR / "Demo" / "PixhtLab" / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
# _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
# # -----------------------------
# # Logging utility
# # -----------------------------
# def _log(*args):
# print("[shadow]", *args, flush=True)
# def _print_image(img: Image.Image, message: str):
# _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")
# # -----------------------------
# # Repo download & validation
# # -----------------------------
# def _download_repo():
# """Download the full repo folder from Google Drive if it does not exist."""
# if REPO_DIR.exists():
# _log("Repo already exists:", REPO_DIR)
# return
# _log("Downloading repository from Google Drive folder...")
# temp_dir = Path(tempfile.gettempdir()) / "pixhtlab_repo"
# if temp_dir.exists():
# shutil.rmtree(temp_dir)
# temp_dir.mkdir(parents=True, exist_ok=True)
# try:
# gdown.download_folder(REPO_DRIVE_FOLDER, output=str(temp_dir), quiet=False, use_cookies=False)
# candidates = list(temp_dir.glob("*"))
# if not candidates:
# raise RuntimeError("No files downloaded from the repo folder.")
# shutil.move(str(candidates[0]), str(REPO_DIR))
# _log("Repo downloaded successfully to:", REPO_DIR)
# except Exception as e:
# _log("❌ Failed to download repo:", repr(e))
# raise RuntimeError("Cannot proceed without repository.")
# def _validate_weights():
# if not WEIGHT_FILE.exists() or WEIGHT_FILE.stat().st_size < _MIN_EXPECTED_WEIGHT_BYTES:
# raise RuntimeError(f"SSN weight file missing or too small: {WEIGHT_FILE}")
# _log("Weight file exists and looks valid:", WEIGHT_FILE)
# # -----------------------------
# # SSN model wrapper
# # -----------------------------
# class SSNWrapper:
# def __init__(self):
# self.model = None
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# _download_repo()
# _validate_weights()
# sys.path.insert(0, str(REPO_DIR.resolve()))
# try:
# from Demo.SSN.models.SSN_Model import SSN_Model # dynamically loaded from repo
# self.model = SSN_Model()
# state = torch.load(str(WEIGHT_FILE), map_location=self.device)
# if isinstance(state, dict):
# if "model_state_dict" in state:
# sd = state["model_state_dict"]
# elif "state_dict" in state:
# sd = state["state_dict"]
# elif "model" in state and isinstance(state["model"], dict):
# sd = state["model"]
# else:
# sd = state
# else:
# sd = state
# self.model.load_state_dict(sd, strict=False)
# self.model.eval().to(self.device)
# _log(f"✅ SSN model loaded on {self.device}")
# except Exception as e:
# _log("❌ Failed to load SSN model:", repr(e))
# self.model = None
# def available(self) -> bool:
# return self.model is not None
# @torch.no_grad()
# def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
# if self.model is None:
# _log("SSN model not available.")
# return None
# _log("Preparing image for inference...")
# target_size = (512, 512)
# to_tensor = T.ToTensor()
# fg_rgb_img = rgba_img.convert("RGB")
# fg_t = to_tensor(fg_rgb_img.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device)
# if bg_rgb is None:
# bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255))
# bg_t = to_tensor(bg_rgb.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device)
# _print_image(fg_rgb_img.resize(target_size, Image.BICUBIC), "Input Foreground Image")
# try:
# _log("Running SSN inference...")
# try:
# out = self.model(fg_t, bg_t)
# except TypeError:
# out = self.model.forward(fg=fg_t, bg=bg_t)
# if out is None:
# _log("SSN inference returned None.")
# return None
# out_t = out[0] if isinstance(out, (tuple, list)) else out
# out_t = torch.clamp(out_t, 0.0, 1.0)
# matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
# matte_img = T.ToPILImage()(matte.cpu())
# _print_image(matte_img, "Generated Shadow Matte (512x512)")
# return matte_img.resize(rgba_img.size, Image.BILINEAR)
# except Exception as e:
# _log("❌ SSN inference error:", repr(e))
# return None
# # -----------------------------
# # Singleton
# # -----------------------------
# _ssn_wrapper: Optional[SSNWrapper] = None
# def load_ssn_once() -> SSNWrapper:
# global _ssn_wrapper
# if _ssn_wrapper is None:
# _ssn_wrapper = SSNWrapper()
# return _ssn_wrapper
# # -----------------------------
# # Shadow compositing
# # -----------------------------
# def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
# c = color.strip().lstrip("#")
# if len(c) == 3:
# c = "".join(ch * 2 for ch in c)
# if len(c) != 6:
# raise ValueError("Invalid color hex.")
# return tuple(int(c[i:i+2],16) for i in (0,2,4))
# def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
# _log("Compositing shadow and image...")
# w, h = original_img.size
# r,g,b = _hex_to_rgb(params["color"])
# matte_rgba = Image.merge("RGBA", (
# Image.new("L", (w,h), r),
# Image.new("L", (w,h), g),
# Image.new("L", (w,h), b),
# matte_img
# ))
# opacity = max(0.0, min(1.0, float(params["opacity"])))
# if opacity < 1.0:
# a = matte_rgba.split()[-1].point(lambda p: int(p*opacity))
# matte_rgba.putalpha(a)
# softness = max(0.0, float(params["softness"]))
# if softness>0:
# _log(f"Applying Gaussian blur: {softness}")
# matte_rgba = matte_rgba.filter(ImageFilter.GaussianBlur(radius=softness))
# out = Image.new("RGBA",(w,h),(0,0,0,0))
# out.alpha_composite(matte_rgba)
# out.alpha_composite(original_img)
# _log("Composition complete.")
# return out
# # -----------------------------
# # Public API
# # -----------------------------
# def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
# defaults = dict(softness=28.0, opacity=0.7, color="#000000")
# merged = {**defaults, **(params or {})}
# merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
# merged["softness"] = max(0.0, float(merged["softness"]))
# return merged
# def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
# _log("--- Starting shadow generation ---")
# params = _apply_params_defaults(params)
# img = Image.open(io.BytesIO(rgba_file_bytes)).convert("RGBA")
# _print_image(img, "Input RGBA Image")
# ssn = load_ssn_once()
# if not ssn.available():
# _log("❌ SSN model unavailable")
# raise RuntimeError("SSN model not available.")
# matte_img = ssn.infer_shadow_matte(img)
# if matte_img is None:
# _log("❌ Failed to generate shadow matte")
# raise RuntimeError("Failed to generate shadow matte.")
# final_img = _composite_shadow_and_image(img, matte_img, params)
# buf = io.BytesIO()
# final_img.save(buf, format="PNG")
# buf.seek(0)
# _log("✅ Shadow generation complete")
# return buf.read()
# def initialize_once():
# _log("Initializing assets and SSN model...")
# load_ssn_once()
# Set the GDOWN_CACHE_DIR environment variable to a directory under /tmp and
# create that directory if it does not exist yet.
# NOTE(review): this runs before `import gdown` further down, presumably so
# that gdown sees the variable when it is imported — confirm against gdown.
import os
os.environ["GDOWN_CACHE_DIR"] = "/tmp/.gdown"
os.makedirs(os.environ["GDOWN_CACHE_DIR"], exist_ok=True)
import io
import sys
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
import subprocess
import gdown
from PIL import Image, ImageFilter
import torch
from torchvision import transforms as T
# -----------------------------
# Paths & constants
# -----------------------------
# Working directory under the system temp dir; holds the cloned repo and the weights.
TMP_DIR = Path(tempfile.gettempdir()) / "pixhtlab"
# Destination of the `git clone` performed by _download_repo().
REPO_DIR = TMP_DIR / "PixHtLab-Src"
# Destination of the checkpoint fetched by _download_weights().
WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
# Hugging Face repo and Google Drive weight links
REPO_HF_URL = "https://huggingface.co/karthikeya1212/test"
DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
# 50 MiB: a weight file smaller than this is treated as missing/invalid.
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
# -----------------------------
# Logging
# -----------------------------
def _log(*args):
    """Print ``args`` to stdout behind a ``[shadow]`` tag and flush immediately."""
    tagged = ("[shadow]", *args)
    print(*tagged, flush=True)
def _print_image(img: Image.Image, message: str):
    """Log a one-line summary of an image: label, size and mode.

    Args:
        img: Image whose ``size`` and ``mode`` attributes are reported.
        message: Label placed in front of the summary.
    """
    # The marker used to be a mis-decoded (mojibake) byte sequence; it is the
    # picture-frame emoji.
    _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")
# -----------------------------
# Repo & weight handling
# -----------------------------
def _download_repo():
    """Clone the Hugging Face repository into ``REPO_DIR`` unless it already exists.

    Raises:
        RuntimeError: If ``git`` cannot be started or the clone fails. The
            underlying exception is attached as ``__cause__``.
    """
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return
    _log("Cloning repository from Hugging Face...")
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # Best effort: a non-zero exit from `git lfs install` must not abort the clone.
        subprocess.run(["git", "lfs", "install"], check=False)
        subprocess.run(["git", "clone", REPO_HF_URL, str(REPO_DIR)], check=True)
    except Exception as e:
        _log("❌ Failed to clone repo:", repr(e))
        # A failed clone can leave the target directory behind (for example
        # when the checkout step fails). Remove it so that the next call
        # retries instead of taking the "already exists" shortcut above.
        shutil.rmtree(REPO_DIR, ignore_errors=True)
        raise RuntimeError("Cannot proceed without repository.") from e
    _log("Repo cloned successfully to:", REPO_DIR)
def _download_weights():
    """Download the SSN checkpoint from Google Drive unless a plausible copy exists.

    A file counts as plausible when it exists and is at least
    ``_MIN_EXPECTED_WEIGHT_BYTES`` bytes long. The same rule is applied before
    and after the download, so a file of exactly the threshold size is no
    longer re-downloaded only to be accepted afterwards.

    Raises:
        RuntimeError: If no file of plausible size exists after the download.
    """
    def _looks_valid() -> bool:
        # Single definition of "valid" shared by the pre- and post-download checks.
        return WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES

    if _looks_valid():
        _log("Weights already exist:", WEIGHT_FILE)
        return
    _log("Downloading weight file from Google Drive...")
    WEIGHT_FILE.parent.mkdir(parents=True, exist_ok=True)
    gdown.download(DRIVE_WEIGHT_FILE, str(WEIGHT_FILE), quiet=False, use_cookies=False)
    if not _looks_valid():
        raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
    # The marker used to be a mis-decoded (mojibake) check-mark emoji.
    _log("✅ Weight file downloaded:", WEIGHT_FILE)
def _validate_assets():
    """Check that the cloned repo and a large-enough weight file are on disk.

    Raises:
        RuntimeError: If ``REPO_DIR`` does not exist, or if ``WEIGHT_FILE`` is
            missing or smaller than ``_MIN_EXPECTED_WEIGHT_BYTES``.
    """
    repo_present = REPO_DIR.exists()
    if not repo_present:
        raise RuntimeError("Repo folder missing.")

    weights_ok = (
        WEIGHT_FILE.exists()
        and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES
    )
    if not weights_ok:
        raise RuntimeError("Weight file missing or too small.")

    _log("Assets validated successfully.")
# -----------------------------
# Auto-search for SSN_Model.py
# -----------------------------
def _find_ssn_model_file():
    """Locate ``SSN_Model.py`` somewhere under ``REPO_DIR``.

    Returns:
        The directory containing the first match reported by the recursive search.

    Raises:
        RuntimeError: If the repo contains no ``SSN_Model.py``.
    """
    matches = list(REPO_DIR.rglob("SSN_Model.py"))
    if len(matches) == 0:
        raise RuntimeError("SSN_Model.py not found in repo.")
    model_path, *_ = matches
    _log("Found SSN_Model.py:", model_path)
    return model_path.parent
# -----------------------------
# SSN model wrapper
# -----------------------------
class SSNWrapper:
    """Fetches the SSN model code and weights, loads the model and runs inference.

    Construction downloads/validates the assets and then tries to load the
    model. A load failure does not raise: it is logged and ``available()``
    returns ``False`` afterwards.
    """

    def __init__(self):
        """Prepare assets and attempt to load the SSN model.

        Raises:
            RuntimeError: Propagated from the asset helpers when the repo or
                the weight file cannot be obtained or located. Errors raised
                while importing or loading the model itself are swallowed
                (logged, ``self.model`` stays ``None``).
        """
        # Loaded network, or None when loading failed.
        self.model = None
        # "cuda" when a GPU is visible to torch, otherwise "cpu".
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        _download_repo()
        _download_weights()
        _validate_assets()
        ssn_model_dir = _find_ssn_model_file()
        # Make the directory that holds SSN_Model.py importable.
        sys.path.insert(0, str(ssn_model_dir.resolve()))
        try:
            from SSN_Model import SSN_Model  # dynamic import from found directory
            self.model = SSN_Model()
            # NOTE(security): weights_only=False runs the unrestricted
            # unpickler on a file downloaded at runtime; a tampered checkpoint
            # can execute arbitrary code. Switch to weights_only=True once the
            # checkpoint is confirmed to load that way.
            # No safe_globals() allow-list is set up here: it is only
            # consulted for weights_only=True loads, and requiring it made
            # loading fail on PyTorch builds that do not provide it.
            state = torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)
            sd = self._extract_state_dict(state)
            # strict=False: tolerate missing/unexpected keys in the checkpoint.
            self.model.load_state_dict(sd, strict=False)
            self.model.eval().to(self.device)
            # The marker used to be a mis-decoded (mojibake) check-mark emoji.
            _log(f"✅ SSN model loaded on {self.device}")
        except Exception as e:
            # Deliberate best effort: callers query available() instead of
            # handling a constructor failure.
            _log("❌ Failed to load SSN model:", repr(e))
            self.model = None

    @staticmethod
    def _extract_state_dict(state):
        """Return the state dict stored in a loaded checkpoint object.

        For a dict, the first of ``"model_state_dict"``, ``"state_dict"`` or a
        dict-valued ``"model"`` entry is used; otherwise (or for a non-dict
        checkpoint) the object itself is returned.
        """
        if not isinstance(state, dict):
            return state
        if "model_state_dict" in state:
            return state["model_state_dict"]
        if "state_dict" in state:
            return state["state_dict"]
        if "model" in state and isinstance(state["model"], dict):
            return state["model"]
        return state

    def available(self) -> bool:
        """Return ``True`` when the model was loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
        """Run the model on an image and return the predicted shadow matte.

        Args:
            rgba_img: Foreground image; it is converted to RGB and resized to
                512x512 before being fed to the model.
            bg_rgb: Optional background image. Defaults to a plain white image
                of the same size as ``rgba_img``.

        Returns:
            The matte resized back to ``rgba_img.size``, or ``None`` when the
            model is unavailable, returns ``None``, or inference raises.
        """
        if self.model is None:
            _log("SSN model not available.")
            return None
        _log("Preparing image for inference...")
        target_size = (512, 512)
        to_tensor = T.ToTensor()
        fg_rgb_img = rgba_img.convert("RGB")
        # Resize once and reuse the result for both the tensor and the log line.
        fg_resized = fg_rgb_img.resize(target_size, Image.BICUBIC)
        fg_t = to_tensor(fg_resized).unsqueeze(0).to(self.device)
        if bg_rgb is None:
            bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255))
        bg_t = to_tensor(bg_rgb.resize(target_size, Image.BICUBIC)).unsqueeze(0).to(self.device)
        _print_image(fg_resized, "Input Foreground Image")
        try:
            _log("Running SSN inference...")
            try:
                out = self.model(fg_t, bg_t)
            except TypeError:
                # Retry with keyword arguments when the positional call is rejected.
                out = self.model.forward(fg=fg_t, bg=bg_t)
            if out is None:
                _log("SSN inference returned None.")
                return None
            # When the model returns several outputs, only the first is used.
            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            # Assumes a (batch, channels, H, W) output — TODO confirm against
            # SSN_Model. One channel is used as is; several are averaged.
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")
            return matte_img.resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("❌ SSN inference error:", repr(e))
            return None
# -----------------------------
# Singleton
# -----------------------------
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    """Return the module-wide ``SSNWrapper``, constructing it on first use."""
    global _ssn_wrapper
    if _ssn_wrapper is not None:
        return _ssn_wrapper
    _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper
# -----------------------------
# Shadow compositing
# -----------------------------
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of 0-255 ints.

    Surrounding whitespace and leading ``#`` characters are ignored. Both the
    3-digit shorthand (``"abc"`` means ``"aabbcc"``) and the 6-digit form are
    accepted.

    Args:
        color: Hex colour such as ``"#000"`` or ``"1A2b3C"``.

    Returns:
        The red, green and blue components.

    Raises:
        ValueError: If the string is not 3 or 6 hexadecimal digits.
    """
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    # int(x, 16) alone also accepts signs and whitespace (e.g. "+1"), so every
    # character is checked explicitly.
    if len(c) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in c):
        raise ValueError("Invalid color hex.")
    return int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    """Tint the matte, apply opacity and blur, and place the original image on top.

    Args:
        original_img: Image composited over the shadow layer; its size defines
            the output size.
        matte_img: Single-band image used as the alpha channel of the shadow layer.
        params: Must provide ``"color"`` (hex string), ``"opacity"`` and
            ``"softness"``; opacity is clamped to [0, 1], softness to >= 0.

    Returns:
        A new RGBA image: transparent canvas, then shadow, then the original.
    """
    _log("Compositing shadow and image...")
    width, height = original_img.size
    canvas_size = (width, height)

    # Solid colour planes plus the matte as alpha give the tinted shadow layer.
    red, green, blue = _hex_to_rgb(params["color"])
    colour_planes = [Image.new("L", canvas_size, level) for level in (red, green, blue)]
    shadow_layer = Image.merge("RGBA", (*colour_planes, matte_img))

    strength = max(0.0, min(1.0, float(params["opacity"])))
    if strength < 1.0:
        *_, alpha_band = shadow_layer.split()
        scaled_alpha = alpha_band.point(lambda level: int(level * strength))
        shadow_layer.putalpha(scaled_alpha)

    blur_radius = max(0.0, float(params["softness"]))
    if blur_radius > 0:
        _log(f"Applying Gaussian blur: {blur_radius}")
        shadow_layer = shadow_layer.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    result = Image.new("RGBA", canvas_size, (0, 0, 0, 0))
    for layer in (shadow_layer, original_img):
        result.alpha_composite(layer)
    _log("Composition complete.")
    return result
# -----------------------------
# Public API
# -----------------------------
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge caller parameters over the defaults and normalise numeric values.

    Args:
        params: Optional overrides for ``softness``, ``opacity`` and ``color``.
            ``None`` or an empty dict selects all defaults. A known key whose
            value is ``None`` is treated as "not specified". Unknown keys are
            passed through unchanged. The input dict is not modified.

    Returns:
        A new dict with ``opacity`` as a float clamped to [0, 1] and
        ``softness`` as a float clamped to >= 0.

    Raises:
        ValueError, TypeError: If ``opacity`` or ``softness`` cannot be
            converted with ``float()``.
    """
    defaults = dict(softness=28.0, opacity=0.7, color="#000000")
    merged = {**defaults, **(params or {})}
    # An explicit None would otherwise crash in float(None) below.
    for key, fallback in defaults.items():
        if merged[key] is None:
            merged[key] = fallback
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    return merged
def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    """Add a model-generated shadow to an image and return the result as PNG bytes.

    Args:
        rgba_file_bytes: Encoded image file contents; the decoded image is
            converted to RGBA.
        params: Optional overrides for ``softness``, ``opacity`` and ``color``
            (see ``_apply_params_defaults``).

    Returns:
        The composited image encoded as PNG.

    Raises:
        RuntimeError: If the SSN model is unavailable or no shadow matte could
            be generated.
    """
    _log("--- Starting shadow generation ---")
    params = _apply_params_defaults(params)
    img = Image.open(io.BytesIO(rgba_file_bytes)).convert("RGBA")
    _print_image(img, "Input RGBA Image")
    ssn = load_ssn_once()
    if not ssn.available():
        _log("❌ SSN model unavailable")
        raise RuntimeError("SSN model not available.")
    matte_img = ssn.infer_shadow_matte(img)
    if matte_img is None:
        _log("❌ Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")
    final_img = _composite_shadow_and_image(img, matte_img, params)
    buf = io.BytesIO()
    final_img.save(buf, format="PNG")
    # The marker used to be a mis-decoded (mojibake) check-mark emoji.
    _log("✅ Shadow generation complete")
    # getvalue() returns the whole buffer regardless of the stream position.
    return buf.getvalue()
def initialize_once():
    """Warm-up hook: trigger the one-time asset setup and SSN model load."""
    _log("Initializing assets and SSN model...")
    _ = load_ssn_once()  # result is cached by load_ssn_once; nothing to return here