From ad2c7036d4089427426957464c04d15c9f405e11 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Mon, 19 May 2025 14:39:13 -0400
Subject: [PATCH 1/2] structured query gen example

Add an example in which an agent returns a structured `Query` model that
is rendered to parameterised SQL by `Query.as_sql()`. Filter values are
collected as positional parameters (`$1`, `$2`, ...) rather than being
interpolated into the SQL text. Combined AND/OR filters are wrapped in
parentheses so the rendered SQL keeps the grouping of the filter tree
instead of relying on SQL operator precedence.

---
 .../structure_query_generation.py             | 138 ++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 examples/pydantic_ai_examples/structure_query_generation.py

diff --git a/examples/pydantic_ai_examples/structure_query_generation.py b/examples/pydantic_ai_examples/structure_query_generation.py
new file mode 100644
index 000000000..1b4eeda68
--- /dev/null
+++ b/examples/pydantic_ai_examples/structure_query_generation.py
@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from devtools import debug
+from pydantic import BaseModel, Field
+
+from pydantic_ai import Agent, ToolOutput
+
+type QueryColumns = Literal[
+    'id', 'first_name', 'last_name', 'dob', 'email', 'address', 'phone'
+]
+table_name = 'users'
+
+
+class Query(BaseModel, use_attribute_docstrings=True):
+    """Structured definition of a database query."""
+
+    select: Literal['*'] | list[SelectColumn] | list[Aggregation] = '*'
+    """List of columns to select from the table."""
+    filter: FilterClause | FilterAnd | FilterOr | None = None
+    """Filter condition for the query."""
+    order_by: list[OrderBy] | None = None
+    """List of columns to order the results by."""
+    limit: int | None = None
+    """Maximum number of rows to return."""
+    offset: int | None = None
+    """Number of rows to skip before returning results."""
+
+    def as_sql(self) -> tuple[str, list[Any]]:
+        param_mgr = ParamManager()
+        parts: list[str] = ['SELECT']
+        if isinstance(self.select, list):
+            parts.append(' ' + ', '.join(c.as_sql() for c in self.select))
+        else:
+            parts.append(' ' + self.select)
+
+        parts.append(f'FROM {table_name}')
+
+        if self.filter:
+            parts.append('WHERE ' + self.filter.as_sql(param_mgr))
+        if self.order_by:
+            parts.append(
+                'ORDER BY ' + ', '.join(order_by.as_sql() for order_by in self.order_by)
+            )
+        if self.limit:
+            parts.append(f'LIMIT {self.limit}')
+        if self.offset:
+            parts.append(f'OFFSET {self.offset}')
+        return '\n'.join(parts), param_mgr.params
+
+
+class SelectColumn(BaseModel):
+    name: QueryColumns
+    alias: str | None = Field(pattern=r'^[a-zA-Z_][a-zA-Z0-9_ ]*$')
+
+    def as_sql(self) -> str:
+        if self.alias:
+            return f'{self.name} AS {f'"{self.alias}"' if " " in self.alias else self.alias}'
+        else:
+            return self.name
+
+
+class Aggregation(BaseModel):
+    column: QueryColumns | Literal['*']
+    function: Literal['SUM', 'AVG', 'MAX', 'MIN', 'COUNT']
+    alias: str | None = Field(pattern=r'^[a-zA-Z_][a-zA-Z0-9_ ]*$')
+
+    def as_sql(self) -> str:
+        func_call = f'{self.function}({self.column})'
+        if self.alias:
+            return f'{func_call} AS {f'"{self.alias}"' if " " in self.alias else self.alias}'
+        else:
+            return func_call
+
+
+class FilterAnd(BaseModel):
+    left: FilterClause | FilterAnd | FilterOr
+    right: FilterClause | FilterAnd | FilterOr
+
+    def as_sql(self, param_mgr: ParamManager) -> str:
+        left_sql = self.left.as_sql(param_mgr)
+        right_sql = self.right.as_sql(param_mgr)
+        return f'({left_sql} AND {right_sql})'
+
+
+class FilterOr(BaseModel):
+    left: FilterClause | FilterAnd | FilterOr
+    right: FilterClause | FilterAnd | FilterOr
+
+    def as_sql(self, param_mgr: ParamManager) -> str:
+        left_sql = self.left.as_sql(param_mgr)
+        right_sql = self.right.as_sql(param_mgr)
+        return f'({left_sql} OR {right_sql})'
+
+
+class FilterClause(BaseModel):
+    column: QueryColumns
+    operator: Literal['=', '!=', '<', '>', '<=', '>=']
+    value: str | int | float | bool
+
+    def as_sql(self, param_mgr: ParamManager) -> str:
+        value_param = param_mgr.add_param(self.value)
+        return f'{self.column} {self.operator} {value_param}'
+
+
+class OrderBy(BaseModel):
+    column: QueryColumns
+    direction: Literal['ASC', 'DESC']
+
+    def as_sql(self) -> str:
+        return f'{self.column} {self.direction}'
+
+
+class ParamManager:
+    def __init__(self) -> None:
+        self.params: list[Any] = []
+
+    def add_param(self, param: Any) -> str:
+        self.params.append(param)
+        return f'${len(self.params)}'
+
+
+query_agent = Agent(
+    'openai:gpt-4o',
+    output_type=ToolOutput(type_=Query, name='query'),
+    instructions="Generate a query to match the users' preferences.",
+)
+
+
+r = query_agent.run_sync(
+    # 'Find user IDs for users who were born before 1990 ordered by how old they are.'
+    # 'How many users are there.'
+    "what is jane's date of birth."
+)
+debug(r.output)
+query, params = r.output.as_sql()
+print(f'SQL:\n-------\n{query}\n-------\n{params=}')

From af504c58c41545fde6a9ebc169e6a453e83446d5 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Mon, 19 May 2025 14:50:36 -0400
Subject: [PATCH 2/2] tweak example

Replace the hand-written `ParamManager` class with a `ParamsManager`
dataclass defined before `Query`, configure logfire instrumentation, and
move the module-level `run_sync` call into an async `main()` executed
under the `__main__` guard. `main()` forwards its `user_query` argument
to `query_agent.run()` so the prompt actually reaches the agent.

---
 .../structure_query_generation.py             | 73 +++++++++++--------
 1 file changed, 44 insertions(+), 29 deletions(-)

diff --git a/examples/pydantic_ai_examples/structure_query_generation.py b/examples/pydantic_ai_examples/structure_query_generation.py
index 1b4eeda68..4b0b385ea 100644
--- a/examples/pydantic_ai_examples/structure_query_generation.py
+++ b/examples/pydantic_ai_examples/structure_query_generation.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
+import asyncio
+from dataclasses import dataclass, field
 from typing import Any, Literal
 
+import logfire
 from devtools import debug
 from pydantic import BaseModel, Field
 
@@ -11,6 +14,20 @@
     'id', 'first_name', 'last_name', 'dob', 'email', 'address', 'phone'
 ]
 table_name = 'users'
+type Param = list[str | int | float | bool]
+
+# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
+logfire.configure(send_to_logfire='if-token-present')
+logfire.instrument_pydantic_ai()
+
+
+@dataclass
+class ParamsManager:
+    params: list[Any] = field(default_factory=list)
+
+    def add_param(self, param: Any) -> str:
+        self.params.append(param)
+        return f'${len(self.params)}'
 
 
 class Query(BaseModel, use_attribute_docstrings=True):
@@ -28,17 +45,17 @@ class Query(BaseModel, use_attribute_docstrings=True):
     """Number of rows to skip before returning results."""
 
     def as_sql(self) -> tuple[str, list[Any]]:
-        param_mgr = ParamManager()
+        params_mgr = ParamsManager()
         parts: list[str] = ['SELECT']
         if isinstance(self.select, list):
             parts.append(' ' + ', '.join(c.as_sql() for c in self.select))
         else:
-            parts.append(' ' + self.select)
+            parts.append(f' {self.select}')
 
         parts.append(f'FROM {table_name}')
 
         if self.filter:
-            parts.append('WHERE ' + self.filter.as_sql(param_mgr))
+            parts.append(f'WHERE {self.filter.as_sql(params_mgr)}')
         if self.order_by:
             parts.append(
                 'ORDER BY ' + ', '.join(order_by.as_sql() for order_by in self.order_by)
@@ -47,7 +64,7 @@ def as_sql(self) -> tuple[str, list[Any]]:
             parts.append(f'LIMIT {self.limit}')
         if self.offset:
             parts.append(f'OFFSET {self.offset}')
-        return '\n'.join(parts), param_mgr.params
+        return '\n'.join(parts), params_mgr.params
 
 
 class SelectColumn(BaseModel):
@@ -78,9 +95,9 @@ class FilterAnd(BaseModel):
     left: FilterClause | FilterAnd | FilterOr
     right: FilterClause | FilterAnd | FilterOr
 
-    def as_sql(self, param_mgr: ParamManager) -> str:
-        left_sql = self.left.as_sql(param_mgr)
-        right_sql = self.right.as_sql(param_mgr)
+    def as_sql(self, params_mgr: ParamsManager) -> str:
+        left_sql = self.left.as_sql(params_mgr)
+        right_sql = self.right.as_sql(params_mgr)
         return f'({left_sql} AND {right_sql})'
 
 
@@ -88,9 +105,9 @@ class FilterOr(BaseModel):
     left: FilterClause | FilterAnd | FilterOr
     right: FilterClause | FilterAnd | FilterOr
 
-    def as_sql(self, param_mgr: ParamManager) -> str:
-        left_sql = self.left.as_sql(param_mgr)
-        right_sql = self.right.as_sql(param_mgr)
+    def as_sql(self, params_mgr: ParamsManager) -> str:
+        left_sql = self.left.as_sql(params_mgr)
+        right_sql = self.right.as_sql(params_mgr)
         return f'({left_sql} OR {right_sql})'
 
 
@@ -99,8 +116,8 @@ class FilterClause(BaseModel):
     operator: Literal['=', '!=', '<', '>', '<=', '>=']
     value: str | int | float | bool
 
-    def as_sql(self, param_mgr: ParamManager) -> str:
-        value_param = param_mgr.add_param(self.value)
+    def as_sql(self, params_mgr: ParamsManager) -> str:
+        value_param = params_mgr.add_param(self.value)
         return f'{self.column} {self.operator} {value_param}'
 
 
@@ -112,15 +129,6 @@ def as_sql(self) -> str:
         return f'{self.column} {self.direction}'
 
 
-class ParamManager:
-    def __init__(self) -> None:
-        self.params: list[Any] = []
-
-    def add_param(self, param: Any) -> str:
-        self.params.append(param)
-        return f'${len(self.params)}'
-
-
 query_agent = Agent(
     'openai:gpt-4o',
     output_type=ToolOutput(type_=Query, name='query'),
@@ -128,11 +136,18 @@ def add_param(self, param: Any) -> str:
 )
 
 
-r = query_agent.run_sync(
-    # 'Find user IDs for users who were born before 1990 ordered by how old they are.'
-    # 'How many users are there.'
-    "what is jane's date of birth."
-)
-debug(r.output)
-query, params = r.output.as_sql()
-print(f'SQL:\n-------\n{query}\n-------\n{params=}')
+async def main(user_query: str):
+    r = await query_agent.run(user_query)
+    debug(r.output)
+    query, params = r.output.as_sql()
+    print(f'SQL:\n-------\n{query}\n-------\n{params=}')
+
+
+if __name__ == '__main__':
+    asyncio.run(
+        main(
+            # 'Find user IDs for users who were born before 1990 ordered by how old they are.'
+            # 'How many users are there.'
+            "what is jane's date of birth."
+        )
+    )