2
2
import time
3
3
from typing import Optional
4
4
5
- import rich
6
5
import typer
7
6
from rich .console import Console
7
+ from ygka .models import LanguageModel
8
+ from ygka .query_clients import QueryClientFactory
9
+ from ygka .utils import YGKAConfigManager
8
10
9
- from aishell .models import AiShellConfigModel
10
- from aishell .models .language_model import LanguageModel
11
- from aishell .query_clients import GPT3Client , OfficialChatGPTClient , ReverseEngineeredChatGPTClient
12
- from aishell .utils import AiShellConfigManager
11
+ from aishell .utils import construct_prompt
13
12
14
13
from .cli_app import cli_app
15
14
from .config_aishell import config_aishell
@@ -20,53 +19,29 @@ def aishell_command(question: str, language_model: Optional[LanguageModel] = Non
20
19
config_manager = _get_config_manager ()
21
20
config_manager .config_model .language_model = language_model or config_manager .config_model .language_model
22
21
23
- query_client = _get_query_client (
24
- language_model = config_manager .config_model .language_model ,
25
- config_model = config_manager .config_model ,
26
- )
22
+ query_client = QueryClientFactory (config_model = config_manager .config_model ).create ()
27
23
28
24
console = Console ()
29
25
30
- try :
31
- with console .status (f'''
26
+ with console .status (f'''
32
27
[green] AiShell is thinking of `{ question } ` ...[/green]
33
28
34
29
[dim]AiShell is not responsible for any damage caused by the command executed by the user.[/dim]''' [1 :]):
35
- start_time = time .time ()
36
- response = query_client .query (question )
37
- end_time = time .time ()
30
+ start_time = time .time ()
31
+ response = query_client .query (construct_prompt ( question ) )
32
+ end_time = time .time ()
38
33
39
- execution_time = end_time - start_time
40
- console .print (f'AiShell: { response } \n \n [dim]Took { execution_time :.2f} seconds to think the command.[/dim]' )
34
+ execution_time = end_time - start_time
35
+ console .print (f'AiShell: { response } \n \n [dim]Took { execution_time :.2f} seconds to think the command.[/dim]' )
41
36
42
- will_execute = typer .confirm ('Execute this command?' )
43
- if will_execute :
44
- os .system (response )
45
- except KeyError :
46
- rich .print ('It looks like the `session_token` is expired. Please reconfigure AiShell.' )
47
- typer .confirm ('Reconfigure AiShell?' , abort = True )
48
- config_aishell ()
49
- aishell_command (question = question , language_model = language_model )
50
- typer .Exit ()
37
+ will_execute = typer .confirm ('Execute this command?' )
38
+ if will_execute :
39
+ os .system (response )
51
40
52
41
53
42
def _get_config_manager ():
54
- is_config_file_available = AiShellConfigManager .is_config_file_available (AiShellConfigManager .DEFAULT_CONFIG_PATH )
43
+ is_config_file_available = YGKAConfigManager .is_config_file_available (YGKAConfigManager .DEFAULT_CONFIG_PATH )
55
44
if is_config_file_available :
56
- return AiShellConfigManager (load_config = True )
45
+ return YGKAConfigManager (load_config = True )
57
46
else :
58
47
return config_aishell ()
59
-
60
-
61
- def _get_query_client (language_model : LanguageModel , config_model : AiShellConfigModel ):
62
- if language_model == LanguageModel .REVERSE_ENGINEERED_CHATGPT :
63
- return ReverseEngineeredChatGPTClient (config = config_model .chatgpt_config )
64
-
65
- if not config_model .openai_api_key :
66
- raise RuntimeError ('OpenAI API key is not provided. Please provide it in the config file.' )
67
-
68
- if language_model == LanguageModel .GPT3 :
69
- return GPT3Client (openai_api_key = config_model .openai_api_key )
70
- if language_model == LanguageModel .OFFICIAL_CHATGPT :
71
- return OfficialChatGPTClient (openai_api_key = config_model .openai_api_key )
72
- raise NotImplementedError (f'Language model { language_model } is not implemented yet.' )
0 commit comments