Skip to content

Commit ce1ee1a

Browse files
committed
Add support for session_token
1 parent 64ea98d commit ce1ee1a

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

aishell/query_clients/reverse_engineered_chatgpt_client.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,28 @@
33

44
from revChatGPT.V1 import Chatbot
55

6+
from aishell.exceptions import UnauthorizedAccessError
67
from aishell.utils import make_executable_command
78

89
from .query_client import QueryClient
910

1011

1112
class ReverseEngineeredChatGPTClient(QueryClient):
12-
access_key: str
13+
config: dict[str, str] = {}
1314

1415
def __init__(
1516
self,
16-
chatgpt_access_key: Optional[str] = None,
17+
access_token: Optional[str] = None,
18+
session_token: Optional[str] = None,
1719
):
18-
super().__init__()
19-
CHATGPT_ACCESS_KEY = os.environ.get('CHATGPT_ACCESS_KEY')
20-
21-
if chatgpt_access_key is not None:
22-
self.access_key = chatgpt_access_key
23-
elif CHATGPT_ACCESS_KEY is not None:
24-
self.access_key = CHATGPT_ACCESS_KEY
20+
CHATGPT_ACCESS_TOKEN = os.environ.get('CHATGPT_ACCESS_TOKEN', access_token)
21+
CHATGPT_SESSION_TOKEN = os.environ.get('CHATGPT_SESSION_TOKEN', session_token)
22+
if CHATGPT_ACCESS_TOKEN is not None:
23+
self.config['access_token'] = CHATGPT_ACCESS_TOKEN
24+
elif CHATGPT_SESSION_TOKEN is not None:
25+
self.config['session_token'] = CHATGPT_SESSION_TOKEN
2526
else:
26-
raise Exception('access_key should not be none')
27+
raise UnauthorizedAccessError('No access token or session token provided.')
2728

2829
def _construct_prompt(self, text: str) -> str:
2930
return f'''You are now a translater from human language to {os.uname()[0]} shell command.
@@ -32,7 +33,7 @@ def _construct_prompt(self, text: str) -> str:
3233

3334
def query(self, prompt: str) -> str:
3435
prompt = self._construct_prompt(prompt)
35-
chatbot = Chatbot(config={'access_token': self.access_key})
36+
chatbot = Chatbot(config=self.config)
3637

3738
response_text = ''
3839
for data in chatbot.ask(prompt):

0 commit comments

Comments
 (0)