
Commit 27da37e

SongChiYoung and ekzhu authored
[Refactor] model family resolution to support non-prefixed names like Mistral (#6158)
This PR improves how `model_family` is resolved when selecting a transformer from the registry. Previously, model families were inferred using a simple prefix-based match:

```
if model.startswith(family): ...
```

This works for cleanly prefixed models (e.g., `gpt-4o`, `claude-3`) but fails for models like `mistral-large-latest` or `codestral-latest`, where prefix-based matching is ambiguous or misleading. To address this:

- `model_family` can now be passed explicitly (e.g., via `ModelInfo`)
- `_find_model_family()` is only used as a fallback when the value is "unknown"
- Transformer lookup is now more robust and predictable
- An example integration in `to_oai_type()` demonstrates this pattern using `self._model_info["family"]`

This change is required for safe support of models like Mistral and other future models that do not follow standard naming conventions.

Linked to discussion in [#6151](#6151)
Related: #6011

---------

Co-authored-by: Eric Zhu <[email protected]>
Parent: 9143e58

4 files changed: +92 −13 lines
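Before the per-file diffs, a minimal usage sketch of the pattern this commit enables: constructing a client for a non-prefixed model name with the family set explicitly, so transformer lookup never has to guess from the name. This is illustrative only; the endpoint URL, API key, capability flags, and the `ModelFamily.MISTRAL` constant are assumptions, not part of the diff.

```python
# Hypothetical sketch: pass model_info with an explicit family so the registry
# does not rely on prefix matching. base_url, api_key, the capability flags,
# and ModelFamily.MISTRAL are placeholder assumptions.
from autogen_core.models import ModelFamily
from autogen_ext.models.openai import OpenAIChatCompletionClient

client = OpenAIChatCompletionClient(
    model="mistral-large-latest",          # no OpenAI-style prefix to match on
    base_url="https://api.mistral.ai/v1",  # placeholder OpenAI-compatible endpoint
    api_key="YOUR_API_KEY",
    model_info={
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "structured_output": True,
        "family": ModelFamily.MISTRAL,  # explicit family; skips _find_model_family()
    },
)
```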

`python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py`

24 additions & 1 deletion

```diff
@@ -1,10 +1,16 @@
+import logging
 from typing import Dict
 
+from autogen_core import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
 from autogen_core.models import ModelFamily, ModelInfo
 
+logger = logging.getLogger(EVENT_LOGGER_NAME)
+trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
+
 # Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
 # This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime
 _MODEL_POINTERS = {
+    # OpenAI models
     "o3-mini": "o3-mini-2025-01-31",
     "o1": "o1-2024-12-17",
     "o1-preview": "o1-preview-2024-09-12",
@@ -18,6 +24,7 @@
     "gpt-4-32k": "gpt-4-32k-0613",
     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
+    # Anthropic models
     "claude-3-haiku": "claude-3-haiku-20240307",
     "claude-3-sonnet": "claude-3-sonnet-20240229",
     "claude-3-opus": "claude-3-opus-20240229",
@@ -291,8 +298,24 @@ def resolve_model(model: str) -> str:
 
 
 def get_info(model: str) -> ModelInfo:
+    # If this is called, the config does not provide a custom model_info
     resolved_model = resolve_model(model)
-    return _MODEL_INFO[resolved_model]
+    model_info: ModelInfo = _MODEL_INFO.get(
+        resolved_model,
+        {
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "family": "FAILED",
+            "structured_output": False,
+        },
+    )
+    if model_info.get("family") == "FAILED":
+        raise ValueError("model_info is required when model name is not a valid OpenAI model")
+    if model_info.get("family") == ModelFamily.UNKNOWN:
+        trace_logger.warning(f"Model info not found for model: {model}")
+
+    return model_info
 
 
 def get_token_limit(model: str) -> int:
```
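A short sketch of the two fallback paths the new `get_info` takes, assuming `"my-custom-model"` is a name absent from `_MODEL_INFO` (module path as in the diff above):

```python
# Sketch of the new get_info behavior; "my-custom-model" is an assumed
# placeholder for any name not present in _MODEL_INFO.
from autogen_ext.models.openai._model_info import get_info

info = get_info("gpt-4o")  # known pointer: resolves the alias and returns its ModelInfo
print(info["family"])

try:
    get_info("my-custom-model")  # absent from _MODEL_INFO -> sentinel "FAILED" family
except ValueError as e:
    print(e)  # "model_info is required when model name is not a valid OpenAI model"
```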

`python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py`

11 additions & 4 deletions

```diff
@@ -162,12 +162,12 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
 
 
 def to_oai_type(
-    message: LLMMessage, prepend_name: bool = False, model_family: str = "gpt-4o"
+    message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
 ) -> Sequence[ChatCompletionMessageParam]:
     context = {
         "prepend_name": prepend_name,
     }
-    transformers = get_transformer("openai", model_family)
+    transformers = get_transformer("openai", model, model_family)
 
     def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
         raise ValueError(f"Unknown message type: {type(message)}")
@@ -280,6 +280,7 @@ def count_tokens_openai(
     *,
     add_name_prefixes: bool = False,
     tools: Sequence[Tool | ToolSchema] = [],
+    model_family: str = ModelFamily.UNKNOWN,
 ) -> int:
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -293,7 +294,7 @@
     # Message tokens.
     for message in messages:
         num_tokens += tokens_per_message
-        oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model_family=model)
+        oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
         for oai_message_part in oai_message:
             for key, value in oai_message_part.items():
                 if value is None:
@@ -556,7 +557,12 @@ def _process_create_args(
         messages = self._rstrip_last_assistant_message(messages)
 
         oai_messages_nested = [
-            to_oai_type(m, prepend_name=self._add_name_prefixes, model_family=create_args.get("model", "unknown"))
+            to_oai_type(
+                m,
+                prepend_name=self._add_name_prefixes,
+                model=create_args.get("model", "unknown"),
+                model_family=self._model_info["family"],
+            )
             for m in messages
         ]
 
@@ -1049,6 +1055,7 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool |
             self._create_args["model"],
             add_name_prefixes=self._add_name_prefixes,
             tools=tools,
+            model_family=self._model_info["family"],
         )
 
     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
```
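For reference, a minimal call of the updated `to_oai_type` signature. It is an internal helper, so this is illustrative only; note that `model` is consulted only when `model_family` is `ModelFamily.UNKNOWN`:

```python
# Illustrative call of the updated internal helper; signature taken from the diff.
from autogen_core.models import ModelFamily, UserMessage
from autogen_ext.models.openai._openai_client import to_oai_type

msg = UserMessage(content="Hello!", source="user")
oai_messages = to_oai_type(
    msg,
    prepend_name=False,
    model="gpt-4o",                   # used only for fallback family inference
    model_family=ModelFamily.GPT_4O,  # explicit family wins when not UNKNOWN
)
print(oai_messages)
```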

`python/packages/autogen-ext/src/autogen_ext/models/openai/_transformation/registry.py`

11 additions & 8 deletions

```diff
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from typing import Any, Callable, Dict, List
 
-from autogen_core.models import LLMMessage
+from autogen_core.models import LLMMessage, ModelFamily
 
 from .types import (
     TransformerFunc,
@@ -87,13 +87,14 @@ def _find_model_family(api: str, model: str) -> str:
     Finds the best matching model family for the given model.
     Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0").
     """
-    for family in MESSAGE_TRANSFORMERS[api].keys():
-        if model.startswith(family):
-            return family
-    return "default"
+    family = ModelFamily.UNKNOWN
+    for _family in MESSAGE_TRANSFORMERS[api].keys():
+        if model.startswith(_family):
+            family = _family
+    return family
 
 
-def get_transformer(api: str, model_family: str) -> TransformerMap:
+def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
     """
     Returns the registered transformer map for the given model family.
 
@@ -107,9 +108,11 @@ def get_transformer(api: str, model_family: str) -> TransformerMap:
     Keeping this as a function (instead of direct dict access) improves long-term flexibility.
     """
 
-    model = _find_model_family(api, model_family)
+    if model_family == ModelFamily.UNKNOWN:
+        # fallback to finding the best matching model family
+        model_family = _find_model_family(api, model)
 
-    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model, {})
+    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
 
     if not transformer:
         raise ValueError(f"No transformer found for model family '{model_family}'")
```

`python/packages/autogen-ext/tests/models/test_openai_model_client.py`

46 additions & 0 deletions

```diff
@@ -29,6 +29,7 @@
     convert_tools,
     to_oai_type,
 )
+from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
 from openai.resources.beta.chat.completions import (  # type: ignore
     AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
 )
@@ -2367,6 +2368,51 @@ async def test_empty_assistant_content_string_with_some_model(
     assert isinstance(result.content, str)
 
 
+def test_openai_model_registry_find_well() -> None:
+    model = "gpt-4o"
+    client1 = OpenAIChatCompletionClient(model=model, api_key="test")
+    client2 = OpenAIChatCompletionClient(
+        model=model,
+        model_info={
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "structured_output": False,
+            "family": ModelFamily.UNKNOWN,
+        },
+        api_key="test",
+    )
+
+    def get_registered_transformer(client: OpenAIChatCompletionClient) -> TransformerMap:
+        model_name = client._create_args["model"]  # pyright: ignore[reportPrivateUsage]
+        model_family = client.model_info["family"]
+        return get_transformer("openai", model_name, model_family)
+
+    assert get_registered_transformer(client1) == get_registered_transformer(client2)
+
+
+def test_openai_model_registry_find_wrong() -> None:
+    with pytest.raises(ValueError, match="No transformer found for model family"):
+        get_transformer("openai", "gpt-7", "foobar")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-4o-mini",
+    ],
+)
+async def test_openai_model_unknown_message_type(model: str, openai_client: OpenAIChatCompletionClient) -> None:
+    class WrongMessage:
+        content = "foo"
+        source = "bar"
+
+    messages: List[WrongMessage] = [WrongMessage()]
+    with pytest.raises(ValueError, match="Unknown message type"):
+        await openai_client.create(messages=messages)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model",
```
