Skip to content

Commit 18392e3

Browse files
committed
Add AiShellConfigModel
1 parent 752661a commit 18392e3

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

aishell/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .aishell_config_model import AiShellConfigModel as AiShellConfigModel
12
from .language_model import LanguageModel as LanguageModel
23
from .open_ai_response_model import OpenAIResponseModel as OpenAIResponseModel
34
from .revchatgpt_chatbot_config_model import RevChatGPTChatbotConfigModel as RevChatGPTChatbotConfigModel
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, root_validator
4+
5+
from .language_model import LanguageModel
6+
from .revchatgpt_chatbot_config_model import RevChatGPTChatbotConfigModel
7+
8+
9+
class AiShellConfigModel(BaseModel):
10+
language_model: LanguageModel = LanguageModel.REVERSE_ENGINEERED_CHATGPT
11+
chatgpt_config: Optional[RevChatGPTChatbotConfigModel] = None
12+
openai_api_key: Optional[str] = None
13+
14+
@root_validator
15+
def check_required_info_provided(cls, values: dict[str, Optional[str]]):
16+
OPENAI_API_KEY_REQUIRED_MODELS = (LanguageModel.GPT3, LanguageModel.OFFICIAL_CHATGPT)
17+
18+
language_model = values.get('language_model')
19+
if language_model in OPENAI_API_KEY_REQUIRED_MODELS:
20+
if not values.get('openai_api_key'):
21+
raise ValueError('openai_api_key should not be none')
22+
elif language_model == LanguageModel.REVERSE_ENGINEERED_CHATGPT:
23+
if not values.get('chatgpt_config'):
24+
raise ValueError('chatgpt_config should not be none')
25+
26+
return values

0 commit comments

Comments
 (0)