lukas-agentix committed
Commit faadde5 · verified · 1 Parent(s): 5d98f4d

Create handler.py

Files changed (1)
  1. handler.py +106 -0
handler.py ADDED
@@ -0,0 +1,106 @@
from typing import Any, Dict, Optional
import base64
import io

from PIL import Image
import torch

from transformers import AutoProcessor
# Use the vendored GUI-Actor sources (copied into the repo as /gui_actor)
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
from gui_actor.inference import inference

class EndpointHandler:
    """
    Accepts JSON like:
    {
        "image": "data:image/png;base64,....",    # or "image_b64": "<raw_base64>"
        "image_url": "https://...png",            # optional; qwen-vl-utils can fetch URLs
        "prompt": "Click the close button",
        "topk": 3,
        "return_pixels": true,
        "screen_w": 1920,
        "screen_h": 1080
    }

    Returns:
    {
        "points_norm": [[x, y], ...],     # normalized to 0..1
        "points_px": [[x_px, y_px], ...]  # only when return_pixels and screen_w/h are given
    }
    """

    def __init__(self, path: str = ""):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 is fine for Qwen2.5 on GPU; fall back to float32 on CPU
        dtype = torch.bfloat16 if device == "cuda" else torch.float32

        self.processor = AutoProcessor.from_pretrained(path)
        self.tokenizer = self.processor.tokenizer

        # Avoid a hard flash-attn requirement; transformers falls back to PyTorch SDPA when it is unavailable
        self.model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
            path, torch_dtype=dtype, device_map="auto"
        ).eval()

    def _load_pil(self, data: Dict[str, Any]) -> Optional[Image.Image]:
        if "image" in data and isinstance(data["image"], str) and data["image"].startswith("data:image"):
            # "data:image/png;base64,......"
            b64 = data["image"].split("base64,")[-1]
            return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
        if "image_b64" in data:
            return Image.open(io.BytesIO(base64.b64decode(data["image_b64"]))).convert("RGB")
        # No inline image; if image_url is provided, __call__ passes the URL through
        # and qwen-vl-utils resolves it internally
        return None

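    # Illustrative payload branches for _load_pil (values are placeholders; the
    # "iVBORw0KGgo" prefix is just the base64 encoding of the PNG magic bytes):
    #   {"image": "data:image/png;base64,iVBORw0KGgo..."}  -> decoded via the data-URL branch
    #   {"image_b64": "iVBORw0KGgo..."}                    -> decoded via the raw-base64 branch
    #   {"image_url": "https://example.com/shot.png"}      -> returns None; URL is forwarded in __call__
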
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept both a direct payload and Hugging Face's nested {"inputs": {...}} format
        payload = data["inputs"] if "inputs" in data else data

        prompt = payload.get("prompt", "")
        topk = int(payload.get("topk", 3))
        img = self._load_pil(payload)

        # Build the conversation in the format the model card expects
        user_content = []
        if img is not None:
            user_content.append({"type": "image", "image": img})
        elif "image_url" in payload:
            user_content.append({"type": "image", "image_url": payload["image_url"]})
        else:
            raise ValueError("No image provided. Supply 'image' (data URL), 'image_b64', or 'image_url'.")

        user_content.append({"type": "text", "text": prompt})
        conversation = [
            {
                "role": "system",
                "content": [{
                    "type": "text",
                    "text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, "
                            "locate the target element and output a click position.",
                }],
            },
            {"role": "user", "content": user_content},
        ]

        try:
            pred = inference(
                conversation,
                self.model,
                self.tokenizer,
                self.processor,
                use_placeholder=True,
                topk=topk,
            )

            points = pred.get("topk_points") or []
            result = {"points_norm": points}

            # Optionally convert normalized points to screen pixels,
            # e.g. (0.5, 0.25) on a 1920x1080 screen -> (960, 270)
            if payload.get("return_pixels") and payload.get("screen_w") and payload.get("screen_h"):
                w = int(payload["screen_w"])
                h = int(payload["screen_h"])
                result["points_px"] = [[int(x * w), int(y * h)] for (x, y) in points]

            return result

        except Exception as e:
            return {"error": str(e), "points_norm": [], "points_px": []}