1- import asyncio
21import json
32import logging
4- import re
53import warnings
6- from typing import Any , List , Union , AsyncGenerator , Dict
4+ from typing import Any , List , Union , Dict
75
86from langchain_community .utilities import SQLDatabase
97from langchain_core .language_models import BaseLLM
2119
2220warnings .filterwarnings ("ignore" )
2321
22+ base_message_count_limit = 5
23+
2424
2525class LLMService :
2626 ds : CoreDatasource
@@ -55,34 +55,51 @@ def __init__(self, chat_question: ChatQuestion, history_records: List[ChatRecord
5555 last_chart_messages [- 1 ].full_chart_message
5656
5757 last_sql_messages : List [dict [str , Any ]] = json .loads (last_sql_message_str )
58+
59+ # todo maybe can configure
60+ count_limit = 0 - base_message_count_limit
61+
62+ self .sql_message = []
5863 if last_sql_messages is None or len (last_sql_messages ) == 0 :
5964 # add sys prompt
6065 self .sql_message .append (SystemMessage (content = self .chat_question .sql_sys_question ()))
6166 else :
67+ # limit count
6268 for last_sql_message in last_sql_messages :
69+ if last_sql_message ['type' ] == 'system' :
70+ _msg = SystemMessage (content = last_sql_message ['content' ])
71+ self .sql_message .append (_msg )
72+ break
73+ for last_sql_message in last_sql_messages [count_limit :]:
6374 _msg : BaseMessage
6475 if last_sql_message ['type' ] == 'human' :
6576 _msg = HumanMessage (content = last_sql_message ['content' ])
77+ self .sql_message .append (_msg )
6678 elif last_sql_message ['type' ] == 'ai' :
6779 _msg = AIMessage (content = last_sql_message ['content' ])
68- else :
69- _msg = SystemMessage (content = last_sql_message ['content' ])
70- self .sql_message .append (_msg )
80+ self .sql_message .append (_msg )
7181
7282 last_chart_messages : List [dict [str , Any ]] = json .loads (last_chart_message_str )
83+
84+ self .chart_message = []
7385 if last_chart_messages is None or len (last_chart_messages ) == 0 :
7486 # add sys prompt
7587 self .chart_message .append (SystemMessage (content = self .chat_question .chart_sys_question ()))
7688 else :
89+ # limit count
90+ for last_chart_message in last_chart_messages :
91+ if last_chart_message ['type' ] == 'system' :
92+ _msg = SystemMessage (content = last_chart_message ['content' ])
93+ self .chart_message .append (_msg )
94+ break
7795 for last_chart_message in last_chart_messages :
7896 _msg : BaseMessage
7997 if last_chart_message ['type' ] == 'human' :
8098 _msg = HumanMessage (content = last_chart_message ['content' ])
99+ self .chart_message .append (_msg )
81100 elif last_chart_message ['type' ] == 'ai' :
82101 _msg = AIMessage (content = last_chart_message ['content' ])
83- else :
84- _msg = SystemMessage (content = last_chart_message ['content' ])
85- self .chart_message .append (_msg )
102+ self .chart_message .append (_msg )
86103
87104 def init_record (self , session : SessionDep , current_user : CurrentUser ) -> ChatRecord :
88105 self .record = save_question (session = session , current_user = current_user , question = self .chat_question )
@@ -98,6 +115,7 @@ def generate_sql(self, session: SessionDep):
98115 full_sql_text = ''
99116 res = self .llm .stream (self .sql_message )
100117 for chunk in res :
118+ print (chunk )
101119 if isinstance (chunk , dict ):
102120 full_sql_text += chunk ['content' ]
103121 yield chunk ['content' ]
@@ -196,9 +214,10 @@ def save_error(self, session: SessionDep, message: str):
196214 return save_error_message (session = session , record_id = self .record .id , message = message )
197215
198216 def save_sql_data (self , session : SessionDep , data_obj : Dict [str , Any ]):
199- return save_sql_exec_data (session = session , record_id = self .record .id , data = json .dumps (data_obj , ensure_ascii = False ))
217+ return save_sql_exec_data (session = session , record_id = self .record .id ,
218+ data = json .dumps (data_obj , ensure_ascii = False ))
200219
201- def finish (self ,session : SessionDep ):
220+ def finish (self , session : SessionDep ):
202221 return finish_record (session = session , record_id = self .record .id )
203222
204223 def execute_sql (self , sql : str ):
@@ -237,7 +256,7 @@ def extract_nested_json(text):
237256 pass
238257 else :
239258 stack = [] # 括号不匹配则重置
240- if results [0 ]:
259+ if len ( results ) > 0 and results [0 ]:
241260 return results [0 ]
242261 return None
243262
0 commit comments