karthikeya1212 committed on
Commit
37490ac
·
verified ·
1 Parent(s): fef8aea

Create shadow

Browse files
Files changed (1) hide show
  1. shadow +215 -0
shadow ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import sys
3
+ import shutil
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Optional, Dict, Any, Tuple
7
+
8
+ import gdown
9
+ from PIL import Image, ImageFilter
10
+ import torch
11
+ from torchvision import transforms as T
12
+
13
+ # -----------------------------
14
+ # Repo + weight constants
15
+ # -----------------------------
16
# Google Drive folder from which the PixHtLab source tree is downloaded.
REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing"

# Local directory the repository is placed in (relative to the working dir).
REPO_DIR = Path("PixHtLab-Src")

# Checkpoint file loaded by SSNWrapper.
WEIGHT_FILE = REPO_DIR.joinpath(
    "Demo",
    "PixhtLab",
    "weights",
    "human_baseline_all_21-July-04-52-AM.pt",
)

# A checkpoint smaller than 50 MiB is rejected by _validate_weights.
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 ** 2
20
+
21
+ # -----------------------------
22
+ # Logging utility
23
+ # -----------------------------
24
def _log(*args):
    """Print *args* to stdout behind a ``[shadow]`` tag and flush immediately."""
    tagged = ("[shadow]",) + args
    print(*tagged, flush=True)
26
+
27
def _print_image(img: Image.Image, message: str):
    """Log a one-line summary of *img* (its size and mode) labelled *message*.

    Despite the name, no pixel data is printed or displayed.
    """
    # The emoji was previously stored as mis-decoded UTF-8 bytes (mojibake).
    _log(f"🖼️ {message} | Size: {img.size}, Mode: {img.mode}")
29
+
30
+ # -----------------------------
31
+ # Repo download & validation
32
+ # -----------------------------
33
def _download_repo():
    """Download the full repo folder from Google Drive if it does not exist.

    Does nothing when ``REPO_DIR`` already exists. Otherwise the Drive folder
    is fetched into a private staging directory, moved to ``REPO_DIR``, and
    the staging directory is removed again.

    Raises:
        RuntimeError: if the download fails or yields no files. The
            underlying error is attached as ``__cause__``.
    """
    if REPO_DIR.exists():
        _log("Repo already exists:", REPO_DIR)
        return

    _log("Downloading repository from Google Drive folder...")
    # mkdtemp creates a unique, owner-only directory, so a stale or
    # pre-planted path with a predictable name in the shared temp dir can
    # never be reused.
    staging = Path(tempfile.mkdtemp(prefix="pixhtlab_repo_"))

    try:
        gdown.download_folder(REPO_DRIVE_FOLDER, output=str(staging), quiet=False, use_cookies=False)
        entries = sorted(staging.iterdir())
        if not entries:
            raise RuntimeError("No files downloaded from the repo folder.")
        # A lone directory is taken to be the repo root (as before). With
        # several entries, or a lone file, the staging directory itself holds
        # the repo contents, so move it as a whole instead of an arbitrary
        # first entry.
        if len(entries) == 1 and entries[0].is_dir():
            source = entries[0]
        else:
            source = staging
        shutil.move(str(source), str(REPO_DIR))
        _log("Repo downloaded successfully to:", REPO_DIR)
    except Exception as e:
        _log("❌ Failed to download repo:", repr(e))
        raise RuntimeError("Cannot proceed without repository.") from e
    finally:
        # Remove whatever is left of the staging area, on success or failure.
        shutil.rmtree(staging, ignore_errors=True)
55
+
56
def _validate_weights():
    """Check that the SSN checkpoint is on disk and not implausibly small.

    Raises:
        RuntimeError: if ``WEIGHT_FILE`` does not exist or is smaller than
            ``_MIN_EXPECTED_WEIGHT_BYTES``.
    """
    # The size is only read when the file exists (short-circuit).
    usable = WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size >= _MIN_EXPECTED_WEIGHT_BYTES
    if not usable:
        raise RuntimeError(f"SSN weight file missing or too small: {WEIGHT_FILE}")
    _log("Weight file exists and looks valid:", WEIGHT_FILE)
60
+
61
+ # -----------------------------
62
+ # SSN model wrapper
63
+ # -----------------------------
64
class SSNWrapper:
    """Holder for the SSN shadow model imported from the downloaded repo.

    Constructing an instance downloads the repo if needed, validates the
    checkpoint file and tries to load the model. A model-loading failure is
    logged rather than raised; check :meth:`available` before inference.
    """

    def __init__(self):
        """Fetch the assets and try to load the model.

        Raises:
            RuntimeError: propagated from ``_download_repo`` or
                ``_validate_weights`` when the assets cannot be obtained.
        """
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        _download_repo()
        _validate_weights()

        # Make the repo importable without stacking duplicate entries when
        # the wrapper is constructed more than once.
        repo_path = str(REPO_DIR.resolve())
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)

        try:
            from Demo.SSN.models.SSN_Model import SSN_Model  # dynamically loaded from repo

            model = SSN_Model()
            # NOTE(security): torch.load unpickles the file, and this
            # checkpoint comes from a remote download. Consider passing
            # weights_only=True where the installed torch supports it.
            state = torch.load(str(WEIGHT_FILE), map_location=self.device)
            result = model.load_state_dict(self._extract_state_dict(state), strict=False)
            # strict=False silently skips mismatched keys; report them so a
            # checkpoint that does not fit the architecture is visible.
            missing = list(getattr(result, "missing_keys", ()))
            unexpected = list(getattr(result, "unexpected_keys", ()))
            if missing or unexpected:
                _log(
                    "Warning: checkpoint does not fully match the model:",
                    f"{len(missing)} missing key(s), {len(unexpected)} unexpected key(s)",
                )
            model.eval().to(self.device)
            # Publish the model only once it is fully set up.
            self.model = model
            _log(f"✅ SSN model loaded on {self.device}")
        except Exception as e:
            _log("❌ Failed to load SSN model:", repr(e))
            self.model = None

    @staticmethod
    def _extract_state_dict(state):
        """Return the parameter mapping stored in a loaded checkpoint object.

        Looks for the keys ``model_state_dict``, ``state_dict`` and ``model``
        (the latter only when it holds a dict), in that order; anything else
        is returned unchanged.
        """
        if not isinstance(state, dict):
            return state
        if "model_state_dict" in state:
            return state["model_state_dict"]
        if "state_dict" in state:
            return state["state_dict"]
        nested = state.get("model")
        if isinstance(nested, dict):
            return nested
        return state

    def available(self) -> bool:
        """Return True when a model was loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def infer_shadow_matte(self, rgba_img: Image.Image, bg_rgb: Optional[Image.Image] = None) -> Optional[Image.Image]:
        """Run the model on *rgba_img* and return the resulting matte image.

        Args:
            rgba_img: foreground image. It is converted to RGB (alpha is
                dropped) and resized to 512x512 before inference.
            bg_rgb: optional background image; a plain white image of the
                same size as *rgba_img* is used when omitted.

        Returns:
            The matte resized to ``rgba_img.size``, or ``None`` when the model
            is unavailable, returns ``None``, or inference raises.
        """
        if self.model is None:
            _log("SSN model not available.")
            return None

        _log("Preparing image for inference...")
        target_size = (512, 512)
        to_tensor = T.ToTensor()
        # Resize once and reuse the result for both the tensor and the log.
        fg_resized = rgba_img.convert("RGB").resize(target_size, Image.BICUBIC)
        fg_t = to_tensor(fg_resized).unsqueeze(0).to(self.device)

        if bg_rgb is None:
            bg_rgb = Image.new("RGB", rgba_img.size, (255, 255, 255))
        # Normalise the mode so a caller-supplied RGBA/L background yields the
        # same channel count as the foreground tensor.
        bg_resized = bg_rgb.convert("RGB").resize(target_size, Image.BICUBIC)
        bg_t = to_tensor(bg_resized).unsqueeze(0).to(self.device)

        _print_image(fg_resized, "Input Foreground Image")

        try:
            _log("Running SSN inference...")
            try:
                out = self.model(fg_t, bg_t)
            except TypeError:
                # Retry with keyword arguments if the positional call is
                # rejected.
                out = self.model.forward(fg=fg_t, bg=bg_t)
            if out is None:
                _log("SSN inference returned None.")
                return None

            out_t = out[0] if isinstance(out, (tuple, list)) else out
            out_t = torch.clamp(out_t, 0.0, 1.0)
            # assumes an (N, C, H, W) output -- TODO confirm against SSN_Model.
            # A single channel is used directly; several are averaged.
            matte = out_t[0, 0] if out_t.shape[1] == 1 else out_t[0].mean(0)
            matte_img = T.ToPILImage()(matte.cpu())
            _print_image(matte_img, "Generated Shadow Matte (512x512)")
            return matte_img.resize(rgba_img.size, Image.BILINEAR)
        except Exception as e:
            _log("❌ SSN inference error:", repr(e))
            return None
136
+
137
+ # -----------------------------
138
+ # Singleton
139
+ # -----------------------------
140
# Module-level SSNWrapper instance, created on first use.
_ssn_wrapper: Optional[SSNWrapper] = None


def load_ssn_once() -> SSNWrapper:
    """Return the shared SSNWrapper, constructing it on the first call."""
    global _ssn_wrapper
    if _ssn_wrapper is not None:
        return _ssn_wrapper
    _ssn_wrapper = SSNWrapper()
    return _ssn_wrapper
147
+
148
+ # -----------------------------
149
+ # Shadow compositing
150
+ # -----------------------------
151
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Convert a ``#RGB`` or ``#RRGGBB`` colour string to an ``(r, g, b)`` tuple.

    Surrounding whitespace and leading ``#`` characters are ignored, and the
    three-digit form is expanded by doubling each digit.

    Raises:
        ValueError: if the remaining text is not exactly 3 or 6 hex digits.
    """
    c = color.strip().lstrip("#")
    if len(c) == 3:
        c = "".join(ch * 2 for ch in c)
    # int(..., 16) on its own also accepts signs and padding ("-1", "+f",
    # "1 "), which would let malformed colours through and could even yield
    # negative channels, so every character is checked explicitly.
    if len(c) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in c):
        raise ValueError("Invalid color hex.")
    r, g, b = (int(c[i:i + 2], 16) for i in (0, 2, 4))
    return r, g, b
158
+
159
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
    """Layer a tinted shadow beneath *original_img* and return the result.

    *matte_img* becomes the alpha channel of a solid layer filled with
    ``params["color"]``. That alpha is scaled by ``params["opacity"]``
    (clamped to 0..1), the layer is blurred with radius
    ``params["softness"]`` (floored at 0), and the original image is
    composited on top of it over a transparent canvas.
    """
    _log("Compositing shadow and image...")
    size = original_img.size

    # Three flat colour bands plus the matte as alpha.
    bands = [Image.new("L", size, level) for level in _hex_to_rgb(params["color"])]
    bands.append(matte_img)
    shadow = Image.merge("RGBA", bands)

    opacity = float(params["opacity"])
    opacity = min(1.0, opacity)
    opacity = max(0.0, opacity)
    if opacity < 1.0:
        *_, alpha = shadow.split()
        shadow.putalpha(alpha.point(lambda level: int(level * opacity)))

    softness = max(0.0, float(params["softness"]))
    if softness > 0:
        _log(f"Applying Gaussian blur: {softness}")
        shadow = shadow.filter(ImageFilter.GaussianBlur(radius=softness))

    # Shadow first, then the subject, onto a fully transparent canvas.
    canvas = Image.new("RGBA", size, (0, 0, 0, 0))
    for layer in (shadow, original_img):
        canvas.alpha_composite(layer)
    _log("Composition complete.")
    return canvas
182
+
183
+ # -----------------------------
184
+ # Public API
185
+ # -----------------------------
186
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge caller-supplied shadow parameters over the defaults.

    Defaults are ``softness=28.0``, ``opacity=0.7`` and ``color="#000000"``.
    A ``None`` value for one of these keys (e.g. a JSON ``null``) means "use
    the default" instead of crashing in ``float()``. Unknown keys are passed
    through untouched, and the input mapping is never modified.

    Returns:
        A new dict with ``opacity`` as a float clamped to 0..1 and
        ``softness`` as a float floored at 0.
    """
    defaults = dict(softness=28.0, opacity=0.7, color="#000000")
    merged = dict(defaults)
    for key, value in (params or {}).items():
        if value is None and key in defaults:
            continue  # keep the default for an explicitly unset known option
        merged[key] = value
    merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
    merged["softness"] = max(0.0, float(merged["softness"]))
    return merged
192
+
193
def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]] = None) -> bytes:
    """Add a generated shadow to an encoded image and return PNG bytes.

    Args:
        rgba_file_bytes: contents of an image file that PIL can open; the
            image is converted to RGBA.
        params: optional ``softness`` / ``opacity`` / ``color`` overrides,
            normalised by ``_apply_params_defaults``.

    Returns:
        The PNG-encoded composite of the shadow and the original image.

    Raises:
        RuntimeError: if the SSN model is unavailable or no matte could be
            generated.
    """
    _log("--- Starting shadow generation ---")
    params = _apply_params_defaults(params)
    # convert() returns a fully loaded copy, so the source image can be
    # closed as soon as the block exits.
    with Image.open(io.BytesIO(rgba_file_bytes)) as src:
        img = src.convert("RGBA")
    _print_image(img, "Input RGBA Image")
    ssn = load_ssn_once()
    if not ssn.available():
        _log("❌ SSN model unavailable")
        raise RuntimeError("SSN model not available.")
    matte_img = ssn.infer_shadow_matte(img)
    if matte_img is None:
        _log("❌ Failed to generate shadow matte")
        raise RuntimeError("Failed to generate shadow matte.")
    final_img = _composite_shadow_and_image(img, matte_img, params)
    buf = io.BytesIO()
    final_img.save(buf, format="PNG")
    # The emoji below was previously stored as mis-decoded UTF-8 (mojibake).
    _log("✅ Shadow generation complete")
    return buf.getvalue()
212
+
213
def initialize_once():
    """Log a start-up message and make sure the shared SSN wrapper exists.

    The wrapper returned by ``load_ssn_once`` is discarded here; that function
    constructs it only on its first call.
    """
    _log("Initializing assets and SSN model...")
    _ = load_ssn_once()