# Page header captured with this file (not code):
#   WizardWang01's picture
#   Add app.py — commit d6deb62 (verified)
import numpy as np
from PIL import Image, ImageFilter
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation, AutoModelForDepthEstimation
from scipy.ndimage import gaussian_filter
import gradio as gr
# Model handles shared by all request handlers. They all start as None and
# are filled in exactly once by load_models() when the app starts.
segmentation_model = segmentation_processor = None
depth_model = depth_processor = None
# Torch device string ("cuda" or "cpu"), chosen by load_models().
device = None
def load_models():
    """Populate the module-level model globals; call once at startup.

    Chooses ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise. It then loads two checkpoints: the SegFormer ADE20K
    segmentation checkpoint and the Depth Anything V2 depth checkpoint. For
    each one it loads the image processor and the model, puts the model in
    eval mode and moves it to the selected device. Progress is reported
    with ``print``.
    """
    global segmentation_model, segmentation_processor, depth_model, depth_processor, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Semantic segmentation (used to find people for the Gaussian effect).
    print("Loading segmentation model...")
    seg_checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
    segmentation_processor = AutoImageProcessor.from_pretrained(seg_checkpoint)
    segmentation_model = AutoModelForSemanticSegmentation.from_pretrained(seg_checkpoint).eval()
    segmentation_model.to(device)

    # Monocular depth (drives the per-pixel blur of the lens effect).
    print("Loading depth estimation model...")
    depth_checkpoint = "depth-anything/Depth-Anything-V2-Base-hf"
    depth_processor = AutoImageProcessor.from_pretrained(depth_checkpoint)
    depth_model = AutoModelForDepthEstimation.from_pretrained(depth_checkpoint).eval()
    depth_model.to(device)

    print("Models loaded successfully!")
def get_person_mask(image):
    """Segment people in *image* and return a binary 512x512 mask.

    The image is resized to 512x512 and run through the global segmentation
    model. A pixel is set to 255 when its arg-max class is person-like and to
    0 otherwise.

    A class counts as person-like when any synonym in its label is "person",
    "people" or "human". A label may be a single name or a ``;``- or
    ``,``-separated synonym list. Every matching class id contributes to the
    mask.

    If the label set has no person-like class, an all-zero mask is returned
    and the model is not run.

    Requires ``load_models()`` to have populated ``segmentation_model``,
    ``segmentation_processor`` and ``device``.

    Returns:
        A mode ``"L"`` PIL image of size 512x512.
    """
    person_names = {"person", "people", "human"}
    # Split synonym-list labels so "person;individual;..." still matches.
    person_ids = [
        int(class_id)
        for class_id, label in segmentation_model.config.id2label.items()
        if person_names.intersection(
            synonym.strip()
            for synonym in str(label).lower().replace(",", ";").split(";")
        )
    ]
    if not person_ids:
        # No person class in this label set: nothing can ever be masked.
        return Image.new("L", (512, 512), 0)

    img_512 = image.resize((512, 512), Image.BILINEAR)
    inputs = segmentation_processor(images=img_512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentation_model(**inputs)

    # Resample logits to 512x512 so the mask lines up with the resized input
    # (the model's logit resolution may be lower than its input resolution).
    logits = torch.nn.functional.interpolate(
        outputs.logits, size=(512, 512), mode="bilinear", align_corners=False
    )
    pred = logits.argmax(dim=1)[0].cpu().numpy()

    mask = np.isin(pred, person_ids).astype(np.uint8) * 255
    # A 2-D uint8 array is inferred as mode "L"; the explicit `mode=` argument
    # is unnecessary (and deprecated in recent Pillow releases).
    return Image.fromarray(mask)
def gaussian_blur_effect(image, blur_radius=15):
    """Blur the background of *image* while keeping detected people sharp.

    Steps:
    1. Convert the image to RGB if needed and resize it to 512x512.
    2. Obtain a person mask from ``get_person_mask``.
    3. Blur the whole resized frame with ``ImageFilter.GaussianBlur(radius=blur_radius)``.
    4. Blend per pixel using ``alpha = mask / 255``: original pixels where
       alpha is 1, blurred pixels where it is 0.

    Returns ``None`` when *image* is ``None``; otherwise a 512x512 RGB image.
    """
    if image is None:
        return None

    rgb = image if image.mode == "RGB" else image.convert("RGB")
    frame = rgb.resize((512, 512), Image.BILINEAR)

    person_mask = get_person_mask(frame)
    background = frame.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # Trailing singleton axis broadcasts the mask across the RGB channels.
    alpha = (np.array(person_mask) / 255.0)[..., np.newaxis]
    sharp = np.array(frame)
    soft = np.array(background)

    blended = sharp * alpha + soft * (1 - alpha)
    return Image.fromarray(blended.astype(np.uint8))
def get_depth_map(image):
    """Estimate depth for *image* and return it as a 512x512 NumPy array.

    The image is resized to 512x512 and run through the global depth model.
    The model's ``predicted_depth`` is then resampled bicubically back to
    512x512. Singleton axes are squeezed away, so a single image yields a
    2-D array.

    Requires ``load_models()`` to have populated ``depth_model``,
    ``depth_processor`` and ``device``.
    """
    model_input = image.resize((512, 512), Image.BILINEAR)
    features = depth_processor(images=model_input, return_tensors="pt").to(device)
    with torch.no_grad():
        raw_depth = depth_model(**features).predicted_depth

    # Add a channel axis for interpolate (presumably raw_depth is
    # (batch, H, W), giving (batch, 1, H, W) — confirm for other checkpoints).
    resampled = torch.nn.functional.interpolate(
        raw_depth.unsqueeze(1),
        size=(512, 512),
        mode="bicubic",
        align_corners=False,
    )
    return resampled.squeeze().cpu().numpy()
def _depth_to_blur_map(depth_map, max_blur, focus_threshold):
    """Convert a raw depth prediction into a per-pixel blur sigma in ``[0, max_blur]``."""
    # Larger raw values are treated as nearer (the map is inverted below).
    # Assumes the depth model outputs larger values for closer points.
    #
    # After min-max normalization and inversion, the nearest pixel maps to 0
    # and the farthest to ``max_blur``. Normalized values
    # ``<= focus_threshold`` are in focus (sigma 0). The remainder are
    # stretched linearly from ``(focus_threshold, max_blur]`` onto
    # ``(0, max_blur]``.
    #
    # Returns an all-zero map when nothing can be blurred:
    # - a flat depth map (avoids a 0/0 division that produced NaNs),
    # - ``max_blur <= 0``,
    # - ``focus_threshold >= max_blur``.
    depth_map = np.asarray(depth_map)
    d_max = depth_map.max()
    d_span = float(d_max - depth_map.min())
    if d_span <= 0 or max_blur <= 0 or focus_threshold >= max_blur:
        return np.zeros(depth_map.shape, dtype=np.float32)

    farness = (d_max - depth_map) / d_span * max_blur
    # Negative ratios are the in-focus region; clip them to 0. The upper clip
    # only guards against float round-off past 1.
    ratio = np.clip(
        (farness - focus_threshold) / (max_blur - focus_threshold), 0.0, 1.0
    )
    return ratio * max_blur


def _apply_variable_blur(img_array, blur_map, max_blur, num_levels=20):
    """Blur an H x W x 3 float image with a spatially varying Gaussian."""
    # The range ``(0, max_blur]`` is split into ``num_levels`` equal bands.
    # Each pixel in a band takes its value from a copy of the image blurred
    # with the band's mid-point sigma. Pixels with sigma 0 are never blurred,
    # and bands whose mid-point sigma is <= 0.1 are skipped.
    output = img_array.copy()
    if max_blur <= 0:
        return output

    for level in range(1, num_levels + 1):
        lo = (level - 1) * max_blur / num_levels
        hi = level * max_blur / num_levels
        sigma = (lo + hi) / 2.0
        if sigma <= 0.1:
            continue

        # `blur_map > 0` keeps in-focus pixels out of the first band.
        band = (blur_map >= lo) & (blur_map > 0)
        if level < num_levels:
            band &= blur_map < hi
        # The last band has no upper bound, so pixels at exactly max_blur
        # (the farthest ones) are blurred too.
        if not band.any():
            continue

        # A sigma of 0 on the channel axis means each channel is filtered
        # independently, exactly like a per-channel 2-D Gaussian.
        blurred = gaussian_filter(img_array, sigma=(sigma, sigma, 0))
        output[band] = blurred[band]
    return output


def lens_blur_effect(image, max_blur=15, focus_threshold=5.0):
    """Apply a depth-based lens blur: near regions sharp, far regions blurred.

    Args:
        image: Input PIL image, or ``None``.
        max_blur: Largest Gaussian sigma, applied to the farthest pixels.
        focus_threshold: Cut-off on the normalized (0..max_blur) inverted
            depth scale. Pixels at or below it stay sharp. Values
            ``>= max_blur`` leave the image unblurred.

    Returns:
        ``None`` when *image* is ``None``; otherwise a 512x512 RGB PIL image.
    """
    if image is None:
        return None
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_512 = image.resize((512, 512), Image.BILINEAR)
    depth_map = get_depth_map(img_512)
    blur_map = _depth_to_blur_map(depth_map, max_blur, focus_threshold)

    img_array = np.array(img_512).astype(np.float32)
    output_array = _apply_variable_blur(img_array, blur_map, max_blur)
    output_array = np.clip(output_array, 0, 255).astype(np.uint8)
    return Image.fromarray(output_array)
# Load models at startup.
# NOTE(review): this call runs at import time (not only under __main__), so
# merely importing this module loads both checkpoints.
load_models()

# Create the Gradio interface: one tab per effect. Each tab has an input
# image, parameter slider(s), an "Apply" button wired to the effect function,
# an output image and clickable examples.
with gr.Blocks(title="Image Blur Effects Demo") as demo:
    gr.Markdown("""
# 🎨 Image Blur Effects Demo
Upload an image to apply **Gaussian Blur** or **Lens Blur** effects.
- **Gaussian Blur**: Detects people and blurs the background, keeping the person sharp.
- **Lens Blur**: Uses depth estimation to simulate camera lens bokeh effect (foreground sharp, background blurred).
""")

    # Tab 1: segmentation-driven background blur (gaussian_blur_effect).
    with gr.Tab("Gaussian Blur"):
        gr.Markdown("### Background blur with person detection")
        with gr.Row():
            with gr.Column():
                gaussian_input = gr.Image(type="pil", label="Input Image")
                # Passed as `blur_radius`, i.e. the radius of ImageFilter.GaussianBlur.
                gaussian_radius = gr.Slider(
                    minimum=5, maximum=30, value=15, step=1,
                    label="Blur Radius (σ)"
                )
                gaussian_btn = gr.Button("Apply Gaussian Blur", variant="primary")
            with gr.Column():
                gaussian_output = gr.Image(type="pil", label="Output Image")
        # Inputs map positionally onto gaussian_blur_effect(image, blur_radius).
        gaussian_btn.click(
            fn=gaussian_blur_effect,
            inputs=[gaussian_input, gaussian_radius],
            outputs=gaussian_output
        )
        # NOTE(review): example paths are relative; presumably these files
        # ship next to app.py — confirm they exist in the deployment.
        gr.Examples(
            examples=[["self.jpg"], ["self-pic.jpg"]],
            inputs=gaussian_input,
            label="Example Images"
        )

    # Tab 2: depth-driven variable blur (lens_blur_effect).
    with gr.Tab("Lens Blur (Depth-Based)"):
        gr.Markdown("### Depth-based bokeh effect simulation")
        with gr.Row():
            with gr.Column():
                lens_input = gr.Image(type="pil", label="Input Image")
                # Passed as `max_blur`: the largest Gaussian sigma (farthest pixels).
                lens_max_blur = gr.Slider(
                    minimum=5, maximum=25, value=15, step=1,
                    label="Max Blur Intensity"
                )
                # Passed as `focus_threshold`: compared on the same 0..max_blur
                # scale as the normalized depth. Values at or above Max Blur
                # Intensity leave no pixels in the depth-blurred "far" region.
                lens_focus = gr.Slider(
                    minimum=0, maximum=10, value=5.0, step=0.5,
                    label="Focus Threshold (lower = more blur)"
                )
                lens_btn = gr.Button("Apply Lens Blur", variant="primary")
            with gr.Column():
                lens_output = gr.Image(type="pil", label="Output Image")
        # Inputs map positionally onto lens_blur_effect(image, max_blur, focus_threshold).
        lens_btn.click(
            fn=lens_blur_effect,
            inputs=[lens_input, lens_max_blur, lens_focus],
            outputs=lens_output
        )
        gr.Examples(
            examples=[["self.jpg"], ["self-pic.jpg"]],
            inputs=lens_input,
            label="Example Images"
        )

    gr.Markdown("""
---
**Technical Details:**
- Segmentation: NVIDIA SegFormer (ADE20K)
- Depth Estimation: Depth Anything V2
- All images resized to 512×512 for processing
""")

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()