# test / shadow_generator.py
# Last commit by karthikeya1212 — c031e40 (verified): "Rename shadow to shadow_generator.py"
# Original file size: 8.04 kB
import io
import sys
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
import gdown
from PIL import Image, ImageFilter
import torch
from torchvision import transforms as T
# -----------------------------
# Repo + weight constants
# -----------------------------
# Google Drive folder holding the PixHtLab source tree (fetched by _download_repo).
REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing"
# Local checkout directory. It is relative, so it resolves against the current
# working directory.
REPO_DIR = Path("PixHtLab-Src")
# Checkpoint loaded into the SSN model, located inside the downloaded checkout.
WEIGHT_FILE = REPO_DIR.joinpath(
    "Demo", "PixhtLab", "weights", "human_baseline_all_21-July-04-52-AM.pt"
)
# Size floor (50 MiB) below which the weight file is rejected. Presumably this
# catches truncated or placeholder downloads.
_MIN_EXPECTED_WEIGHT_BYTES = 50 << 20
# -----------------------------
# Logging utility
# -----------------------------
def _log(*args):
    """Write one "[shadow]"-tagged line to stdout.

    The line is flushed immediately so it appears promptly even when stdout is
    buffered.
    """
    tagged = ("[shadow]", *args)
    print(*tagged, sep=" ", end="\n", flush=True)
def _print_image(img: Image.Image, message: str):
    """Log *message* along with the image's size and mode.

    Only metadata is logged; no pixel data is printed.
    """
    # The prefix was previously mojibake ("πŸ–ΌοΈ"), i.e. the UTF-8 bytes
    # of the framed-picture emoji decoded with the wrong codec.
    _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")
# -----------------------------
# Repo download & validation
# -----------------------------
def _pick_repo_root(download_root: Path, rel_weight: Path) -> Path:
    """Choose which downloaded path should become the repo checkout.

    Selection order:
    1. The first directory that actually contains *rel_weight*. Candidates are
       *download_root* itself, then its immediate sub-directories in sorted
       order.
    2. Otherwise, the lone top-level entry when there is exactly one. This is
       the previous behaviour for a single wrapper folder.
    3. Otherwise, *download_root* itself.

    Raises:
        RuntimeError: if *download_root* is empty.
    """
    entries = sorted(download_root.iterdir())
    if not entries:
        raise RuntimeError("No files downloaded from the repo folder.")
    for candidate in [download_root, *(e for e in entries if e.is_dir())]:
        if (candidate / rel_weight).is_file():
            return candidate
    return entries[0] if len(entries) == 1 else download_root


def _download_repo():
    """Download the full repo folder from Google Drive if it does not exist.

    The Drive folder is fetched into a private temporary staging directory.
    The path that holds the expected weight file is then moved to REPO_DIR.
    If REPO_DIR already exists, nothing is downloaded.

    Raises:
        RuntimeError: if the download or the move fails. The original error is
            chained as ``__cause__``.
    """
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return
    _log("Downloading repository from Google Drive folder...")
    try:
        # A unique, auto-cleaned staging area. Concurrent processes cannot
        # clobber each other's download, and leftovers are removed on exit.
        with tempfile.TemporaryDirectory(prefix="pixhtlab_repo_") as tmp:
            # Download into a sub-directory so that moving the whole download
            # out does not delete the TemporaryDirectory root itself.
            download_root = Path(tmp) / "download"
            download_root.mkdir()
            gdown.download_folder(REPO_DRIVE_FOLDER, output=str(download_root), quiet=False, use_cookies=False)
            src = _pick_repo_root(download_root, WEIGHT_FILE.relative_to(REPO_DIR))
            shutil.move(str(src), str(REPO_DIR))
        _log("Repo downloaded successfully to:", REPO_DIR)
    except Exception as e:
        _log("❌ Failed to download repo:", repr(e))
        raise RuntimeError("Cannot proceed without repository.") from e
def _validate_weights():
    """Check that WEIGHT_FILE exists and meets the _MIN_EXPECTED_WEIGHT_BYTES floor.

    Raises:
        RuntimeError: if the file is absent or smaller than the size floor.
    """
    looks_valid = (
        WEIGHT_FILE.exists()
        and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES
    )
    if not looks_valid:
        raise RuntimeError(f"SSN weight file missing or too small: {WEIGHT_FILE}")
    _log("Weight file exists and looks valid:", WEIGHT_FILE)
# -----------------------------
# SSN model wrapper
# -----------------------------
class SSNWrapper:
    """Wrapper that prepares the SSN model from the downloaded PixHtLab repo and runs it.

    Construction does the following:
    - downloads the repo if needed;
    - validates the weight file;
    - tries to import, build and load the model on CUDA if available, else CPU.

    Download/validation failures raise RuntimeError. Model import/load failures
    are logged and leave ``self.model`` as ``None`` (see :meth:`available`).
    """

    def __init__(self):
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        _download_repo()
        _validate_weights()
        # Make ``Demo.SSN...`` importable from the checkout. Skip the insert if
        # the path is already present, so rebuilding the wrapper does not stack
        # duplicate entries.
        repo_path = str(REPO_DIR.resolve())
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
        try:
            from Demo.SSN.models.SSN_Model import SSN_Model  # dynamically loaded from repo
            self.model = SSN_Model()
            # SECURITY NOTE(review): torch.load unpickles the checkpoint, which
            # can execute arbitrary code. The file comes from a Google Drive
            # download. Consider weights_only=True once the checkpoint contents
            # are confirmed to be plain tensors.
            state = torch.load(str(WEIGHT_FILE), map_location=self.device)
            sd = self._extract_state_dict(state)
            result = self.model.load_state_dict(sd, strict=False)
            self._report_key_mismatch(result, len(sd))
            self.model.eval().to(self.device)
            _log(f"✅ SSN model loaded on {self.device}")
        except Exception as e:
            _log("❌ Failed to load SSN model:", repr(e))
            self.model = None

    @staticmethod
    def _extract_state_dict(state: Any) -> Any:
        """Pull the parameter mapping out of a loaded checkpoint object.

        Accepts any of:
        - a raw state dict;
        - a dict wrapping one under ``model_state_dict``, ``state_dict`` or
          ``model``;
        - a pickled ``nn.Module``.

        A ``module.`` prefix (as written by ``nn.DataParallel``) is stripped
        when every key carries it; otherwise no key would match under
        ``strict=False``.
        """
        if isinstance(state, torch.nn.Module):
            return state.state_dict()
        sd = state
        if isinstance(state, dict):
            if "model_state_dict" in state:
                sd = state["model_state_dict"]
            elif "state_dict" in state:
                sd = state["state_dict"]
            elif isinstance(state.get("model"), dict):
                sd = state["model"]
        prefix = "module."
        if isinstance(sd, dict) and sd and all(isinstance(k, str) and k.startswith(prefix) for k in sd):
            sd = {k[len(prefix):]: v for k, v in sd.items()}
        return sd

    @staticmethod
    def _report_key_mismatch(result: Any, n_keys: int) -> None:
        """Log the keys that ``load_state_dict(strict=False)`` skipped.

        Without this, a checkpoint that does not match the architecture would
        load "successfully" while leaving the model at its initial parameters.
        """
        missing = list(getattr(result, "missing_keys", []) or [])
        unexpected = list(getattr(result, "unexpected_keys", []) or [])
        if missing:
            _log(f"⚠️ {len(missing)} model parameter(s) not found in checkpoint, e.g. {missing[:5]}")
        if unexpected:
            _log(f"⚠️ {len(unexpected)} checkpoint key(s) unused by the model, e.g. {unexpected[:5]}")
        if missing and len(unexpected) >= n_keys:
            _log("⚠️ No checkpoint tensors matched the model; it keeps its freshly-initialised parameters.")

    def available(self) -> bool:
        """Return True if model construction and weight loading did not raise."""
        return self.model is not None

    @torch.no_grad()
    def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
        """Run SSN on *rgba_img* and return a single-band ("L") shadow matte.

        Processing steps:
        - Foreground and background are resized to 512x512 for inference.
        - If *bg_rgb* is omitted, a plain white background is used.
        - The matte is resized back to ``rgba_img.size``.

        Returns:
            The matte image, or None if the model is unavailable or inference
            raised.
        """
        if self.model is None:
            _log("SSN model not available.")
            return None
        _log("Preparing image for inference...")
        target_size = (512, 512)
        to_tensor = T.ToTensor()
        # NOTE(review): convert("RGB") drops the alpha channel, so the model
        # never sees the object mask directly. Confirm this matches SSN's
        # expected inputs.
        fg_small = rgba_img.convert("RGB").resize(target_size, Image.BICUBIC)
        fg_t = to_tensor(fg_small).unsqueeze(0).to(self.device)
        if bg_rgb is None:
            # Build the white background directly at inference size instead of
            # allocating it at full size only to resize it.
            bg_small = Image.new("RGB", target_size, (255, 255, 255))
        else:
            bg_small = bg_rgb.convert("RGB").resize(target_size, Image.BICUBIC)
        bg_t = to_tensor(bg_small).unsqueeze(0).to(self.device)
        _print_image(fg_small, "Input Foreground Image")
        try:
            _log("Running SSN inference...")
            try:
                out = self.model(fg_t, bg_t)
            except TypeError:
                # Fallback for a forward() that takes fg/bg as keywords.
                # Presumably this guards against signature differences between
                # repo versions.
                out = self.model.forward(fg=fg_t, bg=bg_t)
            if out is None:
                _log("SSN inference returned None.")
                return None
            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            # Assumes a (batch, channels, H, W) output (TODO confirm).
            # Multi-channel outputs are averaged down to one band.
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")
            return matte_img.resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("❌ SSN inference error:", repr(e))
            return None
# -----------------------------
# Singleton
# -----------------------------
# Lazily created, module-level SSNWrapper shared by every caller of load_ssn_once().
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    """Return the shared SSNWrapper, constructing it on the first call.

    If construction raises, nothing is cached and the next call retries.
    """
    global _ssn_wrapper
    if _ssn_wrapper is not None:
        return _ssn_wrapper
    _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper
# -----------------------------
# Shadow compositing
# -----------------------------
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Parse a hex colour such as ``"#1a2b3c"`` or ``"#abc"`` into ``(r, g, b)``.

    Surrounding whitespace and leading ``#`` characters are ignored. The
    3-digit shorthand is expanded by doubling each digit.

    Raises:
        ValueError: if what remains is not exactly 3 or 6 ASCII hex digits.
    """
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    # int(..., 16) alone is too lenient: it accepts signs, inner whitespace and
    # non-ASCII digits (e.g. "# 1 2 3" or "+1+2+3"). Validate explicitly.
    if len(c) != 6 or not all(ch in "0123456789abcdefABCDEF" for ch in c):
        raise ValueError("Invalid color hex.")
    return tuple(int(c[i:i + 2], 16) for i in (0, 2, 4))
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    """Composite a flat-coloured shadow layer beneath *original_img*.

    The matte's grey levels become the shadow layer's alpha channel:
    - The alpha is first scaled by ``params["opacity"]`` (clamped to [0, 1]).
    - It is then Gaussian-blurred with radius ``params["softness"]``
      (clamped to >= 0).
    - The layer's colour comes from ``params["color"]``.
    *original_img* is then alpha-composited on top.

    Args:
        original_img: foreground image (converted to RGBA if needed).
        matte_img: shadow matte. It is converted to "L" and resized to the
            canvas if it differs.
        params: must contain "color", "opacity" and "softness".

    Returns:
        A new RGBA image the size of *original_img*.

    Raises:
        KeyError: if a required param is missing.
        ValueError: if ``params["color"]`` is not a valid hex colour.
    """
    _log("Compositing shadow and image...")
    if original_img.mode != "RGBA":
        original_img = original_img.convert("RGBA")
    w, h = original_img.size
    r, g, b = _hex_to_rgb(params["color"])

    # The matte becomes the shadow's alpha band, so it must be single-band and
    # exactly canvas-sized for putalpha().
    alpha = matte_img if matte_img.mode == "L" else matte_img.convert("L")
    if alpha.size != (w, h):
        alpha = alpha.resize((w, h), Image.BILINEAR)

    opacity = max(0.0, min(1.0, float(params["opacity"])))
    if opacity < 1.0:
        alpha = alpha.point(lambda p: int(p * opacity))

    softness = max(0.0, float(params["softness"]))
    if softness > 0:
        _log(f"Applying Gaussian blur: {softness}")
        # Only alpha varies. The colour bands are constant, so blurring the
        # whole RGBA layer would do 4x the work for the same alpha result.
        alpha = alpha.filter(ImageFilter.GaussianBlur(radius=softness))

    shadow = Image.new("RGBA", (w, h), (r, g, b, 0))
    shadow.putalpha(alpha)

    out = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    out.alpha_composite(shadow)
    out.alpha_composite(original_img)
    _log("Composition complete.")
    return out
# -----------------------------
# Public API
# -----------------------------
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge caller *params* over the defaults and normalise the numeric fields.

    Behaviour:
    - Defaults are softness=28.0, opacity=0.7, color="#000000".
    - Keys whose value is ``None`` fall back to the default.
    - Unknown keys are passed through unchanged.
    - ``opacity`` is coerced to float and clamped to [0, 1].
    - ``softness`` is coerced to float and clamped to >= 0.

    Returns:
        A new dict; *params* is not modified.

    Raises:
        ValueError/TypeError: if opacity/softness cannot be converted to float.
    """
    defaults = dict(softness=28.0, opacity=0.7, color="#000000")
    # Treat an explicit None (e.g. JSON null from a request body) as "use the
    # default" rather than crashing in float(None) below.
    overrides = {k: v for k, v in (params or {}).items() if v is not None}
    merged = {**defaults, **overrides}
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    return merged
def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    """Add an SSN-generated shadow beneath an image and return the result as PNG bytes.

    Args:
        rgba_file_bytes: an encoded image in any format PIL can open. It is
            converted to RGBA.
        params: optional overrides for "softness", "opacity" and "color".

    Returns:
        PNG-encoded RGBA image with the shadow composited under the input.

    Raises:
        RuntimeError: if the repo/weights cannot be prepared, the SSN model is
            unavailable, or no matte is produced.
    """
    _log("--- Starting shadow generation ---")
    params = _apply_params_defaults(params)
    # convert() returns a new, fully-loaded image, so the source can be
    # closed right away.
    with Image.open(io.BytesIO(rgba_file_bytes)) as src:
        img = src.convert("RGBA")
    _print_image(img, "Input RGBA Image")
    ssn = load_ssn_once()
    if not ssn.available():
        _log("❌ SSN model unavailable")
        raise RuntimeError("SSN model not available.")
    matte_img = ssn.infer_shadow_matte(img)
    if matte_img is None:
        _log("❌ Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")
    final_img = _composite_shadow_and_image(img, matte_img, params)
    buf = io.BytesIO()
    final_img.save(buf, format="PNG")
    # The prefix was previously mojibake ("βœ…") for the check-mark emoji.
    _log("✅ Shadow generation complete")
    return buf.getvalue()
def initialize_once():
    """Warm up the shared SSN wrapper ahead of the first request.

    On the first call this builds the wrapper via load_ssn_once(), which may
    download the repo. Later calls reuse the cached instance.
    """
    banner = "Initializing assets and SSN model..."
    _log(banner)
    load_ssn_once()  # return value unused; only the caching side effect matters