|
| 1 | +import os |
| 2 | +import time |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import rich |
| 6 | +import typer |
| 7 | +from rich.console import Console |
| 8 | + |
| 9 | +from aishell.models import AiShellConfigModel |
| 10 | +from aishell.models.language_model import LanguageModel |
| 11 | +from aishell.query_clients import GPT3Client, OfficialChatGPTClient, ReverseEngineeredChatGPTClient |
| 12 | +from aishell.utils import AiShellConfigManager |
| 13 | + |
| 14 | +from .cli_app import cli_app |
| 15 | +from .config_aishell import config_aishell |
| 16 | + |
| 17 | + |
| 18 | +@cli_app.command() |
| 19 | +def aishell_command(question: str, language_model: Optional[LanguageModel] = None): |
| 20 | + config_manager = _get_config_manager() |
| 21 | + config_manager.config_model.language_model = language_model or config_manager.config_model.language_model |
| 22 | + |
| 23 | + query_client = _get_query_client( |
| 24 | + language_model=config_manager.config_model.language_model, |
| 25 | + config_model=config_manager.config_model, |
| 26 | + ) |
| 27 | + |
| 28 | + console = Console() |
| 29 | + |
| 30 | + try: |
| 31 | + with console.status(f''' |
| 32 | +[green] AiShell is thinking of `{question}` ...[/green] |
| 33 | +
|
| 34 | +[dim]AiShell is not responsible for any damage caused by the command executed by the user.[/dim]'''[1:]): |
| 35 | + start_time = time.time() |
| 36 | + response = query_client.query(question) |
| 37 | + end_time = time.time() |
| 38 | + |
| 39 | + execution_time = end_time - start_time |
| 40 | + console.print(f'AiShell: {response}\n\n[dim]Took {execution_time:.2f} seconds to think the command.[/dim]') |
| 41 | + |
| 42 | + will_execute = typer.confirm('Execute this command?') |
| 43 | + if will_execute: |
| 44 | + os.system(response) |
| 45 | + except KeyError: |
| 46 | + rich.print('It looks like the `session_token` is expired. Please reconfigure AiShell.') |
| 47 | + typer.confirm('Reconfigure AiShell?', abort=True) |
| 48 | + config_aishell() |
| 49 | + aishell_command(question=question, language_model=language_model) |
| 50 | + typer.Exit() |
| 51 | + |
| 52 | + |
| 53 | +def _get_config_manager(): |
| 54 | + is_config_file_available = AiShellConfigManager.is_config_file_available(AiShellConfigManager.DEFAULT_CONFIG_PATH) |
| 55 | + if is_config_file_available: |
| 56 | + return AiShellConfigManager(load_config=True) |
| 57 | + else: |
| 58 | + return config_aishell() |
| 59 | + |
| 60 | + |
| 61 | +def _get_query_client(language_model: LanguageModel, config_model: AiShellConfigModel): |
| 62 | + if language_model == LanguageModel.REVERSE_ENGINEERED_CHATGPT: |
| 63 | + return ReverseEngineeredChatGPTClient(config=config_model.chatgpt_config) |
| 64 | + |
| 65 | + if not config_model.openai_api_key: |
| 66 | + raise RuntimeError('OpenAI API key is not provided. Please provide it in the config file.') |
| 67 | + |
| 68 | + if language_model == LanguageModel.GPT3: |
| 69 | + return GPT3Client(openai_api_key=config_model.openai_api_key) |
| 70 | + if language_model == LanguageModel.OFFICIAL_CHATGPT: |
| 71 | + return OfficialChatGPTClient(openai_api_key=config_model.openai_api_key) |
| 72 | + raise NotImplementedError(f'Language model {language_model} is not implemented yet.') |
0 commit comments