
Commit 0542967

feat: generate sql with history
1 parent 5de9b7d commit 0542967

File tree

25 files changed: +529 -369 lines changed
Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+"""009_modify_chat
+
+Revision ID: 1f077c30e476
+Revises: 35d925df4568
+Create Date: 2025-05-30 16:11:08.020715
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '1f077c30e476'
+down_revision = '35d925df4568'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('chat_record')
+    op.create_table('chat_record',
+    sa.Column('id', sa.Integer(), sa.Identity(always=True), nullable=False),
+    sa.Column('chat_id', sa.Integer(), nullable=True),
+    sa.Column('create_time', sa.DateTime(timezone=True), nullable=True),
+    sa.Column('finish_time', sa.DateTime(timezone=True), nullable=True),
+    sa.Column('create_by', sa.BigInteger(), nullable=True),
+    sa.Column('datasource', sa.Integer(), nullable=False),
+    sa.Column('engine_type', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
+    sa.Column('question', sa.Text(), nullable=True),
+    sa.Column('sql_answer', sa.Text(), nullable=True),
+    sa.Column('sql', sa.Text(), nullable=True),
+    sa.Column('sql_exec_result', sa.Text(), nullable=True),
+    sa.Column('data', sa.Text(), nullable=True),
+    sa.Column('chart_answer', sa.Text(), nullable=True),
+    sa.Column('chart', sa.Text(), nullable=True),
+    sa.Column('full_sql_message', sa.Text(), nullable=True),
+    sa.Column('full_chart_message', sa.Text(), nullable=True),
+    sa.Column('finish', sa.Boolean(), nullable=True),
+    sa.Column('error', sa.Text(), nullable=True),
+    sa.Column('run_time', sa.Float(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('chat_record')
+    op.create_table('chat_record',
+    sa.Column('id', sa.Integer(), sa.Identity(always=True), nullable=False),
+    sa.Column('chat_id', sa.Integer(), nullable=True),
+    sa.Column('create_time', sa.DateTime(timezone=True), nullable=True),
+    sa.Column('create_by', sa.BigInteger(), nullable=True),
+    sa.Column('datasource', sa.Integer(), nullable=False),
+    sa.Column('engine_type', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
+    sa.Column('question', sa.Text(), nullable=True),
+    sa.Column('full_question', sa.Text(), nullable=True),
+    sa.Column('answer', sa.Text(), nullable=True),
+    sa.Column('run_time', sa.Float(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # ### end Alembic commands ###
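
The upgrade() above drops chat_record and recreates it with the new SQL/chart history columns, so existing rows are not preserved. A minimal sketch of applying or rolling back this revision through the Alembic API, assuming a standard alembic.ini at the project root (the path is illustrative):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # illustrative path to the project's Alembic config

# Move the schema up to this revision (recreates chat_record with the new columns).
command.upgrade(cfg, "1f077c30e476")

# Roll back to the previous revision if needed.
# command.downgrade(cfg, "35d925df4568")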

backend/apps/ai_model/__init__.py

Whitespace-only changes.

backend/apps/ai_model/llm.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# todo

backend/apps/chat/schemas/chat_base_schema.py renamed to backend/apps/ai_model/model_factory.py

Lines changed: 17 additions & 1 deletion

@@ -3,6 +3,10 @@
 from abc import ABC, abstractmethod
 from langchain_core.language_models import BaseLLM as LangchainBaseLLM
 from langchain_openai import ChatOpenAI
+
+from apps.system.models.system_model import AiModelDetail
+
+
 # from langchain_community.llms import Tongyi, VLLM

 class LLMConfig(BaseModel):
@@ -84,4 +88,16 @@ def create_llm(cls, config: LLMConfig) -> BaseLLM:
     @classmethod
     def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
         """Register new model type"""
-        cls._llm_types[model_type] = llm_class
+        cls._llm_types[model_type] = llm_class
+
+
+# todo
+def get_llm_config(aimodel: AiModelDetail) -> LLMConfig:
+    config = LLMConfig(
+        model_type="openai",
+        model_name=aimodel.name,
+        api_key=aimodel.api_key,
+        api_base_url=aimodel.endpoint,
+        additional_params={"temperature": aimodel.temperature}
+    )
+    return config
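
get_llm_config bridges the stored AiModelDetail row and the factory's LLMConfig. A brief usage sketch follows; the class that exposes create_llm/register_llm is not named in this hunk, so "LLMFactory" below is a stand-in:

from apps.ai_model.model_factory import get_llm_config, LLMFactory  # "LLMFactory" is an assumed name
from apps.system.models.system_model import AiModelDetail


def build_llm(aimodel: AiModelDetail):
    config = get_llm_config(aimodel)      # maps name/api_key/endpoint/temperature into LLMConfig
    return LLMFactory.create_llm(config)  # resolves model_type="openai" to the registered LLM class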

backend/apps/ai_model/openai/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# todo

backend/apps/chat/api/chat.py

Lines changed: 54 additions & 100 deletions

@@ -1,19 +1,19 @@
+import asyncio
+import json
+from typing import List
+
 from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
 from sqlmodel import select

 from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, save_question, save_answer, rename_chat, \
-    delete_chat
-from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat
-from apps.chat.schemas.chat_base_schema import LLMConfig
-from apps.chat.schemas.chat_schema import ChatQuestion
-from apps.chat.schemas.llm import AgentService
-from apps.datasource.crud.datasource import get_table_obj_by_ds
+    delete_chat, list_records
+from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion
+from apps.chat.task.llm import LLMService
+from apps.datasource.crud.datasource import get_table_schema
 from apps.datasource.models.datasource import CoreDatasource
 from apps.system.models.system_model import AiModelDetail
 from common.core.deps import SessionDep, CurrentUser
-import json
-import asyncio

 router = APIRouter(tags=["Data Q&A"], prefix="/chat")

@@ -96,14 +96,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
             detail="No available datasource configuration found"
         )

-    record: ChatRecord
-    try:
-        record = save_question(session=session, current_user=current_user, question=request_question)
-    except Exception as e1:
-        raise HTTPException(
-            status_code=400,
-            detail=str(e1)
-        )
+    request_question.engine = ds.type_name

     # Get available AI model
     aimodel = session.exec(select(AiModelDetail).where(
@@ -112,93 +105,54 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
     )).first()
     if not aimodel:
         raise HTTPException(
-            status_code=400,
+            status_code=500,
             detail="No available AI model configuration found"
         )

-    # Use Tongyi Qianwen
-    tongyi_config = LLMConfig(
-        model_type="openai",
-        model_name=aimodel.name,
-        api_key=aimodel.api_key,
-        api_base_url=aimodel.endpoint,
-        additional_params={"temperature": aimodel.temperature}
-    )
-    # llm_service = LLMService(tongyi_config)
-    llm_service = AgentService(tongyi_config, ds)
-
-    # Use Custom VLLM model
-    """ vllm_config = LLMConfig(
-        model_type="vllm",
-        model_name="your-model-path",
-        additional_params={
-            "max_new_tokens": 200,
-            "temperature": 0.3
-        }
-    )
-    vllm_service = LLMService(vllm_config) """
-    """ result = llm_service.generate_sql(question)
-    return result """
-
+    history_records: List[ChatRecord] = list_records(session=session, current_user=current_user,
+                                                     chart_id=request_question.chat_id)
     # get schema
-    schema_str = ""
-    table_objs = get_table_obj_by_ds(session=session, ds=ds)
-    db_name = table_objs[0].schema
-    schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
-    for obj in table_objs:
-        schema_str += f"# Table: {db_name}.{obj.table.table_name}"
-        table_comment = ''
-        if obj.table.custom_comment:
-            table_comment = obj.table.custom_comment.strip()
-        if table_comment == '':
-            schema_str += '\n[\n'
-        else:
-            schema_str += f", {table_comment}\n[\n"
-
-        field_list = []
-        for field in obj.fields:
-            field_comment = ''
-            if field.custom_comment:
-                field_comment = field.custom_comment.strip()
-            if field_comment == '':
-                field_list.append(f"({field.field_name}:{field.field_type})")
-            else:
-                field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
-        schema_str += ",\n".join(field_list)
-        schema_str += '\n]\n'
-
-    print(schema_str)
-
-    async def event_generator():
-        all_text = ''
-        try:
-            async for chunk in llm_service.async_generate(question, schema_str):
-                data = json.loads(chunk.replace('data: ', ''))
-
-                if data['type'] in ['final', 'tool_result']:
-                    content = data['content']
-                    print('-- ' + content)
-                    all_text += content
-                    for char in content:
-                        yield f"data: {json.dumps({'type': 'char', 'content': char})}\n\n"
-                        await asyncio.sleep(0.05)
-
-                    if 'html' in data:
-                        yield f"data: {json.dumps({'type': 'html', 'content': data['html']})}\n\n"
-                else:
-                    yield chunk
-
-        except Exception as e:
-            all_text = 'Exception:' + str(e)
-            yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
-
-        try:
-            save_answer(session=session, id=record.id, answer=all_text)
-        except Exception as e:
-            raise HTTPException(
-                status_code=500,
-                detail=str(e)
-            )
+    request_question.db_schema = get_table_schema(session=session, ds=ds)
+    llm_service = LLMService(request_question, history_records, ds, aimodel)
+
+    llm_service.init_record(session=session, current_user=current_user)
+
+    def run_task():
+        sql_res = llm_service.generate_sql(session=session)
+        for chunk in sql_res:
+            yield json.dumps({'content': chunk, 'type': 'sql'}) + '\n\n'
+        yield json.dumps({'type': 'info', 'msg': 'sql generated'}) + '\n\n'
+
+    # async def event_generator():
+    #     all_text = ''
+    #     try:
+    #         async for chunk in llm_service.async_generate(question, request_question.db_schema):
+    #             data = json.loads(chunk.replace('data: ', ''))
+    #
+    #             if data['type'] in ['final', 'tool_result']:
+    #                 content = data['content']
+    #                 print('-- ' + content)
+    #                 all_text += content
+    #                 for char in content:
+    #                     yield f"data: {json.dumps({'type': 'char', 'content': char})}\n\n"
+    #                     await asyncio.sleep(0.05)
+    #
+    #             if 'html' in data:
+    #                 yield f"data: {json.dumps({'type': 'html', 'content': data['html']})}\n\n"
+    #             else:
+    #                 yield chunk
+    #
+    #     except Exception as e:
+    #         all_text = 'Exception:' + str(e)
+    #         yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
+    #
+    #     try:
+    #         save_answer(session=session, id=record.id, answer=all_text)
+    #     except Exception as e:
+    #         raise HTTPException(
+    #             status_code=500,
+    #             detail=str(e)
+    #         )

     # return EventSourceResponse(event_generator(), headers={"Content-Type": "text/event-stream"})
-    return StreamingResponse(event_generator(), media_type="text/event-stream")
+    return StreamingResponse(run_task(), media_type="text/event-stream")
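
With this change the endpoint streams newline-delimited JSON produced by run_task() instead of SSE "data:" frames: a series of {'content': ..., 'type': 'sql'} chunks followed by {'type': 'info', 'msg': 'sql generated'}. A minimal consumer sketch, assuming an httpx client; the URL and request payload are illustrative, not taken from this diff:

import json

import httpx


def read_sql_stream(url: str, payload: dict) -> str:
    sql = ""
    with httpx.stream("POST", url, json=payload, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.strip():
                continue  # skip the blank separators produced by the '\n\n' suffix
            event = json.loads(line)
            if event.get("type") == "sql":
                sql += event["content"]  # accumulate streamed SQL fragments
            elif event.get("type") == "info":
                break  # 'sql generated' marks the end of the SQL stream
    return sql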

backend/apps/chat/curd/chat.py

Lines changed: 32 additions & 7 deletions

@@ -1,12 +1,9 @@
 import datetime
-import json
 from typing import List

-from sqlalchemy import text, and_
-from sqlmodel import select
+from sqlalchemy import and_

-from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat
-from apps.chat.schemas.chat_schema import ChatQuestion
+from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion
 from apps.datasource.models.datasource import CoreDatasource
 from common.core.deps import SessionDep, CurrentUser

@@ -66,6 +63,12 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
     return chat_info


+def list_records(session: SessionDep, chart_id: int, current_user: CurrentUser) -> List[ChatRecord]:
+    record_list = session.query(ChatRecord).filter(
+        and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time).all()
+    return record_list
+
+
 def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat) -> ChatInfo:
     if not create_chat_obj.datasource:
         raise Exception("Datasource cannot be None")
@@ -125,11 +128,33 @@ def save_question(session: SessionDep, current_user: CurrentUser, question: Chat
     return result


-def save_full_question(session: SessionDep, id: int, full_question: str) -> ChatRecord:
+def save_full_sql_message(session: SessionDep, record_id: int, full_message: str) -> ChatRecord:
+    return save_full_sql_message_and_answer(session=session, record_id=record_id, full_message=full_message, answer='')
+
+
+def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer: str, full_message: str) -> ChatRecord:
+    if not record_id:
+        raise Exception("Record id cannot be None")
+    record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
+    record.full_sql_message = full_message
+    record.sql_answer = answer
+
+    result = ChatRecord(**record.model_dump())
+
+    session.add(record)
+    session.flush()
+    session.refresh(record)
+
+    session.commit()
+
+    return result
+
+
+def save_full_chart_message(session: SessionDep, id: int, full_message: str) -> ChatRecord:
     if not id:
         raise Exception("Record id cannot be None")
     record = session.query(ChatRecord).filter(ChatRecord.id == id).first()
-    record.full_question = full_question
+    record.full_chart_message = full_message

     result = ChatRecord(**record.model_dump())