1111# See the License for the specific language governing permissions and
1212# limitations under the License.
1313# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
14- import base64
15- import io
1614import json
17- from typing import Any , Dict , List , Optional , Tuple , Union
15+ from typing import Any
1816
17+ from openai .types .chat import ChatCompletionMessageToolCall
1918from PIL import Image
2019
2120from crab import Action , ActionOutput , BackendModel , BackendOutput , MessageType
3433 CAMEL_ENABLED = False
3534
3635
37- def find_model_platform_type (model_platform_name : str ) -> ModelPlatformType :
36+ def _find_model_platform_type (model_platform_name : str ) -> " ModelPlatformType" :
3837 for platform in ModelPlatformType :
3938 if platform .value .lower () == model_platform_name .lower ():
4039 return platform
@@ -44,33 +43,51 @@ def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
4443 )
4544
4645
47- def find_model_type (model_name : str ) -> Union [ ModelType , str ] :
46+ def _find_model_type (model_name : str ) -> " str | ModelType" :
4847 for model in ModelType :
4948 if model .value .lower () == model_name .lower ():
5049 return model
5150 return model_name
5251
5352
54- def decode_image (encoded_image : str ) -> Image :
55- data = base64 .b64decode (encoded_image )
56- return Image .open (io .BytesIO (data ))
53+ def _convert_action_to_schema (
54+ action_space : list [Action ] | None ,
55+ ) -> "list[OpenAIFunction] | None" :
56+ if action_space is None :
57+ return None
58+ return [OpenAIFunction (action .entry ) for action in action_space ]
59+
60+
61+ def _convert_tool_calls_to_action_list (
62+ tool_calls : list [ChatCompletionMessageToolCall ] | None ,
63+ ) -> list [ActionOutput ] | None :
64+ if tool_calls is None :
65+ return None
66+
67+ return [
68+ ActionOutput (
69+ name = call .function .name ,
70+ arguments = json .loads (call .function .arguments ),
71+ )
72+ for call in tool_calls
73+ ]
5774
5875
5976class CamelModel (BackendModel ):
6077 def __init__ (
6178 self ,
6279 model : str ,
6380 model_platform : str ,
64- parameters : Optional [ Dict [ str , Any ]] = None ,
81+ parameters : dict [ str , Any ] | None = None ,
6582 history_messages_len : int = 0 ,
6683 ) -> None :
6784 if not CAMEL_ENABLED :
6885 raise ImportError ("Please install camel-ai to use CamelModel" )
6986 self .parameters = parameters or {}
7087 # TODO: a better way?
71- self .model_type = find_model_type (model )
72- self .model_platform_type = find_model_platform_type (model_platform )
73- self .client : Optional [ ChatAgent ] = None
88+ self .model_type = _find_model_type (model )
89+ self .model_platform_type = _find_model_platform_type (model_platform )
90+ self .client : ChatAgent | None = None
7491 self .token_usage = 0
7592
7693 super ().__init__ (
@@ -79,11 +96,11 @@ def __init__(
7996 history_messages_len ,
8097 )
8198
82- def get_token_usage (self ):
99+ def get_token_usage (self ) -> int :
83100 return self .token_usage
84101
85- def reset (self , system_message : str , action_space : Optional [ List [ Action ]] ) -> None :
86- action_schema = self . _convert_action_to_schema (action_space )
102+ def reset (self , system_message : str , action_space : list [ Action ] | None ) -> None :
103+ action_schema = _convert_action_to_schema (action_space )
87104 config = self .parameters .copy ()
88105 if action_schema is not None :
89106 config ["tool_choice" ] = "required"
@@ -109,30 +126,9 @@ def reset(self, system_message: str, action_space: Optional[List[Action]]) -> No
109126 )
110127 self .token_usage = 0
111128
112- @staticmethod
113- def _convert_action_to_schema (
114- action_space : Optional [List [Action ]],
115- ) -> Optional [List [OpenAIFunction ]]:
116- if action_space is None :
117- return None
118- return [OpenAIFunction (action .entry ) for action in action_space ]
119-
120- @staticmethod
121- def _convert_tool_calls_to_action_list (tool_calls ) -> List [ActionOutput ]:
122- if tool_calls is None :
123- return tool_calls
124-
125- return [
126- ActionOutput (
127- name = call .function .name ,
128- arguments = json .loads (call .function .arguments ),
129- )
130- for call in tool_calls
131- ]
132-
133- def chat (self , messages : List [Tuple [str , MessageType ]]):
129+ def chat (self , messages : list [tuple [str , MessageType ]]) -> BackendOutput :
134130 # TODO: handle multiple text messages after message refactoring
135- image_list : List [Image .Image ] = []
131+ image_list : list [Image .Image ] = []
136132 content = ""
137133 for message in messages :
138134 if message [1 ] == MessageType .IMAGE_JPG_BASE64 :
@@ -147,12 +143,9 @@ def chat(self, messages: List[Tuple[str, MessageType]]):
147143 )
148144 response = self .client .step (usermsg )
149145 self .token_usage += response .info ["usage" ]["total_tokens" ]
150- tool_call_request = response .info .get ("tool_call_request" )
151-
152- # TODO: delete this after record_message is refactored
153- self .client .record_message (response .msg )
146+ tool_call_request = response .info .get ("external_tool_request" )
154147
155148 return BackendOutput (
156149 message = response .msg .content ,
157- action_list = self . _convert_tool_calls_to_action_list ([tool_call_request ]),
150+ action_list = _convert_tool_calls_to_action_list ([tool_call_request ]),
158151 )
0 commit comments