Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Form | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| from io import BytesIO | |
| import base64 | |
| import torch | |
| import re | |
# FastAPI application object for this service.
app = FastAPI(version="1.0.0", title="GUI-Actor API")

# Hugging Face identifier handed to every from_pretrained() call in load_model().
model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"

# Module-level slots filled in by load_model(); they stay None until it succeeds.
model = processor = tokenizer = None
def _load_processor():
    """Return the processor for ``model_name``, trying ``Qwen2VLProcessor`` first.

    If importing or loading ``Qwen2VLProcessor`` raises, the failure is printed
    and ``AutoProcessor`` (with ``trust_remote_code=True``) is tried instead.
    An error raised by that second attempt propagates to the caller.
    """
    try:
        from transformers import Qwen2VLProcessor

        loaded = Qwen2VLProcessor.from_pretrained(model_name)
        print("Successfully loaded Qwen2VLProcessor")
        return loaded
    except Exception as first_error:
        print(f"Failed to load Qwen2VLProcessor: {first_error}")
        from transformers import AutoProcessor

        loaded = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        print("Successfully loaded AutoProcessor")
        return loaded


def load_model():
    """Populate the module-level ``processor``, ``tokenizer`` and ``model``.

    Progress is reported with ``print``. Any exception raised while loading is
    printed and swallowed, so callers get a plain success flag.

    Returns:
        True when the processor and model were both loaded, False otherwise.
    """
    global model, processor, tokenizer

    try:
        print("Loading processor...")
        processor = _load_processor()
        tokenizer = processor.tokenizer

        print("Loading model...")
        # Model class used for Qwen2-VL checkpoints.
        from transformers import Qwen2VLForConditionalGeneration

        load_options = {
            "torch_dtype": torch.float32,   # float32 for CPU
            "device_map": None,             # CPU only
            "trust_remote_code": True,      # for the custom model
            "attn_implementation": None,    # skip flash attention
        }
        network = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, **load_options
        )
        model = network.eval()
        print("Model loaded successfully!")
    except Exception as load_error:
        print(f"Error loading model: {load_error}")
        return False

    return True
# Attempt the load once, when this module is imported, and keep the outcome
# as a flag for the rest of the module to consult.
model_loaded: bool = load_model()
# Request body for the JSON prediction endpoint: both fields are required strings.
class Base64Request(BaseModel):
    # Instruction text describing what should be clicked.
    instruction: str
    # Base64-encoded image data.
    image_base64: str
# Coordinate formats recognised in model output, ordered from most to least
# specific. The bare "x, y" form must come last: it matches inside every other
# form, so trying it earlier would let an unrelated number pair win.
_COORDINATE_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',   # click(x, y)
        r'point:\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # point: (x, y)
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]',           # [x, y]
        r'(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)',                     # x, y
    )
)


def extract_coordinates(text, image_size=(1920, 1080)):
    """Extract one normalized ``(x, y)`` point from model output text.

    The first pattern (in priority order) that matches anywhere in ``text``
    decides the result; matching is case-insensitive.

    Args:
        text: Model output to scan. ``None`` or an empty string yields the
            default point.
        image_size: ``(width, height)`` used to normalize pixel coordinates.
            Defaults to 1920x1080, the resolution this function has always
            assumed.

    Returns:
        A one-element list ``[(x, y)]``. If either value is greater than 1 the
        pair is treated as pixel coordinates and *both* values are divided by
        the image size; otherwise the pair is returned as found. When nothing
        matches, the center ``[(0.5, 0.5)]`` is returned.
    """
    if not text:
        return [(0.5, 0.5)]

    width, height = image_size
    for pattern in _COORDINATE_PATTERNS:
        match = pattern.search(text)
        if match is None:
            continue
        x, y = float(match.group(1)), float(match.group(2))
        # A value above 1 cannot be a normalized coordinate, so the whole pair
        # is in pixels. Scaling both axes together avoids leaving a pixel value
        # of exactly 1 unscaled next to a scaled partner.
        if x > 1 or y > 1:
            x, y = x / width, y / height
        return [(x, y)]

    # Default to the center when no coordinates were found.
    return [(0.5, 0.5)]
def cpu_inference(conversation, model, tokenizer, processor):
    """Generate a reply for ``conversation`` and extract click coordinates from it.

    Args:
        conversation: Chat messages passed to the processor's chat template.
            The image is read from ``conversation[1]["content"][0]["image"]``.
        model: Object whose ``generate`` method produces the output token ids.
        tokenizer: Supplies ``eos_token_id`` and decodes the generated ids.
        processor: Applies the chat template and builds the model inputs.

    Returns:
        A dict with ``"topk_points"`` (result of ``extract_coordinates``),
        ``"response"`` (decoded reply) and ``"success"`` set to True. If any
        step raises, the dict instead holds the center point, an
        ``"Error during inference: ..."`` message and ``"success"`` False.
    """
    try:
        # Render the conversation to prompt text; tokenization happens below.
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

        # First content item of the second message carries the image.
        screenshot = conversation[1]["content"][0]["image"]

        batch = processor(text=[prompt], images=[screenshot], return_tensors="pt")

        with torch.no_grad():
            sequences = model.generate(
                **batch,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.3,
                top_p=0.8,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Drop the prompt tokens so only the newly generated part is decoded.
        first_sequence = sequences[0]
        prompt_length = batch["input_ids"].shape[1]
        reply = tokenizer.decode(
            first_sequence[prompt_length:], skip_special_tokens=True
        )

        return {
            "topk_points": extract_coordinates(reply),
            "response": reply,
            "success": True,
        }
    except Exception as e:
        failure = {"topk_points": [(0.5, 0.5)]}
        failure["response"] = f"Error during inference: {str(e)}"
        failure["success"] = False
        return failure
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered with `app`.
async def root():
    """Return a short status payload including whether the model loaded."""
    payload = {"message": "GUI-Actor API is running", "status": "healthy"}
    payload["model_loaded"] = model_loaded
    return payload
def _click_error_response(message, status_code):
    """Build the error reply shared by every failure path of the click endpoint."""
    # The center point is always included as a fallback coordinate.
    return JSONResponse(
        content={"error": message, "success": False, "x": 0.5, "y": 0.5},
        status_code=status_code,
    )


def _decode_base64_image(image_base64):
    """Decode base64 text (optionally with a ``data:...,`` prefix) to an RGB image."""
    # Keep only what follows the last comma so data-URL prefixes are tolerated.
    encoded = image_base64.split(",")[-1]
    return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")


# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered with `app`.
async def predict_click_base64(data: Base64Request):
    """Predict a click point for a base64-encoded screenshot and an instruction.

    Args:
        data: Request carrying ``image_base64`` and ``instruction``.

    Returns:
        A ``JSONResponse``. On the normal path it holds ``x`` and ``y``
        (rounded to 4 decimals), the model ``response`` text and the
        ``success`` flag reported by ``cpu_inference``. Error replies hold
        ``error``, ``success`` False and the center point, with status
        503 when the model is not loaded, 400 when the image cannot be
        decoded, and 500 for any other failure.
    """
    # Imported here so the handler brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    if not model_loaded:
        return _click_error_response("Model not loaded properly", 503)

    try:
        try:
            pil_image = _decode_base64_image(data.image_base64)
        except (ValueError, OSError) as e:
            # Bad base64 (ValueError) or unreadable image bytes (OSError) are
            # the client's fault, so report 400 rather than a server error.
            return _click_error_response(str(e), 400)

        conversation = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": pil_image,
                    },
                    {
                        "type": "text",
                        "text": data.instruction,
                    },
                ],
            },
        ]

        # Generation is blocking CPU work; running it in the thread pool keeps
        # the event loop free to serve other requests in the meantime.
        pred = await run_in_threadpool(
            cpu_inference, conversation, model, tokenizer, processor
        )
        px, py = pred["topk_points"][0]
        return JSONResponse(content={
            "x": round(px, 4),
            "y": round(py, 4),
            "response": pred["response"],
            "success": pred["success"],
        })
    except Exception as e:
        return _click_error_response(str(e), 500)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered with `app`.
async def health_check():
    """Return service status along with the model name and runtime settings."""
    report = {"status": "healthy", "model": model_name}
    # Device and dtype are reported as fixed strings.
    report["device"] = "cpu"
    report["torch_dtype"] = "float32"
    report["model_loaded"] = model_loaded
    return report
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered with `app`.
async def predict_click_form(
    image_base64: str = Form(...),
    instruction: str = Form(...)
):
    """Accept the same inputs as form fields and delegate to the JSON handler."""
    # Repackage the form fields so the base64 handler can be reused unchanged.
    payload = Base64Request(instruction=instruction, image_base64=image_base64)
    outcome = await predict_click_base64(payload)
    return outcome