diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index 4785a917..100c0db1 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -41,7 +41,8 @@ def extract_response(r): def table_schema(db_connection, table_name): schema = schema_info(db_connection) s = [table for table in schema if table[0] == table_name] - if len(s): return s[0] + if len(s): + return s[0] def sample_rows_from_table(connection, table_name): @@ -82,28 +83,15 @@ def sample_rows_from_table(connection, table_name): def format_rows_from_table(rows): - column_headers = list(rows[0]) - ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n" - for row in rows[1:]: - row_str = " | ".join(str(item) for item in row) - ret += row_str + "\n" - return ret + """ Given an array of rows (a list of lists), returns e.g. +AlbumId | Title | ArtistId +1 | For Those About To Rock We Salute You | 1 +2 | Let It Rip | 2 +3 | Restless and Wild | 2 -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - import tiktoken - try: - encoding = tiktoken.encoding_for_model(OPENAI_MODEL) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(string)) - return num_tokens - - -def fits_in_window(string: str) -> bool: - # Ratchet down by 5% to account for other boilerplate and system prompt - return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95) + """ + return "\n".join([" | ".join([str(item) for item in row]) for row in rows]) def build_system_prompt(flavor): @@ -145,7 +133,7 @@ class TablePromptData: annotation: TableDescription def render(self): - ret = f"""## Table {self.name} ## + ret = f"""## Information for Table '{self.name}' ## Schema:\n{self.schema} Sample rows:\n{format_rows_from_table(self.sample)}""" if self.annotation: @@ -158,7 +146,7 @@ def build_prompt(db_connection, assistant_request, included_tables, query_error= error_chunk = f"## Query Error ##\n{query_error}" if query_error else "" sql_chunk = f"## Existing User-Written SQL ##\n{sql}" if sql else "" - request_chunk = f"## User's Request to Assistant ##\n\n{assistant_request}" + request_chunk = f"## User's Request to Assistant ##\n{assistant_request}" table_chunks = [ TablePromptData( name=t, diff --git a/explorer/tests/test_assistant.py b/explorer/tests/test_assistant.py index f3705183..d3e3ca9c 100644 --- a/explorer/tests/test_assistant.py +++ b/explorer/tests/test_assistant.py @@ -65,22 +65,16 @@ def test_assistant_help(self, mocked_num_tokens, mocked_openai_client): @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestBuildPrompt(TestCase): - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_vendor_only(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_vendor_only(self, mock_get_item): mock_get_item.return_value.value = "system prompt" + result = build_prompt(default_db_connection(), + "Help me with SQL", [], sql="SELECT * FROM table;") + self.assertIn("sqlite", result["system"]) - included_tables = [] - - result = build_prompt(default_db_connection(), "Help me with SQL", included_tables) - self.assertIn("## Database Type is sqlite", result["user"]) - self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_sample_rows): mock_get_item.return_value.value = "system prompt" included_tables = ["foo"] @@ -89,15 +83,11 @@ def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_ result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Type is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) - self.assertIn("## Usage Notes about Table foo ##\nannotated\n\n", result["user"]) + self.assertIn("Usage Notes:\nannotated", result["user"]) - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_few_shot(self, mock_get_item, mock_sample_rows): mock_get_item.return_value.value = "system prompt" included_tables = ["magic"] @@ -106,13 +96,12 @@ def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mo result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("Example queries using these tables", result["user"]) + self.assertIn("Relevant example queries", result["user"]) self.assertIn("magic value", result["user"]) - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_sample_rows): mock_get_item.return_value.value = "system prompt" included_tables = [] @@ -120,42 +109,22 @@ def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_windo result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, "Syntax error", "SELECT * FROM table;") - self.assertIn("## Database Type is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"]) + self.assertIn("## Existing User-Written SQL ##\nSELECT * FROM table;", result["user"]) self.assertIn("## Query Error ##\nSyntax error\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) + self.assertIn("## User's Request to Assistant ##\nHelp me with SQL", result["user"]) self.assertIn("system prompt", result["system"]) - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item): mock_get_item.return_value.value = "system prompt" - included_tables = ["table1", "table2"] + included_tables = ["explorer_query"] + SimpleQueryFactory() result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Type is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure with Sampled Data ##\nsample data\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=False) - @patch("explorer.assistant.utils.tables_from_schema_info", return_value="table structure") - @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_extra_tables_not_fitting_window(self, mock_get_item, mock_tables_from_schema_info, - mock_fits_in_window, mock_sample_rows): - mock_get_item.return_value.value = "system prompt" - - included_tables = ["table1", "table2"] - - result = build_prompt(default_db_connection(), "Help me with SQL", - included_tables, sql="SELECT * FROM table;") - self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure ##\ntable structure\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) + self.assertIn("## Information for Table 'explorer_query' ##", result["user"]) + self.assertIn("Sample rows:\nid | title", result["user"]) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") @@ -238,12 +207,12 @@ def test_format_rows_from_table(self): ["val1", "val2"], ] ret = format_rows_from_table(d) - self.assertEqual(ret, "col1 | col2\n" + "-" * 50 + "\nval1 | val2\n") + self.assertEqual(ret, "col1 | col2\nval1 | val2") def test_schema_info_from_table_names(self): from explorer.assistant.utils import table_schema ret = table_schema(default_db_connection(), "explorer_query") - expected = [("explorer_query", [ + expected = ("explorer_query", [ ("id", "AutoField"), ("title", "CharField"), ("sql", "TextField"), @@ -253,29 +222,31 @@ def test_schema_info_from_table_names(self): ("created_by_user_id", "IntegerField"), ("snapshot", "BooleanField"), ("connection", "CharField"), - ("database_connection_id", "IntegerField")])] + ("database_connection_id", "IntegerField"), + ("few_shot", "BooleanField")]) self.assertEqual(ret, expected) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestAssistantUtils(TestCase): - def test_sample_rows_from_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_table(self): + from explorer.assistant.utils import sample_rows_from_table, format_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") QueryLogFactory() - ret = sample_rows_from_tables(conn(), ["explorer_query", "explorer_querylog"]) + ret = sample_rows_from_table(conn(), "explorer_query") + self.assertEqual(len(ret), ROW_SAMPLE_SIZE) + ret = format_rows_from_table(ret) self.assertTrue("First Query" in ret) self.assertTrue("Second Query" in ret) - self.assertTrue("explorer_querylog" in ret) - def test_sample_rows_from_tables_no_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_tables_no_table_match(self): + from explorer.assistant.utils import sample_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") - ret = sample_rows_from_tables(conn(), []) - self.assertEqual(ret, "") + ret = sample_rows_from_table(conn(), "banana") + self.assertEqual(ret, [["no such table: banana"]]) def test_relevant_few_shots(self): relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True) diff --git a/requirements/extra/assistant.txt b/requirements/extra/assistant.txt index a04e4fa4..44fbead2 100644 --- a/requirements/extra/assistant.txt +++ b/requirements/extra/assistant.txt @@ -1,2 +1 @@ openai>=1.6.1 -tiktoken>=0.7