-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_agent.py
More file actions
217 lines (182 loc) · 7.57 KB
/
Copy pathbase_agent.py
File metadata and controls
217 lines (182 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# base_agent.py - Agent 抽象基类
"""
所有 Agent 的基类,提供:
1. LLM 客户端初始化(多模型支持:智谱/Groq/Gemini/自定义)
2. 对话历史压缩
3. 抽象接口定义
"""
import openai
import json
import re
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Any, Tuple, Generator
from character_card import CharacterCard
from app_config import load_settings
# --- Default LLM configuration (resolved once, at import time) ---
SETTINGS = load_settings()
SELECTED_LLM = SETTINGS.selected_llm()

# Normalise the configured provider name once. "gemma" is grouped with the
# Zhipu aliases, so it uses the Zhipu settings.
_provider_key = SETTINGS.llm_provider.strip().lower()
USE_ZHIPU = _provider_key in {"gemma", "zhipu", "zhipuai"}
USE_GROQ = _provider_key == "groq"
del _provider_key  # temporary only; not part of the module namespace

# Flat per-provider shortcuts mirroring the values held in SETTINGS.
_zhipu, _groq, _gemini = SETTINGS.zhipu, SETTINGS.groq, SETTINGS.gemini

ZHIPU_API_KEY, ZHIPU_BASE_URL = _zhipu.api_key, _zhipu.base_url
ZHIPU_MODEL, ZHIPU_PROVIDER = _zhipu.model, _zhipu.provider

GROQ_API_KEY, GROQ_BASE_URL, GROQ_MODEL = _groq.api_key, _groq.base_url, _groq.model

GEMINI_API_KEY, GEMINI_BASE_URL, GEMINI_MODEL = (
    _gemini.api_key,
    _gemini.base_url,
    _gemini.model,
)
del _zhipu, _groq, _gemini  # temporaries only
# --- Base state model ---
class BaseChatState(BaseModel):
    """Minimal per-session conversation state shared by all modes.

    Concrete agents are expected to build on this (see
    ``BaseAgent.create_initial_state``).
    """
    # Identifier of the session this state belongs to.
    session_id: str
    # Character card used for this session; None until one is assigned.
    character_card: Optional[CharacterCard] = None
    # Chat history as a list of message dicts. In this file they are read via
    # the 'role' and 'content' keys. default_factory gives each instance its
    # own list.
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    # Name of the currently displayed expression; "default" is also the
    # fallback name BaseAgent.parse_expression returns when no tag is found.
    current_expression: str = "default"
    # NOTE(review): free-form mood label; the set of valid values is not
    # defined in this file — confirm against the concrete agents.
    mood: str = "normal"
class BaseAgent(ABC):
    """Abstract base class for all agents.

    Provides:
      1. LLM client construction (custom config, or the provider selected in
         settings: Zhipu/Gemma, Groq, or Gemini);
      2. conversation-history compression;
      3. the abstract interface concrete agents must implement.
    """

    # History-compression thresholds. They are declared at class level so a
    # subclass can override them with a plain class attribute; an instance may
    # still assign its own values. Compression triggers when the message count
    # exceeds MAX_MESSAGES, and the newest KEEP_RECENT non-system messages are
    # kept verbatim.
    MAX_MESSAGES: int = 50
    KEEP_RECENT: int = 30

    def __init__(self, custom_config: Optional[Dict[str, Any]] = None, skip_warmup: bool = False):
        """Create the LLM client for this agent.

        Args:
            custom_config: Optional API override; see ``_init_llm_client``.
            skip_warmup: If True, skip the test request sent after the client
                is created.
        """
        self._init_llm_client(custom_config, skip_warmup)

    def _init_llm_client(self, custom_config: Optional[Dict[str, Any]] = None, skip_warmup: bool = False):
        """Initialise ``self.client``, ``self.model`` and ``self.api_name``.

        Selection order:
          1. ``custom_config`` when it is truthy;
          2. the Zhipu settings when ``USE_ZHIPU`` is set;
          3. the Groq settings when ``USE_GROQ`` is set;
          4. otherwise the Gemini settings.

        Args:
            custom_config: Optional API override with keys ``'api_key'``,
                ``'base_url'``, ``'model'`` and ``'provider'`` (all str).
                ``'provider'`` is only a display name. It falls back to
                ``'Custom API'`` when missing or empty.
            skip_warmup: If True, skip the one-off test completion request.

        Raises:
            ValueError: If Groq or Gemini is selected but its API key is not
                configured.
        """
        if custom_config:
            self.client = openai.OpenAI(
                api_key=custom_config.get('api_key'),
                base_url=custom_config.get('base_url')
            )
            self.model = custom_config.get('model')
            # `or` also covers an explicit None/"" provider, not only a missing key.
            self.api_name = custom_config.get('provider') or 'Custom API'
            print(f"[AI] 使用自定义API配置: {self.api_name} ({self.model})")
        elif USE_ZHIPU:
            # NOTE(review): unlike Groq/Gemini there is no missing-key check here;
            # presumably intentional for keyless/local endpoints ("gemma") — confirm.
            self.client = openai.OpenAI(api_key=SETTINGS.zhipu.api_key, base_url=SETTINGS.zhipu.base_url)
            self.model = SETTINGS.zhipu.model
            self.api_name = SETTINGS.zhipu.provider
        elif USE_GROQ:
            if not SETTINGS.groq.api_key:
                raise ValueError("Missing GROQ_API_KEY environment variable")
            self.client = openai.OpenAI(api_key=SETTINGS.groq.api_key, base_url=SETTINGS.groq.base_url)
            self.model = SETTINGS.groq.model
            self.api_name = SETTINGS.groq.provider
        else:
            if not SETTINGS.gemini.api_key:
                raise ValueError("Missing GEMINI_API_KEY environment variable")
            self.client = openai.OpenAI(api_key=SETTINGS.gemini.api_key, base_url=SETTINGS.gemini.base_url)
            self.model = SETTINGS.gemini.model
            self.api_name = SETTINGS.gemini.provider

        # Warm-up: a tiny completion that surfaces bad keys/URLs early.
        # Best-effort by design: failures are reported but never raised.
        if not skip_warmup:
            print(f"[AI] 正在预热 {self.api_name} API连接...")
            try:
                self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": "测试连接"}],
                    max_tokens=10
                )
                print(f"[AI] [SUCCESS] {self.api_name} API连接预热成功")
            except Exception as e:
                print(f"[AI] [WARNING] {self.api_name} API连接预热失败: {e}")
                print("[AI] 提示: 请检查API Key是否正确配置")

    @abstractmethod
    def build_system_prompt(self) -> str:
        """Build the system prompt from the character card plus mode-specific rules."""
        pass

    @abstractmethod
    def create_initial_state(self, session_id: str) -> "BaseChatState":
        """Create the initial conversation state for ``session_id``."""
        pass

    @abstractmethod
    def process_message(self, state: "BaseChatState", user_input: str) -> Tuple["BaseChatState", Dict]:
        """Handle one user message; return ``(updated_state, response_dict)``."""
        pass

    @abstractmethod
    def process_message_stream(self, state: "BaseChatState", user_input: str) -> Generator[Dict[str, Any], None, None]:
        """Handle one user message in streaming mode, yielding SSE event dicts."""
        pass

    def _compress_messages(self, state: "BaseChatState") -> "BaseChatState":
        """Compress the conversation history so it stays within the context window.

        When ``state.messages`` holds more than ``MAX_MESSAGES`` entries:
          1. every ``system`` message is kept, in original order, at the front;
          2. the newest ``KEEP_RECENT`` non-system messages are kept verbatim;
          3. the older non-system messages are replaced by one ``assistant``
             message holding a summary from ``_summarize_messages``.

        The input state is not modified. A deep copy carrying the new message
        list is returned. If no compression is needed, the original ``state``
        object itself is returned.
        """
        messages = state.messages
        if len(messages) <= self.MAX_MESSAGES:
            return state

        print(f"[COMPRESS] 对话历史过长({len(messages)}条),开始压缩...")

        # Keep *all* system messages (previously only the last one survived).
        system_messages = [msg for msg in messages if msg.get('role') == 'system']
        other_messages = [msg for msg in messages if msg.get('role') != 'system']

        if len(other_messages) <= self.KEEP_RECENT:
            return state

        # Use an explicit split index: with KEEP_RECENT == 0 the slice
        # `[-0:]` would keep everything and summarize nothing.
        split = len(other_messages) - max(self.KEEP_RECENT, 0)
        # NOTE(review): if 'tool'-role messages are used, the split may separate
        # a tool result from the assistant message that requested it — confirm.
        messages_to_summarize = other_messages[:split]
        recent_messages = other_messages[split:]

        summary = self._summarize_messages(messages_to_summarize)

        compressed = list(system_messages)
        compressed.append({
            "role": "assistant",
            "content": f"[对话摘要]\n{summary}\n\n---\n[以上是之前的对话摘要,现在继续...]"
        })
        compressed.extend(recent_messages)

        new_state = state.model_copy(deep=True)
        new_state.messages = compressed

        print(f"[COMPRESS] 压缩完成: {len(messages)} -> {len(compressed)} 条消息")
        return new_state

    def _summarize_messages(self, messages: List[Dict]) -> str:
        """Build a generic, extractive summary of older conversation turns.

        Subclasses may override this to provide a richer summary.

        Each ``user``/``assistant`` message with non-empty string content
        contributes one bullet holding its first 100 characters, with newlines
        flattened to spaces. Other roles and non-string content are skipped.
        When more than 20 bullets result, only the first 5 and the last 10 are
        kept, separated by an omission marker.

        Returns:
            The bullets joined by newlines, or a fixed placeholder sentence when
            nothing usable was found.
        """
        key_points = []
        for msg in messages:
            content = msg.get('content', '')
            if not content or not isinstance(content, str):
                continue
            role = msg.get('role', 'unknown')
            # The first 100 characters serve as the summary hint for this turn.
            snippet = content[:100].replace('\n', ' ')
            if role == 'user':
                key_points.append(f"- 用户: {snippet}")
            elif role == 'assistant':
                key_points.append(f"- 角色: {snippet}")

        if not key_points:
            return "之前进行了一些对话。"

        # Cap the summary length.
        if len(key_points) > 20:
            key_points = key_points[:5] + ["- ...(中间省略)..."] + key_points[-10:]

        return "\n".join(key_points)

    @staticmethod
    def parse_expression(text: str) -> Tuple[str, str]:
        """Extract an ``[expression:NAME]`` tag from model output.

        The NAME of the first tag is returned as the expression. All tags are
        removed from the text. A run of adjacent tags, together with the
        whitespace around it, is collapsed to the strongest separator it
        contained: up to two newlines, else one space, else nothing. This keeps
        the words on either side from being glued together. The result is then
        stripped.

        Returns:
            ``(cleaned_text, expression_name)``. Without any tag the text is
            returned unchanged together with ``"default"``.
        """
        match = re.search(r'\[expression:(\w+)\]', text)
        if not match:
            return text, "default"
        expression = match.group(1)

        def _separator(run):
            # Whitespace segments before, between and after the tags in this run.
            gaps = re.split(r'\[expression:\w+\]', run.group(0))
            newlines = max(gap.count('\n') for gap in gaps)
            if newlines:
                return '\n' * min(newlines, 2)
            return ' ' if any(gaps) else ''

        cleaned = re.sub(r'\s*(?:\[expression:\w+\]\s*)+', _separator, text).strip()
        return cleaned, expression