Skip to content

Commit 292517b

Browse files
committed
Refactor, please mypy
1 parent 33b5295 commit 292517b

File tree

19 files changed: +283 additions, -93 deletions

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,5 +36,6 @@
3636
"htmlcov": true,
3737
".mypy_cache": true,
3838
".coverage": true
39-
}
39+
},
40+
"python.REPL.enableREPLSmartSend": false
4041
}

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,3 +14,4 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17+
freezegun

src/backend/fastapi_app/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3838
if (
3939
os.getenv("OPENAI_CHAT_HOST") == "azure"
4040
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41-
or os.getenv("POSTGRES_HOST").endswith(".database.azure.com")
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
4242
):
4343
azure_credential = await get_azure_credential()
4444
engine = await create_postgres_engine_from_env(azure_credential)

src/backend/fastapi_app/api_models.py

Lines changed: 33 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
from enum import Enum
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
from pydantic_ai.messages import ModelRequest, ModelResponse
67

78

89
class AIChatRoles(str, Enum):
@@ -95,4 +96,33 @@ class ChatParams(ChatRequestOverrides):
9596
enable_text_search: bool
9697
enable_vector_search: bool
9798
original_user_query: str
98-
past_messages: list[ChatCompletionMessageParam]
99+
past_messages: list[Union[ModelRequest, ModelResponse]]
100+
101+
102+
class Filter(BaseModel):
103+
column: str
104+
comparison_operator: str
105+
value: Any
106+
107+
108+
class PriceFilter(Filter):
109+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
110+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
111+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
112+
113+
114+
class BrandFilter(Filter):
115+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
116+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
117+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
118+
119+
120+
class SearchResults(BaseModel):
121+
query: str
122+
"""The original search query"""
123+
124+
items: list[ItemPublic]
125+
"""List of items that match the search query and filters"""
126+
127+
filters: list[Filter]
128+
"""List of filters applied to the search results"""

src/backend/fastapi_app/openai_clients.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.

src/backend/fastapi_app/rag_advanced.py

Lines changed: 17 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator
2-
from typing import Optional, TypedDict, Union
2+
from typing import Optional, Union
33

44
from openai import AsyncAzureOpenAI, AsyncOpenAI
55
from openai.types.chat import ChatCompletionMessageParam
@@ -11,51 +11,22 @@
1111

1212
from fastapi_app.api_models import (
1313
AIChatRoles,
14+
BrandFilter,
1415
ChatRequestOverrides,
16+
Filter,
1517
ItemPublic,
1618
Message,
19+
PriceFilter,
1720
RAGContext,
1821
RetrievalResponse,
1922
RetrievalResponseDelta,
23+
SearchResults,
2024
ThoughtStep,
2125
)
2226
from fastapi_app.postgres_searcher import PostgresSearcher
2327
from fastapi_app.rag_base import ChatParams, RAGChatBase
2428

2529

26-
class PriceFilter(TypedDict):
27-
column: str = "price"
28-
"""The column to filter on (always 'price' for this filter)"""
29-
30-
comparison_operator: str
31-
"""The operator for price comparison ('>', '<', '>=', '<=', '=')"""
32-
33-
value: float
34-
""" The price value to compare against (e.g., 30.00) """
35-
36-
37-
class BrandFilter(TypedDict):
38-
column: str = "brand"
39-
"""The column to filter on (always 'brand' for this filter)"""
40-
41-
comparison_operator: str
42-
"""The operator for brand comparison ('=' or '!=')"""
43-
44-
value: str
45-
"""The brand name to compare against (e.g., 'AirStrider')"""
46-
47-
48-
class SearchResults(TypedDict):
49-
query: str
50-
"""The original search query"""
51-
52-
items: list[ItemPublic]
53-
"""List of items that match the search query and filters"""
54-
55-
filters: list[Union[PriceFilter, BrandFilter]]
56-
"""List of filters applied to the search results"""
57-
58-
5930
class AdvancedRAGChat(RAGChatBase):
6031
query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
6132
query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
@@ -79,9 +50,13 @@ def __init__(
7950
chat_model if chat_deployment is None else chat_deployment,
8051
provider=OpenAIProvider(openai_client=openai_chat_client),
8152
)
82-
self.search_agent = Agent(
53+
self.search_agent = Agent[ChatParams, SearchResults](
8354
pydantic_chat_model,
84-
model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
55+
model_settings=ModelSettings(
56+
temperature=0.0,
57+
max_tokens=500,
58+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
59+
),
8560
system_prompt=self.query_prompt_template,
8661
tools=[self.search_database],
8762
output_type=SearchResults,
@@ -92,7 +67,7 @@ def __init__(
9267
model_settings=ModelSettings(
9368
temperature=self.chat_params.temperature,
9469
max_tokens=self.chat_params.response_token_limit,
95-
seed=self.chat_params.seed,
70+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
9671
),
9772
)
9873

@@ -115,7 +90,7 @@ async def search_database(
11590
List of formatted items that match the search query and filters
11691
"""
11792
# Only send non-None filters
118-
filters = []
93+
filters: list[Filter] = []
11994
if price_filter:
12095
filters.append(price_filter)
12196
if brand_filter:
@@ -134,12 +109,12 @@ async def search_database(
134109
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
135110
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
136111
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
137-
results = await self.search_agent.run(
112+
results = await self.search_agent.run( # type: ignore[call-overload]
138113
user_query,
139114
message_history=few_shots + self.chat_params.past_messages,
140115
deps=self.chat_params,
141116
)
142-
items = results.output["items"]
117+
items = results.output.items
143118
thoughts = [
144119
ThoughtStep(
145120
title="Prompt to generate search arguments",
@@ -148,12 +123,12 @@ async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
148123
),
149124
ThoughtStep(
150125
title="Search using generated search arguments",
151-
description=results.output["query"],
126+
description=results.output.query,
152127
props={
153128
"top": self.chat_params.top,
154129
"vector_search": self.chat_params.enable_vector_search,
155130
"text_search": self.chat_params.enable_text_search,
156-
"filters": results.output["filters"],
131+
"filters": results.output.filters,
157132
},
158133
),
159134
ThoughtStep(

src/backend/fastapi_app/rag_base.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
import pathlib
22
from abc import ABC, abstractmethod
33
from collections.abc import AsyncGenerator
4+
from typing import Union
45

56
from openai.types.chat import ChatCompletionMessageParam
7+
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
68

79
from fastapi_app.api_models import (
810
ChatParams,
@@ -12,7 +14,6 @@
1214
RetrievalResponseDelta,
1315
ThoughtStep,
1416
)
15-
from fastapi_app.postgres_models import Item
1617

1718

1819
class RAGChatBase(ABC):
@@ -31,7 +32,19 @@ def get_chat_params(
3132
original_user_query = messages[-1]["content"]
3233
if not isinstance(original_user_query, str):
3334
raise ValueError("The most recent message content must be a string.")
34-
past_messages = messages[:-1]
35+
36+
# Convert to PydanticAI format:
37+
past_messages: list[Union[ModelRequest, ModelResponse]] = []
38+
for message in messages[:-1]:
39+
content = message["content"]
40+
if not isinstance(content, str):
41+
raise ValueError("All messages must have string content.")
42+
if message["role"] == "user":
43+
past_messages.append(ModelRequest(parts=[UserPromptPart(content=content)]))
44+
elif message["role"] == "assistant":
45+
past_messages.append(ModelResponse(parts=[TextPart(content=content)]))
46+
else:
47+
raise ValueError(f"Cannot convert message: {message}")
3548

3649
return ChatParams(
3750
top=overrides.top,
@@ -48,9 +61,7 @@ def get_chat_params(
4861
)
4962

5063
@abstractmethod
51-
async def prepare_context(
52-
self, chat_params: ChatParams
53-
) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
64+
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
5465
raise NotImplementedError
5566

5667
def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
@@ -60,19 +71,15 @@ def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
6071
@abstractmethod
6172
async def answer(
6273
self,
63-
chat_params: ChatParams,
64-
contextual_messages: list[ChatCompletionMessageParam],
65-
results: list[Item],
74+
items: list[ItemPublic],
6675
earlier_thoughts: list[ThoughtStep],
6776
) -> RetrievalResponse:
6877
raise NotImplementedError
6978

7079
@abstractmethod
7180
async def answer_stream(
7281
self,
73-
chat_params: ChatParams,
74-
contextual_messages: list[ChatCompletionMessageParam],
75-
results: list[Item],
82+
items: list[ItemPublic],
7683
earlier_thoughts: list[ThoughtStep],
7784
) -> AsyncGenerator[RetrievalResponseDelta, None]:
7885
raise NotImplementedError

src/backend/fastapi_app/rag_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
model_settings=ModelSettings(
4949
temperature=self.chat_params.temperature,
5050
max_tokens=self.chat_params.response_token_limit,
51-
seed=self.chat_params.seed,
51+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
5252
),
5353
)
5454

src/backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"opentelemetry-instrumentation-sqlalchemy",
2020
"opentelemetry-instrumentation-aiohttp-client",
2121
"opentelemetry-instrumentation-openai",
22-
"pydantic-ai"
22+
"pydantic-ai-slim[openai]"
2323
]
2424

2525
[build-system]

0 commit comments

Comments (0)