from typing import Any, Dict, Optional

import base64
import io

import torch
from PIL import Image
from transformers import AutoProcessor

from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
from gui_actor.inference import inference


class EndpointHandler:
""" |
|
|
Accepts JSON like: |
|
|
{ |
|
|
"image": "data:image/png;base64,...." OR "image_b64": "<raw_base64>", |
|
|
"image_url": "https://...png", # optional (qwen-vl-utils supports it) |
|
|
"prompt": "Click the close button", |
|
|
"topk": 3, |
|
|
"return_pixels": true, |
|
|
"screen_w": 1920, |
|
|
"screen_h": 1080 |
|
|
} |
|
|
Returns: |
|
|
{ |
|
|
"points_norm": [[x,y], ...], # 0..1 normalized |
|
|
"points_px": [[x_px,y_px], ...] # if screen_w/h given |
|
|
} |
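
    A Hugging Face Inference Endpoint delivers the payload wrapped under
    "inputs" (see __call__ below); a bare payload also works. Example request
    body, with placeholder values:

        {
            "inputs": {
                "image_b64": "<raw_base64_png>",
                "prompt": "Click the close button",
                "topk": 3,
                "return_pixels": true,
                "screen_w": 1920,
                "screen_h": 1080
            }
        }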
    """

    def __init__(self, path: str = ""):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 keeps memory low on GPU; fall back to float32 on CPU.
        dtype = torch.bfloat16 if device == "cuda" else torch.float32

        self.processor = AutoProcessor.from_pretrained(path)
        self.tokenizer = self.processor.tokenizer

        self.model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
            path, torch_dtype=dtype, device_map="auto"
        ).eval()

    def _load_pil(self, data: Dict[str, Any]) -> Optional[Image.Image]:
        # Data-URL form: "data:image/png;base64,...."
        if isinstance(data.get("image"), str) and data["image"].startswith("data:image"):
            b64 = data["image"].split("base64,")[-1]
            return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
        # Raw base64 form.
        if "image_b64" in data:
            return Image.open(io.BytesIO(base64.b64decode(data["image_b64"]))).convert("RGB")
        # No inline image; the caller may still have supplied "image_url".
        return None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Inference Endpoints wrap the payload under "inputs"; accept both
        # the wrapped and the bare form.
        payload = data["inputs"] if "inputs" in data else data

        prompt = payload.get("prompt", "")
        topk = int(payload.get("topk", 3))
        img = self._load_pil(payload)

        user_content = []
        if img is not None:
            user_content.append({"type": "image", "image": img})
        elif "image_url" in payload:
            user_content.append({"type": "image", "image_url": payload["image_url"]})
        else:
            raise ValueError("No image provided. Supply 'image' (data URL), 'image_b64', or 'image_url'.")
        user_content.append({"type": "text", "text": prompt})

        conversation = [
            {
                "role": "system",
                "content": [{
                    "type": "text",
                    "text": "You are a GUI agent. Given a screenshot of the current GUI and a human "
                            "instruction, locate the target element and output a click position.",
                }],
            },
            {"role": "user", "content": user_content},
        ]

        try:
            pred = inference(
                conversation,
                self.model,
                self.tokenizer,
                self.processor,
                use_placeholder=True,
                topk=topk,
            )

            # gui_actor returns candidate click points normalized to [0, 1].
            points = pred.get("topk_points") or []
            result = {"points_norm": points}

            # Optionally scale to absolute pixels for the caller's screen size,
            # rounding to the nearest pixel rather than truncating.
            if payload.get("return_pixels") and payload.get("screen_w") and payload.get("screen_h"):
                w = int(payload["screen_w"])
                h = int(payload["screen_h"])
                result["points_px"] = [[int(round(x * w)), int(round(y * h))] for (x, y) in points]

            return result

        except Exception as e:
            # Surface the error in the response body instead of an opaque 5xx.
            return {"error": str(e), "points_norm": [], "points_px": []}
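

# --- Local smoke test (not part of the endpoint API) -------------------------
# A minimal sketch for exercising the handler outside the endpoint runtime.
# The checkpoint path and screenshot filename below are hypothetical
# placeholders, not values defined by this repo.
if __name__ == "__main__":
    import json

    handler = EndpointHandler("path/to/gui-actor-checkpoint")  # hypothetical path
    with open("screenshot.png", "rb") as f:  # hypothetical test image
        b64 = base64.b64encode(f.read()).decode("ascii")

    out = handler({
        "inputs": {
            "image_b64": b64,
            "prompt": "Click the close button",
            "topk": 3,
            "return_pixels": True,
            "screen_w": 1920,
            "screen_h": 1080,
        }
    })
    print(json.dumps(out, indent=2))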