Skip to content

Commit

Permalink
refactor: avoid to use extra space when finding model by name (#13043)
Browse files Browse the repository at this point in the history
  • Loading branch information
acelyc111 authored Jan 30, 2025
1 parent b4b09dd commit b09c39c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
11 changes: 5 additions & 6 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,12 @@ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Op
:param credentials: model credentials
:return: model schema
"""
# get predefined models (predefined_models)
models = self.predefined_models()

model_map = {model.model: model for model in models}
if model in model_map:
return model_map[model]
# Try to get model schema from predefined models
for predefined_model in self.predefined_models():
if model == predefined_model.model:
return predefined_model

# Try to get model schema from credentials
if credentials:
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
if model_schema:
Expand Down
17 changes: 9 additions & 8 deletions api/core/model_runtime/model_providers/cohere/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,16 +677,17 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
:return: model schema
"""
# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}

mode = credentials.get("mode")
base_model_schema = None
for predefined_model in self.predefined_models():
if (
mode == "chat" and predefined_model.model == "command-light-chat"
) or predefined_model.model == "command-light":
base_model_schema = predefined_model
break

if mode == "chat":
base_model_schema = model_map["command-light-chat"]
else:
base_model_schema = model_map["command-light"]
if not base_model_schema:
raise ValueError("Model not found")

base_model_schema = cast(AIModelEntity, base_model_schema)

Expand Down
20 changes: 10 additions & 10 deletions api/core/model_runtime/model_providers/openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,6 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]:
:param credentials: provider credentials
:return:
"""
# get predefined models
predefined_models = self.predefined_models()
predefined_models_map = {model.model: model for model in predefined_models}

# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
Expand All @@ -359,9 +356,10 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]:
base_model = model.id.split(":")[1]

base_model_schema = None
for predefined_model_name, predefined_model in predefined_models_map.items():
if predefined_model_name in base_model:
for predefined_model in self.predefined_models():
if predefined_model.model in base_model:
base_model_schema = predefined_model
break

if not base_model_schema:
continue
Expand Down Expand Up @@ -1186,12 +1184,14 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
base_model = model.split(":")[1]

# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}
if base_model not in model_map:
raise ValueError(f"Base model {base_model} not found")
base_model_schema = None
for predefined_model in self.predefined_models():
if base_model == predefined_model.model:
base_model_schema = predefined_model
break

base_model_schema = model_map[base_model]
if not base_model_schema:
raise ValueError(f"Base model {base_model} not found")

base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties
Expand Down

0 comments on commit b09c39c

Please sign in to comment.