Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,6 @@ _build/
# model parameter
*.pth

logs/
logs/

.DS_Store
48 changes: 25 additions & 23 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

try:
from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.toolkits import OpenAIFunction
Expand All @@ -33,29 +32,34 @@
CAMEL_ENABLED = False


def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
for platform in ModelPlatformType:
if platform.value.lower() == model_platform_name.lower():
return platform
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)
def _get_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
try:
return ModelPlatformType(model_platform_name)
except ValueError:
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)


def _find_model_type(model_name: str) -> "str | ModelType":
for model in ModelType:
if model.value.lower() == model_name.lower():
return model
return model_name
def _get_model_type(model_name: str) -> "str | ModelType":
try:
return ModelType(model_name)
except ValueError:
return model_name


def _convert_action_to_schema(
action_space: list[Action] | None,
) -> "list[OpenAIFunction] | None":
if action_space is None:
return None
return [OpenAIFunction(action.entry) for action in action_space]
schema_list = []
for action in action_space:
new_action = action.to_openai_json_schema()
schema = {"type": "function", "function": new_action}
schema_list.append(OpenAIFunction(action.entry, schema))
return schema_list


def _convert_tool_calls_to_action_list(
Expand Down Expand Up @@ -84,9 +88,8 @@ def __init__(
if not CAMEL_ENABLED:
raise ImportError("Please install camel-ai to use CamelModel")
self.parameters = parameters or {}
# TODO: a better way?
self.model_type = _find_model_type(model)
self.model_platform_type = _find_model_platform_type(model_platform)
self.model_type = _get_model_type(model)
self.model_platform_type = _get_model_platform_type(model_platform)
self.client: ChatAgent | None = None
self.token_usage = 0

Expand All @@ -104,15 +107,14 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
config = self.parameters.copy()
if action_schema is not None:
config["tool_choice"] = "required"
config["tools"] = action_schema
config["tools"] = [
schema.get_openai_tool_schema() for schema in action_schema
]

chatgpt_config = ChatGPTConfig(
**config,
)
backend_model = ModelFactory.create(
self.model_platform_type,
self.model_type,
model_config_dict=chatgpt_config.as_dict(),
model_config_dict=config,
)
sysmsg = BaseMessage.make_assistant_message(
role_name="Assistant",
Expand Down
Loading
Loading