Skip to content

Commit a4e456c

Browse files
committed
feat: chat view
1 parent 20ec595 commit a4e456c

File tree

9 files changed

+349
-243
lines changed

9 files changed

+349
-243
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,25 @@ class AiModelQuestion(BaseModel):
7070
engine: str = ""
7171
db_schema: str = ""
7272
sql: str = ""
73+
rule: str = """
74+
请逐步推理后给出答案:
75+
推理过程中不需要输出JSON,仅在最终结果内输出符合要求的JSON
76+
步骤1: [思考内容]
77+
步骤2: [思考内容]
78+
最终答案: [结果]
79+
"""
7380

7481
def sql_sys_question(self):
7582
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question)
7683

7784
def sql_user_question(self):
78-
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question)
85+
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, rule=self.rule)
7986

8087
def chart_sys_question(self):
8188
return get_chart_template()['system'].format(sql=self.sql, question=self.question)
8289

8390
def chart_user_question(self):
84-
return get_chart_template()['user'].format(sql=self.sql, question=self.question)
91+
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule)
8592

8693

8794
class ChatQuestion(AiModelQuestion):

backend/apps/chat/task/llm.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import asyncio
21
import json
32
import logging
4-
import re
53
import warnings
6-
from typing import Any, List, Union, AsyncGenerator, Dict
4+
from typing import Any, List, Union, Dict
75

86
from langchain_community.utilities import SQLDatabase
97
from langchain_core.language_models import BaseLLM
@@ -21,6 +19,8 @@
2119

2220
warnings.filterwarnings("ignore")
2321

22+
base_message_count_limit = 5
23+
2424

2525
class LLMService:
2626
ds: CoreDatasource
@@ -55,34 +55,51 @@ def __init__(self, chat_question: ChatQuestion, history_records: List[ChatRecord
5555
last_chart_messages[-1].full_chart_message
5656

5757
last_sql_messages: List[dict[str, Any]] = json.loads(last_sql_message_str)
58+
59+
# todo maybe can configure
60+
count_limit = 0 - base_message_count_limit
61+
62+
self.sql_message = []
5863
if last_sql_messages is None or len(last_sql_messages) == 0:
5964
# add sys prompt
6065
self.sql_message.append(SystemMessage(content=self.chat_question.sql_sys_question()))
6166
else:
67+
# limit count
6268
for last_sql_message in last_sql_messages:
69+
if last_sql_message['type'] == 'system':
70+
_msg = SystemMessage(content=last_sql_message['content'])
71+
self.sql_message.append(_msg)
72+
break
73+
for last_sql_message in last_sql_messages[count_limit:]:
6374
_msg: BaseMessage
6475
if last_sql_message['type'] == 'human':
6576
_msg = HumanMessage(content=last_sql_message['content'])
77+
self.sql_message.append(_msg)
6678
elif last_sql_message['type'] == 'ai':
6779
_msg = AIMessage(content=last_sql_message['content'])
68-
else:
69-
_msg = SystemMessage(content=last_sql_message['content'])
70-
self.sql_message.append(_msg)
80+
self.sql_message.append(_msg)
7181

7282
last_chart_messages: List[dict[str, Any]] = json.loads(last_chart_message_str)
83+
84+
self.chart_message = []
7385
if last_chart_messages is None or len(last_chart_messages) == 0:
7486
# add sys prompt
7587
self.chart_message.append(SystemMessage(content=self.chat_question.chart_sys_question()))
7688
else:
89+
# limit count
90+
for last_chart_message in last_chart_messages:
91+
if last_chart_message['type'] == 'system':
92+
_msg = SystemMessage(content=last_chart_message['content'])
93+
self.chart_message.append(_msg)
94+
break
7795
for last_chart_message in last_chart_messages:
7896
_msg: BaseMessage
7997
if last_chart_message['type'] == 'human':
8098
_msg = HumanMessage(content=last_chart_message['content'])
99+
self.chart_message.append(_msg)
81100
elif last_chart_message['type'] == 'ai':
82101
_msg = AIMessage(content=last_chart_message['content'])
83-
else:
84-
_msg = SystemMessage(content=last_chart_message['content'])
85-
self.chart_message.append(_msg)
102+
self.chart_message.append(_msg)
86103

87104
def init_record(self, session: SessionDep, current_user: CurrentUser) -> ChatRecord:
88105
self.record = save_question(session=session, current_user=current_user, question=self.chat_question)
@@ -98,6 +115,7 @@ def generate_sql(self, session: SessionDep):
98115
full_sql_text = ''
99116
res = self.llm.stream(self.sql_message)
100117
for chunk in res:
118+
print(chunk)
101119
if isinstance(chunk, dict):
102120
full_sql_text += chunk['content']
103121
yield chunk['content']
@@ -196,9 +214,10 @@ def save_error(self, session: SessionDep, message: str):
196214
return save_error_message(session=session, record_id=self.record.id, message=message)
197215

198216
def save_sql_data(self, session: SessionDep, data_obj: Dict[str, Any]):
199-
return save_sql_exec_data(session=session, record_id=self.record.id, data=json.dumps(data_obj, ensure_ascii=False))
217+
return save_sql_exec_data(session=session, record_id=self.record.id,
218+
data=json.dumps(data_obj, ensure_ascii=False))
200219

201-
def finish(self,session: SessionDep):
220+
def finish(self, session: SessionDep):
202221
return finish_record(session=session, record_id=self.record.id)
203222

204223
def execute_sql(self, sql: str):
@@ -237,7 +256,7 @@ def extract_nested_json(text):
237256
pass
238257
else:
239258
stack = [] # 括号不匹配则重置
240-
if results[0]:
259+
if len(results) > 0 and results[0]:
241260
return results[0]
242261
return None
243262

backend/template.yaml

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,23 @@ template:
1515
{{"success":false,"message":"抱歉,我无法回答您的问题。"}}
1616
- 如果根据提供的表结构不能生成符合问题与条件的SQL,回答:
1717
{{"success":false,"message":"无法生成SQL的原因"}}
18-
user: |
19-
表结构:
18+
- 如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL
19+
- 如果问题是图表展示相关,可参考的图表类型为表格(table)、条形图(bar)、折线图(line)或饼图(pie),返回的JSON:
20+
{{"success":true,"sql":"生成的SQL语句","chart-type":"选择的图表类型(table/bar/line/pie)"}}
21+
user: |
22+
### 表结构:
2023
{schema}
2124
22-
问题:
25+
### 问题:
2326
{question}
27+
28+
### 其他规则:
29+
{rule}
2430
chart:
2531
system: |
2632
### 说明:
2733
您的任务是通过给定的问题和SQL生成 JSON 以进行数据可视化。
28-
请遵守以下规则
34+
请遵守以下规则:
2935
- 如果需要表格,则生成的 JSON 格式应为:
3036
{{"type":"table", "title": "标题", "columns": [{{"name":"中文字段名1", "value": "SQL 查询列 1(有别名用别名)"}}, {{"name": "中文字段名 2", "value": "SQL 查询列 2(有别名用别名)"}}]}}
3137
必须从 SQL 查询列中提取“columns”。
@@ -36,22 +42,25 @@ template:
3642
{{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的中文名称","value": "x轴的SQL查询列(有别名用别名)"}}, "y": {{"name":"y轴的中文名称","value": "y轴的SQL查询列(有别名用别名)"}}}}
3743
其中“x”和“y”必须从SQL查询列中提取。
3844
- 如果需要饼图,则生成的 JSON 格式应为:
39-
{{"type":"pie", "title": "标题", "column": {{"name":"中文字段名","value":"SQL查询列1(有别名用别名)"}}}}
45+
{{"type":"pie", "title": "标题", "column": [{{"name":"中文字段名1","value":"SQL查询列1(有别名用别名)"}},{{"name":"中文字段名2","value":"SQL查询列2(有别名用别名)"}}]}}
4046
其中“column”必须从SQL查询列中提取。
4147
- 如果答案未知或者与生成JSON无关,则生成的 JSON 格式应为:
4248
{{"type":"error", "reason": "抱歉,我无法回答您的问题。"}}
4349
44-
### 示例
50+
### 示例:
4551
如果 SQL 为: SELECT products_sales_data.category, AVG(products_sales_data.price) AS average_price FROM products_sales_data GROUP BY products_sales_data.category;
4652
问题是:每个商品分类的平均价格
4753
则生成的 JSON 可以是: {{"type":"table", "title": "每个商品分类的平均价格", "columns": [{{"name":"商品分类","value":"category"}}, {{"name":"平均价格","value":"average_price"}}]}}
4854
49-
### 响应
55+
### 响应:
5056
根据您的指示,这是我生成的与 问题 和 sql 匹配的 JSON:
5157
```json
5258
user: |
5359
### SQL:
5460
{sql}
5561
56-
### 问题:
57-
{question}
62+
### 问题:
63+
{question}
64+
65+
### 其他规则:
66+
{rule}

frontend/components.d.ts

Lines changed: 0 additions & 62 deletions
This file was deleted.

frontend/package.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
"dependencies": {
1212
"@npkg/tinymce-plugins": "^0.0.7",
1313
"@tinymce/tinymce-vue": "^5.1.0",
14-
"tinymce": "^5.8.2",
14+
"dayjs": "^1.11.13",
1515
"lodash": "^4.17.21",
1616
"snowflake-id": "^1.1.0",
17+
"tinymce": "^5.8.2",
1718
"vue": "^3.5.13",
1819
"vue-router": "^4.5.0",
19-
"web-storage-cache": "^1.1.1",
20-
"dayjs": "^1.11.13"
20+
"web-storage-cache": "^1.1.1"
2121
},
2222
"devDependencies": {
2323
"@element-plus/icons-vue": "^2.3.1",

0 commit comments

Comments
 (0)