File size: 8,589 Bytes
c749ab5
 
e719a67
 
 
46d271d
 
 
 
2b7cea2
46d271d
 
 
 
 
81e21b3
 
 
 
e719a67
81e21b3
 
 
 
 
 
c749ab5
 
acce386
 
46d271d
 
c749ab5
46d271d
 
 
 
 
 
 
 
 
acce386
81e21b3
acce386
 
 
3a2ddb6
acce386
 
 
46d271d
c749ab5
46d271d
 
 
 
 
 
 
c749ab5
46d271d
 
acce386
46d271d
 
 
 
 
 
 
 
 
 
 
 
 
c749ab5
46d271d
acce386
46d271d
 
 
 
 
 
e719a67
81e21b3
46d271d
 
acce386
 
 
 
46d271d
 
81e21b3
46d271d
c749ab5
81e21b3
46d271d
 
c749ab5
46d271d
e719a67
acce386
e719a67
 
 
 
 
 
 
acce386
46d271d
81e21b3
46d271d
e719a67
46d271d
 
 
acce386
 
46d271d
 
 
 
e719a67
81e21b3
 
 
46d271d
 
81e21b3
 
 
 
 
 
 
c749ab5
46d271d
81e21b3
 
c749ab5
46d271d
 
 
81e21b3
 
 
 
 
 
46d271d
 
 
 
 
 
 
 
3a2ddb6
46d271d
3a2ddb6
 
 
 
acce386
3a2ddb6
46d271d
acce386
 
46d271d
 
 
 
acce386
46d271d
 
e719a67
1b4e413
acce386
 
 
 
 
 
 
 
 
e719a67
acce386
 
 
 
 
 
0103eea
c749ab5
 
1b4e413
46d271d
3a2ddb6
acce386
 
 
 
 
3a2ddb6
 
acce386
 
 
 
 
46d271d
c749ab5
46d271d
81e21b3
 
46d271d
c749ab5
46d271d
acce386
e719a67
46d271d
3a2ddb6
c749ab5
81e21b3
46d271d
0103eea
acce386
46d271d
 
acce386
46d271d
 
 
acce386
46d271d
 
 
 
 
 
 
 
 
 
acce386
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
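"""
SHARP Gradio demo: turn a single image into a 3D Gaussian Splat (.ply file),
with an optional rendered video preview.

- Standard two-column layout: inputs and controls on the left, outputs on the right.
- Fixed: the inference function is decorated with @spaces.GPU for ZeroGPU
  compatibility (this fixes the 'dummy' output).
- Fixed: visibility logic of the .PLY download button.
"""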

from __future__ import annotations

import json
import logging
import warnings
from pathlib import Path
from typing import Final

import gradio as gr

# --- 1. Import Spaces for ZeroGPU Support ---
try:
    import spaces
except ImportError:
    # Fallback for local testing when the Hugging Face `spaces` package is not
    # installed. It accepts both decorator forms of the real `spaces.GPU`:
    #   @spaces.GPU              -> GPU(func)
    #   @spaces.GPU(duration=N)  -> GPU(duration=N)(func)
    # so `@spaces.GPU(duration=...)` also works locally as a no-op.
    class spaces:  # noqa: N801 - stands in for the `spaces` module
        """Minimal no-op stand-in for the `spaces` module."""

        @staticmethod
        def GPU(func=None, **_kwargs):  # noqa: N802 - mirrors `spaces.GPU`
            """Return *func* unchanged; with no function, return an identity decorator.

            Keyword options such as ``duration`` are accepted and ignored.
            """
            if func is None:
                return lambda fn: fn
            return func

# Suppress internal warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")

# Ensure model_utils is present in your directory
from model_utils import TrajectoryType, predict_and_maybe_render_gpu

# -----------------------------------------------------------------------------
# Paths & Config
# -----------------------------------------------------------------------------

# Every location below is anchored to the directory of this file
# (symlinks resolved), not to the current working directory.
APP_DIR: Final[Path] = Path(__file__).resolve().parents[0]
OUTPUTS_DIR: Final[Path] = APP_DIR.joinpath("outputs")
ASSETS_DIR: Final[Path] = APP_DIR.joinpath("assets")
EXAMPLES_DIR: Final[Path] = ASSETS_DIR.joinpath("examples")

# File suffixes (lower-case, leading dot included) treated as example images.
IMAGE_EXTS: Final[tuple[str, ...]] = tuple(
    "." + ext for ext in ("png", "jpg", "jpeg", "webp")
)

# -----------------------------------------------------------------------------
# SEO
# -----------------------------------------------------------------------------

# Extra markup placed in the page <head>: a meta description and a viewport tag.
# The leading "" and trailing "" make the result start and end with a newline.
SEO_HEAD = "\n".join(
    (
        "",
        '<meta name="description" content="Turn 2D images into 3D Gaussian Splats instantly. SHARP (Apple) AI Demo.">',
        '<meta name="viewport" content="width=device-width, initial-scale=1">',
        "",
    )
)

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path

def get_example_files() -> list[list[str]]:
    """Discover example images for the UI's ``gr.Examples`` list.

    ``EXAMPLES_DIR`` is created if it is missing. When
    ``EXAMPLES_DIR/manifest.json`` exists, it is read as a JSON list of objects
    whose ``"image"`` value is a path relative to ``EXAMPLES_DIR``. Entries that
    point at an existing file are returned in manifest order. If the manifest is
    absent, unreadable, or yields no usable entries, the directory is scanned
    instead. The scan matches files whose suffix, compared case-insensitively,
    is one of ``IMAGE_EXTS``.

    Returns:
        One single-item row ``[path]`` per image, the row shape ``gr.Examples``
        uses for a single input component. Empty if nothing was found.
    """
    _ensure_dir(EXAMPLES_DIR)

    manifest_examples = _examples_from_manifest(EXAMPLES_DIR / "manifest.json")
    if manifest_examples:
        return manifest_examples

    # Fallback: directory scan. The order is by the suffix's position in
    # IMAGE_EXTS, then by path, matching a per-extension glob. Suffixes are
    # lower-cased so files like "photo.JPG" are not skipped on case-sensitive
    # filesystems.
    ext_rank = {ext: rank for rank, ext in enumerate(IMAGE_EXTS)}
    images = sorted(
        (
            p
            for p in EXAMPLES_DIR.iterdir()
            if p.is_file() and p.suffix.lower() in ext_rank
        ),
        key=lambda p: (ext_rank[p.suffix.lower()], p),
    )
    return [[str(p)] for p in images]


def _examples_from_manifest(manifest_path: Path) -> list[list[str]]:
    """Return ``[[path], ...]`` rows listed in *manifest_path*, or ``[]`` if unusable."""
    if not manifest_path.exists():
        return []
    try:
        data = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # Best effort: a broken manifest must not stop the UI from loading.
        # (json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.)
        print(f"Manifest error: {e}")
        return []
    if not isinstance(data, list):
        print(f"Manifest error: expected a JSON list, got {type(data).__name__}")
        return []

    examples = []
    for entry in data:
        # Require a dict with a string "image" value. A bare `"image" in entry`
        # check would also match plain strings that merely contain "image".
        image = entry.get("image") if isinstance(entry, dict) else None
        if isinstance(image, str):
            img_path = manifest_path.parent / image
            if img_path.is_file():
                examples.append([str(img_path)])
    return examples

# --- 2. Apply @spaces.GPU Decorator ---
@spaces.GPU(duration=120)
def run_sharp(
    image_path: str | None,
    trajectory_type: str,
    output_long_side: int | None,
    num_frames: int,
    fps: int,
    render_video: bool,
    progress=gr.Progress()
) -> tuple[str | None, gr.DownloadButton, str]:
    """Run SHARP on one image and build the three outputs bound in the UI.

    Decorated with ``@spaces.GPU`` so the call runs with a GPU attached on
    Hugging Face ZeroGPU hardware. ``duration`` is the requested per-call GPU
    time budget in seconds, per the ``spaces`` documentation.

    Args:
        image_path: File path from the image input; ``None``/empty when
            nothing was uploaded.
        trajectory_type: Camera-movement name (e.g. ``"rotate_forward"``). It is
            upper-cased and mapped to the ``TrajectoryType`` member of that
            name when one exists; otherwise the raw string is forwarded.
        output_long_side: Requested output long side (the UI offers 512 and
            1024 px). ``0``, any negative value, or ``None`` (a cleared
            dropdown) is forwarded to the backend as ``None``.
        num_frames: Forwarded to the backend as ``num_frames`` (cast to int).
        fps: Forwarded to the backend as ``fps`` (cast to int).
        render_video: Forwarded as ``render_video``.
        progress: Progress tracker injected by Gradio.

    Returns:
        ``(video_path_or_None, download_button_update, status_markdown)``.
        On failure, the video is ``None``, the download button is hidden and
        the status shows the error.

    Raises:
        gr.Error: If no image or no camera movement was provided.
    """
    if not image_path:
        raise gr.Error("Please upload an image first.")
    if not trajectory_type:
        raise gr.Error("Please choose a camera movement.")

    # The "Original" dropdown entry is 0, and the backend receives None for it.
    # `or 0` also maps a cleared dropdown (None) there instead of crashing int().
    long_side = int(output_long_side or 0)
    out_long_side_val = long_side if long_side > 0 else None

    # Map the dropdown string onto TrajectoryType by member name when possible.
    traj_key = trajectory_type.upper()
    if hasattr(TrajectoryType, traj_key):
        traj_enum = TrajectoryType[traj_key]
    else:
        traj_enum = trajectory_type

    try:
        progress(0.1, desc="Initializing SHARP model on GPU...")

        video_path, ply_path = predict_and_maybe_render_gpu(
            image_path,
            trajectory_type=traj_enum,
            num_frames=int(num_frames),
            fps=int(fps),
            output_long_side=out_long_side_val,
            render_video=bool(render_video),
        )

        status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
        video_result = str(video_path) if video_path else None
        if video_path:
            status_msg += f"\nVideo: `{video_path.name}`"

        # Returning a component instance updates the bound output's properties
        # (Gradio 4+): this reveals the hidden button and points it at the .ply.
        download_btn_update = gr.DownloadButton(
            value=str(ply_path),
            visible=True,
            label=f"Download {ply_path.name}",
        )
        return video_result, download_btn_update, status_msg

    except Exception as e:
        # UI boundary: report the failure in the status panel rather than
        # raising, but keep the full traceback in the server log.
        logging.getLogger(__name__).exception("SHARP inference failed")
        return (
            None,
            gr.DownloadButton(visible=False),
            f"### ❌ Error\n{type(e).__name__}: {e}",
        )

# -----------------------------------------------------------------------------
# UI Construction
# -----------------------------------------------------------------------------

def build_demo() -> gr.Blocks:
    """Construct the SHARP Gradio app; the caller is responsible for queueing and launching it.

    Layout: a title row, then two side-by-side columns.
    - Left: image input, generation settings, the run button and, if any
      example images are found, an examples list.
    - Right: video preview, a status line and a ``.PLY`` download button that
      starts hidden.

    The run button calls ``run_sharp`` with at most one concurrent execution.
    """
    camera_moves = ["swipe", "shake", "rotate", "rotate_forward"]
    # Each entry is (label shown, value passed); "Original" is passed as 0.
    resolutions = [("Original", 0), ("512px", 512), ("1024px", 1024)]
    intro_md = (
        "# SHARP: Single-Image 3D Generator\n"
        "Convert any static image into a 3D Gaussian Splat scene instantly."
    )

    with gr.Blocks(
        theme=gr.themes.Default(),
        head=SEO_HEAD,
        title="SHARP 3D Generator",
    ) as app:

        # Title row.
        with gr.Row(), gr.Column(scale=1):
            gr.Markdown(intro_md)

        # Two-column body; columns keep their natural heights.
        with gr.Row(equal_height=False):

            # Left: input and controls.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="filepath",
                    label="Input Image",
                    interactive=True,
                    sources=["upload", "clipboard"],
                )

                with gr.Group():
                    with gr.Row():
                        camera_dd = gr.Dropdown(
                            choices=camera_moves,
                            value="rotate_forward",
                            label="Camera Movement",
                            scale=2,
                        )
                        resolution_dd = gr.Dropdown(
                            choices=resolutions,
                            value=0,
                            label="Output Resolution",
                            scale=1,
                        )
                    with gr.Row():
                        frames_slider = gr.Slider(
                            minimum=24, maximum=120, value=60, step=1, label="Frames"
                        )
                        fps_slider = gr.Slider(
                            minimum=8, maximum=60, value=30, step=1, label="FPS"
                        )

                    video_checkbox = gr.Checkbox(value=True, label="Render Video Preview")

                generate_btn = gr.Button(
                    "🚀 Generate 3D Scene", variant="primary", size="lg"
                )

                # Clicking an example only fills the image input; it does not run.
                example_rows = get_example_files()
                if example_rows:
                    gr.Examples(
                        examples=example_rows,
                        inputs=[image_input],
                        label="Examples",
                        run_on_click=False,
                        cache_examples=False,
                    )

            # Right: outputs.
            with gr.Column(scale=1):
                preview_video = gr.Video(
                    label="3D Preview",
                    autoplay=True,
                    elem_id="output-video",
                    interactive=False,
                )

                with gr.Group():
                    status_text = gr.Markdown("Ready to generate.")
                    # Starts hidden; run_sharp returns an update that reveals it.
                    ply_button = gr.DownloadButton(
                        label="Download .PLY File",
                        variant="secondary",
                        visible=False,
                    )

        generate_btn.click(
            fn=run_sharp,
            inputs=[
                image_input,
                camera_dd,
                resolution_dd,
                frames_slider,
                fps_slider,
                video_checkbox,
            ],
            outputs=[preview_video, ply_button, status_text],
            concurrency_limit=1,
        )

    return app

# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------

# Runs at import time, outside the __main__ guard, so the outputs directory
# also exists when this module is imported rather than executed directly.
_ensure_dir(OUTPUTS_DIR)

if __name__ == "__main__":
    demo = build_demo()
    demo.queue()
    # allowed_paths lets Gradio serve files under the assets directory (which
    # holds the example images); ssr_mode=False disables server-side rendering.
    demo.launch(ssr_mode=False, allowed_paths=[str(ASSETS_DIR)])