forked from sinaptik-ai/pandas-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgoogle_vertexai.py
174 lines (135 loc) · 5.3 KB
/
google_vertexai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Optional
from pandasai.core.code_execution.environment import import_dependency
from pandasai.exceptions import UnsupportedModelError
from pandasai.helpers.memory import Memory
from .base import BaseGoogle
class GoogleVertexAI(BaseGoogle):
    """Google Palm Vertexai LLM.

    BaseGoogle class is extended for Google Palm model using Vertexai.
    The default model support at the moment is text-bison@001.
    However, user can choose to use code-bison@001 too.
    """

    _supported_code_models = [
        "code-bison",
        "code-bison-32k",
        "code-bison-32k@002",
        "code-bison@001",
        "code-bison@002",
    ]
    _supported_text_models = [
        "text-bison",
        "text-bison-32k",
        "text-bison-32k@002",
        "text-bison@001",
        "text-bison@002",
        "text-unicorn@001",
    ]
    _supported_generative_models = [
        "gemini-pro",
    ]
    _supported_code_chat_models = ["codechat-bison@001", "codechat-bison@002"]

    def __init__(
        self, project_id: str, location: str, model: Optional[str] = None, **kwargs
    ):
        """
        A init class to implement the Google Vertexai Models.

        Args:
            project_id (str): GCP project.
            location (str): GCP project Location.
            model (Optional[str]): Model to use. Defaults to text-bison@001.
            **kwargs: Arguments to control the Model Parameters.
        """
        self.model = model or "text-bison@001"
        self._configure(project_id, location)
        self.project_id = project_id
        self.location = location
        self._set_params(**kwargs)

    def _configure(self, project_id: str, location: str):
        """
        Configure Google VertexAi. Set value `self.vertexai` attribute.

        Args:
            project_id (str): GCP Project.
            location (str): Location of Project.

        Returns:
            None.
        """
        err_msg = "Install google-cloud-aiplatform for Google Vertexai"
        vertexai = import_dependency("vertexai", extra=err_msg)
        vertexai.init(project=project_id, location=location)
        self.vertexai = vertexai

    def _valid_params(self):
        """Returns if the Parameters are valid or Not"""
        return super()._valid_params() + ["model"]

    def _validate(self):
        """
        A method to Validate the Model.

        Raises:
            ValueError: If no model name has been set.
        """
        super()._validate()

        if not self.model:
            raise ValueError("model is required.")

    def _generate_text(self, prompt: str, memory: Optional[Memory] = None) -> str:
        """
        Generates text for prompt.

        Args:
            prompt (str): A string representation of the prompt.
            memory (Optional[Memory]): Conversation memory; when provided, its
                agent description / history is prepended to the prompt.

        Returns:
            str: LLM response.

        Raises:
            UnsupportedModelError: If `self.model` is not in any supported list.
        """
        self._validate()

        updated_prompt = self.prepend_system_prompt(prompt, memory)
        self.last_prompt = updated_prompt

        if self.model in self._supported_code_models:
            from vertexai.preview.language_models import CodeGenerationModel

            code_generation = CodeGenerationModel.from_pretrained(self.model)

            # BUGFIX: previously passed the raw `prompt`, silently dropping the
            # prepended system prompt that every other branch uses.
            completion = code_generation.predict(
                prefix=updated_prompt,
                temperature=self.temperature,
                max_output_tokens=self.max_output_tokens,
            )
        elif self.model in self._supported_text_models:
            from vertexai.preview.language_models import TextGenerationModel

            text_generation = TextGenerationModel.from_pretrained(self.model)

            completion = text_generation.predict(
                prompt=updated_prompt,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                max_output_tokens=self.max_output_tokens,
            )
        elif self.model in self._supported_generative_models:
            from vertexai.preview.generative_models import GenerativeModel

            model = GenerativeModel(self.model)
            responses = model.generate_content(
                [updated_prompt],
                generation_config={
                    "max_output_tokens": self.max_output_tokens,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": self.top_k,
                },
            )

            completion = responses.candidates[0].content.parts[0]
        elif self.model in self._supported_code_chat_models:
            from vertexai.language_models import ChatMessage, CodeChatModel

            code_chat_model = CodeChatModel.from_pretrained(self.model)

            # BUGFIX: `memory` is Optional; the original dereferenced it
            # unconditionally and raised AttributeError when omitted.
            messages = []
            context = None
            if memory is not None:
                for message in memory.all():
                    # Memory entries carry an `is_user` flag; map it onto the
                    # ChatMessage author expected by the Vertex AI SDK.
                    author = "user" if message["is_user"] else "model"
                    messages.append(
                        ChatMessage(author=author, content=message["message"])
                    )
                context = memory.agent_description

            chat = code_chat_model.start_chat(
                context=context, message_history=messages
            )
            # History/context already carry the memory, so the raw prompt is
            # sent rather than `updated_prompt` (avoids double-prepending).
            response = chat.send_message(prompt)
            return response.text
        else:
            raise UnsupportedModelError(self.model)

        return completion.text

    @property
    def type(self) -> str:
        return "google-vertexai"