|
6 | 6 |
|
7 | 7 | import asyncio |
8 | 8 | import time |
9 | | -from collections.abc import AsyncGenerator, AsyncIterator |
| 9 | +from collections.abc import AsyncIterator |
10 | 10 | from datetime import UTC, datetime |
11 | 11 | from typing import Annotated, Any |
12 | 12 |
|
|
15 | 15 | from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam |
16 | 16 | from pydantic import TypeAdapter |
17 | 17 |
|
18 | | -from llama_stack.apis.common.content_types import ( |
19 | | - InterleavedContent, |
20 | | -) |
21 | 18 | from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError |
22 | 19 | from llama_stack.apis.inference import ( |
23 | | - ChatCompletionResponse, |
24 | | - ChatCompletionResponseEventType, |
25 | | - ChatCompletionResponseStreamChunk, |
26 | | - CompletionMessage, |
27 | | - CompletionResponse, |
28 | | - CompletionResponseStreamChunk, |
29 | 20 | Inference, |
30 | 21 | ListOpenAIChatCompletionResponse, |
31 | | - Message, |
32 | 22 | OpenAIAssistantMessageParam, |
33 | 23 | OpenAIChatCompletion, |
34 | 24 | OpenAIChatCompletionChunk, |
|
45 | 35 | OpenAIMessageParam, |
46 | 36 | Order, |
47 | 37 | RerankResponse, |
48 | | - StopReason, |
49 | | - ToolPromptFormat, |
50 | 38 | ) |
51 | 39 | from llama_stack.apis.inference.inference import ( |
52 | 40 | OpenAIChatCompletionContentPartImageParam, |
53 | 41 | OpenAIChatCompletionContentPartTextParam, |
54 | 42 | ) |
55 | | -from llama_stack.apis.models import Model, ModelType |
56 | | -from llama_stack.core.telemetry.telemetry import MetricEvent, MetricInResponse |
| 43 | +from llama_stack.apis.models import ModelType |
| 44 | +from llama_stack.core.telemetry.telemetry import MetricEvent |
57 | 45 | from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span |
58 | 46 | from llama_stack.log import get_logger |
59 | 47 | from llama_stack.models.llama.llama3.chat_format import ChatFormat |
@@ -153,35 +141,6 @@ def _construct_metrics( |
153 | 141 | ) |
154 | 142 | return metric_events |
155 | 143 |
|
156 | | - async def _compute_and_log_token_usage( |
157 | | - self, |
158 | | - prompt_tokens: int, |
159 | | - completion_tokens: int, |
160 | | - total_tokens: int, |
161 | | - model: Model, |
162 | | - ) -> list[MetricInResponse]: |
163 | | - metrics = self._construct_metrics( |
164 | | - prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id |
165 | | - ) |
166 | | - if self.telemetry_enabled: |
167 | | - for metric in metrics: |
168 | | - enqueue_event(metric) |
169 | | - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] |
170 | | - |
171 | | - async def _count_tokens( |
172 | | - self, |
173 | | - messages: list[Message] | InterleavedContent, |
174 | | - tool_prompt_format: ToolPromptFormat | None = None, |
175 | | - ) -> int | None: |
176 | | - if not hasattr(self, "formatter") or self.formatter is None: |
177 | | - return None |
178 | | - |
179 | | - if isinstance(messages, list): |
180 | | - encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) |
181 | | - else: |
182 | | - encoded = self.formatter.encode_content(messages) |
183 | | - return len(encoded.tokens) if encoded and encoded.tokens else 0 |
184 | | - |
185 | 144 | async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]: |
186 | 145 | model = await self.routing_table.get_object_by_identifier("model", model_id) |
187 | 146 | if model: |
@@ -375,121 +334,6 @@ async def health(self) -> dict[str, HealthResponse]: |
375 | 334 | ) |
376 | 335 | return health_statuses |
377 | 336 |
|
378 | | - async def stream_tokens_and_compute_metrics( |
379 | | - self, |
380 | | - response, |
381 | | - prompt_tokens, |
382 | | - fully_qualified_model_id: str, |
383 | | - provider_id: str, |
384 | | - tool_prompt_format: ToolPromptFormat | None = None, |
385 | | - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: |
386 | | - completion_text = "" |
387 | | - async for chunk in response: |
388 | | - complete = False |
389 | | - if hasattr(chunk, "event"): # only ChatCompletions have .event |
390 | | - if chunk.event.event_type == ChatCompletionResponseEventType.progress: |
391 | | - if chunk.event.delta.type == "text": |
392 | | - completion_text += chunk.event.delta.text |
393 | | - if chunk.event.event_type == ChatCompletionResponseEventType.complete: |
394 | | - complete = True |
395 | | - completion_tokens = await self._count_tokens( |
396 | | - [ |
397 | | - CompletionMessage( |
398 | | - content=completion_text, |
399 | | - stop_reason=StopReason.end_of_turn, |
400 | | - ) |
401 | | - ], |
402 | | - tool_prompt_format=tool_prompt_format, |
403 | | - ) |
404 | | - else: |
405 | | - if hasattr(chunk, "delta"): |
406 | | - completion_text += chunk.delta |
407 | | - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled: |
408 | | - complete = True |
409 | | - completion_tokens = await self._count_tokens(completion_text) |
410 | | - # if we are done receiving tokens |
411 | | - if complete: |
412 | | - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) |
413 | | - |
414 | | - # Create a separate span for streaming completion metrics |
415 | | - if self.telemetry_enabled: |
416 | | - # Log metrics in the new span context |
417 | | - completion_metrics = self._construct_metrics( |
418 | | - prompt_tokens=prompt_tokens, |
419 | | - completion_tokens=completion_tokens, |
420 | | - total_tokens=total_tokens, |
421 | | - fully_qualified_model_id=fully_qualified_model_id, |
422 | | - provider_id=provider_id, |
423 | | - ) |
424 | | - for metric in completion_metrics: |
425 | | - if metric.metric in [ |
426 | | - "completion_tokens", |
427 | | - "total_tokens", |
428 | | - ]: # Only log completion and total tokens |
429 | | - enqueue_event(metric) |
430 | | - |
431 | | - # Return metrics in response |
432 | | - async_metrics = [ |
433 | | - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics |
434 | | - ] |
435 | | - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics |
436 | | - else: |
437 | | - # Fallback if no telemetry |
438 | | - completion_metrics = self._construct_metrics( |
439 | | - prompt_tokens or 0, |
440 | | - completion_tokens or 0, |
441 | | - total_tokens, |
442 | | - fully_qualified_model_id=fully_qualified_model_id, |
443 | | - provider_id=provider_id, |
444 | | - ) |
445 | | - async_metrics = [ |
446 | | - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics |
447 | | - ] |
448 | | - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics |
449 | | - yield chunk |
450 | | - |
451 | | - async def count_tokens_and_compute_metrics( |
452 | | - self, |
453 | | - response: ChatCompletionResponse | CompletionResponse, |
454 | | - prompt_tokens, |
455 | | - fully_qualified_model_id: str, |
456 | | - provider_id: str, |
457 | | - tool_prompt_format: ToolPromptFormat | None = None, |
458 | | - ): |
459 | | - if isinstance(response, ChatCompletionResponse): |
460 | | - content = [response.completion_message] |
461 | | - else: |
462 | | - content = response.content |
463 | | - completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) |
464 | | - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) |
465 | | - |
466 | | - # Create a separate span for completion metrics |
467 | | - if self.telemetry_enabled: |
468 | | - # Log metrics in the new span context |
469 | | - completion_metrics = self._construct_metrics( |
470 | | - prompt_tokens=prompt_tokens, |
471 | | - completion_tokens=completion_tokens, |
472 | | - total_tokens=total_tokens, |
473 | | - fully_qualified_model_id=fully_qualified_model_id, |
474 | | - provider_id=provider_id, |
475 | | - ) |
476 | | - for metric in completion_metrics: |
477 | | - if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens |
478 | | - enqueue_event(metric) |
479 | | - |
480 | | - # Return metrics in response |
481 | | - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] |
482 | | - |
483 | | - # Fallback if no telemetry |
484 | | - metrics = self._construct_metrics( |
485 | | - prompt_tokens or 0, |
486 | | - completion_tokens or 0, |
487 | | - total_tokens, |
488 | | - fully_qualified_model_id=fully_qualified_model_id, |
489 | | - provider_id=provider_id, |
490 | | - ) |
491 | | - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] |
492 | | - |
493 | 337 | async def stream_tokens_and_compute_metrics_openai_chat( |
494 | 338 | self, |
495 | 339 | response: AsyncIterator[OpenAIChatCompletionChunk], |
|