Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Aug 26, 2024
1 parent f78a8ed commit 6586b63
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 85 deletions.
34 changes: 11 additions & 23 deletions explorer/assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def extract_response(r):
def table_schema(db_connection, table_name):
schema = schema_info(db_connection)
s = [table for table in schema if table[0] == table_name]
if len(s): return s[0]
if len(s):
return s[0]


def sample_rows_from_table(connection, table_name):
Expand Down Expand Up @@ -82,28 +83,15 @@ def sample_rows_from_table(connection, table_name):


def format_rows_from_table(rows):
column_headers = list(rows[0])
ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n"
for row in rows[1:]:
row_str = " | ".join(str(item) for item in row)
ret += row_str + "\n"
return ret
""" Given an array of rows (a list of lists), returns e.g.
AlbumId | Title | ArtistId
1 | For Those About To Rock We Salute You | 1
2 | Let It Rip | 2
3 | Restless and Wild | 2
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
import tiktoken
try:
encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(string))
return num_tokens


def fits_in_window(string: str) -> bool:
# Ratchet down by 5% to account for other boilerplate and system prompt
return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95)
"""
return "\n".join([" | ".join([str(item) for item in row]) for row in rows])


def build_system_prompt(flavor):
Expand Down Expand Up @@ -145,7 +133,7 @@ class TablePromptData:
annotation: TableDescription

def render(self):
ret = f"""## Table {self.name} ##
ret = f"""## Information for Table '{self.name}' ##
Schema:\n{self.schema}
Sample rows:\n{format_rows_from_table(self.sample)}"""
if self.annotation:
Expand All @@ -158,7 +146,7 @@ def build_prompt(db_connection, assistant_request, included_tables, query_error=

error_chunk = f"## Query Error ##\n{query_error}" if query_error else ""
sql_chunk = f"## Existing User-Written SQL ##\n{sql}" if sql else ""
request_chunk = f"## User's Request to Assistant ##\n\n{assistant_request}"
request_chunk = f"## User's Request to Assistant ##\n{assistant_request}"
table_chunks = [
TablePromptData(
name=t,
Expand Down
93 changes: 32 additions & 61 deletions explorer/tests/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,16 @@ def test_assistant_help(self, mocked_num_tokens, mocked_openai_client):
@unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled")
class TestBuildPrompt(TestCase):

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_vendor_only(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
def test_build_prompt_with_vendor_only(self, mock_get_item):
mock_get_item.return_value.value = "system prompt"
result = build_prompt(default_db_connection(),
"Help me with SQL", [], sql="SELECT * FROM table;")
self.assertIn("sqlite", result["system"])

included_tables = []

result = build_prompt(default_db_connection(), "Help me with SQL", included_tables)
self.assertIn("## Database Type is sqlite", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data")
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_sample_rows):
mock_get_item.return_value.value = "system prompt"

included_tables = ["foo"]
Expand All @@ -89,15 +83,11 @@ def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_

result = build_prompt(default_db_connection(),
"Help me with SQL", included_tables, sql="SELECT * FROM table;")
self.assertIn("## Database Type is sqlite", result["user"])
self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])
self.assertIn("## Usage Notes about Table foo ##\nannotated\n\n", result["user"])
self.assertIn("Usage Notes:\nannotated", result["user"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data")
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
def test_build_prompt_with_few_shot(self, mock_get_item, mock_sample_rows):
mock_get_item.return_value.value = "system prompt"

included_tables = ["magic"]
Expand All @@ -106,56 +96,35 @@ def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mo

result = build_prompt(default_db_connection(),
"Help me with SQL", included_tables, sql="SELECT * FROM table;")
self.assertIn("Example queries using these tables", result["user"])
self.assertIn("Relevant example queries", result["user"])
self.assertIn("magic value", result["user"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data")
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_sample_rows):
mock_get_item.return_value.value = "system prompt"

included_tables = []

result = build_prompt(default_db_connection(),
"Help me with SQL", included_tables,
"Syntax error", "SELECT * FROM table;")
self.assertIn("## Database Type is sqlite", result["user"])
self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"])
self.assertIn("## Existing User-Written SQL ##\nSELECT * FROM table;", result["user"])
self.assertIn("## Query Error ##\nSyntax error\n", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL", result["user"])
self.assertIn("system prompt", result["system"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item):
mock_get_item.return_value.value = "system prompt"

included_tables = ["table1", "table2"]
included_tables = ["explorer_query"]
SimpleQueryFactory()

result = build_prompt(default_db_connection(), "Help me with SQL",
included_tables, sql="SELECT * FROM table;")
self.assertIn("## Database Type is sqlite", result["user"])
self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"])
self.assertIn("## Table Structure with Sampled Data ##\nsample data\n\n", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=False)
@patch("explorer.assistant.utils.tables_from_schema_info", return_value="table structure")
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_extra_tables_not_fitting_window(self, mock_get_item, mock_tables_from_schema_info,
mock_fits_in_window, mock_sample_rows):
mock_get_item.return_value.value = "system prompt"

included_tables = ["table1", "table2"]

result = build_prompt(default_db_connection(), "Help me with SQL",
included_tables, sql="SELECT * FROM table;")
self.assertIn("## Existing SQL ##\nSELECT * FROM table;\n\n", result["user"])
self.assertIn("## Table Structure ##\ntable structure\n\n", result["user"])
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])
self.assertIn("## Information for Table 'explorer_query' ##", result["user"])
self.assertIn("Sample rows:\nid | title", result["user"])


@unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled")
Expand Down Expand Up @@ -238,12 +207,12 @@ def test_format_rows_from_table(self):
["val1", "val2"],
]
ret = format_rows_from_table(d)
self.assertEqual(ret, "col1 | col2\n" + "-" * 50 + "\nval1 | val2\n")
self.assertEqual(ret, "col1 | col2\nval1 | val2")

def test_schema_info_from_table_names(self):
from explorer.assistant.utils import table_schema
ret = table_schema(default_db_connection(), "explorer_query")
expected = [("explorer_query", [
expected = ("explorer_query", [
("id", "AutoField"),
("title", "CharField"),
("sql", "TextField"),
Expand All @@ -253,29 +222,31 @@ def test_schema_info_from_table_names(self):
("created_by_user_id", "IntegerField"),
("snapshot", "BooleanField"),
("connection", "CharField"),
("database_connection_id", "IntegerField")])]
("database_connection_id", "IntegerField"),
("few_shot", "BooleanField")])
self.assertEqual(ret, expected)


@unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled")
class TestAssistantUtils(TestCase):

def test_sample_rows_from_tables(self):
from explorer.assistant.utils import sample_rows_from_tables
def test_sample_rows_from_table(self):
from explorer.assistant.utils import sample_rows_from_table, format_rows_from_table
SimpleQueryFactory(title="First Query")
SimpleQueryFactory(title="Second Query")
QueryLogFactory()
ret = sample_rows_from_tables(conn(), ["explorer_query", "explorer_querylog"])
ret = sample_rows_from_table(conn(), "explorer_query")
self.assertEqual(len(ret), ROW_SAMPLE_SIZE)
ret = format_rows_from_table(ret)
self.assertTrue("First Query" in ret)
self.assertTrue("Second Query" in ret)
self.assertTrue("explorer_querylog" in ret)

def test_sample_rows_from_tables_no_tables(self):
from explorer.assistant.utils import sample_rows_from_tables
def test_sample_rows_from_tables_no_table_match(self):
from explorer.assistant.utils import sample_rows_from_table
SimpleQueryFactory(title="First Query")
SimpleQueryFactory(title="Second Query")
ret = sample_rows_from_tables(conn(), [])
self.assertEqual(ret, "")
ret = sample_rows_from_table(conn(), "banana")
self.assertEqual(ret, [["no such table: banana"]])

def test_relevant_few_shots(self):
relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True)
Expand Down
1 change: 0 additions & 1 deletion requirements/extra/assistant.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
openai>=1.6.1
tiktoken>=0.7

0 comments on commit 6586b63

Please sign in to comment.