20
20
"""
21
21
import json
22
22
23
- import openai
24
23
from langchain .chains import LLMChain
25
24
from langchain .chat_models import ChatOpenAI
26
25
from langchain .memory import ConversationBufferMemory
31
30
SystemMessagePromptTemplate ,
32
31
)
33
32
from openai_api .common .conf import settings
34
- from openai_api .common .const import (
33
+ from openai_api .common .const import ( # VALID_EMBEDDING_MODELS,
35
34
VALID_CHAT_COMPLETION_MODELS ,
36
- VALID_EMBEDDING_MODELS ,
37
- OpenAIEndPoint ,
38
35
OpenAIMessageKeys ,
36
+ OpenAIObjectTypes ,
39
37
OpenAIResponseCodes ,
40
38
)
41
39
from openai_api .common .exceptions import EXCEPTION_MAP
49
47
http_response_factory ,
50
48
parse_request ,
51
49
)
52
- from openai_api .common .validators import (
50
+ from openai_api .common .validators import ( # validate_embedding_request,
53
51
validate_completion_request ,
54
- validate_embedding_request ,
55
52
validate_item ,
56
53
validate_messages ,
57
54
validate_request_body ,
63
60
# from langchain.schema.messages import BaseMessage
64
61
65
62
66
- ###############################################################################
67
- # ENVIRONMENT CREDENTIALS
68
- ###############################################################################
69
- openai .organization = settings .openai_api_organization
70
- openai .api_key = settings .openai_api_key
71
-
72
-
73
63
# pylint: disable=too-many-locals
74
64
# pylint: disable=unused-argument
75
65
def handler (event , context ):
@@ -86,20 +76,22 @@ def handler(event, context):
86
76
# ----------------------------------------------------------------------
87
77
request_body = get_request_body (event = event )
88
78
validate_request_body (request_body = request_body )
89
- end_point , model , messages , input_text , temperature , max_tokens = parse_request (request_body )
90
- validate_messages (request_body = request_body )
79
+ object_type , model , messages , input_text , temperature , max_tokens = parse_request (request_body )
91
80
request_meta_data = {
92
81
"request_meta_data" : {
93
- "lambda" : "lambda_langchain " ,
82
+ "lambda" : "lambda_openai_v2 " ,
94
83
"model" : model ,
95
- "end_point " : end_point ,
84
+ "object_type " : object_type ,
96
85
"temperature" : temperature ,
97
86
"max_tokens" : max_tokens ,
87
+ "input_text" : input_text ,
98
88
}
99
89
}
100
90
101
- match end_point :
102
- case OpenAIEndPoint .ChatCompletion :
91
+ validate_messages (request_body = request_body )
92
+
93
+ match object_type :
94
+ case OpenAIObjectTypes .ChatCompletion :
103
95
# pylint: disable=pointless-string-statement
104
96
"""
105
97
Need to keep in mind that this is a stateless operation. We have to bring
@@ -120,7 +112,9 @@ def handler(event, context):
120
112
121
113
# 2. initialize the LangChain ChatOpenAI model
122
114
# -------------------------------------------------------------
123
- llm = ChatOpenAI (model = model , temperature = temperature , max_tokens = max_tokens )
115
+ llm = ChatOpenAI (
116
+ model = model , temperature = temperature , max_tokens = max_tokens , api_key = settings .openai_api_key
117
+ )
124
118
prompt = ChatPromptTemplate (
125
119
messages = [
126
120
SystemMessagePromptTemplate .from_template (system_message ),
@@ -159,30 +153,34 @@ def handler(event, context):
159
153
conversation_response = json .loads (conversation .memory .json ())
160
154
openai_results = conversation_response
161
155
162
- case OpenAIEndPoint .Embedding :
156
+ case OpenAIObjectTypes .Embedding :
163
157
# https://platform.openai.com/docs/guides/embeddings/embeddings
164
- validate_item (
165
- item = model ,
166
- valid_items = VALID_EMBEDDING_MODELS ,
167
- item_type = "Embedding models" ,
168
- )
169
- validate_embedding_request (request_body )
170
- openai_results = openai .Embedding .create (input = input_text , model = model )
171
-
172
- case OpenAIEndPoint .Image :
158
+ raise NotImplementedError ("Refactoring of Embedding API v1 is in progress." )
159
+ # validate_item(
160
+ # item=model,
161
+ # valid_items=VALID_EMBEDDING_MODELS,
162
+ # item_type="Embedding models",
163
+ # )
164
+ # validate_embedding_request(request_body)
165
+ # openai_results = openai.Embedding.create(input=input_text, model=model)
166
+
167
+ case OpenAIObjectTypes .Image :
173
168
# https://platform.openai.com/docs/guides/images
174
- n = request_body .get ("n" , settings .openai_endpoint_image_n ) # pylint: disable=invalid-name
175
- size = request_body .get ("size" , settings .openai_endpoint_image_size )
176
- return openai .Image .create (prompt = input_text , n = n , size = size )
169
+ raise NotImplementedError ("Refactoring of Image API v1 is in progress." )
170
+ # n = request_body.get("n", settings.openai_endpoint_image_n) # pylint: disable=invalid-name
171
+ # size = request_body.get("size", settings.openai_endpoint_image_size)
172
+ # return openai.Image.create(prompt=input_text, n=n, size=size)
177
173
178
- case OpenAIEndPoint .Moderation :
174
+ case OpenAIObjectTypes .Moderation :
179
175
# https://platform.openai.com/docs/guides/moderation
180
- openai_results = openai .Moderation .create (input = input_text )
176
+ raise NotImplementedError ("Refactoring of Moderation API v1 is in progress." )
177
+ # openai_results = openai.Moderation.create(input=input_text)
181
178
182
- case OpenAIEndPoint .Models :
183
- openai_results = openai .Model .retrieve (model ) if model else openai .Model .list ()
179
+ case OpenAIObjectTypes .Models :
180
+ raise NotImplementedError ("Refactoring of Models API v1 is in progress." )
181
+ # openai_results = openai.Model.retrieve(model) if model else openai.Model.list()
184
182
185
- case OpenAIEndPoint .Audio :
183
+ case OpenAIObjectTypes .Audio :
186
184
raise NotImplementedError ("Audio support is coming soon" )
187
185
188
186
# handle anything that went wrong
0 commit comments