LiuZichen committed
Commit 191dbfa · verified · 1 Parent(s): b05fd22

Update app.py

Files changed (1)
  1. app.py +240 -17
app.py CHANGED
@@ -3,38 +3,261 @@ import random
 import torch
 import numpy as np
 from PIL import Image, ImageOps
-from fastapi import FastAPI, Request
+import os
+import json
+import sys
+import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
+import time
+
+# Assume MagicQuill and other dependencies are present as per user instruction
 from MagicQuill import folder_paths
 from MagicQuill.llava_new import LLaVAModel
 from huggingface_hub import snapshot_download
+
+# Imports for SAM (only needed in the worker processes, but imported here for checking)
+from segment_anything import sam_model_registry, SamPredictor
+
+# Download models (the main process does this once)
+hf_token = os.environ.get("HF_TOKEN")
 snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
+snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
 
+# --- Global Models for Main Process ---
+print("Initializing LLaVAModel (Main Process)...")
+# LLaVA is stateless/thread-safe enough, and too big to duplicate, so we keep it in the main process (and use threads)
 llavaModel = LLaVAModel()
+print("LLaVAModel initialized.")
+
+# --- Worker Process Logic for SAM ---
+# Global variable for each worker process to hold its own SAM instance
+worker_sam = None
+
+def init_worker_sam(device='cuda'):
+    """
+    Called when a new worker process starts.
+    Initializes a standalone SAM model for that process.
+    """
+    global worker_sam
+    print(f"Process {os.getpid()}: Initializing SAM model...")
+
+    # Ideally the SAM loading logic would live in a separate module so it is
+    # easily picklable; for this script, we define it here.
+    checkpoint_path = 'models_v2/sams/sam_vit_b_01ec64.pth'
+
+    # Load model
+    try:
+        sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
+        sam.to(device=device)
+        predictor = SamPredictor(sam)
+
+        worker_sam = {
+            "predictor": predictor
+        }
+        print(f"Process {os.getpid()}: SAM initialized.")
+    except Exception as e:
+        print(f"Process {os.getpid()}: Failed to init SAM: {e}")
+
+def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
+    """
+    The actual inference function, running inside a worker process.
+    """
+    global worker_sam
+
+    if worker_sam is None:
+        # Fallback if init didn't run or failed (though the ProcessPool initializer should handle it)
+        init_worker_sam()
+
+    predictor = worker_sam["predictor"]
+
+    # Set image
+    predictor.set_image(image_np)
+
+    input_point = []
+    input_label = []
+
+    # Process points
+    if coordinates_positive:
+        coords = json.loads(coordinates_positive) if isinstance(coordinates_positive, str) else coordinates_positive
+        for p in coords:
+            input_point.append([p['x'], p['y']])
+            input_label.append(1)
+
+    if coordinates_negative:
+        coords = json.loads(coordinates_negative) if isinstance(coordinates_negative, str) else coordinates_negative
+        for p in coords:
+            input_point.append([p['x'], p['y']])
+            input_label.append(0)
+
+    # Process bboxes
+    input_box = None
+    if bboxes:
+        if isinstance(bboxes, str):
+            try:
+                bboxes = json.loads(bboxes)
+            except:
+                pass
+
+        box_list = []
+        if isinstance(bboxes, list):
+            for box in bboxes:
+                box_list.append(list(box))
+
+        if len(box_list) > 0:
+            input_box = np.array(box_list)
+
+    if len(input_point) > 0:
+        input_point = np.array(input_point)
+        input_label = np.array(input_label)
+    else:
+        input_point = None
+        input_label = None
+
+    # Predict
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        box=input_box,
+        multimask_output=False,
+    )
+
+    mask_np = masks[0]
+
+    # Post-processing: simply convert the mask to uint8 [0, 255] for transport
+    if mask_np.dtype == bool:
+        mask_np = mask_np.astype(np.uint8) * 255
+    else:
+        mask_np = (mask_np > 0).astype(np.uint8) * 255
+
+    # Return the mask as an image for the client to use.
+    # mask_np is returned twice to satisfy the unpacker in segment(),
+    # which expects (image_with_alpha_np, mask_np).
+    return mask_np, mask_np
+
+
+# --- Main Process Helpers ---
+
+# We need a pool; since this is a script, it is initialized in the __main__ block.
+sam_pool = None
 
 def numpy_to_tensor(numpy_array):
     tensor = torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.
     return tensor
 
-def guess(original_image_tensor, add_color_image_tensor, add_edge_mask):
-    # print("original_image_tensor:", original_image_tensor.shape)
-    # print("add_color_image_tensor:", add_color_image_tensor.shape)
-    # print("add_edge_mask:", add_edge_mask.shape)
-    original_image_tensor = numpy_to_tensor(original_image_tensor)
-    add_color_image_tensor = numpy_to_tensor(add_color_image_tensor)
-    add_edge_mask = numpy_to_tensor(add_edge_mask)
-    description, ans1, ans2 = llavaModel.process(original_image_tensor, add_color_image_tensor, add_edge_mask)
+def guess(original_image, add_color_image, add_edge_mask):
+    # LLaVA inference runs in the main process (threaded)
+    original_image_tensor = numpy_to_tensor(original_image)
+    add_color_image_tensor = numpy_to_tensor(add_color_image)
+    add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)
+
+    description, ans1, ans2 = llavaModel.process(original_image_tensor, add_color_image_tensor, add_edge_mask_tensor)
+
     ans_list = []
     if ans1 and ans1 != "":
         ans_list.append(ans1)
     if ans2 and ans2 != "":
         ans_list.append(ans2)
+
     return ", ".join(ans_list)
 
-# Simplified Gradio interface, following the official format
-gr.Interface(
-    fn=guess,
-    inputs=[gr.Image(label="Original Image"),
-            gr.Image(label="Colored Image"),
-            gr.Image(image_mode="L", label="Edge Mask")],
-    outputs=gr.Textbox(label="Prediction Output")
-).queue(max_size=40, status_update_rate=0.1).launch(max_threads=4)
+def get_mask_bbox(mask_np):
+    # mask_np: [1, H, W] or [H, W]
+    if mask_np.ndim == 3:
+        mask_np = mask_np[0]
+
+    rows = np.any(mask_np, axis=1)
+    cols = np.any(mask_np, axis=0)
+    if not np.any(rows) or not np.any(cols):
+        return None
+
+    y_min, y_max = np.where(rows)[0][[0, -1]]
+    x_min, x_max = np.where(cols)[0][[0, -1]]
+    return int(x_min), int(y_min), int(x_max), int(y_max)
+
+def segment(image, coordinates_positive, coordinates_negative, bboxes):
+    # image: numpy array (uint8); submit the task to the process pool
+    print("image.shape:", image.shape)
+    print("coordinates_positive:", coordinates_positive)
+    print("coordinates_negative:", coordinates_negative)
+    print("bboxes:", bboxes)
+
+    if sam_pool is None:
+        return None, json.dumps({'error': 'SAM pool not initialized'})
+
+    # Future result
+    future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
+
+    # Wait for the result
+    image_with_alpha_np, mask_np = future.result(timeout=60)  # 60s timeout
+
+    # Convert back to PIL for Gradio
+    res_pil = Image.fromarray(image_with_alpha_np)
+
+    # Calculate bbox
+    mask_bbox = get_mask_bbox(mask_np)
+    if mask_bbox:
+        x_min, y_min, x_max, y_max = mask_bbox
+        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
+    else:
+        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
+
+    return res_pil, json.dumps(seg_bbox)
+
+# --- Gradio UI ---
+with gr.Blocks() as app:
+    with gr.Row():
+        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
+
+    with gr.Tab("Draw & Guess"):
+        with gr.Row():
+            dg_input_img = gr.Image(label="Original Image")
+            dg_color_img = gr.Image(label="Colored Image")
+            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
+        dg_output = gr.Textbox(label="Prediction Output")
+        dg_btn = gr.Button("Guess")
+
+        dg_btn.click(
+            fn=guess,
+            inputs=[dg_input_img, dg_color_img, dg_edge_img],
+            outputs=dg_output,
+            api_name="guess_prompt"
+        )
+
+    with gr.Tab("SAM Segmentation"):
+        with gr.Row():
+            sam_input_img = gr.Image(label="Input Image", type="numpy")
+            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
+            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
+            sam_bboxes = gr.Textbox(label="BBoxes JSON")
+
+        with gr.Row():
+            sam_output_img = gr.Image(label="Segmented Image", format="png")
+            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
+
+        sam_btn = gr.Button("Segment")
+
+        sam_btn.click(
+            fn=segment,
+            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
+            outputs=[sam_output_img, sam_output_bbox],
+            api_name="segment"
+        )
+
+if __name__ == "__main__":
+    # Set the start method to spawn for CUDA compatibility
+    multiprocessing.set_start_method('spawn', force=True)
+
+    # Initialize the SAM pool
+    # Adjust max_workers based on GPU memory (e.g., 2-4 workers for SAM-B)
+    NUM_SAM_WORKERS = 5
+    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
+    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)
+
+    # Launch Gradio
+    app.queue(max_size=40).launch(max_threads=5, server_port=7861)
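
For anyone wiring a client against this Space: the two endpoints registered above via api_name ("guess_prompt" and "segment") can be called programmatically. The sketch below is illustrative only and not part of this commit; it assumes the gradio_client package, a server reachable at the hard-coded port 7861, and placeholder image paths.

import json
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7861/")

# Draw & Guess: three images in, predicted prompt text out
prompt = client.predict(
    handle_file("original.png"),   # original image (placeholder path)
    handle_file("colored.png"),    # colored image (placeholder path)
    handle_file("edge_mask.png"),  # grayscale edge mask (placeholder path)
    api_name="/guess_prompt",
)
print("guessed prompt:", prompt)

# SAM segmentation: an image plus JSON-encoded clicks/boxes in,
# a mask image and a bounding-box JSON string out
mask_path, bbox_json = client.predict(
    handle_file("input.png"),
    json.dumps([{"x": 120, "y": 80}]),  # positive clicks
    json.dumps([]),                     # negative clicks
    json.dumps([[50, 40, 300, 260]]),   # boxes as [x_min, y_min, x_max, y_max]
    api_name="/segment",
)
print("mask at:", mask_path, "bbox:", bbox_json)

Note the ProcessPoolExecutor(initializer=init_worker_sam) pattern in the commit: with the spawn start method, each worker process loads its own CUDA copy of SAM exactly once at startup rather than once per request, which is why segment() only has to submit a task and wait on the future.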