|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# pylint: disable=wrong-import-position |
| 3 | +"""Test configuration Settings class.""" |
| 4 | + |
| 5 | +# python stuff |
| 6 | +import json |
| 7 | +import os |
| 8 | +import sys |
| 9 | +import unittest |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | + |
| 13 | +HERE = os.path.abspath(os.path.dirname(__file__)) |
| 14 | +PROJECT_ROOT = str(Path(HERE).parent.parent) |
| 15 | +PYTHON_ROOT = str(Path(PROJECT_ROOT).parent) |
| 16 | +if PYTHON_ROOT not in sys.path: |
| 17 | + sys.path.append(PYTHON_ROOT) # noqa: E402 |
| 18 | + |
| 19 | +from openai_api.common.const import OpenAIMessageKeys # noqa: E402 |
| 20 | + |
| 21 | +# our stuff |
| 22 | +from openai_api.common.tests.test_setup import get_test_file # noqa: E402 |
| 23 | +from openai_api.common.utils import ( # noqa: E402 |
| 24 | + exception_response_factory, |
| 25 | + get_content_for_role, |
| 26 | + get_message_history, |
| 27 | + get_messages_for_role, |
| 28 | + get_request_body, |
| 29 | + http_response_factory, |
| 30 | + parse_request, |
| 31 | +) |
| 32 | + |
| 33 | + |
| 34 | +class TestUtils(unittest.TestCase): |
| 35 | + """Test utils.""" |
| 36 | + |
| 37 | + # Get the directory of the current script |
| 38 | + here = HERE |
| 39 | + request = get_test_file("json/passthrough_openai_v2_request.json") |
| 40 | + response = get_test_file("json/passthrough_openai_v2_response.json") |
| 41 | + |
| 42 | + def setUp(self): |
| 43 | + """Set up test fixtures.""" |
| 44 | + |
| 45 | + def test_http_response_factory(self): |
| 46 | + """Test test_http_response_factory.""" |
| 47 | + retval = http_response_factory(200, self.response) |
| 48 | + self.assertEqual(retval["statusCode"], 200) |
| 49 | + self.assertEqual(retval["body"], json.dumps(self.response)) |
| 50 | + self.assertEqual(retval["isBase64Encoded"], False) |
| 51 | + self.assertEqual(retval["headers"]["Content-Type"], "application/json") |
| 52 | + |
| 53 | + def test_exception_response_factory(self): |
| 54 | + """Test exception_response_factory.""" |
| 55 | + try: |
| 56 | + raise AssertionError("test") |
| 57 | + except AssertionError as exception: |
| 58 | + retval = exception_response_factory(exception) |
| 59 | + self.assertIn("error", retval) |
| 60 | + self.assertIn("description", retval) |
| 61 | + |
| 62 | + def test_get_request_body(self): |
| 63 | + """Test get_request_body""" |
| 64 | + request_body = get_request_body(self.request) |
| 65 | + self.assertEqual(request_body, self.request) |
| 66 | + self.assertEqual(request_body["model"], "gpt-3.5-turbo") |
| 67 | + self.assertEqual(request_body["object"], "chat.completion") |
| 68 | + self.assertIn("temperature", request_body) |
| 69 | + self.assertIn("max_tokens", request_body) |
| 70 | + self.assertIn("messages", request_body) |
| 71 | + |
| 72 | + def test_parse_request(self): |
| 73 | + """Test parse_request""" |
| 74 | + request_body = get_request_body(self.request) |
| 75 | + object_type, model, messages, input_text, temperature, max_tokens = parse_request(request_body) |
| 76 | + self.assertEqual(object_type, "chat.completion") |
| 77 | + self.assertEqual(model, "gpt-3.5-turbo") |
| 78 | + self.assertEqual(input_text, None) |
| 79 | + self.assertEqual(temperature, 0) |
| 80 | + self.assertEqual(max_tokens, 256) |
| 81 | + self.assertEqual(len(messages), 2) |
| 82 | + |
| 83 | + def test_get_content_for_role(self): |
| 84 | + """Test get_content_for_role""" |
| 85 | + request_body = get_request_body(self.request) |
| 86 | + _, _, messages, _, _, _ = parse_request(request_body) |
| 87 | + system_message = get_content_for_role(messages, OpenAIMessageKeys.OPENAI_SYSTEM_MESSAGE_KEY) |
| 88 | + user_message = get_content_for_role(messages, OpenAIMessageKeys.OPENAI_USER_MESSAGE_KEY) |
| 89 | + self.assertEqual(system_message, "you always return the integer value 42.") |
| 90 | + self.assertEqual(user_message, "return the integer value 42.") |
| 91 | + |
| 92 | + def test_get_message_history(self): |
| 93 | + """test get_message_history""" |
| 94 | + request_body = get_request_body(self.request) |
| 95 | + _, _, messages, _, _, _ = parse_request(request_body) |
| 96 | + message_history = get_message_history(messages) |
| 97 | + self.assertIsInstance(message_history, list) |
| 98 | + self.assertEqual(len(message_history), 1) |
| 99 | + self.assertEqual(message_history[0]["role"], "user") |
| 100 | + self.assertEqual(message_history[0]["content"], "return the integer value 42.") |
| 101 | + |
| 102 | + def test_get_messages_for_role(self): |
| 103 | + """test get_messages_for_role""" |
| 104 | + request_body = get_request_body(self.request) |
| 105 | + _, _, messages, _, _, _ = parse_request(request_body) |
| 106 | + message_history = get_message_history(messages) |
| 107 | + self.assertIsInstance(message_history, list) |
| 108 | + user_messages = get_messages_for_role(message_history, OpenAIMessageKeys.OPENAI_USER_MESSAGE_KEY) |
| 109 | + self.assertEqual(len(user_messages), 1) |
| 110 | + self.assertEqual(user_messages[0], "return the integer value 42.") |
0 commit comments