-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgm_api_server.py
More file actions
89 lines (69 loc) · 2.58 KB
/
Copy pathgm_api_server.py
File metadata and controls
89 lines (69 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import List, Dict, Any

import uvicorn
from fastapi import FastAPI, HTTPException
# 确保导入了 BaseModel, Field, List, Dict
from pydantic import BaseModel, Field

# 导入我们上一步编写的检索器
from rag_retriever import RAGRetriever, KnowledgeBaseType
# --- 2. Initialization ---
# Build the FastAPI application object that the route decorators below attach to.
print("正在初始化 FastAPI 应用...")
app = FastAPI(
    title="GM Agent RAG API",
    description="为LLM Agent提供CoC规则库和模组库的检索服务",
    version="1.0.0",
)

# Load the retriever once at import time. A failure is deliberately non-fatal:
# the server still starts, `retriever` is left as None, and the endpoints
# below check for that.
print("正在加载知识库索引 (FAISS)...")
try:
    retriever = RAGRetriever()
except Exception as exc:
    print(f"[错误] 知识库加载失败: {exc}")
    retriever = None
else:
    print("[成功] 知识库加载成功!")
# --- 3. API data models ---
class SearchQuery(BaseModel):
    """Request body for ``POST /search``."""

    # Natural-language text to search with.
    query_text: str = Field(..., description="要检索的自然语言查询", example="How does Sanity work?")
    # Target knowledge base; pydantic parses and validates it as KnowledgeBaseType.
    kb_type: KnowledgeBaseType = Field(..., description="要查询的知识库: 'coc_rules' 或 'current_module'",
                                       example="coc_rules")
    # Number of most-similar results to return. Must be >= 1: a zero or negative
    # k is meaningless for nearest-neighbour search and would otherwise be
    # forwarded unchecked to the retriever (now rejected with a 422).
    top_k: int = Field(3, ge=1, description="返回最相似的结果数量", example=3)
# --- Key fix: define a precise response model ---
# This tells FastAPI that we return an int (index_id) and a float (distance),
# so each hit is validated and serialized with those exact types.
class SearchResult(BaseModel):
    """One search hit in the list returned by ``POST /search``."""

    kb_type: str  # name of the knowledge base the hit came from
    index_id: int  # presumably the hit's id inside that knowledge base's index — confirm against RAGRetriever
    text_chunk: str  # text of the matched chunk
    distance: float  # distance score reported by the retriever; NOTE(review): assumed smaller = more similar — confirm
# --- 4. API routes ---
# response_model=List[SearchResult] makes FastAPI validate and serialize every
# hit with the typed SearchResult model (int ids, float distances).
@app.post("/search", response_model=List[SearchResult])
def search_knowledge_base(query: SearchQuery):
    """
    Run a semantic search over the selected knowledge base (rules or module).

    Args:
        query: Request body carrying the query text, the target knowledge base
            and the number of hits to return.

    Returns:
        The retriever's hits. ``retriever.search`` returns a list of dicts,
        which FastAPI validates against ``List[SearchResult]``.

    Raises:
        HTTPException: 503 when the retriever failed to load at startup. An
            empty list would be indistinguishable from "no matches", so the
            calling agent could not detect the outage.
    """
    if retriever is None:
        print("错误: RAG retriever 未初始化")
        raise HTTPException(status_code=503, detail="RAG retriever 未初始化")

    print(f"收到 /search 请求: kb='{query.kb_type}', k={query.top_k}, query='{query.query_text}'")
    results = retriever.search(
        query_text=query.query_text,
        kb_type=query.kb_type,
        top_k=query.top_k,
    )
    return results
@app.get("/health")
def health_check():
    """
    Health-check endpoint.

    Returns ``{"status": "ok", "loaded_kb": [...]}`` listing the keys of
    ``retriever.indexes`` when the retriever exists and its ``indexes`` is
    truthy. Otherwise it returns an error payload; that payload is still an
    ordinary (non-error-status) response.
    """
    loaded_indexes = retriever.indexes if retriever else None
    if not loaded_indexes:
        return {"status": "error", "message": "RAG retriever 未加载"}
    return {"status": "ok", "loaded_kb": list(loaded_indexes.keys())}
# --- 5. Start the server ---
if __name__ == "__main__":
    import os

    # Host and port may be overridden through environment variables. The
    # defaults match the previous hard-coded values (0.0.0.0:8000).
    # NOTE: 0.0.0.0 listens on every network interface, so the API is reachable
    # from other machines unless something else blocks it; set
    # GM_API_HOST=127.0.0.1 for local-only access.
    host = os.environ.get("GM_API_HOST", "0.0.0.0")
    port = int(os.environ.get("GM_API_PORT", "8000"))
    print(f"启动 GM Agent API 服务器,访问 http://127.0.0.1:{port}/docs 查看 API 文档")
    uvicorn.run(app, host=host, port=port)