petergilani commited on
Commit
eb9f6cd
·
verified ·
1 Parent(s): 5c7307c

Upload qwen3_coder_detector_sgl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. qwen3_coder_detector_sgl.py +474 -0
qwen3_coder_detector_sgl.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import logging
4
+ import re
5
+ from typing import Any, List, Optional
6
+
7
+ from sglang.srt.entrypoints.openai.protocol import Tool
8
+ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
9
+ from sglang.srt.function_call.core_types import (
10
+ StreamingParseResult,
11
+ ToolCallItem,
12
+ _GetInfoFunc,
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class Qwen3CoderDetector(BaseFormatDetector):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ # Sentinel tokens
23
+ self.tool_call_start_token: str = "<tool_call>"
24
+ self.tool_call_end_token: str = "</tool_call>"
25
+ self.tool_call_prefix: str = "<function="
26
+ self.function_end_token: str = "</function>"
27
+ self.parameter_prefix: str = "<parameter="
28
+ self.parameter_end_token: str = "</parameter>"
29
+
30
+ # Regex for non-streaming fallback
31
+ self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
32
+ self.tool_call_function_regex = re.compile(
33
+ r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
34
+ )
35
+ self.tool_call_parameter_regex = re.compile(
36
+ r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
37
+ re.DOTALL,
38
+ )
39
+
40
+ # Streaming State
41
+ # Base class already initializes _buffer, we just use it directly
42
+ # No need to check with hasattr - we control the lifecycle through inheritance
43
+
44
+ # Index pointing to the next character to be processed in buffer
45
+ self.parsed_pos: int = 0
46
+ # Parameter count inside the current tool being processed, used to determine whether to add comma
47
+ self.current_tool_param_count: int = 0
48
+ # Flag indicating whether current tool has already sent '{'
49
+ self.json_started: bool = False
50
+
51
+ # [FIX] New state flag: mark whether inside tool_call structure block
52
+ self.is_inside_tool_call: bool = False
53
+
54
+ # Initialize attributes that were missing in the original PR
55
+ self.current_func_name: Optional[str] = None
56
+
57
+ def has_tool_call(self, text: str) -> bool:
58
+ return self.tool_call_start_token in text
59
+
60
+ def _get_arguments_config(
61
+ self, func_name: str, tools: Optional[list[Tool]]
62
+ ) -> dict:
63
+ """Extract argument configuration for a function."""
64
+ if tools is None:
65
+ return {}
66
+ for config in tools:
67
+ try:
68
+ config_type = config.type
69
+ config_function = config.function
70
+ config_function_name = config_function.name
71
+ except AttributeError:
72
+ continue
73
+
74
+ if config_type == "function" and config_function_name == func_name:
75
+ try:
76
+ params = config_function.parameters
77
+ except AttributeError:
78
+ return {}
79
+
80
+ if isinstance(params, dict) and "properties" in params:
81
+ return params["properties"]
82
+ elif isinstance(params, dict):
83
+ return params
84
+ else:
85
+ return {}
86
+ logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
87
+ return {}
88
+
89
+ def _convert_param_value(
90
+ self, param_value: str, param_name: str, param_config: dict, func_name: str
91
+ ) -> Any:
92
+ """Convert parameter value based on its type in the schema."""
93
+ # Handle null value for any type
94
+ if param_value.lower() == "null":
95
+ return None
96
+
97
+ if param_name not in param_config:
98
+ if param_config != {}:
99
+ logger.warning(
100
+ f"Parsed parameter '{param_name}' is not defined in the tool "
101
+ f"parameters for tool '{func_name}', directly returning the string value."
102
+ )
103
+ return param_value
104
+
105
+ if (
106
+ isinstance(param_config[param_name], dict)
107
+ and "type" in param_config[param_name]
108
+ ):
109
+ param_type = str(param_config[param_name]["type"]).strip().lower()
110
+ else:
111
+ param_type = "string"
112
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
113
+ return param_value
114
+ elif (
115
+ param_type.startswith("int")
116
+ or param_type.startswith("uint")
117
+ or param_type.startswith("long")
118
+ or param_type.startswith("short")
119
+ or param_type.startswith("unsigned")
120
+ ):
121
+ try:
122
+ param_value = int(param_value)
123
+ except Exception:
124
+ logger.warning(
125
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
126
+ f"'{func_name}', degenerating to string."
127
+ )
128
+ return param_value
129
+ elif param_type.startswith("num") or param_type.startswith("float"):
130
+ try:
131
+ maybe_convert = (
132
+ False if "." in param_value or "e" in param_value.lower() else True
133
+ )
134
+ param_value: float = float(param_value)
135
+ if maybe_convert and param_value.is_integer():
136
+ param_value = int(param_value)
137
+ except Exception:
138
+ logger.warning(
139
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
140
+ f"'{func_name}', degenerating to string."
141
+ )
142
+ return param_value
143
+ elif param_type in ["boolean", "bool", "binary"]:
144
+ param_value = param_value.lower()
145
+ if param_value not in ["true", "false"]:
146
+ logger.warning(
147
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
148
+ )
149
+ return param_value == "true"
150
+ else:
151
+ if (
152
+ param_type in ["object", "array", "arr"]
153
+ or param_type.startswith("dict")
154
+ or param_type.startswith("list")
155
+ ):
156
+ try:
157
+ param_value = json.loads(param_value)
158
+ return param_value
159
+ except Exception:
160
+ logger.warning(
161
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
162
+ f"'{func_name}', will try other methods to parse it."
163
+ )
164
+ try:
165
+ param_value = ast.literal_eval(param_value) # safer
166
+ except Exception:
167
+ logger.warning(
168
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
169
+ )
170
+ return param_value
171
+
172
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
173
+ """One-shot parsing for non-streaming scenarios."""
174
+ if self.tool_call_start_token not in text:
175
+ return StreamingParseResult(normal_text=text)
176
+
177
+ calls = []
178
+ try:
179
+ # Simple cleanup of the text to find tool calls
180
+ # Note: This is a simplified regex approach consistent with vLLM
181
+ raw_tool_calls = self.tool_call_regex.findall(text)
182
+ if not raw_tool_calls:
183
+ # Fallback: maybe the whole text is inside the tag or tags are stripped
184
+ if self.tool_call_prefix in text:
185
+ raw_tool_calls = [text]
186
+
187
+ tool_idx = 0
188
+ for tool_content in raw_tool_calls:
189
+ # Find function calls
190
+ funcs = self.tool_call_function_regex.findall(tool_content)
191
+ for func_match in funcs:
192
+ func_body = func_match[0] or func_match[1]
193
+ if ">" not in func_body:
194
+ continue
195
+
196
+ name_end = func_body.index(">")
197
+ func_name = func_body[:name_end]
198
+ params_str = func_body[name_end + 1 :]
199
+
200
+ param_config = self._get_arguments_config(func_name, tools)
201
+ parsed_params = {}
202
+
203
+ for p_match in self.tool_call_parameter_regex.findall(params_str):
204
+ if ">" not in p_match:
205
+ continue
206
+ p_idx = p_match.index(">")
207
+ p_name = p_match[:p_idx]
208
+ p_val = p_match[p_idx + 1 :]
209
+ # Remove prefixing and trailing \n
210
+ if p_val.startswith("\n"):
211
+ p_val = p_val[1:]
212
+ if p_val.endswith("\n"):
213
+ p_val = p_val[:-1]
214
+
215
+ parsed_params[p_name] = self._convert_param_value(
216
+ p_val, p_name, param_config, func_name
217
+ )
218
+
219
+ calls.append(
220
+ ToolCallItem(
221
+ tool_index=tool_idx,
222
+ name=func_name,
223
+ parameters=json.dumps(parsed_params, ensure_ascii=False),
224
+ )
225
+ )
226
+ tool_idx += 1
227
+
228
+ # Determine normal text (text before the first tool call)
229
+ start_idx = text.find(self.tool_call_start_token)
230
+ if start_idx == -1:
231
+ start_idx = text.find(self.tool_call_prefix)
232
+ normal_text = text[:start_idx] if start_idx > 0 else ""
233
+
234
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
235
+
236
+ except Exception as e:
237
+ logger.error(f"Error in detect_and_parse: {e}")
238
+ return StreamingParseResult(normal_text=text)
239
+
240
+ def parse_streaming_increment(
241
+ self, new_text: str, tools: List[Tool]
242
+ ) -> StreamingParseResult:
243
+ """
244
+ Robust cursor-based streaming parser.
245
+ """
246
+ self._buffer += new_text
247
+
248
+ # Guard against empty buffer
249
+ if not self._buffer:
250
+ return StreamingParseResult()
251
+
252
+ calls = []
253
+ normal_text_chunks = []
254
+
255
+ while True:
256
+ # Working text slice
257
+ current_slice = self._buffer[self.parsed_pos :]
258
+
259
+ # Optimization: If almost empty, wait for more
260
+ if not current_slice:
261
+ break
262
+
263
+ # -------------------------------------------------------
264
+ # 1. Priority detection: check if it's the start of Tool Call
265
+ # -------------------------------------------------------
266
+ if current_slice.startswith(self.tool_call_start_token):
267
+ self.parsed_pos += len(self.tool_call_start_token)
268
+ self.is_inside_tool_call = True
269
+ continue
270
+
271
+ # -------------------------------------------------------
272
+ # 2. Function Name: <function=name>
273
+ # -------------------------------------------------------
274
+ if current_slice.startswith(self.tool_call_prefix):
275
+ end_angle = current_slice.find(">")
276
+ if end_angle != -1:
277
+ func_name = current_slice[len(self.tool_call_prefix) : end_angle]
278
+
279
+ self.current_tool_id += 1
280
+ self.current_tool_name_sent = True
281
+ self.current_tool_param_count = 0
282
+ self.json_started = False
283
+ self.current_func_name = func_name
284
+
285
+ calls.append(
286
+ ToolCallItem(
287
+ tool_index=self.current_tool_id,
288
+ name=func_name,
289
+ parameters="",
290
+ )
291
+ )
292
+
293
+ self.parsed_pos += end_angle + 1
294
+ continue
295
+ else:
296
+ # Incomplete tag
297
+ break
298
+
299
+ # -------------------------------------------------------
300
+ # 3. Parameter: <parameter=name>value...
301
+ # -------------------------------------------------------
302
+ if current_slice.startswith(self.parameter_prefix):
303
+ name_end = current_slice.find(">")
304
+ if name_end != -1:
305
+ value_start_idx = name_end + 1
306
+ rest_of_slice = current_slice[value_start_idx:]
307
+
308
+ # A parameter can end in multiple ways:
309
+ # 1. [Normal] Encounter </parameter>
310
+ # 2. [Abnormal] Encounter next <parameter=
311
+ # 3. [Abnormal] Encounter </function>
312
+ # So we need to find the smallest one as the parameter end position.
313
+ cand_end_param = rest_of_slice.find(self.parameter_end_token)
314
+ cand_next_param = rest_of_slice.find(self.parameter_prefix)
315
+ cand_end_func = rest_of_slice.find(self.function_end_token)
316
+
317
+ candidates = []
318
+ if cand_end_param != -1:
319
+ candidates.append(
320
+ (cand_end_param, len(self.parameter_end_token))
321
+ )
322
+ if cand_next_param != -1:
323
+ candidates.append((cand_next_param, 0))
324
+ if cand_end_func != -1:
325
+ candidates.append((cand_end_func, 0))
326
+
327
+ if candidates:
328
+ best_cand = min(candidates, key=lambda x: x[0])
329
+ end_pos = best_cand[0]
330
+ end_token_len = best_cand[1]
331
+
332
+ param_name = current_slice[
333
+ len(self.parameter_prefix) : name_end
334
+ ]
335
+ raw_value = rest_of_slice[:end_pos]
336
+
337
+ # Cleanup value
338
+ if raw_value.startswith("\n"):
339
+ raw_value = raw_value[1:]
340
+ if raw_value.endswith("\n"):
341
+ raw_value = raw_value[:-1]
342
+
343
+ # JSON Construction
344
+ if not self.json_started:
345
+ calls.append(
346
+ ToolCallItem(
347
+ tool_index=self.current_tool_id, parameters="{"
348
+ )
349
+ )
350
+ self.json_started = True
351
+
352
+ param_config = self._get_arguments_config(
353
+ self.current_func_name, tools
354
+ )
355
+ converted_val = self._convert_param_value(
356
+ raw_value, param_name, param_config, self.current_func_name
357
+ )
358
+
359
+ # Construct JSON fragment: "key": value
360
+ # Note: We must be careful with json.dumps to ensure valid JSON streaming
361
+ json_key_val = f"{json.dumps(param_name)}: {json.dumps(converted_val, ensure_ascii=False)}"
362
+
363
+ if self.current_tool_param_count > 0:
364
+ fragment = f", {json_key_val}"
365
+ else:
366
+ fragment = json_key_val
367
+
368
+ calls.append(
369
+ ToolCallItem(
370
+ tool_index=self.current_tool_id, parameters=fragment
371
+ )
372
+ )
373
+ self.current_tool_param_count += 1
374
+
375
+ # Advance cursor
376
+ total_len = (name_end + 1) + end_pos + end_token_len
377
+ self.parsed_pos += total_len
378
+ continue
379
+
380
+ # Incomplete parameter tag or value
381
+ break
382
+
383
+ # -------------------------------------------------------
384
+ # 4. Function End: </function>
385
+ # -------------------------------------------------------
386
+ if current_slice.startswith(self.function_end_token):
387
+ if not self.json_started:
388
+ calls.append(
389
+ ToolCallItem(tool_index=self.current_tool_id, parameters="{")
390
+ )
391
+ self.json_started = True
392
+
393
+ calls.append(
394
+ ToolCallItem(tool_index=self.current_tool_id, parameters="}")
395
+ )
396
+ self.parsed_pos += len(self.function_end_token)
397
+ self.current_func_name = None
398
+ continue
399
+
400
+ # -------------------------------------------------------
401
+ # 5. Tool Call End: </tool_call>
402
+ # -------------------------------------------------------
403
+ if current_slice.startswith(self.tool_call_end_token):
404
+ self.parsed_pos += len(self.tool_call_end_token)
405
+ self.is_inside_tool_call = False # [FIX] Exit tool call region
406
+ continue
407
+
408
+ # -------------------------------------------------------
409
+ # 6. Handling content / whitespace / normal text
410
+ # -------------------------------------------------------
411
+ # If current position is not the start of a tag (i.e., doesn't start with <), it might be plain text,
412
+ # or a newline between two tags.
413
+ # But we need to be careful not to output truncated tags like "<fun" as text.
414
+
415
+ next_open_angle = current_slice.find("<")
416
+
417
+ if next_open_angle == -1:
418
+ # This entire segment is plain text
419
+ if not self.is_inside_tool_call:
420
+ normal_text_chunks.append(current_slice)
421
+ # [FIX] If inside tool call, discard this text (usually \n), don't append
422
+ self.parsed_pos += len(current_slice)
423
+ continue
424
+
425
+ elif next_open_angle == 0:
426
+ # Looks like a Tag, but doesn't match any known Tag above
427
+
428
+ possible_tags = [
429
+ self.tool_call_start_token,
430
+ self.tool_call_end_token,
431
+ self.tool_call_prefix,
432
+ self.function_end_token,
433
+ self.parameter_prefix,
434
+ self.parameter_end_token,
435
+ ]
436
+
437
+ is_potential_tag = False
438
+ for tag in possible_tags:
439
+ if tag.startswith(current_slice):
440
+ is_potential_tag = True
441
+ break
442
+
443
+ if is_potential_tag:
444
+ break # Wait for more
445
+ else:
446
+ # Just a plain '<' symbol
447
+ if not self.is_inside_tool_call:
448
+ normal_text_chunks.append("<")
449
+ self.parsed_pos += 1
450
+ continue
451
+
452
+ else:
453
+ # '<' is in the middle
454
+ text_segment = current_slice[:next_open_angle]
455
+ if not self.is_inside_tool_call:
456
+ normal_text_chunks.append(text_segment)
457
+ # [FIX] If inside tool call, discard whitespace/text before Tag
458
+ self.parsed_pos += next_open_angle
459
+ continue
460
+
461
+ # Memory Cleanup: Slice the buffer
462
+ # Keep unparsed part, discard parsed part
463
+ if self.parsed_pos > 0:
464
+ self._buffer = self._buffer[self.parsed_pos :]
465
+ self.parsed_pos = 0
466
+
467
+ normal_text = "".join(normal_text_chunks) if normal_text_chunks else ""
468
+ return StreamingParseResult(calls=calls, normal_text=normal_text)
469
+
470
+ def supports_structural_tag(self) -> bool:
471
+ return False
472
+
473
+ def structure_info(self) -> _GetInfoFunc:
474
+ raise NotImplementedError