petergilani commited on
Commit
8937abf
·
verified ·
1 Parent(s): f68880a

Upload qwen3coder_tool_parser_vllm.py

Browse files
Files changed (1) hide show
  1. qwen3coder_tool_parser_vllm.py +690 -0
qwen3coder_tool_parser_vllm.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import ast
4
+ import json
5
+ import uuid
6
+ from collections.abc import Sequence
7
+ from typing import Any, List, Optional, Union
8
+
9
+ import regex as re
10
+
11
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
+ ChatCompletionToolsParam,
13
+ DeltaFunctionCall, DeltaMessage,
14
+ DeltaToolCall,
15
+ ExtractedToolCallInformation,
16
+ FunctionCall, ToolCall)
17
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
+ ToolParser, ToolParserManager)
19
+ from vllm.logger import init_logger
20
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
+ @ToolParserManager.register_module("qwen3_coder")
26
+ class Qwen3CoderToolParser(ToolParser):
27
+
28
+ def __init__(self, tokenizer: AnyTokenizer):
29
+ super().__init__(tokenizer)
30
+
31
+ self.current_tool_name_sent: bool = False
32
+ self.prev_tool_call_arr: list[dict] = []
33
+ self.current_tool_id: int = -1
34
+ self.streamed_args_for_tool: list[str] = []
35
+
36
+ # Sentinel tokens for streaming mode
37
+ self.tool_call_start_token: str = "<tool_call>"
38
+ self.tool_call_end_token: str = "</tool_call>"
39
+ self.tool_call_prefix: str = "<function="
40
+ self.function_end_token: str = "</function>"
41
+ self.parameter_prefix: str = "<parameter="
42
+ self.parameter_end_token: str = "</parameter>"
43
+ self.is_tool_call_started: bool = False
44
+ self.failed_count: int = 0
45
+
46
+ # Enhanced streaming state - reset for each new message
47
+ self._reset_streaming_state()
48
+
49
+ # Regex patterns
50
+ self.tool_call_complete_regex = re.compile(
51
+ r"<tool_call>(.*?)</tool_call>", re.DOTALL)
52
+ self.tool_call_regex = re.compile(
53
+ r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
54
+ self.tool_call_function_regex = re.compile(
55
+ r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
56
+ self.tool_call_parameter_regex = re.compile(
57
+ r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
58
+ re.DOTALL)
59
+
60
+ if not self.model_tokenizer:
61
+ raise ValueError(
62
+ "The model tokenizer must be passed to the ToolParser "
63
+ "constructor during construction.")
64
+
65
+ self.tool_call_start_token_id = self.vocab.get(
66
+ self.tool_call_start_token)
67
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
68
+
69
+ if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
70
+ raise RuntimeError(
71
+ "Qwen3 XML Tool parser could not locate tool call start/end "
72
+ "tokens in the tokenizer!")
73
+
74
+ logger.info(
75
+ f"vLLM Successfully import tool parser {self.__class__.__name__} !"
76
+ )
77
+
78
+ def _generate_tool_call_id(self) -> str:
79
+ """Generate a unique tool call ID."""
80
+ return f"call_{uuid.uuid4().hex[:24]}"
81
+
82
+ def _reset_streaming_state(self):
83
+ """Reset all streaming state."""
84
+ self.current_tool_index = 0
85
+ self.is_tool_call_started = False
86
+ self.header_sent = False
87
+ self.current_tool_id = None
88
+ self.current_function_name = None
89
+ self.current_param_name = None
90
+ self.current_param_value = ""
91
+ self.param_count = 0
92
+ self.in_param = False
93
+ self.in_function = False
94
+ self.accumulated_text = ""
95
+ self.json_started = False
96
+ self.json_closed = False
97
+ # Store accumulated parameters for type conversion
98
+ self.accumulated_params = {}
99
+ self.streaming_request = None
100
+
101
+ def _get_arguments_config(
102
+ self, func_name: str,
103
+ tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
104
+ """Extract argument configuration for a function."""
105
+ if tools is None:
106
+ return {}
107
+ for config in tools:
108
+ if not hasattr(config, "type") or not (hasattr(
109
+ config, "function") and hasattr(config.function, "name")):
110
+ continue
111
+ if config.type == "function" and config.function.name == func_name:
112
+ if not hasattr(config.function, "parameters"):
113
+ return {}
114
+ params = config.function.parameters
115
+ if isinstance(params, dict) and "properties" in params:
116
+ return params["properties"]
117
+ elif isinstance(params, dict):
118
+ return params
119
+ else:
120
+ return {}
121
+ logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
122
+ return {}
123
+
124
+ def _convert_param_value(self, param_value: str, param_name: str,
125
+ param_config: dict, func_name: str) -> Any:
126
+ """Convert parameter value based on its type in the schema."""
127
+ # Handle null value for any type
128
+ if param_value.lower() == "null":
129
+ return None
130
+
131
+ if param_name not in param_config:
132
+ if param_config != {}:
133
+ logger.warning(
134
+ f"Parsed parameter '{param_name}' is not defined in the tool "
135
+ f"parameters for tool '{func_name}', directly returning the string value."
136
+ )
137
+ return param_value
138
+
139
+ if isinstance(param_config[param_name],
140
+ dict) and "type" in param_config[param_name]:
141
+ param_type = str(param_config[param_name]["type"]).strip().lower()
142
+ else:
143
+ param_type = "string"
144
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
145
+ return param_value
146
+ elif param_type.startswith("int") or param_type.startswith(
147
+ "uint") or param_type.startswith(
148
+ "long") or param_type.startswith(
149
+ "short") or param_type.startswith("unsigned"):
150
+ try:
151
+ param_value = int(param_value)
152
+ except:
153
+ logger.warning(
154
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
155
+ f"'{func_name}', degenerating to string.")
156
+ return param_value
157
+ elif param_type.startswith("num") or param_type.startswith("float"):
158
+ try:
159
+ maybe_convert = False if "." in param_value or "e" in param_value.lower() else True
160
+ param_value: float = float(param_value)
161
+ if maybe_convert and param_value.is_integer():
162
+ param_value = int(param_value)
163
+ except:
164
+ logger.warning(
165
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
166
+ f"'{func_name}', degenerating to string.")
167
+ return param_value
168
+ elif param_type in ["boolean", "bool", "binary"]:
169
+ param_value = param_value.lower()
170
+ if param_value not in ["true", "false"]:
171
+ logger.warning(
172
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
173
+ )
174
+ return param_value == "true"
175
+ else:
176
+ if param_type in ["object", "array", "arr"
177
+ ] or param_type.startswith(
178
+ "dict") or param_type.startswith("list"):
179
+ try:
180
+ param_value = json.loads(param_value)
181
+ return param_value
182
+ except:
183
+ logger.warning(
184
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
185
+ f"'{func_name}', will try other methods to parse it.")
186
+ try:
187
+ param_value = ast.literal_eval(param_value) # safer
188
+ except:
189
+ logger.warning(
190
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
191
+ )
192
+ return param_value
193
+
194
+ def _parse_xml_function_call(
195
+ self, function_call_str: str,
196
+ tools: Optional[list[ChatCompletionToolsParam]]
197
+ ) -> Optional[ToolCall]:
198
+
199
+ # Extract function name
200
+ end_index = function_call_str.index(">")
201
+ function_name = function_call_str[:end_index]
202
+ param_config = self._get_arguments_config(function_name, tools)
203
+ parameters = function_call_str[end_index + 1:]
204
+ param_dict = {}
205
+ for match_text in self.tool_call_parameter_regex.findall(parameters):
206
+ idx = match_text.index(">")
207
+ param_name = match_text[:idx]
208
+ param_value = str(match_text[idx + 1:])
209
+ # Remove prefix and trailing \n
210
+ if param_value.startswith("\n"):
211
+ param_value = param_value[1:]
212
+ if param_value.endswith("\n"):
213
+ param_value = param_value[:-1]
214
+
215
+ param_dict[param_name] = self._convert_param_value(
216
+ param_value, param_name, param_config, function_name)
217
+ return ToolCall(
218
+ type="function",
219
+ function=FunctionCall(name=function_name,
220
+ arguments=json.dumps(param_dict,
221
+ ensure_ascii=False)),
222
+ )
223
+
224
+ def _get_function_calls(self, model_output: str) -> List[str]:
225
+ # Find all tool calls
226
+ matched_ranges = self.tool_call_regex.findall(model_output)
227
+ raw_tool_calls = [
228
+ match[0] if match[0] else match[1] for match in matched_ranges
229
+ ]
230
+
231
+ # Back-off strategy if no tool_call tags found
232
+ if len(raw_tool_calls) == 0:
233
+ raw_tool_calls = [model_output]
234
+
235
+ raw_function_calls = []
236
+ for tool_call in raw_tool_calls:
237
+ raw_function_calls.extend(
238
+ self.tool_call_function_regex.findall(tool_call))
239
+
240
+ function_calls = [
241
+ match[0] if match[0] else match[1] for match in raw_function_calls
242
+ ]
243
+ return function_calls
244
+
245
+ def extract_tool_calls(
246
+ self,
247
+ model_output: str,
248
+ request: ChatCompletionRequest,
249
+ ) -> ExtractedToolCallInformation:
250
+ # Quick check to avoid unnecessary processing
251
+ if self.tool_call_prefix not in model_output:
252
+ return ExtractedToolCallInformation(tools_called=False,
253
+ tool_calls=[],
254
+ content=model_output)
255
+
256
+ try:
257
+ function_calls = self._get_function_calls(model_output)
258
+ if len(function_calls) == 0:
259
+ return ExtractedToolCallInformation(tools_called=False,
260
+ tool_calls=[],
261
+ content=model_output)
262
+
263
+ tool_calls = [
264
+ self._parse_xml_function_call(function_call_str, request.tools)
265
+ for function_call_str in function_calls
266
+ ]
267
+
268
+ # Populate prev_tool_call_arr for serving layer to set finish_reason
269
+ self.prev_tool_call_arr.clear() # Clear previous calls
270
+ for tool_call in tool_calls:
271
+ if tool_call:
272
+ self.prev_tool_call_arr.append({
273
+ "name":
274
+ tool_call.function.name,
275
+ "arguments":
276
+ tool_call.function.arguments,
277
+ })
278
+
279
+ # Extract content before tool calls
280
+ content_index = model_output.find(self.tool_call_start_token)
281
+ content_index = content_index if content_index >= 0 else model_output.find(
282
+ self.tool_call_prefix)
283
+ content = model_output[:content_index] # .rstrip()
284
+
285
+ return ExtractedToolCallInformation(
286
+ tools_called=(len(tool_calls) > 0),
287
+ tool_calls=tool_calls,
288
+ content=content if content else None,
289
+ )
290
+
291
+ except Exception:
292
+ logger.exception("Error in extracting tool call from response.")
293
+ return ExtractedToolCallInformation(tools_called=False,
294
+ tool_calls=[],
295
+ content=model_output)
296
+
297
+ def extract_tool_calls_streaming(
298
+ self,
299
+ previous_text: str,
300
+ current_text: str,
301
+ delta_text: str,
302
+ previous_token_ids: Sequence[int],
303
+ current_token_ids: Sequence[int],
304
+ delta_token_ids: Sequence[int],
305
+ request: ChatCompletionRequest,
306
+ ) -> Union[DeltaMessage, None]:
307
+ # Store request for type conversion
308
+ if not previous_text:
309
+ self._reset_streaming_state()
310
+ self.streaming_request = request
311
+
312
+ # If no delta text, return None unless it's an EOS token after tool calls
313
+ if not delta_text:
314
+ # Check if this is an EOS token after all tool calls are complete
315
+ # We check for tool calls in the text even if is_tool_call_started is False
316
+ # because it might have been reset after processing all tools
317
+ if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
318
+ # Count complete tool calls
319
+ complete_calls = len(
320
+ self.tool_call_complete_regex.findall(current_text))
321
+
322
+ # If we have completed tool calls and populated prev_tool_call_arr
323
+ if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
324
+ # Check if all tool calls are closed
325
+ open_calls = current_text.count(
326
+ self.tool_call_start_token) - current_text.count(
327
+ self.tool_call_end_token)
328
+ if open_calls == 0:
329
+ # Return empty delta message to allow finish_reason processing
330
+ return DeltaMessage(content="")
331
+ elif not self.is_tool_call_started and current_text:
332
+ # This is a regular content response that's now complete
333
+ return DeltaMessage(content="")
334
+ return None
335
+
336
+ # Update accumulated text
337
+ self.accumulated_text = current_text
338
+
339
+ # Check if we need to advance to next tool
340
+ if self.json_closed and not self.in_function:
341
+ # Check if this tool call has ended
342
+ tool_ends = current_text.count(self.tool_call_end_token)
343
+ if tool_ends > self.current_tool_index:
344
+ # This tool has ended, advance to next
345
+ self.current_tool_index += 1
346
+ self.header_sent = False
347
+ self.param_count = 0
348
+ self.json_started = False
349
+ self.json_closed = False
350
+ self.accumulated_params = {}
351
+
352
+ # Check if there are more tool calls
353
+ tool_starts = current_text.count(self.tool_call_start_token)
354
+ if self.current_tool_index >= tool_starts:
355
+ # No more tool calls
356
+ self.is_tool_call_started = False
357
+ # Continue processing next tool
358
+ return None
359
+
360
+ # Handle normal content before tool calls
361
+ if not self.is_tool_call_started:
362
+ # Check if tool call is starting
363
+ if self.tool_call_start_token_id in delta_token_ids or self.tool_call_start_token in delta_text:
364
+ self.is_tool_call_started = True
365
+ # Return any content before the tool call
366
+ if self.tool_call_start_token in delta_text:
367
+ content_before = delta_text[:delta_text.index(
368
+ self.tool_call_start_token)]
369
+ if content_before:
370
+ return DeltaMessage(content=content_before)
371
+ return None
372
+ else:
373
+ # Check if we're between tool calls - skip whitespace
374
+ if current_text.rstrip().endswith(self.tool_call_end_token):
375
+ # We just ended a tool call, skip whitespace
376
+ if delta_text.strip() == "":
377
+ return None
378
+ # Normal content, no tool call
379
+ return DeltaMessage(content=delta_text)
380
+
381
+ # Check if we're between tool calls (waiting for next one)
382
+ # Count tool calls we've seen vs processed
383
+ tool_starts_count = current_text.count(self.tool_call_start_token)
384
+ if self.current_tool_index >= tool_starts_count:
385
+ # We're past all tool calls, shouldn't be here
386
+ return None
387
+
388
+ # We're in a tool call, find the current tool call portion
389
+ # Need to find the correct tool call based on current_tool_index
390
+ tool_starts = []
391
+ idx = 0
392
+ while True:
393
+ idx = current_text.find(self.tool_call_start_token, idx)
394
+ if idx == -1:
395
+ break
396
+ tool_starts.append(idx)
397
+ idx += len(self.tool_call_start_token)
398
+
399
+ if self.current_tool_index >= len(tool_starts):
400
+ # No more tool calls to process yet
401
+ return None
402
+
403
+ tool_start_idx = tool_starts[self.current_tool_index]
404
+ # Find where this tool call ends (or current position if not ended yet)
405
+ tool_end_idx = current_text.find(self.tool_call_end_token,
406
+ tool_start_idx)
407
+ if tool_end_idx == -1:
408
+ tool_text = current_text[tool_start_idx:]
409
+ else:
410
+ tool_text = current_text[tool_start_idx:tool_end_idx +
411
+ len(self.tool_call_end_token)]
412
+
413
+ # Looking for function header
414
+ if not self.header_sent:
415
+ if self.tool_call_prefix in tool_text:
416
+ func_start = tool_text.find(self.tool_call_prefix) + len(
417
+ self.tool_call_prefix)
418
+ func_end = tool_text.find(">", func_start)
419
+
420
+ if func_end != -1:
421
+ # Found complete function name
422
+ self.current_function_name = tool_text[func_start:func_end]
423
+ self.current_tool_id = self._generate_tool_call_id()
424
+ self.header_sent = True
425
+ self.in_function = True
426
+
427
+ # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
428
+ # This ensures finish_reason="tool_calls" even if parsing isn't complete
429
+ already_added = any(
430
+ tool.get("name") == self.current_function_name
431
+ for tool in self.prev_tool_call_arr)
432
+ if not already_added:
433
+ self.prev_tool_call_arr.append({
434
+ "name": self.current_function_name,
435
+ "arguments":
436
+ "{}", # Placeholder, will be updated later
437
+ })
438
+
439
+ # Send header with function info
440
+ return DeltaMessage(tool_calls=[
441
+ DeltaToolCall(
442
+ index=self.current_tool_index,
443
+ id=self.current_tool_id,
444
+ function=DeltaFunctionCall(
445
+ name=self.current_function_name, arguments=""),
446
+ type="function",
447
+ )
448
+ ])
449
+ return None
450
+
451
+ # We've sent header, now handle function body
452
+ if self.in_function:
453
+ # Send opening brace if not sent yet
454
+ if not self.json_started and self.parameter_prefix not in delta_text:
455
+ self.json_started = True
456
+ return DeltaMessage(tool_calls=[
457
+ DeltaToolCall(
458
+ index=self.current_tool_index,
459
+ function=DeltaFunctionCall(arguments="{"),
460
+ )
461
+ ])
462
+
463
+ # Make sure json_started is set if we're processing parameters
464
+ if not self.json_started:
465
+ self.json_started = True
466
+
467
+ # Check for function end in accumulated text
468
+ if not self.json_closed and self.function_end_token in tool_text:
469
+ # Close JSON
470
+ self.json_closed = True
471
+
472
+ # Extract the complete tool call to update prev_tool_call_arr with final arguments
473
+ # Find the function content
474
+ func_start = tool_text.find(self.tool_call_prefix) + len(
475
+ self.tool_call_prefix)
476
+ func_content_end = tool_text.find(self.function_end_token,
477
+ func_start)
478
+ if func_content_end != -1:
479
+ func_content = tool_text[func_start:func_content_end]
480
+ # Parse to get the complete arguments
481
+ try:
482
+ parsed_tool = self._parse_xml_function_call(
483
+ func_content, self.streaming_request.tools
484
+ if self.streaming_request else None)
485
+ if parsed_tool:
486
+ # Update existing entry in prev_tool_call_arr with complete arguments
487
+ for i, tool in enumerate(self.prev_tool_call_arr):
488
+ if tool.get(
489
+ "name") == parsed_tool.function.name:
490
+ self.prev_tool_call_arr[i][
491
+ "arguments"] = parsed_tool.function.arguments
492
+ break
493
+ except Exception:
494
+ pass # Ignore parsing errors during streaming
495
+
496
+ result = DeltaMessage(tool_calls=[
497
+ DeltaToolCall(
498
+ index=self.current_tool_index,
499
+ function=DeltaFunctionCall(arguments="}"),
500
+ )
501
+ ])
502
+
503
+ # Reset state for next tool
504
+ self.in_function = False
505
+ self.json_closed = True
506
+ self.accumulated_params = {}
507
+
508
+ return result
509
+
510
+ # Look for parameters
511
+ # Find all parameter starts
512
+ param_starts = []
513
+ idx = 0
514
+ while True:
515
+ idx = tool_text.find(self.parameter_prefix, idx)
516
+ if idx == -1:
517
+ break
518
+ param_starts.append(idx)
519
+ idx += len(self.parameter_prefix)
520
+
521
+ # Check if we should start a new parameter
522
+ if not self.in_param and self.param_count < len(param_starts):
523
+
524
+ if len(param_starts) > self.param_count:
525
+ # Process the next parameter
526
+ param_idx = param_starts[self.param_count]
527
+ param_start = param_idx + len(self.parameter_prefix)
528
+ remaining = tool_text[param_start:]
529
+
530
+ if ">" in remaining:
531
+ # We have the complete parameter name
532
+ name_end = remaining.find(">")
533
+ self.current_param_name = remaining[:name_end]
534
+
535
+ # Find the parameter value
536
+ value_start = param_start + name_end + 1
537
+ value_text = tool_text[value_start:]
538
+ if value_text.startswith("\n"):
539
+ value_text = value_text[1:]
540
+
541
+ # Find where this parameter ends
542
+ param_end_idx = value_text.find(
543
+ self.parameter_end_token)
544
+ if param_end_idx == -1:
545
+ # No closing tag, look for next parameter or function end
546
+ next_param_idx = value_text.find(
547
+ self.parameter_prefix)
548
+ func_end_idx = value_text.find(
549
+ self.function_end_token)
550
+
551
+ if next_param_idx != -1 and (func_end_idx == -1
552
+ or next_param_idx
553
+ < func_end_idx):
554
+ param_end_idx = next_param_idx
555
+ elif func_end_idx != -1:
556
+ param_end_idx = func_end_idx
557
+ else:
558
+ # Neither found, check if tool call is complete
559
+ if self.tool_call_end_token in tool_text:
560
+ # Tool call is complete, so parameter must be complete too
561
+ # Use all remaining text before function end as value
562
+ param_end_idx = len(value_text)
563
+ else:
564
+ # Still streaming, wait for more content
565
+ return None
566
+
567
+ if param_end_idx != -1:
568
+ # Complete parameter found
569
+ param_value = value_text[:param_end_idx]
570
+ if param_value.endswith("\n"):
571
+ param_value = param_value[:-1]
572
+
573
+ # Store raw value for later processing
574
+ self.accumulated_params[
575
+ self.current_param_name] = param_value
576
+
577
+ # Get parameter configuration for type conversion
578
+ param_config = self._get_arguments_config(
579
+ self.current_function_name,
580
+ self.streaming_request.tools
581
+ if self.streaming_request else None)
582
+
583
+ # Convert the parameter value to the appropriate type
584
+ converted_value = self._convert_param_value(
585
+ param_value, self.current_param_name,
586
+ param_config, self.current_function_name)
587
+
588
+ # Build JSON fragment based on the converted type
589
+ # Use json.dumps to properly serialize the value
590
+ serialized_value = json.dumps(converted_value,
591
+ ensure_ascii=False)
592
+
593
+ if self.param_count == 0:
594
+ json_fragment = f'"{self.current_param_name}": {serialized_value}'
595
+ else:
596
+ json_fragment = f', "{self.current_param_name}": {serialized_value}'
597
+
598
+ self.param_count += 1
599
+
600
+ return DeltaMessage(tool_calls=[
601
+ DeltaToolCall(
602
+ index=self.current_tool_index,
603
+ function=DeltaFunctionCall(
604
+ arguments=json_fragment),
605
+ )
606
+ ])
607
+
608
+ # Continue parameter value - Not used in the current implementation
609
+ # since we process complete parameters above
610
+ if self.in_param:
611
+ if self.parameter_end_token in delta_text:
612
+ # End of parameter
613
+ end_idx = delta_text.find(self.parameter_end_token)
614
+ value_chunk = delta_text[:end_idx]
615
+
616
+ # Skip past > if at start
617
+ if not self.current_param_value and ">" in value_chunk:
618
+ gt_idx = value_chunk.find(">")
619
+ value_chunk = value_chunk[gt_idx + 1:]
620
+
621
+ if not self.current_param_value and value_chunk.startswith(
622
+ "\n"):
623
+ value_chunk = value_chunk[1:]
624
+
625
+ # Store complete value
626
+ full_value = self.current_param_value + value_chunk
627
+ self.accumulated_params[
628
+ self.current_param_name] = full_value
629
+
630
+ # Get parameter configuration for type conversion
631
+ param_config = self._get_arguments_config(
632
+ self.current_function_name,
633
+ self.streaming_request.tools
634
+ if self.streaming_request else None)
635
+
636
+ # Convert the parameter value to the appropriate type
637
+ converted_value = self._convert_param_value(
638
+ full_value, self.current_param_name, param_config,
639
+ self.current_function_name)
640
+
641
+ # Serialize the converted value
642
+ serialized_value = json.dumps(converted_value,
643
+ ensure_ascii=False)
644
+
645
+ # Since we've been streaming the quoted version, we need to close it properly
646
+ # This is complex - for now just complete the value
647
+ self.in_param = False
648
+ self.current_param_value = ""
649
+
650
+ # Just close the current parameter string
651
+ return DeltaMessage(tool_calls=[
652
+ DeltaToolCall(
653
+ index=self.current_tool_index,
654
+ function=DeltaFunctionCall(
655
+ arguments='"'), # Close the string quote
656
+ )
657
+ ])
658
+ else:
659
+ # Continue accumulating value
660
+ value_chunk = delta_text
661
+
662
+ # Handle first chunk after param name
663
+ if not self.current_param_value and ">" in value_chunk:
664
+ gt_idx = value_chunk.find(">")
665
+ value_chunk = value_chunk[gt_idx + 1:]
666
+
667
+ if not self.current_param_value and value_chunk.startswith(
668
+ "\n"):
669
+ value_chunk = value_chunk[1:]
670
+
671
+ if value_chunk:
672
+ # Stream the escaped delta
673
+ prev_escaped = json.dumps(
674
+ self.current_param_value, ensure_ascii=False
675
+ )[1:-1] if self.current_param_value else ""
676
+ self.current_param_value += value_chunk
677
+ full_escaped = json.dumps(self.current_param_value,
678
+ ensure_ascii=False)[1:-1]
679
+ delta_escaped = full_escaped[len(prev_escaped):]
680
+
681
+ if delta_escaped:
682
+ return DeltaMessage(tool_calls=[
683
+ DeltaToolCall(
684
+ index=self.current_tool_index,
685
+ function=DeltaFunctionCall(
686
+ arguments=delta_escaped),
687
+ )
688
+ ])
689
+
690
+ return None