"""MagicQuill worker server exposing Draw&Guess (LLaVA) and SAM segmentation via Gradio.

Process layout:

* The main process downloads the model snapshots, holds the single LLaVA
  instance and serves the Gradio UI.
* SAM inference runs in a ``ProcessPoolExecutor``; each worker process owns
  its own ``SamPredictor`` (created by ``init_worker_sam``).

Because the pool uses the ``spawn`` start method, every worker re-imports
this module.  All expensive top-level work is therefore restricted to the
main process; workers only need the function definitions below.
"""
import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import os
import json
import sys
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import time

# Assume MagicQuill and other dependencies are present as per user instruction
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download

# Imports for SAM (only used inside worker processes, but imported here so a
# missing dependency is detected at start-up)
from segment_anything import sam_model_registry, SamPredictor


def _is_main_process():
    """Return True unless this interpreter is a multiprocessing child.

    Children created by ``multiprocessing`` get names such as
    ``SpawnProcess-1``; only the original interpreter is ``MainProcess``.
    """
    return multiprocessing.current_process().name == "MainProcess"


hf_token = os.environ.get("HF_TOKEN")


def _download_models():
    """Fetch both model snapshots into ``models`` and ``models_v2``."""
    snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
    snapshot_download(
        repo_id="LiuZichen/MagicQuillV2-models",
        repo_type="model",
        local_dir="models_v2",
        token=hf_token,
    )


def _load_llava_model():
    """Create the LLaVA model and store it in the module-level ``llavaModel``."""
    global llavaModel
    print("Initializing LLaVAModel (Main Process)...")
    llavaModel = LLaVAModel()
    print("LLaVAModel initialized.")
    return llavaModel


# --- Global Models for Main Process ---
# LLaVA is kept as a single instance in the main process (the original author
# judged it too big to duplicate).  Spawned SAM workers re-import this module,
# so the download and the model load must NOT run there: without the guard
# every worker would load its own copy of LLaVA.
llavaModel = None
if _is_main_process():
    _download_models()
    _load_llava_model()


# --- Worker Process Logic for SAM ---

# Per-process holder for the SAM predictor: ``{"predictor": SamPredictor}`` once
# initialised, ``None`` before that or when loading failed.
worker_sam = None


def init_worker_sam(device='cuda'):
    """Initialise this process's private SAM predictor.

    Used as the ``initializer`` of the process pool, so it runs once when a
    worker starts.  Failures are caught and printed rather than raised: an
    exception escaping a pool initializer would break the entire pool.  On
    failure ``worker_sam`` simply stays ``None``.

    Args:
        device: Torch device to place the model on.  If ``'cuda'`` is
            requested but CUDA is unavailable, the model is loaded on CPU
            instead so the worker remains usable.
    """
    global worker_sam
    print(f"Process {os.getpid()}: Initializing SAM model...")

    checkpoint_path = 'models_v2/sam/sam_vit_b_01ec64.pth'

    try:
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print(f"Process {os.getpid()}: CUDA unavailable, falling back to CPU for SAM.")
            device = 'cpu'
        sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
        sam.to(device=device)
        worker_sam = {"predictor": SamPredictor(sam)}
        print(f"Process {os.getpid()}: SAM initialized.")
    except Exception as e:
        print(f"Process {os.getpid()}: Failed to init SAM: {e}")


def _load_json_if_str(value):
    """Decode *value* with ``json.loads`` when it is a string, else return it as is."""
    return json.loads(value) if isinstance(value, str) else value


def _collect_points(coordinates, label):
    """Turn a point list (or its JSON text) into ``[x, y]`` pairs plus labels.

    Each point is expected to be a mapping with ``'x'`` and ``'y'`` keys.
    Empty / ``None`` input yields two empty lists.
    """
    if not coordinates:
        return [], []
    points = [[p['x'], p['y']] for p in _load_json_if_str(coordinates)]
    return points, [label] * len(points)


def _parse_boxes(bboxes):
    """Return the box prompts as a numpy array, or ``None`` if there are none.

    Best effort: a string that is not valid JSON, or a value that is not a
    list, results in no box prompt instead of an error.
    """
    if not bboxes:
        return None
    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            print(f"Process {os.getpid()}: ignoring unparsable bboxes: {e}")
            return None
    if not isinstance(bboxes, list):
        return None
    box_list = [list(box) for box in bboxes]
    # NOTE(review): this passes an (N, 4) array to ``predict``; confirm against
    # the SamPredictor API that more than one box is actually supported.
    return np.array(box_list) if box_list else None


def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
    """Run SAM on one image inside a worker process.

    Args:
        image_np: Image array handed to ``SamPredictor.set_image``.
        coordinates_positive: Foreground points, a list of ``{'x', 'y'}``
            mappings or the JSON text of such a list.
        coordinates_negative: Background points, same format.
        bboxes: Box prompts, a list of boxes or its JSON text.

    Returns:
        ``(mask, mask)``: the same ``uint8`` mask (values 0 or 255) twice,
        because ``segment`` unpacks a two-element result.

    Raises:
        RuntimeError: If the SAM model could not be loaded in this process.
    """
    if worker_sam is None:
        # The pool initializer did not run or failed; try once more.
        init_worker_sam()
    if worker_sam is None:
        raise RuntimeError(
            f"Process {os.getpid()}: SAM model is not available (initialization failed)"
        )

    predictor = worker_sam["predictor"]
    predictor.set_image(image_np)

    pos_points, pos_labels = _collect_points(coordinates_positive, 1)
    neg_points, neg_labels = _collect_points(coordinates_negative, 0)
    points = pos_points + neg_points
    labels = pos_labels + neg_labels

    input_box = _parse_boxes(bboxes)

    if points:
        input_point = np.array(points)
        input_label = np.array(labels)
    else:
        input_point = None
        input_label = None

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )

    # Convert the first mask to uint8 {0, 255} for transport back to the main
    # process.  ``> 0`` is a no-op on boolean masks and thresholds other dtypes.
    mask_np = (masks[0] > 0).astype(np.uint8) * 255

    return mask_np, mask_np


# --- Main Process Helpers ---

# Process pool for SAM inference; created in the ``__main__`` block below.
sam_pool = None


def numpy_to_tensor(numpy_array):
    """Convert an array to a float tensor scaled by 1/255 with a new leading axis."""
    return torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.


def guess(original_image, add_color_image, add_edge_mask):
    """Ask LLaVA to describe the user's edit and return its answers.

    Args:
        original_image: Original image as a numpy array.
        add_color_image: Image with the colour strokes applied.
        add_edge_mask: Edge mask array.

    Returns:
        The non-empty answers from ``llavaModel.process`` joined by ``", "``.
    """
    # Normally loaded at import time in the main process; load lazily if this
    # module was imported in a process where the eager load was skipped.
    model = llavaModel if llavaModel is not None else _load_llava_model()

    description, ans1, ans2 = model.process(
        numpy_to_tensor(original_image),
        numpy_to_tensor(add_color_image),
        numpy_to_tensor(add_edge_mask),
    )

    return ", ".join(ans for ans in (ans1, ans2) if ans)


def get_mask_bbox(mask_np):
    """Return the tight bounding box of the non-zero area of a mask.

    Args:
        mask_np: Mask of shape ``[H, W]`` or ``[1, H, W]``.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as ints (inclusive pixel indices),
        or ``None`` when the mask has no non-zero pixel.
    """
    if mask_np.ndim == 3:
        mask_np = mask_np[0]

    rows = np.any(mask_np, axis=1)
    cols = np.any(mask_np, axis=0)
    if not np.any(rows) or not np.any(cols):
        return None

    row_idx = np.where(rows)[0]
    col_idx = np.where(cols)[0]
    return int(col_idx[0]), int(row_idx[0]), int(col_idx[-1]), int(row_idx[-1])


def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Segment *image* with SAM in a worker process.

    Args:
        image: Input image as a numpy array, or ``None`` if nothing was uploaded.
        coordinates_positive: Foreground points (JSON text or list).
        coordinates_negative: Background points (JSON text or list).
        bboxes: Box prompts (JSON text or list).

    Returns:
        ``(mask_image, bbox_json)``: the mask as a PIL image and a JSON string
        with ``startX/startY/endX/endY`` (all zero for an empty mask).  When
        the request cannot be served, ``(None, json_error)`` is returned.

    Raises:
        Whatever the worker raised, or a timeout error if no result arrives
        within 60 seconds.
    """
    if image is None:
        # Gradio passes None when no image was provided.
        return None, json.dumps({'error': 'No input image provided'})

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    if sam_pool is None:
        return None, json.dumps({'error': 'SAM pool not initialized'})

    future = sam_pool.submit(
        run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes
    )
    # The worker returns the mask twice; both values are the same array.
    mask_image_np, mask_np = future.result(timeout=60)  # 60s timeout

    res_pil = Image.fromarray(mask_image_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox:
        x_min, y_min, x_max, y_max = mask_bbox
    else:
        x_min = y_min = x_max = y_max = 0
    seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}

    return res_pil, json.dumps(seg_bbox)


# --- Gradio UI ---

def _build_app():
    """Build the Gradio Blocks UI with the Draw&Guess and SAM tabs."""
    with gr.Blocks() as blocks:
        with gr.Row():
            gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")

        with gr.Tab("Draw & Guess"):
            with gr.Row():
                dg_input_img = gr.Image(label="Original Image")
                dg_color_img = gr.Image(label="Colored Image")
                dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
            dg_output = gr.Textbox(label="Prediction Output")
            dg_btn = gr.Button("Guess")
            dg_btn.click(
                fn=guess,
                inputs=[dg_input_img, dg_color_img, dg_edge_img],
                outputs=dg_output,
                api_name="guess_prompt",
                concurrency_limit=1,
            )

        with gr.Tab("SAM Segmentation"):
            with gr.Row():
                sam_input_img = gr.Image(label="Input Image", type="numpy")
                sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
                sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
                sam_bboxes = gr.Textbox(label="BBoxes JSON")
            with gr.Row():
                sam_output_img = gr.Image(label="Segmented Image", format="png")
                sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
            sam_btn = gr.Button("Segment")
            sam_btn.click(
                fn=segment,
                inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
                outputs=[sam_output_img, sam_output_bbox],
                api_name="segment",
                concurrency_limit=5,
            )
    return blocks


# Spawned SAM workers re-import this module but never serve the UI.
app = _build_app() if _is_main_process() else None


if __name__ == "__main__":
    # Set start method to spawn for CUDA compatibility
    multiprocessing.set_start_method('spawn', force=True)

    # Initialize SAM Pool
    # Adjust max_workers based on GPU memory (e.g., 2-4 workers for SAM-B)
    NUM_SAM_WORKERS = 5
    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)

    try:
        # Launch Gradio
        app.queue(max_size=40).launch(max_threads=5)
    finally:
        # Release the worker processes (and their GPU memory) on exit.
        sam_pool.shutdown(wait=False)