Skip to content

Commit 9a2b00d

Browse files
committed
feat: add unit test for langchain
1 parent c8ceff3 commit 9a2b00d

20 files changed

+185
-1586
lines changed

api/terraform/python/openai_api/lambda_langchain/lambda_handler.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
"""
2121
import json
2222

23-
import openai
2423
from langchain.chains import LLMChain
2524
from langchain.chat_models import ChatOpenAI
2625
from langchain.memory import ConversationBufferMemory
@@ -31,11 +30,10 @@
3130
SystemMessagePromptTemplate,
3231
)
3332
from openai_api.common.conf import settings
34-
from openai_api.common.const import (
33+
from openai_api.common.const import ( # VALID_EMBEDDING_MODELS,
3534
VALID_CHAT_COMPLETION_MODELS,
36-
VALID_EMBEDDING_MODELS,
37-
OpenAIEndPoint,
3835
OpenAIMessageKeys,
36+
OpenAIObjectTypes,
3937
OpenAIResponseCodes,
4038
)
4139
from openai_api.common.exceptions import EXCEPTION_MAP
@@ -49,9 +47,8 @@
4947
http_response_factory,
5048
parse_request,
5149
)
52-
from openai_api.common.validators import (
50+
from openai_api.common.validators import ( # validate_embedding_request,
5351
validate_completion_request,
54-
validate_embedding_request,
5552
validate_item,
5653
validate_messages,
5754
validate_request_body,
@@ -63,13 +60,6 @@
6360
# from langchain.schema.messages import BaseMessage
6461

6562

66-
###############################################################################
67-
# ENVIRONMENT CREDENTIALS
68-
###############################################################################
69-
openai.organization = settings.openai_api_organization
70-
openai.api_key = settings.openai_api_key
71-
72-
7363
# pylint: disable=too-many-locals
7464
# pylint: disable=unused-argument
7565
def handler(event, context):
@@ -86,20 +76,22 @@ def handler(event, context):
8676
# ----------------------------------------------------------------------
8777
request_body = get_request_body(event=event)
8878
validate_request_body(request_body=request_body)
89-
end_point, model, messages, input_text, temperature, max_tokens = parse_request(request_body)
90-
validate_messages(request_body=request_body)
79+
object_type, model, messages, input_text, temperature, max_tokens = parse_request(request_body)
9180
request_meta_data = {
9281
"request_meta_data": {
93-
"lambda": "lambda_langchain",
82+
"lambda": "lambda_openai_v2",
9483
"model": model,
95-
"end_point": end_point,
84+
"object_type": object_type,
9685
"temperature": temperature,
9786
"max_tokens": max_tokens,
87+
"input_text": input_text,
9888
}
9989
}
10090

101-
match end_point:
102-
case OpenAIEndPoint.ChatCompletion:
91+
validate_messages(request_body=request_body)
92+
93+
match object_type:
94+
case OpenAIObjectTypes.ChatCompletion:
10395
# pylint: disable=pointless-string-statement
10496
"""
10597
Need to keep in mind that this is a stateless operation. We have to bring
@@ -120,7 +112,9 @@ def handler(event, context):
120112

121113
# 2. initialize the LangChain ChatOpenAI model
122114
# -------------------------------------------------------------
123-
llm = ChatOpenAI(model=model, temperature=temperature, max_tokens=max_tokens)
115+
llm = ChatOpenAI(
116+
model=model, temperature=temperature, max_tokens=max_tokens, api_key=settings.openai_api_key
117+
)
124118
prompt = ChatPromptTemplate(
125119
messages=[
126120
SystemMessagePromptTemplate.from_template(system_message),
@@ -159,30 +153,34 @@ def handler(event, context):
159153
conversation_response = json.loads(conversation.memory.json())
160154
openai_results = conversation_response
161155

162-
case OpenAIEndPoint.Embedding:
156+
case OpenAIObjectTypes.Embedding:
163157
# https://platform.openai.com/docs/guides/embeddings/embeddings
164-
validate_item(
165-
item=model,
166-
valid_items=VALID_EMBEDDING_MODELS,
167-
item_type="Embedding models",
168-
)
169-
validate_embedding_request(request_body)
170-
openai_results = openai.Embedding.create(input=input_text, model=model)
171-
172-
case OpenAIEndPoint.Image:
158+
raise NotImplementedError("Refactoring of Embedding API v1 is in progress.")
159+
# validate_item(
160+
# item=model,
161+
# valid_items=VALID_EMBEDDING_MODELS,
162+
# item_type="Embedding models",
163+
# )
164+
# validate_embedding_request(request_body)
165+
# openai_results = openai.Embedding.create(input=input_text, model=model)
166+
167+
case OpenAIObjectTypes.Image:
173168
# https://platform.openai.com/docs/guides/images
174-
n = request_body.get("n", settings.openai_endpoint_image_n) # pylint: disable=invalid-name
175-
size = request_body.get("size", settings.openai_endpoint_image_size)
176-
return openai.Image.create(prompt=input_text, n=n, size=size)
169+
raise NotImplementedError("Refactoring of Image API v1 is in progress.")
170+
# n = request_body.get("n", settings.openai_endpoint_image_n) # pylint: disable=invalid-name
171+
# size = request_body.get("size", settings.openai_endpoint_image_size)
172+
# return openai.Image.create(prompt=input_text, n=n, size=size)
177173

178-
case OpenAIEndPoint.Moderation:
174+
case OpenAIObjectTypes.Moderation:
179175
# https://platform.openai.com/docs/guides/moderation
180-
openai_results = openai.Moderation.create(input=input_text)
176+
raise NotImplementedError("Refactoring of Moderation API v1 is in progress.")
177+
# openai_results = openai.Moderation.create(input=input_text)
181178

182-
case OpenAIEndPoint.Models:
183-
openai_results = openai.Model.retrieve(model) if model else openai.Model.list()
179+
case OpenAIObjectTypes.Models:
180+
raise NotImplementedError("Refactoring of Models API v1 is in progress.")
181+
# openai_results = openai.Model.retrieve(model) if model else openai.Model.list()
184182

185-
case OpenAIEndPoint.Audio:
183+
case OpenAIObjectTypes.Audio:
186184
raise NotImplementedError("Audio support is coming soon")
187185

188186
# handle anything that went wrong

api/terraform/python/openai_api/lambda_langchain/tests/.env.test_01

Lines changed: 0 additions & 8 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/.env.test_nulls

Lines changed: 0 additions & 8 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/mock_data/json/apigateway_index_lambda_event.json

Lines changed: 0 additions & 40 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/mock_data/json/apigateway_index_lambda_event_bad_event.json

Lines changed: 0 additions & 40 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/mock_data/json/apigateway_index_lambda_event_bad_source.json

Lines changed: 0 additions & 40 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/mock_data/json/apigateway_index_lambda_event_no_records.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

api/terraform/python/openai_api/lambda_langchain/tests/mock_data/json/apigateway_index_lambda_response.json

Lines changed: 0 additions & 88 deletions
This file was deleted.

0 commit comments

Comments
 (0)