-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfc_inferencer.py
More file actions
681 lines (584 loc) · 23.9 KB
/
fc_inferencer.py
File metadata and controls
681 lines (584 loc) · 23.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
"""
Async Function Call Inferencer with direct tool calling.
"""
import asyncio
import ast
import json
import logging
import os
import random
import re
import time
from typing import List, Literal, Optional, TypedDict, Union
import httpx
from openai import (
APIConnectionError,
APITimeoutError,
AsyncOpenAI,
AuthenticationError,
BadRequestError,
InternalServerError,
RateLimitError,
)
from pydantic import BaseModel, Field
from llm_streaming import collect_openai_chat_stream, is_streaming_unsupported_error
from tools.registry import ToolRegistry, build_default_registry
# Import-time side effect: configure the root logger at INFO level.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
# Module-wide logger shared by the helpers and the inferencer class below.
logger = logging.getLogger("AsyncFCInferencer")
def get_middle_mixed(text: str, max_num: int = 4000) -> str:
    """
    Shorten mixed Chinese-English text while keeping its beginning and end.

    The text is measured in "units": a run of ASCII letters, digits, ``_``,
    ``'`` or ``-`` counts as one unit, and every other non-whitespace
    character counts as one unit on its own. Whitespace is never counted.

    Args:
        text: Original text.
        max_num: Maximum number of units to keep.

    Returns:
        ``text`` unchanged when it has at most ``max_num`` units; otherwise
        the leading ``max_num // 2`` units and the trailing remainder joined
        by ``"...(truncated)..."``. Empty input or a non-positive ``max_num``
        yields ``""``.
    """
    if not text or max_num <= 0:
        return ""

    # (start, end) offsets of every unit, in order of appearance.
    spans = [m.span() for m in re.finditer(r"[a-zA-Z0-9_'-]+|[^\s]", text)]
    if len(spans) <= max_num:
        return text

    keep_head = max_num // 2
    keep_tail = max_num - keep_head

    # Slice on unit boundaries so the original spacing inside each kept
    # region is preserved verbatim.
    prefix = text[:spans[keep_head - 1][1]] if keep_head > 0 else ""
    suffix = text[spans[len(spans) - keep_tail][0]:] if keep_tail > 0 else ""
    return f"{prefix}...(truncated)...{suffix}"
class FunctionCall(BaseModel):
    """Function call model for tool calls: which function to run and its raw arguments."""

    # Name of the function being called; optional, defaults to None.
    name: Optional[str] = None
    # Argument payload kept as a string (empty by default).
    # NOTE(review): presumably JSON-encoded text, as in OpenAI tool calls - confirm.
    arguments: str = ""
class ToolCall(BaseModel):
    """Tool call model: one tool invocation attached to a chat message."""

    # Identifier of this call (required).
    id: str
    # Call kind; the only accepted value is 'function', which is also the default.
    type: Literal['function'] = 'function'
    # The function name/arguments pair for this call (required).
    function: FunctionCall
class ChatMessage(BaseModel):
    """Chat message model compatible with OpenAI format.

    Only ``role`` is required; every other field defaults to ``None``.
    ``AsyncFCInferencer.infer`` serializes these with
    ``model_dump(exclude_none=True)``, so unset fields are omitted.
    """

    # Role of the message author (required).
    role: str
    # Text content of the message, if any.
    content: Optional[str] = None
    # NOTE(review): presumably provider-specific reasoning text - confirm.
    reasoning_content: Optional[str] = Field(default=None)
    # Tool calls carried by this message, if any.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # NOTE(review): presumably links a tool-result message to its call - confirm.
    tool_call_id: Optional[str] = None
    # Optional name attached to the message.
    name: Optional[str] = None
class ModelConfig(TypedDict):
    """Model configuration consumed by ``AsyncFCInferencer.__init__``."""

    # Model identifier; stored by the inferencer as ``model_name``.
    model: str
    # One endpoint URL or a list of them; the inferencer builds one client per URL.
    base_url: Union[str, List[str]]
    # API key; the inferencer substitutes "dummy" when this is missing or empty.
    api_key: Optional[str]
class SampleParameters(TypedDict, total=False):
    """Sampling parameters for LLM inference.

    Declared with ``total=False``, so every key is optional.
    """

    # NOTE(review): presumably the standard sampling knobs forwarded to the
    # provider; this type is not referenced elsewhere in the visible code.
    temperature: float
    top_p: float
    top_k: int
class AsyncFCInferencer:
    """
    Async Function Call Inferencer with direct tool support.

    This inferencer supports:
    - Multiple LLM backends with load balancing (one client per base URL,
      picked at random for each request attempt)
    - Direct tool calling via ToolRegistry (no MCP protocol)
    - Automatic retry mechanism (application-level; SDK retries are disabled)
    - Tool response truncation
    """
def __init__(
    self,
    model: ModelConfig,
    model_infer_params: Optional[dict] = None,
    registry: Optional[ToolRegistry] = None,
    max_iterations: Optional[int] = None,
    timeout: Optional[int] = None,
    max_retry: Optional[int] = None,
    sleep_interval: Optional[int] = None,
    max_tool_response_length: Optional[int] = None,
    max_tool_calls_per_turn: Optional[int] = None,
    task_id: Optional[str] = None,
):
    """Create the HTTP/LLM clients and resolve runtime limits.

    Explicit arguments win; otherwise the value comes from the environment
    variable named below, and finally from the built-in default.

    Args:
        model: Model name, base URL(s) and API key. A list of base URLs
            yields one client per URL sharing a single HTTP connection pool.
        model_infer_params: Extra keyword arguments merged into every
            chat-completions request.
        registry: Tool registry; a default registry is built when omitted.
        max_iterations: Max LLM/tool rounds (``MAX_ITERATIONS``, 50).
        timeout: End-to-end budget in seconds (``REQUEST_TIMEOUT``, 2000).
            The deadline clock starts here, at construction time. A
            non-positive value leaves the request without a deadline.
        max_retry: Attempts per LLM call / tool call (``MAX_RETRY``, 25).
        sleep_interval: Seconds between retries (``RETRY_INTERVAL``, 5).
        max_tool_response_length: Unit limit for tool output truncation
            (``MAX_TOOL_RESPONSE_LENGTH``, 8192). ``0`` disables truncation.
        max_tool_calls_per_turn: Max tool calls accepted in one assistant
            turn (``MAX_TOOL_CALLS_PER_TURN``, 5).
        task_id: Identifier used in log lines; "unknown" when omitted.

    Other environment variables read: ``MAX_CONNECTIONS``,
    ``MAX_KEEPALIVE_CONNECTIONS``, ``KEEPALIVE_EXPIRY``, ``HTTP_TIMEOUT``
    (falling back to ``TIMEOUT``) and ``ENABLE_LLM_STREAMING``.
    """
    base_urls = model['base_url'] if isinstance(model['base_url'], list) else [model['base_url']]

    # Create independent HTTP client for this instance
    max_connections = int(os.getenv("MAX_CONNECTIONS", "100"))
    max_keepalive = int(os.getenv("MAX_KEEPALIVE_CONNECTIONS", "20"))
    keepalive_expiry = float(os.getenv("KEEPALIVE_EXPIRY", "10.0"))
    http_timeout = float(os.getenv("HTTP_TIMEOUT", os.getenv("TIMEOUT", "60.0")))
    request_timeout = (
        float(timeout)
        if timeout is not None
        else float(os.getenv("REQUEST_TIMEOUT", "2000.0"))
    )
    self.timeout = request_timeout
    self._started_at = time.monotonic()
    # A non-positive timeout means "no end-to-end deadline".
    self._deadline_at = (
        self._started_at + request_timeout
        if request_timeout and request_timeout > 0
        else None
    )
    # Connect/write/pool waits never exceed the overall request budget.
    io_timeout = min(http_timeout, request_timeout) if request_timeout > 0 else http_timeout
    # NOTE(review): with a non-positive timeout the read timeout below is
    # also non-positive, and the same value is later passed per request -
    # confirm the "unbounded" mode is actually usable.
    self.http_client = httpx.AsyncClient(
        limits=httpx.Limits(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive,
            keepalive_expiry=keepalive_expiry
        ),
        timeout=httpx.Timeout(
            connect=io_timeout,
            read=request_timeout,
            write=io_timeout,
            pool=io_timeout
        )
    )
    self.clients = [
        AsyncOpenAI(
            api_key=model.get("api_key") or "dummy",
            base_url=url,
            http_client=self.http_client,
            max_retries=0  # Disable SDK auto-retry, use application-level retry only
        )
        for url in base_urls
    ]
    self.model_name = model["model"]
    self.model_infer_params = model_infer_params or {}
    self.max_iterations = max_iterations or int(os.getenv("MAX_ITERATIONS", "50"))
    self.max_retry = max_retry or int(os.getenv("MAX_RETRY", "25"))
    self.sleep_interval = sleep_interval or int(os.getenv("RETRY_INTERVAL", "5"))
    # An explicit 0 must survive: the tool executor treats a falsy limit as
    # "do not truncate", so ``or`` would silently re-enable the default.
    self.max_tool_response_length = (
        max_tool_response_length
        if max_tool_response_length is not None
        else int(os.getenv("MAX_TOOL_RESPONSE_LENGTH", "8192"))
    )
    self.max_tool_calls_per_turn = max_tool_calls_per_turn or int(os.getenv("MAX_TOOL_CALLS_PER_TURN", "5"))
    # Streaming is on unless the variable is set to "0", "false" or "no".
    self.enable_llm_streaming = os.getenv("ENABLE_LLM_STREAMING", "1").strip().lower() not in {
        "0",
        "false",
        "no",
    }
    self.task_id = str(task_id or "unknown")
    # One-shot flags so the streaming mode / fallback is logged only once.
    self._streaming_mode_logged = False
    self._streaming_fallback_logged = False
    self.registry = registry or build_default_registry()
    # Terminal outcome of the most recent request, for callers to inspect.
    self.last_error: Optional[str] = None
    self.last_status: Optional[str] = None
def _set_terminal_state(self, error: Optional[str], *, status: str) -> None:
    """Store how the request ended.

    ``status`` always replaces ``last_status``. ``error`` replaces
    ``last_error`` only when it carries text; ``None`` or ``""`` leaves any
    previously recorded error untouched.
    """
    carries_text = not (error is None or error == "")
    if carries_text:
        self.last_error = error
    self.last_status = status
def _clear_failure_state(self) -> None:
    """Forget any recorded error and status (both become ``None``)."""
    self.last_error = self.last_status = None
@staticmethod
def _extract_status_code(exc: Exception) -> Optional[int]:
    """Best-effort lookup of an HTTP status code on an exception.

    The exception itself is inspected first, then its ``response``
    attribute (when present and not ``None``). On each object
    ``status_code`` is preferred over ``status``, and only integer values
    are accepted.

    Returns:
        The first integer status found, or ``None``.
    """

    def first_int_status(holder) -> Optional[int]:
        # Attributes are read lazily, in order, stopping at the first int.
        candidates = (getattr(holder, attr, None) for attr in ("status_code", "status"))
        return next((value for value in candidates if isinstance(value, int)), None)

    found = first_int_status(exc)
    if found is not None:
        return found
    response = getattr(exc, "response", None)
    return None if response is None else first_int_status(response)
@classmethod
def _is_terminal_llm_exception(cls, exc: Exception) -> bool:
    """Tell whether an LLM error is permanent, so retrying is pointless.

    Checked in order:
    1. SDK authentication / bad-request exception types are terminal.
    2. If an integer HTTP status can be extracted, it alone decides:
       400, 401, 403, 404 and 422 are terminal, anything else is not.
    3. Otherwise the lower-cased ``"<Type>: <message>"`` text is scanned for
       known authentication/token failure phrases.
    """
    if isinstance(exc, (AuthenticationError, BadRequestError)):
        return True

    code = cls._extract_status_code(exc)
    if isinstance(code, int):
        return code in {400, 401, 403, 404, 422}

    description = f"{type(exc).__name__}: {exc}".lower()
    terminal_phrases = (
        "authentication",
        "invalid task gateway token",
        "invalid token",
    )
    return any(phrase in description for phrase in terminal_phrases)
def _remaining_budget(self) -> Optional[float]:
    """Seconds left before the request deadline.

    Returns ``None`` when no deadline is set; the value can be negative once
    the deadline has passed.
    """
    deadline = self._deadline_at
    return None if deadline is None else deadline - time.monotonic()
def _mark_deadline_exceeded(self, context: str) -> None:
    """Record an "error" terminal state saying the deadline ran out.

    ``context`` describes what was in progress and is appended to the
    message when non-empty.
    """
    suffix = f" while {context}" if context else ""
    self._set_terminal_state(f"Request deadline exceeded{suffix}", status="error")
def _ensure_time_remaining(self, context: str, *, min_remaining: float = 1.0) -> bool:
    """Check that more than ``min_remaining`` seconds of budget are left.

    Returns ``True`` when the request is unbounded or has enough time left.
    Otherwise records a deadline-exceeded state mentioning ``context`` and
    returns ``False``.
    """
    budget = self._remaining_budget()
    if budget is None or budget > min_remaining:
        return True
    self._mark_deadline_exceeded(context)
    return False
def _effective_call_timeout(self, context: str) -> Optional[float]:
    """Timeout to use for one model call, capped by the remaining budget.

    Returns the configured timeout as-is when the request has no deadline.
    With a deadline, returns the smaller of the configured timeout and the
    time left; if that is one second or less, a deadline-exceeded state is
    recorded and ``None`` is returned.
    """
    budget = self._remaining_budget()
    if budget is None:
        return self.timeout

    capped = min(float(self.timeout), budget)
    if capped > 1.0:
        return capped
    self._mark_deadline_exceeded(context)
    return None
async def _sleep_before_retry(self, context: str) -> bool:
    """Pause before a retry without overrunning the request deadline.

    Args:
        context: Description of the pending retry, used in the recorded
            error if the deadline prevents it.

    Returns:
        ``True`` after sleeping (possibly for zero seconds) when a retry may
        proceed. ``False`` when one second or less of budget is left, in
        which case a deadline-exceeded state is recorded.
    """
    # A negative interval is treated as "no delay".
    interval = max(0.0, float(self.sleep_interval))

    remaining = self._remaining_budget()
    if remaining is None:
        await asyncio.sleep(interval)
        return True

    # Keep one second in reserve so the retry itself has time to run.
    spare = remaining - 1.0
    if spare <= 0:
        self._mark_deadline_exceeded(context)
        return False

    # Decide on the budget, not on the sleep length: a configured interval
    # of 0 is a valid "retry immediately" and must not be reported as a
    # deadline failure.
    await asyncio.sleep(min(interval, spare))
    return True
async def _handle_llm_retry_error(self, exc: Exception, attempt: int) -> bool:
    """Log a failed LLM attempt and decide whether another one may follow.

    Args:
        exc: The exception raised by the attempt.
        attempt: Zero-based index of the attempt that failed.

    Returns:
        ``True`` when the caller should retry (after the back-off sleep).
        ``False`` when retries are exhausted or the deadline leaves no room
        for the back-off; a terminal error state is recorded in both cases.
    """
    error_name = type(exc).__name__
    logger.error(
        "LLM request error (attempt %d/%d): %s: %s",
        attempt + 1,
        self.max_retry,
        error_name,
        exc,
    )

    retries_exhausted = attempt == self.max_retry - 1
    if not retries_exhausted:
        may_retry = await self._sleep_before_retry(
            f"waiting to retry LLM attempt {attempt + 2}/{self.max_retry}"
        )
        return True if may_retry else False

    self._set_terminal_state(
        f"LLM request failed after {self.max_retry} attempts: {error_name}: {exc}",
        status="error",
    )
    # Called for its side effect only: if the deadline has also passed, the
    # recorded error becomes the deadline message. Either way we stop.
    self._ensure_time_remaining(
        f"handling exhausted retries for LLM attempt {attempt + 1}",
        min_remaining=0.0,
    )
    return False
async def infer(self, messages: List[ChatMessage]) -> List[dict]:
    """Run inference with tool calling loop.

    Each iteration asks the model for a reply, appends it to the transcript
    and, if the reply requests tools, runs them and appends their results.
    The loop ends on a reply without tool calls, on any failure, when the
    deadline is reached, or after ``max_iterations`` rounds.

    Args:
        messages: Initial conversation.

    Returns:
        The full transcript as plain dicts: the input messages followed by
        every assistant and tool message produced. The outcome is reported
        through ``last_status`` / ``last_error`` rather than by raising.
    """
    self._clear_failure_state()
    current_messages = [m.model_dump(exclude_none=True) for m in messages]
    tools_schema = self.registry.schemas

    for iteration in range(1, self.max_iterations + 1):
        if not self._ensure_time_remaining(f"starting iteration {iteration}"):
            break

        logger.info("Iteration %d/%d", iteration, self.max_iterations)
        response = await self._call_llm(current_messages, tools_schema)
        if response is None:
            # _call_llm normally records why it failed; keep a fallback.
            if self.last_error is None:
                self._set_terminal_state(
                    f"LLM request failed after {self.max_retry} attempts",
                    status="error",
                )
            break

        # A response without choices would otherwise raise IndexError and
        # lose the transcript gathered so far.
        choices = getattr(response, "choices", None)
        if not choices:
            logger.error("LLM response contained no choices")
            self._set_terminal_state("LLM response contained no choices", status="error")
            break

        message_data = choices[0].message
        assistant_msg = message_data.model_dump(exclude_none=True)
        # exclude_none drops an empty content field; restore it as "".
        if "content" not in assistant_msg:
            assistant_msg["content"] = ""
        current_messages.append(assistant_msg)

        if not message_data.tool_calls:
            # Final answer reached.
            self._clear_failure_state()
            break

        tool_call_count = len(message_data.tool_calls)
        if tool_call_count > self.max_tool_calls_per_turn:
            error = (
                f"model returned {tool_call_count} tool calls in one turn, "
                f"exceeding limit {self.max_tool_calls_per_turn}"
            )
            logger.warning("Too many tool calls: %s", tool_call_count)
            self._set_terminal_state(error, status="completed")
            break

        logger.info("Tools called: %s", [tc.function.name for tc in message_data.tool_calls])
        tool_results = await self._execute_tool_calls(message_data.tool_calls)
        if tool_results is None:
            if self.last_error is None:
                self._set_terminal_state("Tool execution failed", status="error")
            break
        current_messages.extend(tool_results)
    else:
        # Loop ran out of iterations without hitting any break above.
        self._set_terminal_state(
            f"Reached max iterations ({self.max_iterations}) without a final answer",
            status="completed",
        )
    return current_messages
async def _call_llm(self, messages: List[dict], tools_schema: list):
    """Call LLM with retry logic.

    Makes up to ``max_retry`` attempts, each against a randomly chosen
    client and limited to the remaining request budget.

    Returns:
        The completion response, or ``None`` when the call could not be
        completed. The cases that return ``None`` from an error handler
        (terminal error, retries exhausted, deadline reached) record a
        reason through the terminal-state helpers.
    """
    transient_errors = (
        APITimeoutError,
        TimeoutError,
        APIConnectionError,
        RateLimitError,
        InternalServerError,
        httpx.TimeoutException,
        httpx.NetworkError,
        httpx.RemoteProtocolError,
    )

    for attempt in range(self.max_retry):
        call_timeout = self._effective_call_timeout(
            f"starting LLM attempt {attempt + 1}/{self.max_retry}"
        )
        if call_timeout is None:
            return None

        try:
            chosen_client = random.choice(self.clients)
            response = await self._request_completion(
                chosen_client,
                messages=messages,
                tools_schema=tools_schema,
                stream=self.enable_llm_streaming,
                timeout=call_timeout,
            )
            self._clear_failure_state()
            return response
        except transient_errors as err:
            # Timeouts, connection problems, rate limits, server errors.
            if not await self._handle_llm_retry_error(err, attempt):
                return None
        except (AuthenticationError, BadRequestError) as err:
            # Permanent failures: report once and stop.
            if isinstance(err, AuthenticationError):
                log_format = "LLM terminal authentication error: %s"
                reason = "LLM authentication failed"
            else:
                log_format = "LLM terminal bad request: %s"
                reason = "LLM bad request"
            logger.error(log_format, err)
            self._set_terminal_state(
                f"{reason}: {type(err).__name__}: {err}",
                status="error",
            )
            return None
        except Exception as err:
            if self._is_terminal_llm_exception(err):
                logger.error("LLM terminal error: %s: %s", type(err).__name__, err)
                self._set_terminal_state(
                    f"LLM request failed: {type(err).__name__}: {err}",
                    status="error",
                )
                return None

            status_code = self._extract_status_code(err)
            known_retryable = status_code == 429 or (
                isinstance(status_code, int) and status_code >= 500
            )
            # Unrecognized errors are retried too, with an extra log line.
            if not known_retryable:
                logger.error("LLM unknown error will retry: %s: %s", type(err).__name__, err)
            if not await self._handle_llm_retry_error(err, attempt):
                return None
    return None
def _build_call_params(
    self,
    messages: List[dict],
    tools_schema: list,
    *,
    stream: bool,
    timeout: float,
) -> dict:
    """Build chat-completions parameters for either stream or non-stream mode.

    Args:
        messages: Conversation to send.
        tools_schema: Tool definitions; may be empty or ``None``.
        stream: Whether the request is a streaming one.
        timeout: Per-request timeout in seconds.

    Returns:
        Keyword arguments for ``chat.completions.create``. Entries of
        ``model_infer_params`` are merged last and so override the base
        keys, except ``stream`` (always taken from the argument) and
        ``stream_options`` (only sent when streaming, defaulting to
        ``{"include_usage": True}``). ``model_infer_params`` is not mutated.
    """
    infer_params = dict(self.model_infer_params or {})
    infer_params.pop("stream", None)
    stream_options = infer_params.pop("stream_options", None)

    call_params = {
        "model": self.model_name,
        "messages": messages,
        "stream": stream,
        "timeout": timeout,
    }
    # Omit the key entirely when there are no tools: an explicit
    # ``tools=None`` is serialized as JSON null, which some backends reject.
    if tools_schema:
        call_params["tools"] = tools_schema
    call_params.update(infer_params)

    if stream:
        call_params["stream_options"] = stream_options or {"include_usage": True}
    return call_params
async def _request_completion(
    self,
    client: AsyncOpenAI,
    *,
    messages: List[dict],
    tools_schema: list,
    stream: bool,
    timeout: float,
):
    """Request one chat completion, aggregating stream chunks when enabled.

    In streaming mode the chunks are collected into a single response via
    ``collect_openai_chat_stream``. If the upstream rejects streaming (as
    judged by ``is_streaming_unsupported_error``), the same request is
    re-issued once without streaming; any other error propagates. The
    chosen mode and the first fallback are each logged only once per
    instance.
    """
    call_params = self._build_call_params(
        messages,
        tools_schema,
        stream=stream,
        timeout=timeout,
    )

    if not self._streaming_mode_logged:
        mode_message = (
            "Task %s LLM streaming enabled for main inferencer (model=%s)"
            if stream
            else "Task %s LLM streaming disabled for main inferencer (model=%s)"
        )
        logger.info(mode_message, self.task_id, self.model_name)
        self._streaming_mode_logged = True

    if not stream:
        return await client.chat.completions.create(**call_params)

    try:
        chunks = await client.chat.completions.create(**call_params)
        return await collect_openai_chat_stream(chunks, model_name=self.model_name)
    except Exception as exc:
        if not is_streaming_unsupported_error(exc):
            raise
        if not self._streaming_fallback_logged:
            logger.warning(
                "Task %s LLM fallback to non-stream for main inferencer (model=%s): upstream rejected streaming",
                self.task_id,
                self.model_name,
            )
            self._streaming_fallback_logged = True
        # Retry the identical request in non-stream mode.
        return await self._request_completion(
            client,
            messages=messages,
            tools_schema=tools_schema,
            stream=False,
            timeout=timeout,
        )
def _strip_code_fences(self, text: str) -> str:
    """Remove surrounding Markdown code fences from tool arguments.

    Handles the usual multi-line form (an opening fence line with an
    optional language tag, the payload, a closing fence) as well as payloads
    that share a line with the opening or closing fence.

    Args:
        text: Raw argument text, possibly fenced.

    Returns:
        The payload with outer whitespace and fences removed. Text that does
        not start with a fence is only whitespace-stripped.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped

    # Drop the opening run of backticks (three or more).
    body = stripped.lstrip("`")
    # Drop a closing fence even when it directly follows the payload.
    if body.endswith("```"):
        body = body.rstrip("`")

    opening_line, newline, remainder = body.partition("\n")
    # The rest of the opening line is normally a language tag such as
    # "json" and is discarded. Keep it only when it already looks like the
    # start of the payload, so that payload text is never thrown away.
    looks_like_payload = opening_line.lstrip().startswith(("{", "[", '"', "'"))
    if newline and not looks_like_payload:
        body = remainder
    return body.strip()
def _parse_tool_arguments(self, args_str: str) -> dict:
    """Parse tool-call arguments with a strict-JSON first strategy.

    Models occasionally emit Python-style dicts or fenced JSON. We accept a
    small compatibility envelope here because the tool schema is already
    known and execution remains local.

    The text is decoded exactly as received first. Typographic quotes are
    converted to ASCII quotes only if that fails, so valid JSON whose string
    values contain such quotes is returned unaltered.

    Args:
        args_str: Raw arguments from the model (converted with ``str()``).

    Returns:
        The decoded arguments object.

    Raises:
        ValueError: The text is neither JSON nor a Python literal.
        TypeError: The text decodes to something other than an object.
    """

    def decode(candidate: str) -> dict:
        # Strict JSON first; unwrap JSON-encoded strings up to three levels.
        for _ in range(3):
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                break
            if isinstance(parsed, dict):
                return parsed
            if isinstance(parsed, str):
                candidate = self._strip_code_fences(parsed)
                continue
            raise TypeError(
                f"Tool arguments must decode to an object, got {type(parsed).__name__}"
            )

        # Fallback for Python-style literals such as {'a': 1}. literal_eval
        # only evaluates literals; it never executes code.
        try:
            parsed = ast.literal_eval(candidate)
        except Exception as exc:
            raise ValueError(
                f"Invalid tool arguments: {candidate[:200]!r}"
            ) from exc
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, str):
            return self._parse_tool_arguments(parsed)
        raise TypeError(
            f"Tool arguments must decode to an object, got {type(parsed).__name__}"
        )

    raw = self._strip_code_fences(str(args_str))
    # Curly double/single quotes mapped to their ASCII counterparts.
    normalized = raw.translate(
        str.maketrans({"\u201c": '"', "\u201d": '"', "\u2018": "'", "\u2019": "'"})
    )
    if normalized == raw:
        return decode(raw)
    try:
        return decode(raw)
    except (TypeError, ValueError):
        # The quotes were probably meant as delimiters; retry normalized.
        return decode(normalized)
async def _execute_tool_calls(self, tool_calls) -> Optional[List[dict]]:
    """Run the requested tools one after another and build their reply messages.

    Args:
        tool_calls: Tool call objects exposing ``id`` and
            ``function.name`` / ``function.arguments``.

    Returns:
        One ``role="tool"`` message dict per call, in call order, or
        ``None`` as soon as the time budget runs out or a tool yields no
        result (results gathered so far are discarded).
    """
    tool_messages: List[dict] = []
    for call in tool_calls:
        function = call.function
        name, raw_arguments, call_id = function.name, function.arguments, call.id

        if not self._ensure_time_remaining(f"executing tool '{name}'"):
            return None

        content = await self._execute_single_tool(name, raw_arguments)
        if content is None:
            return None

        tool_messages.append(
            dict(role="tool", tool_call_id=call_id, content=content, name=name)
        )
    return tool_messages
async def _execute_single_tool(self, tool_name: str, args_str: str) -> Optional[str]:
    """Execute a single tool call via registry (direct function call).

    An unknown tool or unparseable arguments end the call immediately with a
    "completed" terminal state. Otherwise the tool is run up to
    ``max_retry`` times, sleeping between failed attempts, and its output is
    truncated to ``max_tool_response_length`` units when that limit is set.

    Returns:
        The (possibly truncated) tool output, or ``None`` when the call
        could not be completed; the reason is recorded as terminal state.
    """
    if not self.registry.has_tool(tool_name):
        logger.error(f"Tool not found: {tool_name}")
        self._set_terminal_state(f"Tool not found: {tool_name}", status="completed")
        return None

    try:
        parsed_args = self._parse_tool_arguments(args_str)
    except Exception as parse_error:
        logger.error(
            "Invalid tool arguments for %s: %s. Raw arguments preview: %r",
            tool_name,
            parse_error,
            str(args_str)[:200],
        )
        self._set_terminal_state(
            f"Invalid tool arguments for '{tool_name}': "
            f"{type(parse_error).__name__}: {parse_error}",
            status="completed",
        )
        return None

    total_attempts = self.max_retry
    for attempt_no in range(1, total_attempts + 1):
        if not self._ensure_time_remaining(
            f"starting tool '{tool_name}' attempt {attempt_no}/{total_attempts}"
        ):
            return None

        try:
            logger.info(f"Executing {tool_name} with args: {str(parsed_args)[:200]}")
            output = await self.registry.execute(tool_name, parsed_args)
            if self.max_tool_response_length:
                output = get_middle_mixed(output, self.max_tool_response_length)
            self._clear_failure_state()
            return output
        except Exception as run_error:
            logger.error(f"Tool execution error (attempt {attempt_no}): {run_error}")
            # Recorded on every failure so the latest error is available
            # even if the deadline cuts the retries short.
            self._set_terminal_state(
                f"Tool '{tool_name}' failed after {attempt_no}/{total_attempts} attempts: "
                f"{type(run_error).__name__}: {run_error}",
                status="error",
            )
            if attempt_no == total_attempts:
                return None
            if not await self._sleep_before_retry(
                f"waiting to retry tool '{tool_name}' attempt {attempt_no + 1}/{total_attempts}"
            ):
                return None
    return None
async def close(self):
    """Close HTTP client and release all connections.

    Does nothing when the instance has no ``http_client`` attribute.
    """
    if not hasattr(self, 'http_client'):
        return
    await self.http_client.aclose()
def extract_final_answer(self, messages: List[dict]) -> str:
    """Return the content of the most recent assistant message that has any.

    Messages are scanned from last to first. ``""`` is returned for an
    empty history or when no assistant message carries content.
    """
    newest_first = reversed(messages) if messages else ()
    return next(
        (
            entry["content"]
            for entry in newest_first
            if entry.get("role") == "assistant" and entry.get("content")
        ),
        "",
    )