@@ -29,27 +29,26 @@ class OpenAIModel(BackendModel):
2929 def __init__ (
3030 self ,
3131 model : str ,
32- parameters : dict [str , Any ] = dict () ,
32+ parameters : dict [str , Any ] | None = None ,
3333 history_messages_len : int = 0 ,
3434 tool_call_required : bool = False ,
35+ base_url : str | None = None ,
3536 ) -> None :
3637 if not openai_model_enable :
3738 raise ImportError ("Please install openai to use OpenAIModel" )
38- super ().__init__ (
39- model ,
40- parameters ,
41- history_messages_len ,
42- )
43- self .client = openai .OpenAI ()
44- self .tool_call_required = tool_call_required
45- self .system_message = "You are a helpful assistant."
46- self .openai_system_message = {
47- "role" : "system" ,
48- "content" : self .system_message ,
49- }
50- self .action_space = None
51- self .action_schema = None
52- self .token_usage = 0
39+
40+ self .model = model
41+ self .parameters = parameters if parameters is not None else {}
42+ self .history_messages_len = history_messages_len
43+
44+ assert self .history_messages_len >= 0
45+
46+ self .client = openai .OpenAI (base_url = base_url )
47+ self .tool_call_required : bool = tool_call_required
48+ self .system_message : str = "You are a helpful assistant."
49+ self .action_space : list [Action ] | None = None
50+ self .action_schema : list [dict ] | None = None
51+ self .token_usage : int = 0
5352 self .chat_history : list [list [ChatCompletionMessage | dict ]] = []
5453
5554 def reset (self , system_message : str , action_space : list [Action ] | None ) -> None :
@@ -59,9 +58,9 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
5958 "content" : system_message ,
6059 }
6160 self .action_space = action_space
62- self .action_schema = self . _convert_action_to_schema (self .action_space )
61+ self .action_schema = _convert_action_to_schema (self .action_space )
6362 self .token_usage = 0
64- self .chat_history : list [ list [ ChatCompletionMessage | dict ]] = []
63+ self .chat_history = []
6564
6665 def chat (self , message : list [Message ] | Message ) -> BackendOutput :
6766 if isinstance (message , tuple ):
@@ -93,10 +92,12 @@ def record_message(
9392 }
9493 ) # extend conversation with function response
9594
96- def call_api (self , request_messages : list ) -> ChatCompletionMessage :
95+ def call_api (
96+ self , request_messages : list [ChatCompletionMessage | dict ]
97+ ) -> ChatCompletionMessage :
9798 if self .action_schema is not None :
9899 response = self .client .chat .completions .create (
99- messages = request_messages ,
100+ messages = request_messages , # type: ignore
100101 model = self .model ,
101102 tools = self .action_schema ,
102103 tool_choice = "required" if self .tool_call_required else "auto" ,
@@ -115,8 +116,8 @@ def call_api(self, request_messages: list) -> ChatCompletionMessage:
115116 def fetch_from_memory (self ) -> list [ChatCompletionMessage | dict ]:
116117 request : list [ChatCompletionMessage | dict ] = [self .openai_system_message ]
117118 if self .history_messages_len > 0 :
118- fetch_hisotry_len = min (self .history_messages_len , len (self .chat_history ))
119- for history_message in self .chat_history [- fetch_hisotry_len :]:
119+ fetch_history_len = min (self .history_messages_len , len (self .chat_history ))
120+ for history_message in self .chat_history [- fetch_history_len :]:
120121 request = request + history_message
121122 return request
122123
@@ -161,12 +162,14 @@ def generate_backend_output(
161162 action_list = action_list ,
162163 )
163164
164- @staticmethod
165- def _convert_action_to_schema (action_space ):
166- if action_space is None :
167- return None
168- actions = []
169- for action in action_space :
170- new_action = action .to_openai_json_schema ()
171- actions .append ({"type" : "function" , "function" : new_action })
172- return actions
165+
166+ def _convert_action_to_schema (
167+ action_space : list [Action ] | None ,
168+ ) -> list [dict ] | None :
169+ if action_space is None :
170+ return None
171+ actions = []
172+ for action in action_space :
173+ new_action = action .to_openai_json_schema ()
174+ actions .append ({"type" : "function" , "function" : new_action })
175+ return actions
0 commit comments