Spaces:
Sleeping
Sleeping
Update shadow_generator.py
Browse files- shadow_generator.py +32 -22
shadow_generator.py
CHANGED
|
@@ -213,13 +213,13 @@
|
|
| 213 |
# def initialize_once():
|
| 214 |
# _log("Initializing assets and SSN model...")
|
| 215 |
# load_ssn_once()
|
| 216 |
-
|
| 217 |
-
import io
|
| 218 |
import sys
|
| 219 |
import shutil
|
| 220 |
import tempfile
|
| 221 |
from pathlib import Path
|
| 222 |
from typing import Optional, Dict, Any, Tuple
|
|
|
|
| 223 |
|
| 224 |
import gdown
|
| 225 |
from PIL import Image, ImageFilter
|
|
@@ -233,8 +233,8 @@ TMP_DIR = Path(tempfile.gettempdir()) / "pixhtlab"
|
|
| 233 |
REPO_DIR = TMP_DIR / "PixHtLab-Src"
|
| 234 |
WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
|
| 235 |
|
| 236 |
-
#
|
| 237 |
-
|
| 238 |
DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
|
| 239 |
|
| 240 |
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
|
|
@@ -242,6 +242,7 @@ _MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
|
|
| 242 |
# -----------------------------
|
| 243 |
# Logging
|
| 244 |
# -----------------------------
|
|
|
|
| 245 |
def _log(*args):
|
| 246 |
print("[shadow]", *args, flush=True)
|
| 247 |
|
|
@@ -251,31 +252,24 @@ def _print_image(img: Image.Image, message: str):
|
|
| 251 |
# -----------------------------
|
| 252 |
# Repo & weight handling
|
| 253 |
# -----------------------------
|
|
|
|
| 254 |
def _download_repo():
|
| 255 |
-
"""Download the repo folder from Google Drive into /tmp."""
|
| 256 |
if REPO_DIR.exists():
|
| 257 |
_log("Repo already exists:", REPO_DIR)
|
| 258 |
return
|
| 259 |
|
| 260 |
-
_log("
|
| 261 |
-
|
| 262 |
-
if temp_dir.exists():
|
| 263 |
-
shutil.rmtree(temp_dir)
|
| 264 |
-
temp_dir.mkdir(parents=True, exist_ok=True)
|
| 265 |
-
|
| 266 |
try:
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
raise RuntimeError("No files downloaded from the repo folder.")
|
| 271 |
-
shutil.move(str(candidates[0]), str(REPO_DIR))
|
| 272 |
-
_log("Repo downloaded successfully to:", REPO_DIR)
|
| 273 |
except Exception as e:
|
| 274 |
-
_log("β Failed to
|
| 275 |
raise RuntimeError("Cannot proceed without repository.")
|
| 276 |
|
|
|
|
| 277 |
def _download_weights():
|
| 278 |
-
"""Download .pt weight file into /tmp/weights/."""
|
| 279 |
if WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size > _MIN_EXPECTED_WEIGHT_BYTES:
|
| 280 |
_log("Weights already exist:", WEIGHT_FILE)
|
| 281 |
return
|
|
@@ -288,6 +282,7 @@ def _download_weights():
|
|
| 288 |
raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
|
| 289 |
_log("β
Weight file downloaded:", WEIGHT_FILE)
|
| 290 |
|
|
|
|
| 291 |
def _validate_assets():
|
| 292 |
if not REPO_DIR.exists():
|
| 293 |
raise RuntimeError("Repo folder missing.")
|
|
@@ -295,9 +290,21 @@ def _validate_assets():
|
|
| 295 |
raise RuntimeError("Weight file missing or too small.")
|
| 296 |
_log("Assets validated successfully.")
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
# -----------------------------
|
| 299 |
# SSN model wrapper
|
| 300 |
# -----------------------------
|
|
|
|
| 301 |
class SSNWrapper:
|
| 302 |
def __init__(self):
|
| 303 |
self.model = None
|
|
@@ -307,11 +314,11 @@ class SSNWrapper:
|
|
| 307 |
_download_weights()
|
| 308 |
_validate_assets()
|
| 309 |
|
| 310 |
-
|
| 311 |
-
sys.path.insert(0, str(
|
| 312 |
|
| 313 |
try:
|
| 314 |
-
from
|
| 315 |
self.model = SSN_Model()
|
| 316 |
state = torch.load(str(WEIGHT_FILE), map_location=self.device)
|
| 317 |
|
|
@@ -378,6 +385,7 @@ class SSNWrapper:
|
|
| 378 |
# -----------------------------
|
| 379 |
# Singleton
|
| 380 |
# -----------------------------
|
|
|
|
| 381 |
_ssn_wrapper: Optional[SSNWrapper] = None
|
| 382 |
|
| 383 |
def load_ssn_once() -> SSNWrapper:
|
|
@@ -389,6 +397,7 @@ def load_ssn_once() -> SSNWrapper:
|
|
| 389 |
# -----------------------------
|
| 390 |
# Shadow compositing
|
| 391 |
# -----------------------------
|
|
|
|
| 392 |
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
|
| 393 |
c = color.strip().lstrip("#")
|
| 394 |
if len(c) == 3:
|
|
@@ -424,6 +433,7 @@ def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Imag
|
|
| 424 |
# -----------------------------
|
| 425 |
# Public API
|
| 426 |
# -----------------------------
|
|
|
|
| 427 |
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 428 |
defaults = dict(softness=28.0, opacity=0.7, color="#000000")
|
| 429 |
merged = {**defaults, **(params or {})}
|
|
|
|
| 213 |
# def initialize_once():
|
| 214 |
# _log("Initializing assets and SSN model...")
|
| 215 |
# load_ssn_once()
|
| 216 |
+
import io
|
|
|
|
| 217 |
import sys
|
| 218 |
import shutil
|
| 219 |
import tempfile
|
| 220 |
from pathlib import Path
|
| 221 |
from typing import Optional, Dict, Any, Tuple
|
| 222 |
+
import subprocess
|
| 223 |
|
| 224 |
import gdown
|
| 225 |
from PIL import Image, ImageFilter
|
|
|
|
| 233 |
REPO_DIR = TMP_DIR / "PixHtLab-Src"
|
| 234 |
WEIGHT_FILE = TMP_DIR / "weights" / "human_baseline_all_21-July-04-52-AM.pt"
|
| 235 |
|
| 236 |
+
# Hugging Face repo and Google Drive weight links
|
| 237 |
+
REPO_HF_URL = "https://huggingface.co/karthikeya1212/test"
|
| 238 |
DRIVE_WEIGHT_FILE = "https://drive.google.com/uc?id=1XwR2krb473vc46E_XB9ET611XPD1i8SH"
|
| 239 |
|
| 240 |
_MIN_EXPECTED_WEIGHT_BYTES = 50 * 1024 * 1024
|
|
|
|
| 242 |
# -----------------------------
|
| 243 |
# Logging
|
| 244 |
# -----------------------------
|
| 245 |
+
|
| 246 |
def _log(*args):
|
| 247 |
print("[shadow]", *args, flush=True)
|
| 248 |
|
|
|
|
| 252 |
# -----------------------------
|
| 253 |
# Repo & weight handling
|
| 254 |
# -----------------------------
|
| 255 |
+
|
| 256 |
def _download_repo():
|
|
|
|
| 257 |
if REPO_DIR.exists():
|
| 258 |
_log("Repo already exists:", REPO_DIR)
|
| 259 |
return
|
| 260 |
|
| 261 |
+
_log("Cloning repository from Hugging Face...")
|
| 262 |
+
TMP_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
try:
|
| 264 |
+
subprocess.run(["git", "lfs", "install"], check=False)
|
| 265 |
+
subprocess.run(["git", "clone", REPO_HF_URL, str(REPO_DIR)], check=True)
|
| 266 |
+
_log("Repo cloned successfully to:", REPO_DIR)
|
|
|
|
|
|
|
|
|
|
| 267 |
except Exception as e:
|
| 268 |
+
_log("β Failed to clone repo:", repr(e))
|
| 269 |
raise RuntimeError("Cannot proceed without repository.")
|
| 270 |
|
| 271 |
+
|
| 272 |
def _download_weights():
|
|
|
|
| 273 |
if WEIGHT_FILE.exists() and WEIGHT_FILE.stat().st_size > _MIN_EXPECTED_WEIGHT_BYTES:
|
| 274 |
_log("Weights already exist:", WEIGHT_FILE)
|
| 275 |
return
|
|
|
|
| 282 |
raise RuntimeError(f"Downloaded weight file is invalid: {WEIGHT_FILE}")
|
| 283 |
_log("β
Weight file downloaded:", WEIGHT_FILE)
|
| 284 |
|
| 285 |
+
|
| 286 |
def _validate_assets():
|
| 287 |
if not REPO_DIR.exists():
|
| 288 |
raise RuntimeError("Repo folder missing.")
|
|
|
|
| 290 |
raise RuntimeError("Weight file missing or too small.")
|
| 291 |
_log("Assets validated successfully.")
|
| 292 |
|
| 293 |
+
# -----------------------------
|
| 294 |
+
# Auto-search for SSN_Model.py
|
| 295 |
+
# -----------------------------
|
| 296 |
+
|
| 297 |
+
def _find_ssn_model_file():
|
| 298 |
+
candidates = list(REPO_DIR.rglob("SSN_Model.py"))
|
| 299 |
+
if not candidates:
|
| 300 |
+
raise RuntimeError("SSN_Model.py not found in repo.")
|
| 301 |
+
_log("Found SSN_Model.py:", candidates[0])
|
| 302 |
+
return candidates[0].parent
|
| 303 |
+
|
| 304 |
# -----------------------------
|
| 305 |
# SSN model wrapper
|
| 306 |
# -----------------------------
|
| 307 |
+
|
| 308 |
class SSNWrapper:
|
| 309 |
def __init__(self):
|
| 310 |
self.model = None
|
|
|
|
| 314 |
_download_weights()
|
| 315 |
_validate_assets()
|
| 316 |
|
| 317 |
+
ssn_model_dir = _find_ssn_model_file()
|
| 318 |
+
sys.path.insert(0, str(ssn_model_dir.resolve()))
|
| 319 |
|
| 320 |
try:
|
| 321 |
+
from SSN_Model import SSN_Model # dynamic import from found directory
|
| 322 |
self.model = SSN_Model()
|
| 323 |
state = torch.load(str(WEIGHT_FILE), map_location=self.device)
|
| 324 |
|
|
|
|
| 385 |
# -----------------------------
|
| 386 |
# Singleton
|
| 387 |
# -----------------------------
|
| 388 |
+
|
| 389 |
_ssn_wrapper: Optional[SSNWrapper] = None
|
| 390 |
|
| 391 |
def load_ssn_once() -> SSNWrapper:
|
|
|
|
| 397 |
# -----------------------------
|
| 398 |
# Shadow compositing
|
| 399 |
# -----------------------------
|
| 400 |
+
|
| 401 |
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
|
| 402 |
c = color.strip().lstrip("#")
|
| 403 |
if len(c) == 3:
|
|
|
|
| 433 |
# -----------------------------
|
| 434 |
# Public API
|
| 435 |
# -----------------------------
|
| 436 |
+
|
| 437 |
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 438 |
defaults = dict(softness=28.0, opacity=0.7, color="#000000")
|
| 439 |
merged = {**defaults, **(params or {})}
|