karthikeya1212 commited on
Commit
b12ac8d
Β·
verified Β·
1 Parent(s): 6a1604f

Update shadow_generator.py

Browse files
Files changed (1) hide show
  1. shadow_generator.py +32 -22
shadow_generator.py CHANGED
@@ -213,13 +213,13 @@
213
  # def initialize_once():
214
  # _log("Initializing assets and SSN model...")
215
  # load_ssn_once()
216
-
217
- import io
218
  import sys
219
  import shutil
220
  import tempfile
221
  from pathlib import Path
222
  from typing import Optional, Dict, Any, Tuple
 
223
 
224
  import gdown
225
  from PIL import Image, ImageFilter
@@ -233,8 +233,8 @@ TMP_DIR = Path(tempfile.gettempdir()) / "pixhtlab"
233
  REPO_DIR = TMP_DIR / "PixHtLab-Src"
234
  WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
235
 
236
- # Replace with your Drive folder & file link
237
- REPO_DRIVE_FOLDER = "https://drive.google.com/drive/folders/1YzxVaxoOXwBrdB9XHyoOgWz8z4BpLQFl?usp=sharing"
238
  DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
239
 
240
  _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
@@ -242,6 +242,7 @@ _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
242
  # -----------------------------
243
  # Logging
244
  # -----------------------------
 
245
  def _log(*args):
246
  print("[shadow]", *args, flush=True)
247
 
@@ -251,31 +252,24 @@ def _print_image(img: Image.Image, message: str):
251
  # -----------------------------
252
  # Repo & weight handling
253
  # -----------------------------
 
254
  def _download_repo():
255
- """Download the repo folder from Google Drive into /tmp."""
256
  if REPO_DIR.exists():
257
  _log("Repo already exists:", REPO_DIR)
258
  return
259
 
260
- _log("Downloading repository from Google Drive folder...")
261
- temp_dir = TMP_DIR / "repo_dl"
262
- if temp_dir.exists():
263
- shutil.rmtree(temp_dir)
264
- temp_dir.mkdir(parents=True, exist_ok=True)
265
-
266
  try:
267
- gdown.download_folder(REPO_DRIVE_FOLDER, output=str(temp_dir), quiet=False, use_cookies=False)
268
- candidates = list(temp_dir.glob("*"))
269
- if not candidates:
270
- raise RuntimeError("No files downloaded from the repo folder.")
271
- shutil.move(str(candidates[0]), str(REPO_DIR))
272
- _log("Repo downloaded successfully to:", REPO_DIR)
273
  except Exception as e:
274
- _log("❌ Failed to download repo:", repr(e))
275
  raise RuntimeError("Cannot proceed without repository.")
276
 
 
277
  def _download_weights():
278
- """Download .pt weight file into /tmp/weights/."""
279
  if WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size > _MIN_EXPECTED_WEIGHT_BYTES:
280
  _log("Weights already exist:", WEIGHT_FILE)
281
  return
@@ -288,6 +282,7 @@ def _download_weights():
288
  raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
289
  _log("βœ… Weight file downloaded:", WEIGHT_FILE)
290
 
 
291
  def _validate_assets():
292
  if not REPO_DIR.exists():
293
  raise RuntimeError("Repo folder missing.")
@@ -295,9 +290,21 @@ def _validate_assets():
295
  raise RuntimeError("Weight file missing or too small.")
296
  _log("Assets validated successfully.")
297
 
 
 
 
 
 
 
 
 
 
 
 
298
  # -----------------------------
299
  # SSN model wrapper
300
  # -----------------------------
 
301
  class SSNWrapper:
302
  def __init__(self):
303
  self.model = None
@@ -307,11 +314,11 @@ class SSNWrapper:
307
  _download_weights()
308
  _validate_assets()
309
 
310
- # Add repo code to sys.path
311
- sys.path.insert(0, str(REPO_DIR.resolve()))
312
 
313
  try:
314
- from Demo.SSN.models.SSN_Model import SSN_Model # loaded from repo
315
  self.model = SSN_Model()
316
  state = torch.load(str(WEIGHT_FILE), map_location=self.device)
317
 
@@ -378,6 +385,7 @@ class SSNWrapper:
378
  # -----------------------------
379
  # Singleton
380
  # -----------------------------
 
381
  _ssn_wrapper: Optional[SSNWrapper] = None
382
 
383
  def load_ssn_once() -> SSNWrapper:
@@ -389,6 +397,7 @@ def load_ssn_once() -> SSNWrapper:
389
  # -----------------------------
390
  # Shadow compositing
391
  # -----------------------------
 
392
  def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
393
  c = color.strip().lstrip("#")
394
  if len(c) == 3:
@@ -424,6 +433,7 @@ def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Imag
424
  # -----------------------------
425
  # Public API
426
  # -----------------------------
 
427
  def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
428
  defaults = dict(softness=28.0, opacity=0.7, color="#000000")
429
  merged = {**defaults, **(params or {})}
 
213
  # def initialize_once():
214
  # _log("Initializing assets and SSN model...")
215
  # load_ssn_once()
216
+ import io
 
217
  import sys
218
  import shutil
219
  import tempfile
220
  from pathlib import Path
221
  from typing import Optional, Dict, Any, Tuple
222
+ import subprocess
223
 
224
  import gdown
225
  from PIL import Image, ImageFilter
 
233
  REPO_DIR = TMP_DIR / "PixHtLab-Src"
234
  WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
235
 
236
+ # Hugging Face repo and Google Drive weight links
237
+ REPO_HF_URL = "https://huggingface.co/karthikeya1212/test"
238
  DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
239
 
240
  _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
 
242
  # -----------------------------
243
  # Logging
244
  # -----------------------------
245
+
246
  def _log(*args):
247
  print("[shadow]", *args, flush=True)
248
 
 
252
  # -----------------------------
253
  # Repo & weight handling
254
  # -----------------------------
255
+
256
  def _download_repo():
 
257
  if REPO_DIR.exists():
258
  _log("Repo already exists:", REPO_DIR)
259
  return
260
 
261
+ _log("Cloning repository from Hugging Face...")
262
+ TMP_DIR.mkdir(parents=True, exist_ok=True)
 
 
 
 
263
  try:
264
+ subprocess.run(["git", "lfs", "install"], check=False)
265
+ subprocess.run(["git", "clone", REPO_HF_URL, str(REPO_DIR)], check=True)
266
+ _log("Repo cloned successfully to:", REPO_DIR)
 
 
 
267
  except Exception as e:
268
+ _log("❌ Failed to clone repo:", repr(e))
269
  raise RuntimeError("Cannot proceed without repository.")
270
 
271
+
272
  def _download_weights():
 
273
  if WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size > _MIN_EXPECTED_WEIGHT_BYTES:
274
  _log("Weights already exist:", WEIGHT_FILE)
275
  return
 
282
  raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
283
  _log("βœ… Weight file downloaded:", WEIGHT_FILE)
284
 
285
+
286
  def _validate_assets():
287
  if not REPO_DIR.exists():
288
  raise RuntimeError("Repo folder missing.")
 
290
  raise RuntimeError("Weight file missing or too small.")
291
  _log("Assets validated successfully.")
292
 
293
+ # -----------------------------
294
+ # Auto-search for SSN_Model.py
295
+ # -----------------------------
296
+
297
+ def _find_ssn_model_file():
298
+ candidates = list(REPO_DIR.rglob("SSN_Model.py"))
299
+ if not candidates:
300
+ raise RuntimeError("SSN_Model.py not found in repo.")
301
+ _log("Found SSN_Model.py:", candidates[0])
302
+ return candidates[0].parent
303
+
304
  # -----------------------------
305
  # SSN model wrapper
306
  # -----------------------------
307
+
308
  class SSNWrapper:
309
  def __init__(self):
310
  self.model = None
 
314
  _download_weights()
315
  _validate_assets()
316
 
317
+ ssn_model_dir = _find_ssn_model_file()
318
+ sys.path.insert(0, str(ssn_model_dir.resolve()))
319
 
320
  try:
321
+ from SSN_Model import SSN_Model # dynamic import from found directory
322
  self.model = SSN_Model()
323
  state = torch.load(str(WEIGHT_FILE), map_location=self.device)
324
 
 
385
  # -----------------------------
386
  # Singleton
387
  # -----------------------------
388
+
389
  _ssn_wrapper: Optional[SSNWrapper] = None
390
 
391
  def load_ssn_once() -> SSNWrapper:
 
397
  # -----------------------------
398
  # Shadow compositing
399
  # -----------------------------
400
+
401
  def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
402
  c = color.strip().lstrip("#")
403
  if len(c) == 3:
 
433
  # -----------------------------
434
  # Public API
435
  # -----------------------------
436
+
437
  def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
438
  defaults = dict(softness=28.0, opacity=0.7, color="#000000")
439
  merged = {**defaults, **(params or {})}