diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 2263b6667..bde32de38 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -29,15 +29,8 @@ branchProtectionRules: requiresStrictStatusChecks: true requiredStatusCheckContexts: - "cla/google" - - "retrieval_service" - "llm_demo" - - "retrieval-service-app-pr (retrieval-app-testing)" - - "retrieval-service-postgres-pr (retrieval-app-testing)" - - "retrieval-service-alloydb-pr (retrieval-app-testing)" - - "retrieval-service-cloudsql-pg-pr (retrieval-app-testing)" - - "llm-demo-langchain-tools-pr (retrieval-app-testing)" - "llm-demo-langgraph-pr (retrieval-app-testing)" - - "llm-demo-vertexai-fc-pr (retrieval-app-testing)" # Set team access permissionRules: - team: senseai-eco diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c5030f74e..8de62e866 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -30,12 +30,8 @@ permissions: read-all jobs: integration: - name: ${{ matrix.dir }} + name: Cymbal Air runs-on: ubuntu-latest - strategy: - matrix: - dir: [retrieval_service, llm_demo] - fail-fast: false permissions: contents: read id-token: write @@ -51,11 +47,9 @@ jobs: python-version: "3.11" - name: Install requirements - working-directory: ${{ matrix.dir }} run: pip install -r requirements.txt -r requirements-test.txt - name: Run lints - working-directory: ${{ matrix.dir }} run: | black --check . isort --check . 
diff --git a/.github/workflows/lint_fallback.yml b/.github/workflows/lint_fallback.yml index faf534b58..f61c61ca5 100644 --- a/.github/workflows/lint_fallback.yml +++ b/.github/workflows/lint_fallback.yml @@ -21,11 +21,8 @@ on: jobs: integration: - name: ${{ matrix.dir }} + name: Cymbal Air runs-on: ubuntu-latest - strategy: - matrix: - dir: [retrieval_service, llm_demo] permissions: contents: none diff --git a/llm_demo/Dockerfile b/Dockerfile similarity index 100% rename from llm_demo/Dockerfile rename to Dockerfile diff --git a/llm_demo/app.py b/app.py similarity index 95% rename from llm_demo/app.py rename to app.py index 6addf875d..ea5a0773d 100644 --- a/llm_demo/app.py +++ b/app.py @@ -27,7 +27,7 @@ from markdown import markdown from starlette.middleware.sessions import SessionMiddleware -from orchestrator import createOrchestrator +from orchestrator import Orchestrator routes = APIRouter() templates = Jinja2Templates(directory="templates") @@ -195,13 +195,19 @@ async def decline_flight(request: Request): # This is helpful in case of reloads so there doesn't seem to be a break in communication. orchestrator = request.app.state.orchestrator response = await orchestrator.user_session_decline_ticket(request.session["uuid"]) + response = ( + response["output"] + if response + else "Booking declined. What else can I help you with?" 
+ ) request.session["history"].append( {"type": "ai", "data": {"content": "Please confirm if you would like to book."}} ) request.session["history"].append( {"type": "human", "data": {"content": "I changed my mind."}} ) - return None + request.session["history"].append({"type": "ai", "data": {"content": response}}) + return response @routes.post("/reset") @@ -237,16 +243,13 @@ def clear_user_info(session: dict[str, Any]): def init_app( - orchestration_type: Optional[str], client_id: Optional[str], middleware_secret: Optional[str], ) -> FastAPI: # FastAPI setup - if orchestration_type is None: - raise HTTPException(status_code=500, detail="Orchestrator not found") app = FastAPI(lifespan=lifespan) app.state.client_id = client_id - app.state.orchestrator = createOrchestrator(orchestration_type) + app.state.orchestrator = Orchestrator() app.include_router(routes) app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(SessionMiddleware, secret_key=middleware_secret) @@ -256,12 +259,9 @@ def init_app( if __name__ == "__main__": PORT = int(os.getenv("PORT", default=8081)) HOST = os.getenv("HOST", default="0.0.0.0") - ORCHESTRATION_TYPE = os.getenv("ORCHESTRATION_TYPE", default="langchain-tools") CLIENT_ID = os.getenv("CLIENT_ID") MIDDLEWARE_SECRET = os.getenv("MIDDLEWARE_SECRET", default="this is a secret") - app = init_app( - ORCHESTRATION_TYPE, client_id=CLIENT_ID, middleware_secret=MIDDLEWARE_SECRET - ) + app = init_app(client_id=CLIENT_ID, middleware_secret=MIDDLEWARE_SECRET) if app is None: raise TypeError("app not instantiated") uvicorn.run(app, host=HOST, port=PORT) diff --git a/llm_demo/app_test.py b/app_test.py similarity index 100% rename from llm_demo/app_test.py rename to app_test.py diff --git a/llm_demo/evaluation.cloudbuild.yaml b/evaluation.cloudbuild.yaml similarity index 92% rename from llm_demo/evaluation.cloudbuild.yaml rename to evaluation.cloudbuild.yaml index 7a5136a16..0f4ec02d3 100644 --- 
a/llm_demo/evaluation.cloudbuild.yaml +++ b/evaluation.cloudbuild.yaml @@ -14,14 +14,11 @@ steps: - id: Install dependencies name: python:3.11 - dir: llm_demo script: pip install -r requirements.txt -r requirements-test.txt --user - id: "Run evaluation service" name: python:3.11 - dir: llm_demo env: # Set env var expected by tests - - "ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE}" - "RETRIEVAL_EXPERIMENT_NAME=${_RETRIEVAL_EXPERIMENT_NAME}" - "RESPONSE_EXPERIMENT_NAME=${_RESPONSE_EXPERIMENT_NAME}" secretEnv: @@ -37,7 +34,6 @@ options: dynamic_substitutions: true substitutions: - _ORCHESTRATION_TYPE: "langchain-tools" _RETRIEVAL_EXPERIMENT_NAME: "retrieval-phase-eval-${_PR_NUMBER}" _RESPONSE_EXPERIMENT_NAME: "response-phase-eval-${_PR_NUMBER}" diff --git a/llm_demo/evaluation/__init__.py b/evaluation/__init__.py similarity index 100% rename from llm_demo/evaluation/__init__.py rename to evaluation/__init__.py diff --git a/llm_demo/evaluation/eval_golden.py b/evaluation/eval_golden.py similarity index 100% rename from llm_demo/evaluation/eval_golden.py rename to evaluation/eval_golden.py diff --git a/llm_demo/evaluation/evaluation.py b/evaluation/evaluation.py similarity index 95% rename from llm_demo/evaluation/evaluation.py rename to evaluation/evaluation.py index fa6d0caca..8208687f8 100644 --- a/llm_demo/evaluation/evaluation.py +++ b/evaluation/evaluation.py @@ -17,22 +17,20 @@ from typing import Dict, List import pandas as pd -from pydantic import BaseModel, Field from vertexai.evaluation import EvalTask from vertexai.evaluation import _base as evaluation_base -from orchestrator import BaseOrchestrator +from orchestrator import Orchestrator from .eval_golden import EvalData, ToolCall from .metrics import response_phase_metrics, retrieval_phase_metrics async def run_llm_for_eval( - eval_list: List[EvalData], orc: BaseOrchestrator, session: Dict, session_id: str + eval_list: List[EvalData], orc: Orchestrator, session: Dict, session_id: str ) -> List[EvalData]: """ 
Generate llm_tool_calls and llm_output for golden dataset query. - This function is only compatible with the langchain-tools orchestration. """ agent = orc.get_user_session(session_id) for eval_data in eval_list: diff --git a/llm_demo/evaluation/metrics.py b/evaluation/metrics.py similarity index 100% rename from llm_demo/evaluation/metrics.py rename to evaluation/metrics.py diff --git a/llm_demo/langgraph.int.tests.cloudbuild.yaml b/integration.cloudbuild.yaml similarity index 93% rename from llm_demo/langgraph.int.tests.cloudbuild.yaml rename to integration.cloudbuild.yaml index 65c817620..90549cca9 100644 --- a/llm_demo/langgraph.int.tests.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -14,14 +14,13 @@ steps: - id: "Deploy to Cloud Run" name: "gcr.io/cloud-builders/gcloud:latest" - dir: llm_demo script: | #!/usr/bin/env bash gcloud run deploy ${_SERVICE} \ --source . \ --region ${_REGION} \ --no-allow-unauthenticated \ - --update-env-vars ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} + --update-env-vars TOOLBOX_URL=${_TOOLBOX_URL} - id: "Test Frontend" name: "gcr.io/cloud-builders/gcloud:latest" @@ -32,7 +31,6 @@ steps: - | export URL=$(gcloud run services describe ${_SERVICE} --region ${_REGION} --format 'value(status.url)') export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL) - export ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} # Test `/` route curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL @@ -77,4 +75,4 @@ substitutions: _GCR_HOSTNAME: ${_REGION}-docker.pkg.dev _SERVICE: demo-service-${BUILD_ID} _REGION: us-central1 - _ORCHESTRATION_TYPE: langgraph + _TOOLBOX_URL: https://toolbox-107716898620.us-central1.run.app \ No newline at end of file diff --git a/llm_demo/langchain_tools.int.tests.cloudbuild.yaml b/llm_demo/langchain_tools.int.tests.cloudbuild.yaml deleted file mode 100644 index 0ad807b78..000000000 --- a/llm_demo/langchain_tools.int.tests.cloudbuild.yaml +++ /dev/null @@ -1,80 +0,0 @@ -# 
Copyright 2023 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -steps: - - id: "Deploy to Cloud Run" - name: "gcr.io/cloud-builders/gcloud:latest" - dir: llm_demo - script: | - #!/usr/bin/env bash - gcloud run deploy ${_SERVICE} \ - --source . \ - --region ${_REGION} \ - --no-allow-unauthenticated \ - --update-env-vars ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} - - - id: "Test Frontend" - name: "gcr.io/cloud-builders/gcloud:latest" - entrypoint: /bin/bash - env: # Set env var expected by app - args: - - "-c" - - | - export URL=$(gcloud run services describe ${_SERVICE} --region ${_REGION} --format 'value(status.url)') - export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL) - export ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} - - # Test `/` route - curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL - - # Test `/chat` route should fail - msg=$(curl -si --show-error \ - -X POST \ - -H "Authorization: Bearer $$ID_TOKEN" \ - -H 'Content-Type: application/json' \ - -d '{"prompt":"How can you help me?"}' \ - $$URL/chat) - - if grep -q "400" <<< "$msg"; then - echo "Chat Handler Test: PASSED" - else - echo "Chat Handler Test: FAILED" - echo $msg && exit 1 - fi - - # Test `/chat` route - curl -b cookies.txt -si --fail --show-error \ - -X POST \ - -H "Authorization: Bearer $$ID_TOKEN" \ - -H 'Content-Type: application/json' \ - -d '{"prompt":"How can you help me?"}' \ - $$URL/chat - - - id: "Delete image and service" - 
name: "gcr.io/cloud-builders/gcloud" - script: | - #!/usr/bin/env bash - gcloud artifacts docker images delete $_GCR_HOSTNAME/$PROJECT_ID/cloud-run-source-deploy/$_SERVICE --quiet - gcloud run services delete ${_SERVICE} --region ${_REGION} --quiet - -serviceAccount: "projects/$PROJECT_ID/serviceAccounts/548341735270-compute@developer.gserviceaccount.com" # Necessary for ID token creation -options: - automapSubstitutions: true - logging: CLOUD_LOGGING_ONLY # Necessary for custom service account - dynamic_substitutions: true - -substitutions: - _GCR_HOSTNAME: ${_REGION}-docker.pkg.dev - _SERVICE: demo-service-${BUILD_ID} - _REGION: us-central1 - _ORCHESTRATION_TYPE: langchain-tools diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py deleted file mode 100644 index cc4edd679..000000000 --- a/llm_demo/orchestrator/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . 
import langchain_tools, langgraph, vertexai_function_calling -from .orchestrator import BaseOrchestrator, createOrchestrator - -__ALL__ = [ - "BaseOrchestrator", - "createOrchestrator", - "langchain_tools", - "vertexai_function_calling", - "langgraph", -] diff --git a/llm_demo/orchestrator/langchain_tools/__init__.py b/llm_demo/orchestrator/langchain_tools/__init__.py deleted file mode 100644 index 0633529c2..000000000 --- a/llm_demo/orchestrator/langchain_tools/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .langchain_tools_orchestrator import LangChainToolsOrchestrator - -__ALL__ = ["LangChainToolsOrchestrator"] diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py deleted file mode 100644 index f1ee7679e..000000000 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import os -import uuid -from datetime import datetime -from typing import Any, Dict, List, Optional - -from aiohttp import ClientSession, TCPConnector -from fastapi import HTTPException -from langchain.agents import AgentType, initialize_agent -from langchain.agents.agent import AgentExecutor -from langchain.globals import set_verbose # type: ignore -from langchain.memory import ConversationBufferMemory -from langchain_community.chat_message_histories import ChatMessageHistory -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.tools import StructuredTool -from langchain_google_vertexai import ChatVertexAI -from pytz import timezone - -from ..orchestrator import BaseOrchestrator, classproperty -from .tools import ( - get_confirmation_needing_tools, - initialize_tools, - insert_ticket, - validate_ticket, -) - -set_verbose(bool(os.getenv("DEBUG", default=False))) -BASE_HISTORY = { - "type": "ai", - "data": {"content": "Welcome to Cymbal Air! 
How may I assist you?"}, -} - - -class UserAgent: - client: ClientSession - agent: AgentExecutor - - def __init__( - self, - client: ClientSession, - agent: AgentExecutor, - memory: ConversationBufferMemory, - ): - self.client = client - self.agent = agent - self.memory = memory - - @classmethod - def initialize_agent( - cls, - client: ClientSession, - tools: List[StructuredTool], - history: List[BaseMessage], - prompt: ChatPromptTemplate, - model: str, - ) -> "UserAgent": - # TODO: Use .bind_tools(tools) to bind the tools with the LLM. - llm = ChatVertexAI(max_output_tokens=512, model_name=model, temperature=0.0) - memory = ConversationBufferMemory( - chat_memory=ChatMessageHistory(messages=history), - memory_key="chat_history", - input_key="input", - output_key="output", - ) - agent = initialize_agent( - tools, - llm, - agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - memory=memory, - handle_parsing_errors=True, - max_iterations=3, - early_stopping_method="generate", - return_intermediate_steps=True, - ) - agent.agent.llm_chain.prompt = prompt # type: ignore - return UserAgent(client, agent, memory) - - async def close(self): - await self.client.close() - - async def invoke(self, prompt: str) -> Dict[str, Any]: - try: - response = await self.agent.ainvoke({"input": prompt}) - except Exception as err: - raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") - return response - - async def insert_ticket(self, params: str): - return await insert_ticket(self.client, params) - - def reset_memory(self, base_message: List[BaseMessage]): - self.memory.clear() - self.memory.chat_memory = ChatMessageHistory(messages=base_message) - - -class LangChainToolsOrchestrator(BaseOrchestrator): - _user_sessions: Dict[str, UserAgent] - # aiohttp context - connector = None - - def __init__(self): - self._user_sessions = {} - - @classproperty - def kind(cls): - return "langchain-tools" - - def user_session_exist(self, uuid: str) -> bool: - return 
uuid in self._user_sessions - - async def user_session_insert_ticket(self, uuid: str, params: str) -> Any: - user_session = self.get_user_session(uuid) - response = await user_session.insert_ticket(params) - return response - - async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]: - """ - Used if there's a process to be done after user decline ticket. - Return None is nothing is needed to be done. - """ - return None - - async def check_and_add_confirmations(self, response: Dict[str, Any]): - for step in response.get("intermediate_steps") or []: - if len(step) > 0: - # Find the called tool in the step - called_tool = step[0] - # Check to see if the agent has made a decision to call Prepare Insert Ticket - # This tool is a no-op and requires user confirmation before continuing - if called_tool.tool in self.confirmation_needing_tools: - if called_tool.tool == "Insert Ticket": - flight_info = await validate_ticket( - self.client, called_tool.tool_input - ) - return {"tool": called_tool.tool, "params": flight_info} - return {"tool": called_tool.tool, "params": called_tool.tool_input} - return None - - async def user_session_create(self, session: dict[str, Any]): - """Create and load an agent executor with tools and LLM.""" - print("Initializing agent..") - if "uuid" not in session: - session["uuid"] = str(uuid.uuid4()) - id = session["uuid"] - if "history" not in session: - session["history"] = [BASE_HISTORY] - history = self.parse_messages(session["history"]) - client = await self.create_client_session() - tools = await initialize_tools(client) - prompt = self.create_prompt_template(tools) - agent = UserAgent.initialize_agent(client, tools, history, prompt, self.MODEL) - self._user_sessions[id] = agent - self.confirmation_needing_tools = get_confirmation_needing_tools() - self.client = client - - async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]: - user_session = self.get_user_session(uuid) - # Send prompt to 
LLM - agent_response = await user_session.invoke(prompt) - # Check for calls that may require confirmation to proceed - confirmation = await self.check_and_add_confirmations(agent_response) - # Build final response - response = {} - response["output"] = agent_response.get("output") - if confirmation: - response["confirmation"] = confirmation - return response - - def user_session_reset(self, session: dict[str, Any], uuid: str): - user_session = self.get_user_session(uuid) - del session["history"] - base_history = self.get_base_history(session) - session["history"] = [base_history] - history = self.parse_messages(session["history"]) - user_session.reset_memory(history) - - def get_user_session(self, uuid: str) -> UserAgent: - return self._user_sessions[uuid] - - async def get_connector(self) -> TCPConnector: - if self.connector is None: - self.connector = TCPConnector(limit=100) - return self.connector - - async def create_client_session(self) -> ClientSession: - return ClientSession( - connector=await self.get_connector(), - connector_owner=False, - headers={}, - raise_for_status=True, - ) - - def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: - # Create new prompt template - tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] - ) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = FORMAT_INSTRUCTIONS.format( - tool_names=tool_names, - ) - current_datetime = "Today's date and current time is {cur_datetime}." 
- template = "\n\n".join( - [ - PREFIX, - current_datetime, - TOOLS_PREFIX, - tool_strings, - format_instructions, - SUFFIX, - ] - ) - human_message_template = "{input}\n\n{agent_scratchpad}" - - prompt = ChatPromptTemplate.from_messages( - [("system", template), ("human", human_message_template)] - ) - prompt = prompt.partial(cur_datetime=self.get_datetime) - return prompt - - def get_datetime(self): - formatter = "%A, %m/%d/%Y, %H:%M:%S" - now = datetime.now(timezone("US/Pacific")) - return now.strftime(formatter) - - def parse_messages(self, datas: List[Any]) -> List[BaseMessage]: - messages: List[BaseMessage] = [] - for data in datas: - if data["type"] == "human": - messages.append(HumanMessage(content=data["data"]["content"])) - elif data["type"] == "ai": - messages.append(AIMessage(content=data["data"]["content"])) - else: - raise Exception("Message type not found.") - return messages - - def get_base_history(self, session: dict[str, Any]): - if "user_info" in session: - base_history = { - "type": "ai", - "data": { - "content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" - }, - } - return base_history - return BASE_HISTORY - - async def user_session_signout(self, uuid: str): - user_session = self.get_user_session(uuid) - if user_session: - await user_session.close() - del self._user_sessions[uuid] - - async def close_clients(self): - close_client_tasks = [ - asyncio.create_task(a.close()) for a in self._user_sessions.values() - ] - await asyncio.gather(*close_client_tasks) - - -PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. - -Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its -hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer -service! 
- -Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist -with a wide range of tasks, from answering simple questions to complex multi-query questions that -require passing results from one query to another. Using the latest AI models, Assistant is able to -generate human-like text based on the input it receives, allowing it to engage in natural-sounding -conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should -not answer questions about other peoples information for privacy reasons. - -Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air -as well as ammenities of San Francisco Airport.""" - -TOOLS_PREFIX = """ -TOOLS: ------- - -Assistant has access to the following tools:""" - -FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) -and an action_input key (tool input). - -Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - -SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate. -If using a tool, reminder to ALWAYS respond with a valid json blob of a single action. -Format is Action:```$JSON_BLOB```then Observation:. 
-Thought: - -Previous conversation history: -{chat_history} -""" diff --git a/llm_demo/orchestrator/langchain_tools/tools.py b/llm_demo/orchestrator/langchain_tools/tools.py deleted file mode 100644 index 9aa006577..000000000 --- a/llm_demo/orchestrator/langchain_tools/tools.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -from datetime import date, datetime -from typing import Any, Dict, Optional - -import aiohttp -import google.oauth2.id_token # type: ignore -from google.auth import compute_engine # type: ignore -from google.auth.transport.requests import Request # type: ignore -from langchain_core.tools import StructuredTool -from pydantic import BaseModel, Field - -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - - -def filter_none_values(params: Dict) -> Dict: - return {key: value for key, value in params.items() if value is not None} - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - 
-def get_headers(client: aiohttp.ClientSession): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -# Tools -class AirportSearchInput(BaseModel): - country: Optional[str] = Field(description="Country") - city: Optional[str] = Field(description="City") - name: Optional[str] = Field(description="Airport name") - - -def generate_search_airports(client: aiohttp.ClientSession): - async def search_airports(country: str, city: str, name: str): - params = { - "country": country, - "city": city, - "name": name, - } - response = await client.get( - url=f"{BASE_URL}/airports/search", - params=filter_none_values(params), - headers=get_headers(client), - ) - - response_json = await response.json() - response_results = response_json.get("results") - if len(response_results) < 1: - return "There are no airports matching that query. Let the user know there are no results." 
- else: - return response_results - - return search_airports - - -class FlightNumberInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - - -def generate_search_flights_by_number(client: aiohttp.ClientSession): - async def search_flights_by_number(airline: str, flight_number: str): - response = await client.get( - url=f"{BASE_URL}/flights/search", - params={"airline": airline, "flight_number": flight_number}, - headers=get_headers(client), - ) - - response_json = await response.json() - return response_json.get("results") - - return search_flights_by_number - - -class ListFlights(BaseModel): - departure_airport: Optional[str] = Field( - description="Departure airport 3-letter code", - ) - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - date: str = Field(description="Date of flight departure") - - -def generate_list_flights(client: aiohttp.ClientSession): - async def list_flights( - departure_airport: str, - arrival_airport: str, - date: str, - ): - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "date": date, - } - response = await client.get( - url=f"{BASE_URL}/flights/search", - params=filter_none_values(params), - headers=get_headers(client), - ) - - response_json = await response.json() - response_results = response_json.get("results") - if len(response_results) < 1: - return "There are no flights matching that query. Let the user know there are no results." 
- else: - return response_results - - return list_flights - - -class QueryInput(BaseModel): - query: str = Field(description="Search query") - - -def generate_search_amenities(client: aiohttp.ClientSession): - async def search_amenities(query: str): - response = await client.get( - url=f"{BASE_URL}/amenities/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client), - ) - - response_json = await response.json() - response_results = response_json.get("results") - return response_results - - return search_amenities - - -def generate_search_policies(client: aiohttp.ClientSession): - async def search_policies(query: str): - response = await client.get( - url=f"{BASE_URL}/policies/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client), - ) - - response_json = await response.json() - response_results = response_json.get("results") - return response_results - - return search_policies - - -class TicketInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - departure_airport: str = Field( - description="Departure airport 3-letter code", - ) - departure_time: datetime = Field(description="Flight departure datetime") - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - arrival_time: Optional[datetime] = Field(description="Flight arrival datetime") - - -def generate_insert_ticket(client: aiohttp.ClientSession): - async def insert_ticket( - airline: str | None = None, - flight_number: str | None = None, - departure_airport: str | None = None, - arrival_airport: str | None = None, - departure_time: datetime | date | None = None, - arrival_time: datetime | date | None = None, - ): - return f"Booking ticket on {airline} {flight_number}" - - return insert_ticket - - -async def insert_ticket(client: aiohttp.ClientSession, params: str): - ticket_info = json.loads(params) - response = await client.post( - 
url=f"{BASE_URL}/tickets/insert", - params={ - "airline": ticket_info.get("airline"), - "flight_number": ticket_info.get("flight_number"), - "departure_airport": ticket_info.get("departure_airport"), - "arrival_airport": ticket_info.get("arrival_airport"), - "departure_time": ticket_info.get("departure_time").replace("T", " "), - "arrival_time": ticket_info.get("arrival_time").replace("T", " "), - }, - headers=get_headers(client), - ) - response_json = await response.json() - return "Flight booking successful." - - -async def validate_ticket(client: aiohttp.ClientSession, ticket_info: Dict[Any, Any]): - response = await client.get( - url=f"{BASE_URL}/tickets/validate", - params=filter_none_values( - { - "airline": ticket_info.get("airline"), - "flight_number": ticket_info.get("flight_number"), - "departure_airport": ticket_info.get("departure_airport"), - "departure_time": ticket_info.get("departure_time", "").replace( - "T", " " - ), - } - ), - headers=get_headers(client), - ) - response_json = await response.json() - response_results = response_json.get("results") - - flight_info = { - "airline": response_results.get("airline"), - "flight_number": response_results.get("flight_number"), - "departure_airport": response_results.get("departure_airport"), - "arrival_airport": response_results.get("arrival_airport"), - "departure_time": response_results.get("departure_time"), - "arrival_time": response_results.get("arrival_time"), - } - return flight_info - - -def generate_list_tickets(client: aiohttp.ClientSession): - async def list_tickets(): - response = await client.get( - url=f"{BASE_URL}/tickets/list", - headers=get_headers(client), - ) - - response_json = await response.json() - tickets = response_json.get("results") - if len(tickets) == 0: - return { - "results": "There are no upcoming tickets", - "sql": response_json.get("sql"), - } - else: - return response_json - - return list_tickets - - -# Tools for agent -async def initialize_tools(client: 
aiohttp.ClientSession): - return [ - StructuredTool.from_function( - coroutine=generate_search_airports(client), - name="Search Airport", - description=""" - Use this tool to list all airports matching search criteria. - Takes at least one of country, city, name, or all and returns all matching airports. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - country, city, name. - Example: - {{ - "country": "United States", - "city": "San Francisco", - "name": null - }} - Example: - {{ - "country": null, - "city": "Goroka", - "name": "Goroka" - }} - Example: - {{ - "country": "Mexico", - "city": null, - "name": null - }} - """, - args_schema=AirportSearchInput, - ), - StructuredTool.from_function( - coroutine=generate_search_flights_by_number(client), - name="Search Flights By Flight Number", - description=""" - Use this tool to get information for a specific flight. - Takes an airline code and flight number and returns info on the flight. - Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. - A airline code is a code for an airline service consisting of two-character - airline designator and followed by flight number, which is 1 to 4 digit number. - For example, if given CY 0123, the airline is "CY", and flight_number is "123". - Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". - If the tool returns more than one option choose the date closes to today. - Example: - {{ - "airline": "CY", - "flight_number": "888", - }} - Example: - {{ - "airline": "DL", - "flight_number": "1234", - }} - """, - args_schema=FlightNumberInput, - ), - StructuredTool.from_function( - coroutine=generate_list_flights(client), - name="List Flights", - description=""" - Use this tool to list flights information matching search criteria. 
- Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. - If 3-letter iata code is not provided for departure_airport or arrival_airport, use search airport tools to get iata code information. - Do NOT guess a date, ask user for date input if it is not given. Date must be in the following format: YYYY-MM-DD. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": null, - "date": 2025-10-30" - }} - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": "SEA", - "date": "2025-11-01" - }} - Example: - {{ - "departure_airport": null, - "arrival_airport": "SFO", - "date": "2025-01-01" - }} - """, - args_schema=ListFlights, - ), - StructuredTool.from_function( - coroutine=generate_search_amenities(client), - name="Search Amenities", - description=""" - Use this tool to search amenities by name or to recommended airport amenities at SFO. - If user provides flight info, use 'Search Flights by Flight Number' - first to get gate info and location. - Only recommend amenities that are returned by this query. - Find amenities close to the user by matching the terminal and then comparing - the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 - B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_search_policies(client), - name="Search Policies", - description=""" - Use this tool to search for cymbal air passenger policy. - Policy that are listed is unchangeable. - You will not answer any questions outside of the policy given. 
- Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_insert_ticket(client), - name="Insert Ticket", - description=""" - Use this tool to book a flight ticket for the user. - Example: - {{ - "airline": "AA", - "flight_number": "452", - "departure_airport": "LAX", - "arrival_airport": "SFO", - "departure_time": "2025-01-01 05:50:00", - "arrival_time": "2025-01-01 09:23:00" - }} - Example: - {{ - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": "2025-01-08 05:50:00", - "arrival_time": "2025-01-08 09:23:00" - }} - Example: - {{ - "airline": "OO", - "flight_number": "6307", - "departure_airport": "SFO", - "arrival_airport": "MSP", - "departure_time": "2025-10-28 20:13:00", - "arrival_time": "2025-10-28 21:07:00" - }} - """, - args_schema=TicketInput, - ), - StructuredTool.from_function( - coroutine=generate_list_tickets(client), - name="List Tickets", - description=""" - Use this tool to list a user's flight tickets. - Takes no input and returns a list of current user's flight tickets. - Input is always empty JSON blob. Example: {{}} - """, - ), - ] - - -def get_confirmation_needing_tools(): - return ["Insert Ticket"] diff --git a/llm_demo/orchestrator/langgraph/tool_node.py b/llm_demo/orchestrator/langgraph/tool_node.py deleted file mode 100644 index fd3111a72..000000000 --- a/llm_demo/orchestrator/langgraph/tool_node.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import copy -import json -import uuid -from itertools import repeat -from typing import Any, Callable, Dict, Optional, Sequence, Union - -from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage -from langchain_core.runnables import RunnableConfig -from langchain_core.runnables.config import get_executor_for_config -from langchain_core.tools import BaseTool -from langchain_core.tools import tool as create_tool -from langgraph.utils.runnable import RunnableCallable - - -def str_output(output: Any) -> str: - if isinstance(output, str): - return output - else: - try: - return json.dumps(output) - except Exception: - return str(output) - - -class ToolNode(RunnableCallable): - """ - A node that runs the tools requested in the last AIMessage. It can be used - either in StateGraph with a "messages" key or in MessageGraph. If multiple - tool calls are requested, they will be run in parallel. The output will be - a list of ToolMessages, one for each tool call. 
- """ - - def __init__( - self, - tools: Sequence[Union[BaseTool, Callable]], - *, - name: str = "tools", - tags: Optional[list[str]] = None, - ) -> None: - super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) - self.tools_by_name: Dict[str, BaseTool] = {} - for tool_ in tools: - if not isinstance(tool_, BaseTool): - tool_ = create_tool(tool_) - else: - base_tool_ = tool_ - if hasattr(tool_, "name"): - self.tools_by_name[tool_.name] = base_tool_ - - def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any: - if messages := input.get("messages", []): - output_type = "dict" - message = messages[-1] - else: - raise ValueError("No message found in input") - - if not isinstance(message, AIMessage): - raise ValueError("Last message is not an AIMessage") - - user_id_token = input.get("user_id_token") - - def run_one(call: ToolCall, user_id_token: Optional[str]): - args = copy.copy(call["args"]) or {} - args["user_id_token"] = user_id_token - response = self.tools_by_name[call["name"]].invoke(args, config) - output = response.get("results") - sql = response.get("sql") - tool_call_id = call.get("id") or str(uuid.uuid4()) - return ToolMessage( - content=str_output(output), - name=call["name"], - tool_call_id=tool_call_id, - additional_kwargs={"sql": sql}, - ) - - with get_executor_for_config(config) as executor: - outputs = [ - *executor.map(run_one, message.tool_calls, repeat(user_id_token)) - ] - if output_type == "list": - return outputs - else: - return {"messages": outputs} - - async def _afunc(self, input: dict[str, Any], config: RunnableConfig) -> Any: - if messages := input.get("messages", []): - output_type = "dict" - message = messages[-1] - else: - raise ValueError("No message found in input") - - if not isinstance(message, AIMessage): - raise ValueError("Last message is not an AIMessage") - - user_id_token = input.get("user_id_token") - - async def run_one(call: ToolCall, user_id_token: Optional[str]): - args = 
copy.copy(call["args"]) or {} - args["user_id_token"] = user_id_token - response = await self.tools_by_name[call["name"]].ainvoke(args, config) - output = response.get("results") - sql = response.get("sql") - tool_call_id = call.get("id") or str(uuid.uuid4()) - return ToolMessage( - content=str_output(output), - name=call["name"], - tool_call_id=tool_call_id, - additional_kwargs={"sql": sql}, - ) - - outputs = await asyncio.gather( - *(run_one(call, user_id_token) for call in message.tool_calls) - ) - if output_type == "list": - return outputs - else: - return {"messages": outputs} diff --git a/llm_demo/orchestrator/langgraph/tools.py b/llm_demo/orchestrator/langgraph/tools.py deleted file mode 100644 index 18ed8da11..000000000 --- a/llm_demo/orchestrator/langgraph/tools.py +++ /dev/null @@ -1,455 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -from dataclasses import dataclass -from datetime import date, datetime -from typing import Any, Dict, Optional - -import aiohttp -import google.oauth2.id_token # type: ignore -from google.auth import compute_engine # type: ignore -from google.auth.transport.requests import Request # type: ignore -from langchain_core.tools import StructuredTool -from pydantic import BaseModel, Field - -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - - -def filter_none_values(params: Dict) -> Dict: - return {key: value for key, value in params.items() if value is not None} - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - -def get_headers(client: aiohttp.ClientSession, user_id_token: str): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - headers["User-Id-Token"] = f"Bearer {user_id_token}" - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -# Tools -class AirportSearchInput(BaseModel): - country: Optional[str] = Field(description="Country") - city: Optional[str] = Field(description="City") - name: Optional[str] = Field(description="Airport name") - user_id_token: Optional[str] - - -def generate_search_airports(client: aiohttp.ClientSession): - async def search_airports(country: str, city: str, name: str, user_id_token: str): - params = { - "country": country, - "city": city, - "name": 
name, - } - response = await client.get( - url=f"{BASE_URL}/airports/search", - params=filter_none_values(params), - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - if len(response_json) < 1: - return "There are no airports matching that query. Let the user know there are no results." - else: - return response_json - - return search_airports - - -class FlightNumberInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - user_id_token: Optional[str] - - -def generate_search_flights_by_number(client: aiohttp.ClientSession): - async def search_flights_by_number( - airline: str, flight_number: str, user_id_token: str - ): - response = await client.get( - url=f"{BASE_URL}/flights/search", - params={"airline": airline, "flight_number": flight_number}, - headers=get_headers(client, user_id_token), - ) - - return await response.json() - - return search_flights_by_number - - -class ListFlightsInput(BaseModel): - departure_airport: Optional[str] = Field( - description="Departure airport 3-letter code", - ) - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - date: str = Field(description="Date of flight departure") - user_id_token: Optional[str] - - -def generate_list_flights(client: aiohttp.ClientSession): - async def list_flights( - departure_airport: str, - arrival_airport: str, - date: str, - user_id_token: str, - ): - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "date": date, - } - response = await client.get( - url=f"{BASE_URL}/flights/search", - params=filter_none_values(params), - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - if len(response_json) < 1: - return { - "results": "There are no flights matching that query. Let the user know there are no results." 
- } - else: - return response_json - - return list_flights - - -class QueryInput(BaseModel): - query: str = Field(description="Search query") - user_id_token: Optional[str] - - -def generate_search_amenities(client: aiohttp.ClientSession): - async def search_amenities(query: str, user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/amenities/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client, user_id_token), - ) - - response = await response.json() - return response - - return search_amenities - - -def generate_search_policies(client: aiohttp.ClientSession): - async def search_policies(query: str, user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/policies/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client, user_id_token), - ) - - response = await response.json() - return response - - return search_policies - - -class TicketInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - departure_airport: str = Field( - description="Departure airport 3-letter code", - ) - departure_time: datetime = Field(description="Flight departure datetime") - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - arrival_time: Optional[datetime] = Field(description="Flight arrival datetime") - - -def generate_insert_ticket(client: aiohttp.ClientSession): - async def insert_ticket( - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: datetime, - arrival_time: datetime, - ): - return {"results": f"Booking ticket on {airline} {flight_number}"} - - return insert_ticket - - -@dataclass -class TicketInfo: - airline: str - flight_number: str - departure_airport: str - departure_time: str - arrival_airport: str - arrival_time: str - - -async def insert_ticket( - client: aiohttp.ClientSession, ticket_info: TicketInfo, 
user_id_token: str -): - response = await client.post( - url=f"{BASE_URL}/tickets/insert", - params={ - "airline": ticket_info.airline, - "flight_number": ticket_info.flight_number, - "departure_airport": ticket_info.departure_airport, - "arrival_airport": ticket_info.arrival_airport, - "departure_time": ticket_info.departure_time.replace("T", " "), - "arrival_time": ticket_info.arrival_time.replace("T", " "), - }, - headers=get_headers(client, user_id_token), - ) - response = await response.json() - return "Flight booking successful." - - -async def validate_ticket( - client: aiohttp.ClientSession, ticket_info: Dict[Any, Any], user_id_token: str -): - response = await client.get( - url=f"{BASE_URL}/tickets/validate", - params=filter_none_values( - { - "airline": ticket_info.get("airline"), - "flight_number": ticket_info.get("flight_number"), - "departure_airport": ticket_info.get("departure_airport"), - "departure_time": ticket_info.get("departure_time", "").replace( - "T", " " - ), - } - ), - headers=get_headers(client, user_id_token), - ) - response_json = await response.json() - response_results = response_json.get("results") - - flight_info = { - "airline": response_results.get("airline"), - "flight_number": response_results.get("flight_number"), - "departure_airport": response_results.get("departure_airport"), - "arrival_airport": response_results.get("arrival_airport"), - "departure_time": response_results.get("departure_time"), - "arrival_time": response_results.get("arrival_time"), - } - return flight_info - - -def generate_list_tickets(client: aiohttp.ClientSession): - async def list_tickets(user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/tickets/list", - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - tickets = response_json.get("results") - if len(tickets) == 0: - return { - "results": "There are no upcoming tickets", - "sql": response_json.get("sql"), - } - else: - return response_json 
- - return list_tickets - - -# Tools for agent -async def initialize_tools(client: aiohttp.ClientSession): - return [ - StructuredTool.from_function( - coroutine=generate_search_airports(client), - name="Search Airport", - description=""" - Use this tool to list all airports matching search criteria. - Takes at least one of country, city, name, or all and returns all matching airports. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - country, city, name. - Example: - {{ - "country": "United States", - "city": "San Francisco", - "name": null - }} - Example: - {{ - "country": null, - "city": "Goroka", - "name": "Goroka" - }} - Example: - {{ - "country": "Mexico", - "city": null, - "name": null - }} - """, - args_schema=AirportSearchInput, - ), - StructuredTool.from_function( - coroutine=generate_search_flights_by_number(client), - name="Search Flights By Flight Number", - description=""" - Use this tool to get information for a specific flight. - Takes an airline code and flight number and returns info on the flight. - Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. - A airline code is a code for an airline service consisting of two-character - airline designator and followed by flight number, which is 1 to 4 digit number. - For example, if given CY 0123, the airline is "CY", and flight_number is "123". - Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". - If the tool returns more than one option choose the date closes to today. - Example: - {{ - "airline": "CY", - "flight_number": "888", - }} - Example: - {{ - "airline": "DL", - "flight_number": "1234", - }} - """, - args_schema=FlightNumberInput, - ), - StructuredTool.from_function( - coroutine=generate_list_flights(client), - name="List Flights", - description=""" - Use this tool to list flights information matching search criteria. 
- Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. - If 3-letter iata code is not provided for departure_airport or arrival_airport, use search airport tools to get iata code information. - Do NOT guess a date, ask user for date input if it is not given. Date must be in the following format: YYYY-MM-DD. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": null, - "date": 2025-10-30" - }} - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": "SEA", - "date": "2025-11-01" - }} - Example: - {{ - "departure_airport": null, - "arrival_airport": "SFO", - "date": "2025-01-01" - }} - """, - args_schema=ListFlightsInput, - ), - StructuredTool.from_function( - coroutine=generate_search_amenities(client), - name="Search Amenities", - description=""" - Use this tool to search amenities by name or to recommended airport amenities at SFO. - If user provides flight info, use 'Search Flights by Flight Number' - first to get gate info and location. - Only recommend amenities that are returned by this query. - Find amenities close to the user by matching the terminal and then comparing - the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 - B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_search_policies(client), - name="Search Policies", - description=""" - Use this tool to search for cymbal air passenger policy. - Policy that are listed is unchangeable. - You will not answer any questions outside of the policy given. 
- Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_insert_ticket(client), - name="Insert Ticket", - description=""" - Use this tool to book a flight ticket for the user. - Example: - {{ - "airline": "AA", - "flight_number": "452", - "departure_airport": "LAX", - "arrival_airport": "SFO", - "departure_time": "2025-01-01 05:50:00", - "arrival_time": "2025-01-01 09:23:00" - }} - Example: - {{ - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": "2025-01-08 05:50:00", - "arrival_time": "2025-01-08 09:23:00" - }} - Example: - {{ - "airline": "OO", - "flight_number": "6307", - "departure_airport": "SFO", - "arrival_airport": "MSP", - "departure_time": "2025-10-28 20:13:00", - "arrival_time": "2025-10-28 21:07:00" - }} - """, - args_schema=TicketInput, - ), - StructuredTool.from_function( - coroutine=generate_list_tickets(client), - name="List Tickets", - description=""" - Use this tool to list a user's flight tickets. - Takes no input and returns a list of current user's flight tickets. - Input is always empty JSON blob. Example: {{}} - """, - ), - ] - - -def get_confirmation_needing_tools(): - return ["Insert Ticket"] diff --git a/llm_demo/orchestrator/orchestrator.py b/llm_demo/orchestrator/orchestrator.py deleted file mode 100644 index 084570b5d..000000000 --- a/llm_demo/orchestrator/orchestrator.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Optional - - -class classproperty: - def __init__(self, func): - self.fget = func - - def __get__(self, instance, owner): - return self.fget(owner) - - -class BaseOrchestrator(ABC): - MODEL = "gemini-2.0-flash-001" - - @classproperty - @abstractmethod - def kind(cls): - pass - - @abstractmethod - def user_session_exist(self, uuid: str) -> bool: - """Check if user session exist.""" - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def user_session_create(self, session: dict[str, Any]): - """Create user session for orchestrator.""" - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]: - """Invoke user session and return a response from llm orchestrator.""" - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - def user_session_reset(self, session: dict[str, Any], uuid: str): - """Reset and clear history from user session.""" - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - def get_user_session(self, uuid: str) -> Any: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def user_session_insert_ticket(self, uuid: str, params: str) -> Any: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]: - raise NotImplementedError("Subclass 
should implement this!") - - @abstractmethod - async def user_session_signout(self, uuid: str): - """Sign out from user session. Clear and restart session.""" - raise NotImplementedError("Subclass should implement this!") - - def set_user_session_header(self, uuid: str, user_id_token: str): - user_session = self.get_user_session(uuid) - user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}" - - def get_user_id_token(self, uuid: str) -> Optional[str]: - if self.user_session_exist(uuid): - user_session = self.get_user_session(uuid) - if user_session.client and "User-Id-Token" in user_session.client.headers: - token = user_session.client.headers["User-Id-Token"] - parts = str(token).split(" ") - if len(parts) != 2 or parts[0] != "Bearer": - raise Exception("Invalid ID token") - return parts[1] - return None - - -def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator": - for cls in BaseOrchestrator.__subclasses__(): - s = f"{orchestration_type} == {cls.kind}" - if orchestration_type == cls.kind: - return cls() # type: ignore - raise TypeError(f"No orchestration type of kind {orchestration_type}") diff --git a/llm_demo/orchestrator/vertexai_function_calling/__init__.py b/llm_demo/orchestrator/vertexai_function_calling/__init__.py deleted file mode 100644 index 5b743df4c..000000000 --- a/llm_demo/orchestrator/vertexai_function_calling/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from .function_calling_orchestrator import FunctionCallingOrchestrator - -__ALL__ = ["FunctionCallingOrchestrator"] diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py deleted file mode 100644 index 1c17427c1..000000000 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import os -import uuid -from datetime import datetime -from typing import Any, Dict, List, Optional - -from aiohttp import ClientSession, TCPConnector -from fastapi import HTTPException -from google.protobuf.json_format import MessageToDict # type: ignore -from pytz import timezone -from vertexai.preview.generative_models import ( # type: ignore - Content, - GenerationConfig, - GenerativeModel, - Part, -) - -from ..orchestrator import BaseOrchestrator, classproperty -from .functions import ( - BASE_URL, - assistant_tool, - function_request, - get_confirmation_needing_tools, - get_headers, - insert_ticket, -) - -DEBUG = os.getenv("DEBUG", default=False) -BASE_HISTORY = { - "type": "ai", - "data": {"content": "Welcome to Cymbal Air! 
How may I assist you?"}, -} - - -class UserModel: - client: ClientSession - model: GenerativeModel - history: List[Content] - - def __init__(self, client: ClientSession, model: GenerativeModel): - self.client = client - self.model = model - self.history = [] - - @classmethod - def initialize_model(cls, client: ClientSession, model: str) -> "UserModel": - model = GenerativeModel(model, tools=[assistant_tool()]) - return UserModel(client, model) - - async def close(self): - await self.client.close() - - async def invoke(self, input_prompt: str) -> Dict[str, Any]: - prompt = self.get_prompt() - user_prompt_content = Content( - role="user", - parts=[ - Part.from_text(prompt + input_prompt), - ], - ) - self.history.append(user_prompt_content) - model_response = await self.request_model(user_prompt_content) - self.debug_log(f"Prompt:\n{prompt}\n\nQuestion: {input_prompt}.") - self.debug_log(f"\nFunction call response:\n{model_response}") - response_function_call_content = model_response.candidates[0].content - part_response = response_function_call_content.parts[0] - confirmation = None - - # implement multi turn chat with while loop - while "function_call" in part_response._raw_part: - self.history.append(response_function_call_content) - function_call = MessageToDict(part_response.function_call._pb) - function_name = function_call.get("name") - if function_name in get_confirmation_needing_tools(): - function_response = self.confirmation_response( - function_name, function_call.get("args") - ) - confirmation = { - "tool": function_name, - "params": function_call.get("args"), - } - else: - function_response = await self.request_function(function_call) - self.debug_log(f"Function response:\n{function_response}") - part = Part.from_function_response( - name=function_call["name"], - response={ - "content": function_response, - }, - ) - content = Content( - parts=[part], - ) - self.history.append(content) - model_response = await self.request_model(self.history) - 
response_function_call_content = model_response.candidates[0].content - part_response = response_function_call_content.parts[0] - - if "text" in part_response._raw_part: - model_text = part_response.text - model_content = Content( - role="model", - parts=[ - Part.from_text(model_text), - ], - ) - self.history.append(model_content) - self.debug_log(f"Output content: {model_text}") - return {"output": model_text, "confirmation": confirmation} - else: - raise HTTPException( - status_code=500, detail="Error: Chat model response unknown" - ) - - def get_prompt(self) -> str: - formatter = "%A, %m/%d/%Y, %H:%M:%S" - now = datetime.now(timezone("US/Pacific")).strftime("%A, %m/%d/%Y, %H:%M:%S") - prompt = f"{PREFIX}\nToday's date and current time is {now}." - return prompt - - def debug_log(self, output: str) -> None: - if DEBUG: - print(output) - - async def request_model(self, contents: List[Content]): - try: - model_response = await self.model.generate_content_async( - contents, - generation_config=GenerationConfig(temperature=0), - ) - except Exception as err: - raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") - return model_response - - def confirmation_response(self, function_name, function_params): - if function_name == "insert_ticket": - return f"Booking ticket on {function_params.get('airline')} {function_params.get('flight_number')}" - return "" - - async def request_function(self, function_call): - url = function_request(function_call["name"]) - params = function_call["args"] - self.debug_log(f"Function url is {url}.\nParams is {params}.") - response = await self.client.get( - url=f"{BASE_URL}/{url}", - params=params, - headers=get_headers(self.client), - ) - response_json = await response.json() - response_results = response_json.get("results") - return response_results - - async def insert_ticket(self, params: str): - return await insert_ticket(self.client, params) - - def reset_memory(self, model: str): - """reinitiate chat model to 
reset memory.""" - del self.history - self.history = [] - - -class FunctionCallingOrchestrator(BaseOrchestrator): - _user_sessions: Dict[str, UserModel] - # aiohttp context - connector = None - - def __init__(self): - self._user_sessions = {} - - @classproperty - def kind(cls): - return "vertexai-function-calling" - - def user_session_exist(self, uuid: str) -> bool: - return uuid in self._user_sessions - - async def user_session_insert_ticket(self, uuid: str, params: str) -> Any: - user_session = self.get_user_session(uuid) - response = await user_session.insert_ticket(params) - return response - - async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]: - """ - Used if there's a process to be done after user decline ticket. - Return None is nothing is needed to be done. - """ - return None - - async def user_session_create(self, session: dict[str, Any]): - """Create and load an agent executor with tools and LLM.""" - print("Initializing agent..") - if "uuid" not in session: - session["uuid"] = str(uuid.uuid4()) - id = session["uuid"] - if "history" not in session: - session["history"] = [BASE_HISTORY] - client = await self.create_client_session() - model = UserModel.initialize_model(client, self.MODEL) - self._user_sessions[id] = model - self.client = client - - async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]: - user_session = self.get_user_session(uuid) - # Send prompt to LLM - response = await user_session.invoke(prompt) - return response - - def user_session_reset(self, session: dict[str, Any], uuid: str): - user_session = self.get_user_session(uuid) - del session["history"] - base_history = self.get_base_history(session) - session["history"] = [base_history] - user_session.reset_memory(self.MODEL) - - def get_user_session(self, uuid: str) -> UserModel: - return self._user_sessions[uuid] - - async def get_connector(self) -> TCPConnector: - if self.connector is None: - self.connector = 
TCPConnector(limit=100) - return self.connector - - async def create_client_session(self) -> ClientSession: - return ClientSession( - connector=await self.get_connector(), - connector_owner=False, - headers={}, - raise_for_status=True, - ) - - def get_base_history(self, session: dict[str, Any]): - if "user_info" in session: - base_history = { - "type": "ai", - "data": { - "content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" - }, - } - return base_history - return BASE_HISTORY - - async def user_session_signout(self, uuid: str): - user_session = self.get_user_session(uuid) - if user_session: - await user_session.close() - del self._user_sessions[uuid] - - async def close_clients(self): - close_client_tasks = [ - asyncio.create_task(a.close()) for a in self._user_sessions.values() - ] - await asyncio.gather(*close_client_tasks) - - -PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. - -Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its -hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer -service! - -Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist -with a wide range of tasks, from answering simple questions to complex multi-query questions that -require passing results from one query to another. Using the latest AI models, Assistant is able to -generate human-like text based on the input it receives, allowing it to engage in natural-sounding -conversations and provide responses that are coherent and relevant to the topic at hand. - -Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air -as well as ammenities of San Francisco Airport. 
-""" diff --git a/llm_demo/orchestrator/vertexai_function_calling/functions.py b/llm_demo/orchestrator/vertexai_function_calling/functions.py deleted file mode 100644 index ae11d50a0..000000000 --- a/llm_demo/orchestrator/vertexai_function_calling/functions.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import aiohttp -from vertexai.preview import generative_models # type: ignore - -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - -search_airports_func = generative_models.FunctionDeclaration( - name="airports_search", - description=""" - Use this tool to list all airports matching search criteria. - Takes at least one of country, city, name, or all and returns all matching airports. - This function could also be used to serach for airport information such as iata code. - """, - parameters={ - "type": "object", - "properties": { - "country": {"type": "string", "description": "country"}, - "city": {"type": "string", "description": "city"}, - "name": { - "type": "string", - "description": "Full or partial name of an airport", - }, - }, - }, -) - -search_amenities_func = generative_models.FunctionDeclaration( - name="amenities_search", - description=""" - Use this tool to search amenities by name or to recommended airport amenities at SFO. - If user provides flight info, use 'search_flights_by_number' - first to get gate info and location. 
- Only recommend amenities that are returned by this query. - Find amenities close to the user by matching the terminal and then comparing - the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 - B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. - top_k value is defaulted to 5. - """, - parameters={ - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "top_k": { - "type": "integer", - "description": "Number of matching amenities to return. Default this value to 5.", - }, - }, - "required": ["query", "top_k"], - }, -) - -search_policies_func = generative_models.FunctionDeclaration( - name="policies_search", - description="Use this tool to search for cymbal air passenger policy. Policy that are listed is unchangeable. You will not answer any questions outside of the policy given. Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. If top_k is not specified, default to 5.", - parameters={ - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "top_k": { - "type": "integer", - "description": "Number of matching policy to return. Default this value to 5.", - }, - }, - }, -) - -search_flights_by_number_func = generative_models.FunctionDeclaration( - name="search_flights_by_number", - description=""" - Use this tool to get info for a specific flight. Do NOT use this tool with a flight id. - Takes an airline and flight number and returns info on the flight. - Do NOT guess an airline or flight number. - A flight number is a code for an airline service consisting of two-character - airline designator and a 1 to 4 digit number ex. OO123, DL 1234, BA 405, AS 3452. - If the tool returns more than one option choose the date closes to today. 
- """, - parameters={ - "type": "object", - "properties": { - "airline": { - "type": "string", - "description": "A code for an airline service consisting of two-character airline designator.", - }, - "flight_number": { - "type": "string", - "description": "A 1 to 4 digit number of the flight.", - }, - }, - "required": ["airline", "flight_number"], - }, -) - -list_flights_func = generative_models.FunctionDeclaration( - name="list_flights", - description=""" - Use this tool to list all flights matching search criteria. - Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. - Date must be provided, prompt user if it is not given. - The format of date must be YYYY-MM-DD. Convert terms like 'today' or 'yesterday' to a valid date format. - If iata code is not provided for departure_airport or arrival_airport, use airports_search function to get iata code. - """, - parameters={ - "type": "object", - "properties": { - "departure_airport": { - "type": "string", - "description": "The iata code for flight departure airport. Example: 'SFO', 'DEN'.", - }, - "arrival_airport": { - "type": "string", - "description": "The iata code for flight arrival airport. 
Example: 'SFO', 'DEN'.", - }, - "date": { - "type": "string", - "description": "The date of flight must be in the following format: YYYY-MM-DD.", - }, - }, - "required": ["date"], - }, -) - -insert_ticket_func = generative_models.FunctionDeclaration( - name="insert_ticket", - description="Use this tool to book a flight ticket for the user.", - parameters={ - "type": "object", - "properties": { - "airline": { - "type": "string", - "description": "A code for an airline service consisting of two-character airline designator.", - }, - "flight_number": { - "type": "string", - "description": "A 1 to 4 digit number of the flight.", - }, - "departure_airport": { - "type": "string", - "description": "The iata code for flight departure airport.", - }, - "arrival_airport": { - "type": "string", - "description": "The iata code for flight arrival airport.", - }, - "departure_time": { - "type": "string", - "description": "The departure time for flight.", - }, - "arrival_time": { - "type": "string", - "description": "The arrival time for flight.", - }, - }, - "required": [ - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - ], - }, -) - -list_tickets_func = generative_models.FunctionDeclaration( - name="list_tickets", - description="Use this tool to list a user's flight tickets. 
This tool takes no input parameters and returns a list of current user's flight tickets.", - parameters={ - "type": "object", - }, -) - - -async def insert_ticket(client: aiohttp.ClientSession, params: str): - ticket_info = json.loads(params) - response = await client.post( - url=f"{BASE_URL}/tickets/insert", - params={ - "airline": ticket_info.get("airline"), - "flight_number": ticket_info.get("flight_number"), - "departure_airport": ticket_info.get("departure_airport"), - "arrival_airport": ticket_info.get("arrival_airport"), - "departure_time": ticket_info.get("departure_time").replace("T", " "), - "arrival_time": ticket_info.get("arrival_time").replace("T", " "), - }, - headers=get_headers(client), - ) - response = await response.json() - return response - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - -def get_headers(client: aiohttp.ClientSession): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -def function_request(function_call_name: str) -> str: - functions_url = { - "airports_search": "airports/search", - "search_flights_by_number": "flights/search", - "list_flights": "flights/search", - "amenities_search": "amenities/search", - "policies_search": "policies/search", - "insert_ticket": "tickets/insert", - "list_tickets": "tickets/list", - } - return 
functions_url[function_call_name] - - -def assistant_tool(): - return generative_models.Tool( - function_declarations=[ - search_airports_func, - search_amenities_func, - search_policies_func, - search_flights_by_number_func, - list_flights_func, - insert_ticket_func, - list_tickets_func, - ], - ) - - -def get_confirmation_needing_tools(): - return ["insert_ticket"] diff --git a/llm_demo/vertexai_function_calling.int.tests.cloudbuild.yaml b/llm_demo/vertexai_function_calling.int.tests.cloudbuild.yaml deleted file mode 100644 index 68ad009f1..000000000 --- a/llm_demo/vertexai_function_calling.int.tests.cloudbuild.yaml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2023 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -steps: - - id: "Deploy to Cloud Run" - name: "gcr.io/cloud-builders/gcloud:latest" - dir: llm_demo - script: | - #!/usr/bin/env bash - gcloud run deploy ${_SERVICE} \ - --source . 
\ - --region ${_REGION} \ - --no-allow-unauthenticated \ - --update-env-vars ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} - - - id: "Test Frontend" - name: "gcr.io/cloud-builders/gcloud:latest" - entrypoint: /bin/bash - env: # Set env var expected by app - args: - - "-c" - - | - export URL=$(gcloud run services describe ${_SERVICE} --region ${_REGION} --format 'value(status.url)') - export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL) - export ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} - - # Test `/` route - curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL - - # Test `/chat` route should fail - msg=$(curl -si --show-error \ - -X POST \ - -H "Authorization: Bearer $$ID_TOKEN" \ - -H 'Content-Type: application/json' \ - -d '{"prompt":"How can you help me?"}' \ - $$URL/chat) - - if grep -q "400" <<< "$msg"; then - echo "Chat Handler Test: PASSED" - else - echo "Chat Handler Test: FAILED" - echo $msg && exit 1 - fi - - # Test `/chat` route - curl -b cookies.txt -si --fail --show-error \ - -X POST \ - -H "Authorization: Bearer $$ID_TOKEN" \ - -H 'Content-Type: application/json' \ - -d '{"prompt":"How can you help me?"}' \ - $$URL/chat - - - id: "Delete image and service" - name: "gcr.io/cloud-builders/gcloud" - script: | - #!/usr/bin/env bash - gcloud artifacts docker images delete $_GCR_HOSTNAME/$PROJECT_ID/cloud-run-source-deploy/$_SERVICE --quiet - gcloud run services delete ${_SERVICE} --region ${_REGION} --quiet - -serviceAccount: "projects/$PROJECT_ID/serviceAccounts/548341735270-compute@developer.gserviceaccount.com" # Necessary for ID token creation -options: - automapSubstitutions: true - logging: CLOUD_LOGGING_ONLY # Necessary for custom service account - dynamic_substitutions: true - -substitutions: - _GCR_HOSTNAME: ${_REGION}-docker.pkg.dev - _SERVICE: demo-service-${BUILD_ID} - _REGION: us-central1 - _ORCHESTRATION_TYPE: vertexai-function-calling diff --git a/llm_demo/orchestrator/langgraph/__init__.py 
b/orchestrator/__init__.py similarity index 85% rename from llm_demo/orchestrator/langgraph/__init__.py rename to orchestrator/__init__.py index f7e7d7004..f357920e9 100644 --- a/llm_demo/orchestrator/langgraph/__init__.py +++ b/orchestrator/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .langgraph_orchestrator import LangGraphOrchestrator +from .orchestrator import Orchestrator -__ALL__ = ["LangGraphOrchestrator"] +__ALL__ = ["Orchestrator"] diff --git a/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py b/orchestrator/orchestrator.py similarity index 73% rename from llm_demo/orchestrator/langgraph/langgraph_orchestrator.py rename to orchestrator/orchestrator.py index 4b2d7a191..42673c83d 100644 --- a/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py +++ b/orchestrator/orchestrator.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,20 +16,15 @@ import os import uuid from datetime import datetime -from typing import Annotated, Any, Dict, List, Literal, Optional, Sequence, TypedDict +from typing import Any, Dict, List, Optional, Sequence -from aiohttp import ClientSession, TCPConnector -from fastapi import HTTPException from langchain.globals import set_verbose # type: ignore from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableConfig, RunnableLambda -from langchain_core.tools import StructuredTool +from langchain_core.prompts import ChatPromptTemplate from langgraph.checkpoint.base import 
empty_checkpoint from langgraph.checkpoint.memory import MemorySaver from pytz import timezone -from ..orchestrator import BaseOrchestrator, classproperty from .react_graph import create_graph from .tools import initialize_tools @@ -41,21 +36,18 @@ } -class LangGraphOrchestrator(BaseOrchestrator): +class Orchestrator: + MODEL = "gemini-2.0-flash-001" + _user_sessions: Dict[str, str] # aiohttp context connector = None - client: Optional[ClientSession] = None def __init__(self): self._user_sessions = {} self._langgraph_app = None self._checkpointer = None - @classproperty - def kind(cls): - return "langgraph" - def user_session_exist(self, uuid: str) -> bool: return uuid in self._user_sessions @@ -74,14 +66,19 @@ async def user_session_decline_ticket(self, uuid: str) -> dict[str, Any]: async def user_session_create(self, session: dict[str, Any]): """Create and load an agent executor with tools and LLM.""" - client = await self.create_client_session() if self._langgraph_app is None: print("Initializing graph..") - tools = await initialize_tools(client) - prompt = self.create_prompt_template(tools) + tools, insert_ticket, validate_ticket = await initialize_tools() + prompt = self.create_prompt_template() checkpointer = MemorySaver() langgraph_app = await create_graph( - tools, checkpointer, prompt, self.MODEL, client, DEBUG + tools, + insert_ticket, + validate_ticket, + checkpointer, + prompt, + self.MODEL, + DEBUG, ) self._checkpointer = checkpointer self._langgraph_app = langgraph_app @@ -97,7 +94,6 @@ async def user_session_create(self, session: dict[str, Any]): config = self.get_config(session_id) self._langgraph_app.update_state(config, {"messages": history}) self._user_sessions[session_id] = "" - self.client = client async def user_session_invoke( self, uuid: str, user_prompt: Optional[str] @@ -176,36 +172,12 @@ def set_user_session_header(self, uuid: str, user_id_token: str): def get_user_id_token(self, uuid: str) -> Optional[str]: return 
self._user_sessions.get(uuid) - async def get_connector(self) -> TCPConnector: - if self.connector is None: - self.connector = TCPConnector(limit=100) - return self.connector - - async def create_client_session(self) -> ClientSession: - return ClientSession( - connector=await self.get_connector(), - connector_owner=False, - headers={}, - raise_for_status=True, - ) - - def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: - # Create new prompt template - tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] - ) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = FORMAT_INSTRUCTIONS.format( - tool_names=tool_names, - ) + def create_prompt_template(self) -> ChatPromptTemplate: current_datetime = "Today's date and current time is {cur_datetime}." template = "\n\n".join( [ PREFIX, current_datetime, - TOOLS_PREFIX, - tool_strings, - format_instructions, SUFFIX, ] ) @@ -244,18 +216,24 @@ def get_base_history(self, session: dict[str, Any]): return BASE_HISTORY def get_config(self, uuid: str): - return {"configurable": {"thread_id": uuid, "checkpoint_ns": ""}} + return { + "configurable": { + "thread_id": uuid, + "auth_token_getters": { + "my_google_service": lambda: self.get_user_id_token(uuid) + }, + "checkpoint_ns": "", + }, + } async def user_session_signout(self, uuid: str): checkpoint = empty_checkpoint() config = self.get_config(uuid) - self._checkpointer.put(config=config, checkpoint=checkpoint, metadata={}) + self._checkpointer.put( + config=config, checkpoint=checkpoint, metadata={}, new_versions={} + ) del self._user_sessions[uuid] - async def close_clients(self): - if self.client: - await self.client.close() - PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. @@ -268,43 +246,9 @@ async def close_clients(self): require passing results from one query to another. 
Using the latest AI models, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should -not answer questions about other peoples information for privacy reasons. +not answer questions about other people's information for privacy reasons. Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air -as well as ammenities of San Francisco Airport.""" - -TOOLS_PREFIX = """ -TOOLS: ------- -Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are: - -""" - -FORMAT_INSTRUCTIONS = """ -When responding, please output a response in one of two formats: - -** Option 1:** -Use this is you want to use a tool. -Markdown code snippet formatted in the following schema: -```json -{{{{ - "action": string, \ The action to take. Must be one of {tool_names} - "action_input": string \ The input to the action -}}}} -``` - -**Option 2:** -Use this if you want to respond directly to the human. -Markdown code snippet formatted following schema: -```json -{{{{ - "action": "Final Answer", - "action_input": string \ You should put what you want to return to user here -}}}} -``` -""" - -SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate. - -Remember to respond with a markdown code snippet of a json a single action, and NOTHING else. -""" +as well as amenities of San Francisco Airport.""" + +SUFFIX = """Begin! Use tools if necessary. 
Respond directly if appropriate.""" diff --git a/llm_demo/orchestrator/langgraph/react_graph.py b/orchestrator/react_graph.py similarity index 68% rename from llm_demo/orchestrator/langgraph/react_graph.py rename to orchestrator/react_graph.py index b05134f1e..40c75c879 100644 --- a/llm_demo/orchestrator/langgraph/react_graph.py +++ b/orchestrator/react_graph.py @@ -16,13 +16,10 @@ import uuid from typing import Annotated, Literal, Sequence, TypedDict -from aiohttp import ClientSession from langchain_core.messages import ( AIMessage, BaseMessage, - HumanMessage, ToolCall, - ToolMessage, ) from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import RunnableConfig, RunnableLambda @@ -31,19 +28,15 @@ from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages from langgraph.managed import IsLastStep +from langgraph.prebuilt import ToolNode +from toolbox_langchain import ToolboxTool -from .tool_node import ToolNode -from .tools import ( - TicketInfo, - get_confirmation_needing_tools, - insert_ticket, - validate_ticket, -) +from .tools import get_confirmation_needing_tools class UserState(TypedDict): """ - State with messages and ClientSession for each session/user. + State with messages for each session/user. """ messages: Annotated[Sequence[BaseMessage], add_messages] @@ -52,18 +45,21 @@ class UserState(TypedDict): async def create_graph( - tools, + tools: list[ToolboxTool], + insert_ticket: ToolboxTool, + validate_ticket: ToolboxTool, checkpointer: MemorySaver, prompt: ChatPromptTemplate, model_name: str, - client: ClientSession, debug: bool, ): """ Creates a graph that works with a chat model that utilizes tool calling. Args: - tools: A list of StructuredTools that will bind with the chat model. + tools: A list of ToolboxTools that will bind with the chat model. + insert_ticket: A ToolboxTool that inserts ticket for the logged in user. 
+ validate_ticket: A ToolboxTool that validates the given flight data. checkpointer: The checkpoint saver object. This is useful for persisting the state of the graph (e.g., as chat memory). prompt: Initial prompt for the model. This applies to messages before they @@ -85,11 +81,13 @@ async def create_graph( tool_node = ToolNode(tools) # model node - # TODO: Use .bind_tools(tools) to bind the tools with the LLM. model = ChatVertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0) + # Bind the tools with the LLM. + model_with_tools = model.bind_tools(tools) + # Add the prompt to the model to create a model runnable - model_runnable = prompt | model + model_runnable = prompt | model_with_tools async def acall_model(state: UserState, config: RunnableConfig): """ @@ -98,43 +96,6 @@ async def acall_model(state: UserState, config: RunnableConfig): """ messages = state["messages"] res = await model_runnable.ainvoke({"messages": messages}, config) - - # TODO: Remove the temporary fix of parsing LLM response and invoking - # tools until we use bind_tools API and have automatic response parsing - # and tool calling. 
(see - # https://langchain-ai.github.io/langgraph/#example) - if "```json" in res.content: - try: - response = str(res.content).replace("```json", "").replace("```", "") - json_response = json.loads(response) - action = json_response.get("action") - action_input = json_response.get("action_input") - if action == "Final Answer": - res = AIMessage(content=action_input) - else: - res = AIMessage( - content="suggesting a tool call", - tool_calls=[ - ToolCall( - id=str(uuid.uuid4()), name=action, args=action_input - ) - ], - ) - except Exception as e: - json_response = response - res = AIMessage( - content="Sorry, failed to generate the right format for response" - ) - - # if model exceed the number of steps and has not yet return a final answer - if state["is_last_step"] and hasattr(res, "tool_calls"): - return { - "messages": [ - AIMessage( - content="Sorry, need more steps to process this request.", - ) - ] - } return {"messages": [res]} def agent_should_continue( @@ -151,26 +112,25 @@ def agent_should_continue( for tool_call in last_message.tool_calls: tool_name = tool_call["name"] if tool_name in confirmation_needing_tools: - if tool_name == "Insert Ticket": + if tool_name == "insert_ticket": return "booking_validation" return "continue" # Otherwise, we stop (reply to the user) return "end" - async def booking_validation_node(state: UserState, config: RunnableConfig): + async def booking_validation_node(state: UserState): """ The node representing async function that validate the ticket. After ticket validation, it will return AIMessage with updated ticket args. 
""" messages = state["messages"] last_message = messages[-1] - user_id_token = state["user_id_token"] if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: tool_call = last_message.tool_calls[0] # Run ticket validation and return the correct ticket information - flight_info = await validate_ticket( - client, tool_call.get("args"), user_id_token - ) + flight_info = await validate_ticket.ainvoke(tool_call.get("args")) + flight_info = json.loads(flight_info) + flight_info = flight_info[0] new_message = AIMessage( content="Please confirm if you would like to book the ticket.", @@ -197,26 +157,17 @@ def booking_should_continue(state: UserState) -> Literal["continue", "agent"]: # Otherwise, send response back to agent return "agent" - async def insert_ticket_node(state: UserState, config: RunnableConfig): + async def insert_ticket_node(state: UserState): """ Node to update human response to prevent """ messages = state["messages"] last_message = messages[-1] - user_id_token = state["user_id_token"] # Run insert ticket if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: tool_call = last_message.tool_calls[0] tool_args = tool_call.get("args") - ticket_info = TicketInfo(**tool_args) - output = await insert_ticket(client, ticket_info, user_id_token) - tool_call_id = tool_call.get("id") - tool_message = ToolMessage( - content=output, name="Insert Ticket", tool_call_id=tool_call_id - ) - human_message = HumanMessage(content="Looks good to me.") - ai_message = AIMessage(content=output) - return {"messages": [human_message, tool_message, ai_message]} + await insert_ticket.ainvoke(tool_args) # Define constant node strings AGENT_NODE = "agent" diff --git a/orchestrator/tools.py b/orchestrator/tools.py new file mode 100644 index 000000000..d67af7d4d --- /dev/null +++ b/orchestrator/tools.py @@ -0,0 +1,38 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Optional + +from toolbox_langchain import ToolboxClient + +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +TOOLBOX_URL = os.getenv("TOOLBOX_URL", default="http://127.0.0.1:5000") + + +# Tools for agent +async def initialize_tools(): + client = ToolboxClient(TOOLBOX_URL) + tools = await client.aload_toolset("cymbal_air") + + # Load insert_ticket and validate_ticket tools separately to implement + # human-in-the-loop. + insert_ticket = await client.aload_tool("insert_ticket") + validate_ticket = await client.aload_tool("validate_ticket") + + return (tools, insert_ticket, validate_ticket) + + +def get_confirmation_needing_tools(): + return ["insert_ticket"] diff --git a/llm_demo/pyproject.toml b/pyproject.toml similarity index 100% rename from llm_demo/pyproject.toml rename to pyproject.toml diff --git a/llm_demo/requirements-test.txt b/requirements-test.txt similarity index 100% rename from llm_demo/requirements-test.txt rename to requirements-test.txt diff --git a/llm_demo/requirements.txt b/requirements.txt similarity index 50% rename from llm_demo/requirements.txt rename to requirements.txt index abbef7a8c..f370ddd4f 100644 --- a/llm_demo/requirements.txt +++ b/requirements.txt @@ -1,19 +1,20 @@ fastapi==0.115.0 -google-auth==2.35.0 -google-cloud-aiplatform[evaluation]==1.72.0 +google-auth==2.40.3 +google-cloud-aiplatform[evaluation]==1.97.0 itsdangerous==2.2.0 jinja2==3.1.5 -langchain-community==0.3.2 -langchain==0.3.7 -langchain-google-vertexai==2.0.7 
+langchain-community==0.3.25 +langchain==0.3.25 +langchain-core==0.3.65 +langchain-google-vertexai==2.0.25 markdown==3.7 types-Markdown==3.7.0.20240822 uvicorn[standard]==0.31.0 python-multipart==0.0.18 pytz==2025.1 types-pytz==2025.1.0.20250204 -langgraph==0.2.48 -httpx==0.27.2 +langgraph==0.4.8 pandas-stubs==2.2.2.240807 pandas==2.2.3 -pydantic==2.9.0 \ No newline at end of file +pydantic==2.9.0 +toolbox_langchain==0.2.1 \ No newline at end of file diff --git a/retrieval_service/Dockerfile b/retrieval_service/Dockerfile deleted file mode 100644 index bc3068d6d..000000000 --- a/retrieval_service/Dockerfile +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Use the official lightweight Python image. -# https://hub.docker.com/_/python -FROM python:3.11-slim - -# Allow statements and log messages to immediately appear in the logs -ENV PYTHONUNBUFFERED True - -WORKDIR /app - -# Install production dependencies. -COPY ./requirements.txt requirements.txt -RUN pip install --no-cache-dir -r requirements.txt - -# Copy local code to the container image. -COPY . ./ - -# Run the web service on container startup. 
-CMD ["python", "run_app.py"] diff --git a/retrieval_service/alloydb.tests.cloudbuild.yaml b/retrieval_service/alloydb.tests.cloudbuild.yaml deleted file mode 100644 index 69141ce48..000000000 --- a/retrieval_service/alloydb.tests.cloudbuild.yaml +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Update config - name: python:3.11 - dir: retrieval_service - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - # Create config - cp example-config-alloydb.yml config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - sed -i "s/my-user/$$DB_USER/g" config.yml - sed -i "s/my-password/$$DB_PASS/g" config.yml - sed -i "s/my-project/$PROJECT_ID/g" config.yml - sed -i "s/my-region/${_ALLOYDB_REGION}/g" config.yml - sed -i "s/my-cluster/${_ALLOYDB_CLUSTER}/g" config.yml - sed -i "s/my-instance/${_ALLOYDB_INSTANCE}/g" config.yml - - - id: Run Alloy DB integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_NAME=${_DATABASE_NAME}" - - "DB_PROJECT=$PROJECT_ID" - - "DB_REGION=${_ALLOYDB_REGION}" - - "DB_CLUSTER=${_ALLOYDB_CLUSTER}" - - "DB_INSTANCE=${_ALLOYDB_INSTANCE}" - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - python -m pytest 
--cov=datastore.providers.alloydb --cov-config=coverage/.alloydb-coveragerc datastore/providers/alloydb_test.py - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _DATABASE_USER: postgres - _ALLOYDB_REGION: "us-central1" - _ALLOYDB_CLUSTER: "my-alloydb-cluster" - _ALLOYDB_INSTANCE: "my-alloydb-instance" - -availableSecrets: - secretManager: - - versionName: projects/$PROJECT_ID/secrets/alloy_db_user/versions/latest - env: DB_USER - - versionName: projects/$PROJECT_ID/secrets/alloy_db_pass/versions/latest - env: DB_PASS - -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true diff --git a/retrieval_service/app.tests.cloudbuild.yaml b/retrieval_service/app.tests.cloudbuild.yaml deleted file mode 100644 index 88f4c075f..000000000 --- a/retrieval_service/app.tests.cloudbuild.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Run retrieval service app tests - name: python:3.11 - dir: retrieval_service - script: | - #!/usr/bin/env bash - python -m pytest --cov=app --cov-config=coverage/.app-coveragerc app/app_test.py diff --git a/retrieval_service/app/__init__.py b/retrieval_service/app/__init__.py deleted file mode 100644 index b28c73119..000000000 --- a/retrieval_service/app/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .app import EMBEDDING_MODEL_NAME, init_app, parse_config diff --git a/retrieval_service/app/app.py b/retrieval_service/app/app.py deleted file mode 100644 index 0220702a9..000000000 --- a/retrieval_service/app/app.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from contextlib import asynccontextmanager -from ipaddress import IPv4Address, IPv6Address -from typing import Optional - -import yaml -from fastapi import FastAPI -from langchain_google_vertexai import VertexAIEmbeddings -from pydantic import BaseModel - -import datastore - -from .routes import routes - -EMBEDDING_MODEL_NAME = "text-embedding-005" - - -class AppConfig(BaseModel): - host: IPv4Address | IPv6Address = IPv4Address("127.0.0.1") - port: int = 8080 - datastore: datastore.Config - clientId: Optional[str] = None - - -def parse_config(path: str) -> AppConfig: - with open(path, "r") as file: - config = yaml.safe_load(file) - return AppConfig(**config) - - -# gen_init is a wrapper to initialize the datastore during app startup -def gen_init(cfg: AppConfig): - async def initialize_datastore(app: FastAPI): - app.state.datastore = await datastore.create(cfg.datastore) - app.state.embed_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME) - yield - await app.state.datastore.close() - - return asynccontextmanager(initialize_datastore) - - -def init_app(cfg: AppConfig) -> FastAPI: - app = FastAPI(lifespan=gen_init(cfg)) - app.state.client_id = cfg.clientId - app.include_router(routes) - return app diff --git a/retrieval_service/app/app_test.py b/retrieval_service/app/app_test.py deleted file mode 100644 index f15be9c46..000000000 --- a/retrieval_service/app/app_test.py +++ /dev/null @@ -1,1070 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi.testclient import TestClient -from google.oauth2 import id_token - -import datastore -import models - -from . import init_app - - -@pytest.fixture(scope="module") -def app(): - mock_cfg = MagicMock() - mock_cfg.clientId = "fake client id" - app = init_app(mock_cfg) - if app is None: - raise TypeError("app did not initialize") - return app - - -@patch.object(datastore, "create") -def test_hello_world(m_datastore, app): - m_datastore = AsyncMock() - with TestClient(app) as client: - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"message": "Hello World"} - - -get_airport_params = [ - pytest.param( - "get_airport_by_id", - { - "id": 1, - }, - models.Airport( - id=1, - iata="FOO", - name="get_airport_by_id", - city="BAR", - country="FOO BAR", - ), - { - "id": 1, - "iata": "FOO", - "name": "get_airport_by_id", - "city": "BAR", - "country": "FOO BAR", - }, - id="id_only", - ), - pytest.param( - "get_airport_by_iata", - {"iata": "sfo"}, - models.Airport( - id=1, - iata="FOO", - name="get_airport_by_iata", - city="BAR", - country="FOO BAR", - ), - { - "id": 1, - "iata": "FOO", - "name": "get_airport_by_iata", - "city": "BAR", - "country": "FOO BAR", - }, - id="iata_only", - ), -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", get_airport_params -) -@patch.object(datastore, "create") -def test_get_airport(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/airports", - params=params, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - 
assert output == expected - assert models.Airport.model_validate(output) - - -@pytest.mark.parametrize( - "params", - [ - pytest.param({}, id="no_params"), - ], -) -@patch.object(datastore, "create") -def test_get_airport_with_bad_params(m_datastore, app, params): - m_datastore = AsyncMock() - with TestClient(app) as client: - response = client.get("/airports", params=params) - assert response.status_code == 422 - assert ( - response.json()["detail"] == "Request requires query params: airport id or iata" - ) - - -search_airports_params = [ - pytest.param( - "search_airports", - { - "country": "United States", - "city": "san francisco", - "name": "san francisco", - }, - [ - models.Airport( - id=1, - iata="FOO", - name="search_airports", - city="BAR", - country="FOO BAR", - ) - ], - [ - { - "id": 1, - "iata": "FOO", - "name": "search_airports", - "city": "BAR", - "country": "FOO BAR", - } - ], - id="country_city_and_name", - ), - pytest.param( - "search_airports", - {"country": "United States"}, - [ - models.Airport( - id=1, - iata="FOO", - name="search_airports", - city="BAR", - country="FOO BAR", - ) - ], - [ - { - "id": 1, - "iata": "FOO", - "name": "search_airports", - "city": "BAR", - "country": "FOO BAR", - } - ], - id="country_only", - ), - pytest.param( - "search_airports", - {"city": "san francisco"}, - [ - models.Airport( - id=1, - iata="FOO", - name="search_airports", - city="BAR", - country="FOO BAR", - ) - ], - [ - { - "id": 1, - "iata": "FOO", - "name": "search_airports", - "city": "BAR", - "country": "FOO BAR", - } - ], - id="city_only", - ), - pytest.param( - "search_airports", - {"name": "san francisco"}, - [ - models.Airport( - id=1, - iata="FOO", - name="search_airports", - city="BAR", - country="FOO BAR", - ) - ], - [ - { - "id": 1, - "iata": "FOO", - "name": "search_airports", - "city": "BAR", - "country": "FOO BAR", - } - ], - id="name_only", - ), -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", 
search_airports_params -) -@patch.object(datastore, "create") -def test_search_airports(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/airports/search", - params=params, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert output == expected - assert models.Airport.model_validate(output[0]) - - -@pytest.mark.parametrize( - "params", - [ - pytest.param({}, id="no_params"), - ], -) -@patch.object(datastore, "create") -def test_search_airports_with_bad_params(m_datastore, app, params): - m_datastore = AsyncMock() - with TestClient(app) as client: - response = client.get("/airports/search", params=params) - assert response.status_code == 422 - - -get_amenity_params = [ - pytest.param( - "get_amenity", - {"id": 1}, - models.Amenity( - id=1, - name="get_amenity", - description="FOO", - location="BAR", - terminal="FOO BAR", - category="FEE", - hour="BAZ", - ), - { - "id": 1, - "name": "get_amenity", - "description": "FOO", - "location": "BAR", - "terminal": "FOO BAR", - "category": "FEE", - "hour": "BAZ", - "sunday_start_hour": None, - "sunday_end_hour": None, - "monday_start_hour": None, - "monday_end_hour": None, - "tuesday_start_hour": None, - "tuesday_end_hour": None, - "wednesday_start_hour": None, - "wednesday_end_hour": None, - "thursday_start_hour": None, - "thursday_end_hour": None, - "friday_start_hour": None, - "friday_end_hour": None, - "saturday_start_hour": None, - "saturday_end_hour": None, - "content": None, - "embedding": None, - }, - ) -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", get_amenity_params -) -@patch.object(datastore, "create") -def test_get_amenity(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with 
patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/amenities", - params={ - "id": 1, - }, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert output == expected - assert models.Amenity.model_validate(output) - - -amenities_search_params = [ - pytest.param( - "amenities_search", - { - "query": "A place to get food.", - "top_k": 2, - }, - [ - models.Amenity( - id=1, - name="amenities_search", - description="FOO", - location="BAR", - terminal="FOO BAR", - category="FEE", - hour="BAZ", - ), - models.Amenity( - id=2, - name="amenities_search", - description="FOO", - location="BAR", - terminal="FOO BAR", - category="FEE", - hour="BAZ", - ), - ], - [ - { - "id": 1, - "name": "amenities_search", - "description": "FOO", - "location": "BAR", - "terminal": "FOO BAR", - "category": "FEE", - "hour": "BAZ", - "sunday_start_hour": None, - "sunday_end_hour": None, - "monday_start_hour": None, - "monday_end_hour": None, - "tuesday_start_hour": None, - "tuesday_end_hour": None, - "wednesday_start_hour": None, - "wednesday_end_hour": None, - "thursday_start_hour": None, - "thursday_end_hour": None, - "friday_start_hour": None, - "friday_end_hour": None, - "saturday_start_hour": None, - "saturday_end_hour": None, - "content": None, - "embedding": None, - }, - { - "id": 2, - "name": "amenities_search", - "description": "FOO", - "location": "BAR", - "terminal": "FOO BAR", - "category": "FEE", - "hour": "BAZ", - "sunday_start_hour": None, - "sunday_end_hour": None, - "monday_start_hour": None, - "monday_end_hour": None, - "tuesday_start_hour": None, - "tuesday_end_hour": None, - "wednesday_start_hour": None, - "wednesday_end_hour": None, - "thursday_start_hour": None, - "thursday_end_hour": None, - "friday_start_hour": None, - "friday_end_hour": None, - "saturday_start_hour": None, - "saturday_end_hour": None, - "content": None, - 
"embedding": None, - }, - ], - ) -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", amenities_search_params -) -@patch.object(datastore, "create") -def test_amenities_search(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/amenities/search", - params=params, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert len(output) == params["top_k"] - assert output == expected - assert models.Amenity.model_validate(output[0]) - - -get_flight_params = [ - pytest.param( - "get_flight", - {"flight_id": 1935}, - models.Flight( - id=1, - airline="get_flight", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="BAZ", - arrival_gate="QUX", - ), - { - "id": 1, - "airline": "get_flight", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - }, - id="successful", - ) -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", get_flight_params -) -@patch.object(datastore, "create") -def test_get_flight(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/flights", - params=params, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert output == expected - assert 
models.Flight.model_validate(output) - - -search_flights_params = [ - pytest.param( - "search_flights_by_airports", - { - "departure_airport": "LAX", - "arrival_airport": "SFO", - "date": "2023-11-01", - }, - [ - models.Flight( - id=1, - airline="search_flights_by_airports", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="BAZ", - arrival_gate="QUX", - ) - ], - [ - { - "id": 1, - "airline": "search_flights_by_airports", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - } - ], - id="departure_and_arrival_airport", - ), - pytest.param( - "search_flights_by_airports", - {"arrival_airport": "SFO", "date": "2023-11-01"}, - [ - models.Flight( - id=1, - airline="search_flights_by_airports", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="BAZ", - arrival_gate="QUX", - ) - ], - [ - { - "id": 1, - "airline": "search_flights_by_airports", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - } - ], - id="arrival_airport_only", - ), - pytest.param( - "search_flights_by_airports", - {"departure_airport": "EWR", "date": "2023-11-01"}, - [ - models.Flight( - id=1, - airline="search_flights_by_airports", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - 
departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="BAZ", - arrival_gate="QUX", - ) - ], - [ - { - "id": 1, - "airline": "search_flights_by_airports", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - } - ], - id="departure_airport_only", - ), - pytest.param( - "search_flights_by_number", - {"airline": "DL", "flight_number": "1106"}, - [ - models.Flight( - id=1, - airline="search_flights_by_number", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="BAZ", - arrival_gate="QUX", - ) - ], - [ - { - "id": 1, - "airline": "search_flights_by_number", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - } - ], - id="airline_and_flight_number", - ), -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", search_flights_params -) -@patch.object(datastore, "create") -def test_search_flights(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get("/flights/search", params=params) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert output == expected - assert models.Flight.model_validate(output[0]) - - -search_flights_bad_params = [ 
- pytest.param( - { - "departure_airport": "LAX", - "arrival_airport": "SFO", - }, - id="departure_and_arrival_airport", - ), - pytest.param( - {"arrival_airport": "SFO"}, - id="arrival_airport_only", - ), - pytest.param( - {"departure_airport": "EWR"}, - id="departure_airport_only", - ), - pytest.param( - {"flight_number": "1106"}, - id="flight_number_only", - ), - pytest.param( - {"airline": "DL"}, - id="airline_only", - ), -] - - -@pytest.mark.parametrize("params", search_flights_bad_params) -@patch.object(datastore, "create") -def test_search_flights_with_bad_params(m_datastore, app, params): - m_datastore = AsyncMock() - with TestClient(app) as client: - response = client.get("/flights/search", params=params) - assert response.status_code == 422 - - -validate_ticket_params = [ - pytest.param( - "validate_ticket", - { - "airline": "CY", - "flight_number": "888", - "departure_airport": "LAX", - "departure_time": "2023-01-01T05:57:00", - }, - [ - models.Flight( - id=1, - airline="validate_ticket", - flight_number="FOOBAR", - departure_airport="FOO", - arrival_airport="BAR", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="BAZ", - arrival_gate="QUX", - ) - ], - [ - { - "id": 1, - "airline": "validate_ticket", - "flight_number": "FOOBAR", - "departure_airport": "FOO", - "arrival_airport": "BAR", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - "departure_gate": "BAZ", - "arrival_gate": "QUX", - } - ], - id="validate_ticket", - ), -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", validate_ticket_params -) -@patch.object(datastore, "create") -def test_validate_ticket(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, 
None)), - ) as mock_method: - response = client.get("/tickets/validate", params=params) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert output == expected - assert models.Flight.model_validate(output[0]) - - -policies_search_params = [ - pytest.param( - "policies_search", - { - "query": "Additional fee for flight changes.", - "top_k": 1, - }, - [ - models.Policy( - id=1, - content="foo bar", - ), - ], - [ - { - "id": 1, - "content": "foo bar", - "embedding": None, - }, - ], - ) -] - - -@pytest.mark.parametrize( - "method_name, params, mock_return, expected", policies_search_params -) -@patch.object(datastore, "create") -def test_policies_search(m_datastore, app, method_name, params, mock_return, expected): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - response = client.get( - "/policies/search", - params=params, - ) - assert response.status_code == 200 - res = response.json() - output = res["results"] - assert len(output) == params["top_k"] - assert output == expected - assert models.Policy.model_validate(output[0]) - - -@patch.object(datastore, "create") -def test_insert_ticket_missing_user_info(m_datastore, app): - m_datastore = AsyncMock() - with TestClient(app) as client: - response = client.post( - "/tickets/insert", - headers={"User-Id-Token": "Bearer invalid_token"}, - params={ - "airline": "CY", - "flight_number": "888", - "departure_airport": "LAX", - "arrival_airport": "JFK", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - }, - ) - assert response.status_code == 401 - assert response.json()["detail"] == "User login required for data insertion" - - -insert_ticket_params = [ - pytest.param( - "insert_ticket", - "valid_token", - { - "sub": 123, - "name": "test_user_name", - "email": "test_user_email", - }, - { - "airline": "CY", - "flight_number": "888", 
- "departure_airport": "LAX", - "arrival_airport": "JFK", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - }, - [ - models.Ticket( - user_id=123, - user_name="test_user_name", - user_email="test_user_email", - airline="CY", - flight_number="888", - departure_airport="LAX", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - ), - ], - 200, - [ - { - "user_id": 123, - "user_name": "test_user_name", - "user_email": "test_user_email", - "airline": "CY", - "flight_number": "888", - "departure_airport": "LAX", - "arrival_airport": "JFK", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - } - ], - ), - pytest.param( - "insert_ticket", - "invalid_token", - None, - { - "airline": "CY", - "flight_number": "888", - "departure_airport": "LAX", - "arrival_airport": "JFK", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - }, - [ - models.Ticket( - user_id=123, - user_name="test_user_name", - user_email="test_user_email", - airline="CY", - flight_number="888", - departure_airport="LAX", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - ), - ], - 401, - {"detail": "User login required for data insertion"}, - ), -] - - -@pytest.mark.parametrize( - "method_name, mock_token, mock_user_info, params, mock_return, expected_status, expected", - insert_ticket_params, -) -@patch.object(id_token, "verify_oauth2_token") -@patch.object(datastore, "create") -def test_insert_ticket( - m_datastore, - m_verify_oauth2_token, - app, - method_name, - mock_token, - mock_user_info, - params, - mock_return, - expected_status, - expected, -): - with TestClient(app) as client: - with patch.object( - 
m_datastore.return_value, - method_name, - AsyncMock(return_value=mock_return), - ) as mock_method: - m_verify_oauth2_token.return_value = mock_user_info - response = client.post( - "/tickets/insert", - headers={"User-Id-Token": "Bearer " + mock_token}, - params=params, - ) - assert response.status_code == expected_status - res = response.json() - assert len(res) == 1 - assert res == expected - assert m_verify_oauth2_token.call_count == 1 - assert len(m_verify_oauth2_token.mock_calls[0].args) == 2 - assert m_verify_oauth2_token.mock_calls[0].args[0] == mock_token - if expected_status == 200: - assert models.Ticket.model_validate(res[0]) - assert mock_method.call_count == 1 - assert mock_method.mock_calls[0].args == tuple( - mock_user_info.values() - ) + tuple(params.values()) - else: - assert mock_method.call_count == 0 - - -list_tickets_params = [ - pytest.param( - "list_tickets", - "valid_token", - { - "sub": 123, - "name": "test_user_name", - "email": "test_user_email", - }, - [ - models.Ticket( - user_id=123, - user_name="test_user_name", - user_email="test_user_email", - airline="CY", - flight_number="888", - departure_airport="LAX", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2023-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - ), - ], - 200, - { - "results": [ - { - "user_id": 123, - "user_name": "test_user_name", - "user_email": "test_user_email", - "airline": "CY", - "flight_number": "888", - "departure_airport": "LAX", - "arrival_airport": "JFK", - "departure_time": "2023-01-01T05:57:00", - "arrival_time": "2023-01-01T12:13:00", - }, - ], - "sql": None, - }, - ), - pytest.param( - "list_tickets", - "invalid_token", - None, - [ - models.Ticket( - user_id=123, - user_name="test_user_name", - user_email="test_user_email", - airline="CY", - flight_number="888", - departure_airport="LAX", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2023-01-01 
05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2023-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - ), - ], - 401, - {"detail": "User login required for data insertion"}, - ), -] - - -@pytest.mark.parametrize( - "method_name, mock_token, mock_user_info, mock_return, expected_status, expected", - list_tickets_params, -) -@patch.object(id_token, "verify_oauth2_token") -@patch.object(datastore, "create") -def test_list_tickets( - m_datastore, - m_verify_oauth2_token, - app, - method_name, - mock_token, - mock_user_info, - mock_return, - expected_status, - expected, -): - with TestClient(app) as client: - with patch.object( - m_datastore.return_value, - method_name, - AsyncMock(return_value=(mock_return, None)), - ) as mock_method: - m_verify_oauth2_token.return_value = mock_user_info - response = client.get( - "/tickets/list", - headers={"User-Id-Token": "Bearer " + mock_token}, - ) - assert response.status_code == expected_status - res = response.json() - assert res == expected - assert m_verify_oauth2_token.call_count == 1 - assert len(m_verify_oauth2_token.mock_calls[0].args) == 2 - assert m_verify_oauth2_token.mock_calls[0].args[0] == mock_token - if expected_status == 200: - assert len(res) == 2 - assert len(res["results"]) == 1 - assert models.Ticket.model_validate(res["results"][0]) - assert mock_method.call_count == 1 - assert len(mock_method.mock_calls[0].args) == 1 - assert mock_method.mock_calls[0].args[0] == mock_user_info["sub"] - else: - assert mock_method.call_count == 0 diff --git a/retrieval_service/app/routes.py b/retrieval_service/app/routes.py deleted file mode 100644 index afdfc9f5b..000000000 --- a/retrieval_service/app/routes.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Mapping, Optional - -from fastapi import APIRouter, HTTPException, Request -from google.auth.transport import requests # type:ignore -from google.oauth2 import id_token # type:ignore -from langchain_core.embeddings import Embeddings - -import datastore - -routes = APIRouter() - - -def _ParseUserIdToken(headers: Mapping[str, Any]) -> Optional[str]: - """Parses the bearer token out of the request headers.""" - # authorization_header = headers.lower() - user_id_token_header = headers.get("User-Id-Token") - if not user_id_token_header: - raise Exception("no user authorization header") - - parts = str(user_id_token_header).split(" ") - if len(parts) != 2 or parts[0] != "Bearer": - raise Exception("Invalid ID token") - - return parts[1] - - -async def get_user_info(request): - headers = request.headers - token = _ParseUserIdToken(headers) - try: - id_info = id_token.verify_oauth2_token( - token, requests.Request(), audience=request.app.state.client_id - ) - - return { - "user_id": id_info.get("sub"), - "user_name": id_info.get("name"), - "user_email": id_info.get("email"), - } - - except Exception as e: # pylint: disable=broad-except - print(e) - - -@routes.get("/") -async def root(): - return {"message": "Hello World"} - - -@routes.get("/airports") -async def get_airport( - request: Request, - id: Optional[int] = None, - iata: Optional[str] = None, -): - ds: datastore.Client = request.app.state.datastore - if id: - results, sql = await ds.get_airport_by_id(id) - elif iata: - results, sql = await ds.get_airport_by_iata(iata) - else: - 
raise HTTPException( - status_code=422, - detail="Request requires query params: airport id or iata", - ) - return {"results": results, "sql": sql} - - -@routes.get("/airports/search") -async def search_airports( - request: Request, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, -): - if country is None and city is None and name is None: - raise HTTPException( - status_code=422, - detail="Request requires at least one query params: country, city, or airport name", - ) - - ds: datastore.Client = request.app.state.datastore - results, sql = await ds.search_airports(country, city, name) - return {"results": results, "sql": sql} - - -@routes.get("/amenities") -async def get_amenity(id: int, request: Request): - ds: datastore.Client = request.app.state.datastore - results, sql = await ds.get_amenity(id) - return {"results": results, "sql": sql} - - -@routes.get("/amenities/search") -async def amenities_search(query: str, top_k: int, request: Request): - ds: datastore.Client = request.app.state.datastore - - embed_service: Embeddings = request.app.state.embed_service - query_embedding = embed_service.embed_query(query) - - results, sql = await ds.amenities_search(query_embedding, 0.5, top_k) - return {"results": results, "sql": sql} - - -@routes.get("/flights") -async def get_flight(flight_id: int, request: Request): - ds: datastore.Client = request.app.state.datastore - results, sql = await ds.get_flight(flight_id) - return {"results": results, "sql": sql} - - -@routes.get("/flights/search") -async def search_flights( - request: Request, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - date: Optional[str] = None, - airline: Optional[str] = None, - flight_number: Optional[str] = None, -): - ds: datastore.Client = request.app.state.datastore - if date and (arrival_airport or departure_airport): - results, sql = await ds.search_flights_by_airports( - date, departure_airport, arrival_airport 
- ) - elif airline and flight_number: - results, sql = await ds.search_flights_by_number(airline, flight_number) - else: - raise HTTPException( - status_code=422, - detail="Request requires query params: arrival_airport, departure_airport, date, or both airline and flight_number", - ) - return {"results": results, "sql": sql} - - -@routes.post("/tickets/insert") -async def insert_ticket( - request: Request, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, -): - user_info = await get_user_info(request) - if user_info is None: - raise HTTPException( - status_code=401, - detail="User login required for data insertion", - ) - ds: datastore.Client = request.app.state.datastore - results = await ds.insert_ticket( - user_info["user_id"], - user_info["user_name"], - user_info["user_email"], - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time, - ) - return results - - -@routes.get("/tickets/validate") -async def validate_ticket( - request: Request, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, -): - ds: datastore.Client = request.app.state.datastore - results, sql = await ds.validate_ticket( - airline, - flight_number, - departure_airport, - departure_time, - ) - return {"results": results, "sql": sql} - - -@routes.get("/tickets/list") -async def list_tickets( - request: Request, -): - user_info = await get_user_info(request) - if user_info is None: - raise HTTPException( - status_code=401, - detail="User login required for data insertion", - ) - ds: datastore.Client = request.app.state.datastore - results, sql = await ds.list_tickets(user_info["user_id"]) - return {"results": results, "sql": sql} - - -@routes.get("/policies/search") -async def policies_search(query: str, top_k: int, request: Request): - ds: datastore.Client = request.app.state.datastore - - embed_service: Embeddings = 
request.app.state.embed_service - query_embedding = embed_service.embed_query(query) - - results, sql = await ds.policies_search(query_embedding, 0.5, top_k) - return {"results": results, "sql": sql} diff --git a/retrieval_service/cloudsql-mysql.tests.cloudbuild.yaml b/retrieval_service/cloudsql-mysql.tests.cloudbuild.yaml deleted file mode 100644 index d95282a32..000000000 --- a/retrieval_service/cloudsql-mysql.tests.cloudbuild.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Update config mysql - name: python:3.11 - dir: retrieval_service - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - # Create config - cp example-config-cloudsql.yml config.yml - sed -i "s/cloudsql-engine/cloudsql-mysql/g" config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - sed -i "s/my-user/$$DB_USER/g" config.yml - sed -i "s/my-password/$$DB_PASS/g" config.yml - sed -i "s/my-project/$PROJECT_ID/g" config.yml - sed -i "s/my-region/${_CLOUDSQL_REGION}/g" config.yml - sed -i "s/my-instance/${_CLOUDSQL_INSTANCE}/g" config.yml - - - id: Run Cloud SQL mysql DB integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_NAME=${_DATABASE_NAME}" - - "DB_PROJECT=$PROJECT_ID" - - "DB_REGION=${_CLOUDSQL_REGION}" - - "DB_INSTANCE=${_CLOUDSQL_INSTANCE}" - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - python -m pytest --cov=datastore.providers.cloudsql_mysql --cov-config=coverage/.cloudsql-mysql-coveragerc datastore/providers/cloudsql_mysql_test.py - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _CLOUDSQL_REGION: "us-central1" - _CLOUDSQL_INSTANCE: "my-cloudsql-mysql-instance" - -availableSecrets: - secretManager: - - versionName: projects/$PROJECT_ID/secrets/cloudsql_mysql_pass/versions/latest - env: DB_PASS - - versionName: projects/$PROJECT_ID/secrets/cloudsql_mysql_user/versions/latest - env: DB_USER - -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true diff --git a/retrieval_service/cloudsql-pg.tests.cloudbuild.yaml b/retrieval_service/cloudsql-pg.tests.cloudbuild.yaml deleted file mode 100644 index 240317d17..000000000 --- a/retrieval_service/cloudsql-pg.tests.cloudbuild.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2024 Google LLC 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Update config postgres - name: python:3.11 - dir: retrieval_service - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - # Create config - cp example-config-cloudsql.yml config.yml - sed -i "s/cloudsql-engine/cloudsql-postgres/g" config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - sed -i "s/my-user/$$DB_USER/g" config.yml - sed -i "s/my-password/$$DB_PASS/g" config.yml - sed -i "s/my-project/$PROJECT_ID/g" config.yml - sed -i "s/my-region/${_CLOUDSQL_REGION}/g" config.yml - sed -i "s/my-instance/${_CLOUDSQL_INSTANCE}/g" config.yml - - - id: Run Cloud SQL postgres DB integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_NAME=${_DATABASE_NAME}" - - "DB_PROJECT=$PROJECT_ID" - - "DB_REGION=${_CLOUDSQL_REGION}" - - "DB_INSTANCE=${_CLOUDSQL_INSTANCE}" - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - python -m pytest --cov=datastore.providers.cloudsql_postgres --cov-config=coverage/.cloudsql-pg-coveragerc datastore/providers/cloudsql_postgres_test.py - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _CLOUDSQL_REGION: "us-central1" - _CLOUDSQL_INSTANCE: "my-cloudsql-pg-instance" - -availableSecrets: - secretManager: - - versionName: 
projects/$PROJECT_ID/secrets/cloudsql_pg_pass/versions/latest - env: DB_PASS - - versionName: projects/$PROJECT_ID/secrets/cloudsql_pg_user/versions/latest - env: DB_USER - -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true diff --git a/retrieval_service/coverage/.alloydb-coveragerc b/retrieval_service/coverage/.alloydb-coveragerc deleted file mode 100644 index 996e2b7e3..000000000 --- a/retrieval_service/coverage/.alloydb-coveragerc +++ /dev/null @@ -1,9 +0,0 @@ -[run] -branch = true -omit = - */__init__.py - -[report] -show_missing = true -precision = 2 -fail_under = 93 diff --git a/retrieval_service/coverage/.app-coveragerc b/retrieval_service/coverage/.app-coveragerc deleted file mode 100644 index c298a1dd3..000000000 --- a/retrieval_service/coverage/.app-coveragerc +++ /dev/null @@ -1,10 +0,0 @@ -[run] -branch = true -omit = - */__init__.py - app_test.py - -[report] -show_missing = true -precision = 2 -fail_under = 95 diff --git a/retrieval_service/coverage/.cloudsql-mysql-coveragerc b/retrieval_service/coverage/.cloudsql-mysql-coveragerc deleted file mode 100644 index 7794b3324..000000000 --- a/retrieval_service/coverage/.cloudsql-mysql-coveragerc +++ /dev/null @@ -1,9 +0,0 @@ -[run] -branch = true -omit = - */__init__.py - -[report] -show_missing = true -precision = 2 -fail_under = 95 diff --git a/retrieval_service/coverage/.cloudsql-pg-coveragerc b/retrieval_service/coverage/.cloudsql-pg-coveragerc deleted file mode 100644 index 7794b3324..000000000 --- a/retrieval_service/coverage/.cloudsql-pg-coveragerc +++ /dev/null @@ -1,9 +0,0 @@ -[run] -branch = true -omit = - */__init__.py - -[report] -show_missing = true -precision = 2 -fail_under = 95 diff --git a/retrieval_service/coverage/.postgres-coveragerc b/retrieval_service/coverage/.postgres-coveragerc deleted file mode 100644 index c04c8a29b..000000000 --- a/retrieval_service/coverage/.postgres-coveragerc +++ /dev/null @@ -1,9 +0,0 @@ -[run] -branch = 
true -omit = - */__init__.py - -[report] -show_missing = true -precision = 2 -fail_under = 85 diff --git a/retrieval_service/coverage/.spanner-coveragerc b/retrieval_service/coverage/.spanner-coveragerc deleted file mode 100644 index c0dbdf857..000000000 --- a/retrieval_service/coverage/.spanner-coveragerc +++ /dev/null @@ -1,9 +0,0 @@ -[run] -branch = true -omit = - */__init__.py - -[report] -show_missing = true -precision = 2 -fail_under = 81 diff --git a/retrieval_service/datastore/__init__.py b/retrieval_service/datastore/__init__.py deleted file mode 100644 index 3ca2bae87..000000000 --- a/retrieval_service/datastore/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -from . 
import providers -from .datastore import Client, create - -Config = Union[ - providers.firestore.Config, - providers.postgres.Config, - providers.cloudsql_postgres.Config, - providers.spanner_gsql.Config, - providers.spanner_postgres.Config, - providers.alloydb.Config, - providers.cloudsql_mysql.Config, -] - -__ALL__ = [Client, Config, create, providers] diff --git a/retrieval_service/datastore/datastore.py b/retrieval_service/datastore/datastore.py deleted file mode 100644 index 13fe298a9..000000000 --- a/retrieval_service/datastore/datastore.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import csv -from abc import ABC, abstractmethod -from typing import Any, Generic, List, Optional, TypeVar - -import models - - -class AbstractConfig(ABC): - kind: str - - -C = TypeVar("C", bound=AbstractConfig) - - -class classproperty: - def __init__(self, func): - self.fget = func - - def __get__(self, _, owner): - return self.fget(owner) - - -class Client(ABC, Generic[C]): - @classproperty - @abstractmethod - def kind(cls): - pass - - @classmethod - @abstractmethod - async def create(cls, config: C) -> "Client": - pass - - async def load_dataset( - self, airports_ds_path, amenities_ds_path, flights_ds_path, policies_ds_path - ) -> tuple[ - List[models.Airport], - List[models.Amenity], - List[models.Flight], - List[models.Policy], - ]: - airports: List[models.Airport] = [] - with open(airports_ds_path, "r") as f: - reader = csv.DictReader(f, delimiter=",") - airports = [models.Airport.model_validate(line) for line in reader] - - amenities: list[models.Amenity] = [] - with open(amenities_ds_path, "r") as f: - reader = csv.DictReader(f, delimiter=",") - amenities = [models.Amenity.model_validate(line) for line in reader] - - flights: List[models.Flight] = [] - with open(flights_ds_path, "r") as f: - reader = csv.DictReader(f, delimiter=",") - flights = [models.Flight.model_validate(line) for line in reader] - - policies: List[models.Policy] = [] - with open(policies_ds_path, "r") as f: - reader = csv.DictReader(f, delimiter=",") - policies = [models.Policy.model_validate(line) for line in reader] - return airports, amenities, flights, policies - - async def export_dataset( - self, - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) -> None: - with open(airports_new_path, "w") as f: - col_names = ["id", "iata", "name", "city", "country"] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for a in airports: - writer.writerow(a.model_dump()) - - with 
open(amenities_new_path, "w") as f: - col_names = [ - "id", - "name", - "description", - "location", - "terminal", - "category", - "hour", - "sunday_start_hour", - "sunday_end_hour", - "monday_start_hour", - "monday_end_hour", - "tuesday_start_hour", - "tuesday_end_hour", - "wednesday_start_hour", - "wednesday_end_hour", - "thursday_start_hour", - "thursday_end_hour", - "friday_start_hour", - "friday_end_hour", - "saturday_start_hour", - "saturday_end_hour", - "content", - "embedding", - ] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for a in amenities: - writer.writerow(a.model_dump()) - - with open(flights_new_path, "w") as f: - col_names = [ - "id", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - "departure_gate", - "arrival_gate", - ] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for fl in flights: - writer.writerow(fl.model_dump()) - - with open(policies_new_path, "w") as f: - col_names = [ - "id", - "content", - "embedding", - ] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for p in policies: - writer.writerow(p.model_dump()) - - @abstractmethod - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - pass - - @abstractmethod - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - pass - - @abstractmethod - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - 
async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def search_flights_by_number( - self, - airline: str, - flight_number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def search_flights_by_airports( - self, - date, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], 
Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - raise NotImplementedError("Subclass should implement this!") - - @abstractmethod - async def close(self): - pass - - -async def create(config: AbstractConfig) -> Client: - for cls in Client.__subclasses__(): - if config.kind == cls.kind: - return await cls.create(config) # type: ignore - raise TypeError(f"No clients of kind '{config.kind}'") diff --git a/retrieval_service/datastore/helpers.py b/retrieval_service/datastore/helpers.py deleted file mode 100644 index 13e025f21..000000000 --- a/retrieval_service/datastore/helpers.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -import sqlparse - - -def format_sql(sql: str, params: dict): - """ - Format Postgres SQL to human readable text by replacing placeholders. - Handles dict-based (:key) formats. - """ - for key, value in params.items(): - sql = sql.replace(f":{key}", f"{value}") - # format the SQL - formatted_sql = ( - sqlparse.format( - sql, - reindent=True, - keyword_case="upper", - use_space_around_operators=True, - strip_whitespace=True, - ) - .replace("\n", "
") - .replace(" ", '
') - ) - return formatted_sql.replace("
", "", 1) diff --git a/retrieval_service/datastore/providers/__init__.py b/retrieval_service/datastore/providers/__init__.py deleted file mode 100644 index e5dcb0353..000000000 --- a/retrieval_service/datastore/providers/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import ( - alloydb, - cloudsql_mysql, - cloudsql_postgres, - firestore, - postgres, - spanner_gsql, - spanner_postgres, -) - -__ALL__ = [ - alloydb, - postgres, - cloudsql_mysql, - cloudsql_postgres, - firestore, - spanner_gsql, - spanner_postgres, -] diff --git a/retrieval_service/datastore/providers/alloydb.py b/retrieval_service/datastore/providers/alloydb.py deleted file mode 100644 index 04fcd34eb..000000000 --- a/retrieval_service/datastore/providers/alloydb.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Literal, Optional - -import asyncpg -from google.cloud.alloydb.connector import AsyncConnector, RefreshStrategy -from pgvector.asyncpg import register_vector -from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -import models - -from .. import datastore -from .postgres import Client as PostgresClient - -ALLOYDB_PG_IDENTIFIER = "alloydb-postgres" - - -class Config(BaseModel, datastore.AbstractConfig): - kind: Literal["alloydb-postgres"] - project: str - region: str - cluster: str - instance: str - user: str - password: str - database: str - - -class Client(datastore.Client[Config]): - __connector: Optional[AsyncConnector] = None - __pg_client: PostgresClient - - @datastore.classproperty - def kind(cls): - return ALLOYDB_PG_IDENTIFIER - - def __init__(self, async_engine: AsyncEngine): - self.__pg_client = PostgresClient(async_engine) - - @classmethod - async def create(cls, config: Config) -> "Client": - async def getconn() -> asyncpg.Connection: - if cls.__connector is None: - cls.__connector = AsyncConnector(refresh_strategy=RefreshStrategy.LAZY) - - conn: asyncpg.Connection = await cls.__connector.connect( - # Alloydb instance connection name - f"projects/{config.project}/locations/{config.region}/clusters/{config.cluster}/instances/{config.instance}", - "asyncpg", - user=f"{config.user}", - password=f"{config.password}", - db=f"{config.database}", - ip_type="PUBLIC", - ) - await register_vector(conn) - return conn - - async_engine = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, - ) - if async_engine is None: - raise TypeError("async_engine not instantiated") - return cls(async_engine) - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - await self.__pg_client.initialize_data(airports, amenities, flights, policies) - - async def 
export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - return await self.__pg_client.export_data() - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_client.get_airport_by_id(id) - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_client.get_airport_by_iata(iata) - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - return await self.__pg_client.search_airports(country, city, name) - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - return await self.__pg_client.get_amenity(id) - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_client.amenities_search( - query_embedding, similarity_threshold, top_k - ) - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_client.get_flight(flight_id) - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_client.search_flights_by_number(airline, number) - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_client.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - 
return await self.__pg_client.validate_ticket( - airline, flight_number, departure_airport, departure_time - ) - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - await self.__pg_client.insert_ticket( - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time, - ) - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_client.list_tickets(user_id) - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - return await self.__pg_client.policies_search( - query_embedding, similarity_threshold, top_k - ) - - async def close(self): - await self.__pg_client.close() diff --git a/retrieval_service/datastore/providers/alloydb_test.py b/retrieval_service/datastore/providers/alloydb_test.py deleted file mode 100644 index 0a9a23f0b..000000000 --- a/retrieval_service/datastore/providers/alloydb_test.py +++ /dev/null @@ -1,710 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime -from typing import Any, AsyncGenerator, List - -import asyncpg -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore -from google.cloud.alloydb.connector import AsyncConnector - -import models - -from .. import datastore -from . import alloydb -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_user() -> str: - return get_env_var("DB_USER", "name of a postgres user") - - -@pytest.fixture(scope="module") -def db_pass() -> str: - return get_env_var("DB_PASS", "password for the postgres user") - - -@pytest.fixture(scope="module") -def db_project() -> str: - return get_env_var("DB_PROJECT", "project id for google cloud") - - -@pytest.fixture(scope="module") -def db_region() -> str: - return get_env_var("DB_REGION", "region for alloydb instance") - - -@pytest.fixture(scope="module") -def db_cluster() -> str: - return get_env_var("DB_CLUSTER", "cluster for alloydb") - - -@pytest.fixture(scope="module") -def db_instance() -> str: - return get_env_var("DB_INSTANCE", "instance for alloydb") - - -@pytest_asyncio.fixture(scope="module") -async def create_db( - db_user: str, - db_pass: str, - db_project: str, - db_region: str, - db_cluster: str, - db_instance: str, -) -> AsyncGenerator[str, None]: - db_name = get_env_var("DB_NAME", "name of a postgres database") - connector = AsyncConnector() - project_instance = f"projects/{db_project}/locations/{db_region}/clusters/{db_cluster}/instances/{db_instance}" - # Database does not exist, create it. 
- sys_conn: asyncpg.Connection = await connector.connect( - project_instance, - "asyncpg", - user=f"{db_user}", - password=f"{db_pass}", - db="postgres", - ip_type="PUBLIC", - ) - await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') - await sys_conn.execute(f'CREATE DATABASE "{db_name}";') - conn: asyncpg.Connection = await connector.connect( - project_instance, - "asyncpg", - user=f"{db_user}", - password=f"{db_pass}", - db=f"{db_name}", - ip_type="PUBLIC", - ) - await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") - await conn.close() - yield db_name - await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') - await sys_conn.close() - - -@pytest_asyncio.fixture(scope="module") -async def ds( - create_db: str, - db_user: str, - db_pass: str, - db_project: str, - db_region: str, - db_cluster: str, - db_instance: str, -) -> AsyncGenerator[datastore.Client, None]: - cfg = alloydb.Config( - kind="alloydb-postgres", - user=db_user, - password=db_pass, - database=create_db, - project=db_project, - region=db_region, - cluster=db_cluster, - instance=db_instance, - ) - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - yield ds - await ds.close() - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["changed"] == [] - assert file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - - -async def test_export_dataset(ds: alloydb.Client): - airports, amenities, flights, policies = await 
ds.export_data() - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: alloydb.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is not None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: alloydb.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is not None - - -search_airports_test_data = [ - pytest.param( - 
"Philippines", - "San jose", - None, - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: alloydb.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is not None - - -async def test_get_amenity(ds: alloydb.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - 
saturday_start_hour=None, - saturday_end_hour=None, - ) - assert res == expected - assert sql is not None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.35, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - }, - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" - amenities_query_embedding2, - 0.35, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.1, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: alloydb.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[Any], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None - - -async def test_get_flight(ds: alloydb.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - 
arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is not None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: alloydb.Client, - airline: str, - number: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is not None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 
07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: alloydb.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_airports( - 
date, departure_airport, arrival_airport - ) - assert res == expected - assert sql is not None - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.35, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.35, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). 
Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.35, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: alloydb.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[str], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - 'SELECT *
FROM flights
WHERE airline ILIKE UA
AND flight_number ILIKE 1158
AND departure_airport ILIKE SFO
AND departure_time = 2025-01-01 05:57:00', - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket(ds: alloydb.Client, params, expected_data, expected_sql): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py deleted file mode 100644 index 5545507c8..000000000 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ /dev/null @@ -1,822 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from datetime import datetime -from typing import Any, Literal, Optional - -import pymysql -from google.cloud.sql.connector import Connector, RefreshStrategy -from pydantic import BaseModel -from sqlalchemy import Engine, create_engine, text -from sqlalchemy.engine.base import Engine - -import models - -from .. 
import datastore - -MYSQL_IDENTIFIER = "cloudsql-mysql" - - -class Config(BaseModel, datastore.AbstractConfig): - kind: Literal["cloudsql-mysql"] - project: str - region: str - instance: str - user: str - password: str - database: str - - -class Client(datastore.Client[Config]): - __pool: Engine - __db_name: str - __connector: Optional[Connector] = None - - @datastore.classproperty - def kind(cls): - return MYSQL_IDENTIFIER - - def __init__(self, pool: Engine, db_name: str): - self.__pool = pool - self.__db_name = db_name - - @classmethod - def create_sync(cls, config: Config) -> "Client": - def getconn() -> pymysql.Connection: - if cls.__connector is None: - cls.__connector = Connector(refresh_strategy=RefreshStrategy.LAZY) - - return cls.__connector.connect( - # Cloud SQL instance connection name - f"{config.project}:{config.region}:{config.instance}", - "pymysql", - user=f"{config.user}", - password=f"{config.password}", - db=f"{config.database}", - autocommit=True, - ) - - pool = create_engine( - "mysql+pymysql://", - creator=getconn, - ) - if pool is None: - raise TypeError("pool not instantiated") - return cls(pool, config.database) - - @classmethod - async def create(cls, config: Config) -> "Client": - loop = asyncio.get_running_loop() - - pool = await loop.run_in_executor(None, cls.create_sync, config) - return pool - - def initialize_data_sync( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - with self.__pool.connect() as conn: - # If the table already exists, drop it to avoid conflicts - conn.execute(text("DROP TABLE IF EXISTS airports")) - # Create a new table - conn.execute( - text( - """ - CREATE TABLE airports( - id INT PRIMARY KEY, - iata TEXT, - name TEXT, - city TEXT, - country TEXT - ) - """ - ) - ) - # Insert all the data - conn.execute( - text( - """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" - ), - parameters=[ - { 
- "id": a.id, - "iata": a.iata, - "name": a.name, - "city": a.city, - "country": a.country, - } - for a in airports - ], - ) - - # If the table already exists, drop it to avoid conflicts - conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) - - # Create a new table - conn.execute( - text( - """ - CREATE TABLE amenities( - id INT PRIMARY KEY, - name TEXT, - description TEXT, - location TEXT, - terminal TEXT, - category TEXT, - hour TEXT, - sunday_start_hour TIME, - sunday_end_hour TIME, - monday_start_hour TIME, - monday_end_hour TIME, - tuesday_start_hour TIME, - tuesday_end_hour TIME, - wednesday_start_hour TIME, - wednesday_end_hour TIME, - thursday_start_hour TIME, - thursday_end_hour TIME, - friday_start_hour TIME, - friday_end_hour TIME, - saturday_start_hour TIME, - saturday_end_hour TIME, - content TEXT NOT NULL, - embedding vector(768) USING VARBINARY NOT NULL - ) - """ - ) - ) - - # Insert all the data - conn.execute( - text( - """ - INSERT INTO amenities VALUES (:id, :name, :description, :location, - :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour, - :monday_start_hour, :monday_end_hour, :tuesday_start_hour, - :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour, - :thursday_start_hour, :thursday_end_hour, :friday_start_hour, - :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, string_to_vector(:embedding)) - """ - ), - parameters=[ - { - "id": a.id, - "name": a.name, - "description": a.description, - "location": a.location, - "terminal": a.terminal, - "category": a.category, - "hour": a.hour, - "sunday_start_hour": a.sunday_start_hour, - "sunday_end_hour": a.sunday_end_hour, - "monday_start_hour": a.monday_start_hour, - "monday_end_hour": a.monday_end_hour, - "tuesday_start_hour": a.tuesday_start_hour, - "tuesday_end_hour": a.tuesday_end_hour, - "wednesday_start_hour": a.wednesday_start_hour, - "wednesday_end_hour": a.wednesday_end_hour, - "thursday_start_hour": a.thursday_start_hour, - 
"thursday_end_hour": a.thursday_end_hour, - "friday_start_hour": a.friday_start_hour, - "friday_end_hour": a.friday_end_hour, - "saturday_start_hour": a.saturday_start_hour, - "saturday_end_hour": a.saturday_end_hour, - "content": a.content, - "embedding": f"{a.embedding}", - } - for a in amenities - ], - ) - - # If the table already exists, drop it to avoid conflicts - conn.execute(text("DROP TABLE IF EXISTS flights")) - # Create a new table - conn.execute( - text( - """ - CREATE TABLE flights( - id INTEGER PRIMARY KEY, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP, - departure_gate TEXT, - arrival_gate TEXT - ) - """ - ) - ) - # Insert all the data - conn.execute( - text( - """ - INSERT INTO flights VALUES (:id, :airline, :flight_number, - :departure_airport, :arrival_airport, :departure_time, - :arrival_time, :departure_gate, :arrival_gate) - """ - ), - parameters=[ - { - "id": f.id, - "airline": f.airline, - "flight_number": f.flight_number, - "departure_airport": f.departure_airport, - "arrival_airport": f.arrival_airport, - "departure_time": f.departure_time, - "arrival_time": f.arrival_time, - "departure_gate": f.departure_gate, - "arrival_gate": f.arrival_gate, - } - for f in flights - ], - ) - - # If the table already exists, drop it to avoid conflicts - conn.execute(text("DROP TABLE IF EXISTS tickets")) - # Create a new table - conn.execute( - text( - """ - CREATE TABLE tickets( - user_id TEXT, - user_name TEXT, - user_email TEXT, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP - ) - """ - ) - ) - - # If the table already exists, drop it to avoid conflicts - conn.execute(text("DROP TABLE IF EXISTS policies")) - # Create a new table - conn.execute( - text( - """ - CREATE TABLE policies( - id INT PRIMARY KEY, - content TEXT NOT NULL, - embedding vector(768) USING VARBINARY 
NOT NULL - ) - """ - ) - ) - # Insert all the data - conn.execute( - text( - """ - INSERT INTO policies VALUES (:id, :content, string_to_vector(:embedding)) - """ - ), - parameters=[ - { - "id": p.id, - "content": p.content, - "embedding": f"{p.embedding}", - } - for p in policies - ], - ) - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.initialize_data_sync, airports, amenities, flights, policies - ) - - def export_data_sync( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - with self.__pool.connect() as conn: - airport_task = conn.execute( - text("""SELECT * FROM airports ORDER BY id ASC""") - ) - amenity_task = conn.execute( - text( - """ - SELECT id, - name, - description, - location, - terminal, - category, - hour, - DATE_FORMAT(sunday_start_hour, '%H:%i') AS sunday_start_hour, - DATE_FORMAT(sunday_end_hour, '%H:%i') AS sunday_end_hour, - DATE_FORMAT(monday_start_hour, '%H:%i') AS monday_start_hour, - DATE_FORMAT(monday_end_hour, '%H:%i') AS monday_end_hour, - DATE_FORMAT(tuesday_start_hour, '%H:%i') AS tuesday_start_hour, - DATE_FORMAT(tuesday_end_hour, '%H:%i') AS tuesday_end_hour, - DATE_FORMAT(wednesday_start_hour, '%H:%i') AS wednesday_start_hour, - DATE_FORMAT(wednesday_end_hour, '%H:%i') AS wednesday_end_hour, - DATE_FORMAT(thursday_start_hour, '%H:%i') AS thursday_start_hour, - DATE_FORMAT(thursday_end_hour, '%H:%i') AS thursday_end_hour, - DATE_FORMAT(friday_start_hour, '%H:%i') AS friday_start_hour, - DATE_FORMAT(friday_end_hour, '%H:%i') AS friday_end_hour, - DATE_FORMAT(saturday_start_hour, '%H:%i') AS saturday_start_hour, - DATE_FORMAT(saturday_end_hour, '%H:%i') AS saturday_end_hour, - content, - vector_to_string(embedding) as embedding - FROM amenities ORDER BY 
id ASC - """ - ) - ) - flights_task = conn.execute( - text("""SELECT * FROM flights ORDER BY id ASC""") - ) - policy_task = conn.execute( - text( - """SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC""" - ) - ) - - airport_results = (airport_task).mappings().fetchall() - amenity_results = (amenity_task).mappings().fetchall() - flights_results = (flights_task).mappings().fetchall() - policy_results = (policy_task).mappings().fetchall() - - airports = [models.Airport.model_validate(a) for a in airport_results] - amenities = [models.Amenity.model_validate(a) for a in amenity_results] - flights = [models.Flight.model_validate(f) for f in flights_results] - policies = [models.Policy.model_validate(p) for p in policy_results] - - return airports, amenities, flights, policies - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.export_data_sync) - return res - - def get_airport_by_id_sync( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - with self.__pool.connect() as conn: - s = text("""SELECT * FROM airports WHERE id=:id""") - params = {"id": id} - result = (conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, None - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor(None, self.get_airport_by_id_sync, id) - return res, sql - - def get_airport_by_iata_sync( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - with self.__pool.connect() as conn: - s = text("""SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(:iata)""") - params = {"iata": iata} - result = (conn.execute(s, 
params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, None - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor(None, self.get_airport_by_iata_sync, iata) - return res, sql - - def search_airports_sync( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT * FROM airports - WHERE (:country IS NULL OR LOWER(country) LIKE CONCAT('%', LOWER(:country), '%')) - AND (:city IS NULL OR LOWER(city) LIKE CONCAT('%', LOWER(:city), '%')) - AND (:name IS NULL OR LOWER(name) LIKE CONCAT('%', LOWER(:name), '%')) - LIMIT 10; - """ - ) - params = { - "country": country, - "city": city, - "name": name, - } - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [models.Airport.model_validate(r) for r in results] - return res, None - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, self.search_airports_sync, country, city, name - ) - return res, sql - - def get_amenity_sync( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT id, name, description, location, terminal, category, hour - FROM amenities WHERE id=:id - """ - ) - params = {"id": id} - result = (conn.execute(s, parameters=params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Amenity.model_validate(result) - return res, None - - async def get_amenity( - self, id: int - ) -> 
tuple[Optional[models.Amenity], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor(None, self.get_amenity_sync, id) - return res, sql - - def amenities_search_sync( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT name, description, location, terminal, category, hour - FROM amenities - ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :search_options - """ - ) - params = { - "query": f"{query_embedding}", - "search_options": top_k, - } - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [r for r in results] - return res, None - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, - self.amenities_search_sync, - query_embedding, - similarity_threshold, - top_k, - ) - return res, sql - - def get_flight_sync( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT * FROM flights - WHERE id = :flight_id - """ - ) - params = {"flight_id": flight_id} - result = (conn.execute(s, parameters=params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Flight.model_validate(result) - return res, None - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor(None, self.get_flight_sync, flight_id) - return res, sql - - def search_flights_by_number_sync( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT * FROM flights - 
WHERE airline = :airline - AND flight_number = :number - LIMIT 10 - """ - ) - params = { - "airline": airline, - "number": number, - } - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return res, None - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, self.search_flights_by_number_sync, airline, number - ) - return res, sql - - def search_flights_by_airports_sync( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT * FROM flights - WHERE (CAST(:departure_airport AS CHAR(255)) IS NULL OR LOWER(departure_airport) LIKE LOWER(:departure_airport)) - AND (CAST(:arrival_airport AS CHAR(255)) IS NULL OR LOWER(arrival_airport) LIKE LOWER(:arrival_airport)) - AND departure_time >= CAST(:datetime AS DATETIME) - AND (departure_time < DATE_ADD(CAST(:datetime AS DATETIME), interval 1 day)) - LIMIT 10 - """ - ) - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "datetime": datetime.strptime(date, "%Y-%m-%d"), - } - - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return res, None - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, - self.search_flights_by_airports_sync, - date, - departure_airport, - arrival_airport, - ) - return res, sql - - def validate_ticket_sync( - self, - airline: str, - flight_number: str, - 
departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT * FROM flights - WHERE LOWER(airline) LIKE LOWER(:airline) - AND LOWER(flight_number) LIKE LOWER(:flight_number) - AND LOWER(departure_airport) LIKE LOWER(:departure_airport) - AND departure_time = CAST(:departure_time AS DATETIME) - LIMIT 10 - """ - ) - params = { - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "departure_time": departure_time, - } - - result = (conn.execute(s, parameters=params)).mappings().fetchone() - if result is None: - return None, None - res = models.Flight.model_validate(result) - return res, None - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, - self.validate_ticket_sync, - airline, - flight_number, - departure_airport, - departure_time, - ) - return res, sql - - def insert_ticket_sync( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - with self.__pool.connect() as conn: - s = text( - """ - INSERT INTO tickets ( - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time - ) VALUES ( - :user_id, - :user_name, - :user_email, - :airline, - :flight_number, - :departure_airport, - :arrival_airport, - :departure_time, - :arrival_time - ); - """ - ) - params = { - "user_id": user_id, - "user_name": user_name, - "user_email": user_email, - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "departure_time": 
departure_time, - "arrival_time": arrival_time, - } - conn.execute(s, params).mappings() - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, - self.insert_ticket_sync, - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time, - ) - - def list_tickets_sync( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = :user_id - """ - ) - params = { - "user_id": user_id, - } - - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [r for r in results] - return res, None - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[models.Ticket], Optional[str]]: - loop = asyncio.get_running_loop() - res, sql = await loop.run_in_executor(None, self.list_tickets_sync, user_id) - return res, sql - - def policies_search_sync( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - with self.__pool.connect() as conn: - s = text( - """ - SELECT content - FROM policies - ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :search_options - """ - ) - params = { - "query": f"{query_embedding}", - "search_options": top_k, - } - - results = (conn.execute(s, parameters=params)).mappings().fetchall() - - res = [r["content"] for r in results] - return res, None - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - loop = 
asyncio.get_running_loop() - res, sql = await loop.run_in_executor( - None, - self.policies_search_sync, - query_embedding, - similarity_threshold, - top_k, - ) - return res, sql - - async def close(self): - self.__pool.dispose() diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py deleted file mode 100644 index 382811abf..000000000 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ /dev/null @@ -1,762 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from datetime import datetime -from typing import Any, AsyncGenerator, List - -import pymysql -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore -from google.cloud.sql.connector import Connector - -import models - -from .. import datastore -from . 
import cloudsql_mysql -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_user() -> str: - return get_env_var("DB_USER", "name of a mysql user") - - -@pytest.fixture(scope="module") -def db_pass() -> str: - return get_env_var("DB_PASS", "password for the mysql user") - - -@pytest.fixture(scope="module") -def db_project() -> str: - return get_env_var("DB_PROJECT", "project id for google cloud") - - -@pytest.fixture(scope="module") -def db_region() -> str: - return get_env_var("DB_REGION", "region for cloud sql instance") - - -@pytest.fixture(scope="module") -def db_instance() -> str: - return get_env_var("DB_INSTANCE", "instance for cloud sql") - - -@pytest_asyncio.fixture(scope="module") -async def create_db( - db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str -) -> AsyncGenerator[str, None]: - db_name = get_env_var("DB_NAME", "name of a cloud sql mysql database") - loop = asyncio.get_running_loop() - connector = Connector(loop=loop) - project_instance = f"{db_project}:{db_region}:{db_instance}" - # Database does not exist, create it. 
- sys_conn: pymysql.Connection = await connector.connect_async( - # cloud sql instance connection name - project_instance, - "pymysql", - user=f"{db_user}", - password=f"{db_pass}", - db="mysql", - ) - cursor = sys_conn.cursor() - - cursor.execute(f"drop database if exists {db_name};") - cursor.execute(f"CREATE DATABASE {db_name};") - conn: pymysql.Connection = await connector.connect_async( - # Cloud SQL instance connection name - project_instance, - "pymysql", - user=f"{db_user}", - password=f"{db_pass}", - db=f"{db_name}", - ) - conn.close() - yield db_name - cursor.execute(f"drop database if exists {db_name};") - cursor.close() - - -@pytest_asyncio.fixture(scope="module") -async def ds( - create_db: str, - db_user: str, - db_pass: str, - db_project: str, - db_region: str, - db_instance: str, -) -> AsyncGenerator[datastore.Client, None]: - cfg = cloudsql_mysql.Config( - kind="cloudsql-mysql", - user=db_user, - password=db_pass, - database=create_db, - project=db_project, - region=db_region, - instance=db_instance, - ) - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - yield ds - - await ds.close() - - -def only_embedding_changed(file_diff): - return all( - key == "embedding" - for change in file_diff["changed"] - for key in change["changes"] - ) - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - assert file_diff["changed"] == [] or 
only_embedding_changed(file_diff) - - -async def test_export_dataset(ds: cloudsql_mysql.Client): - airports, amenities, flights, policies = await ds.export_data() - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - - check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: cloudsql_mysql.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: cloudsql_mysql.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", 
- city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is None - - -search_airports_test_data = [ - pytest.param( - "Philippines", - "San jose", - None, - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=1714, - iata="GSJ", - name="San José Airport", - city="San Jose", - country="Guatemala", - ), - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. 
Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: cloudsql_mysql.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is None - - -async def test_get_amenity(ds: cloudsql_mysql.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - saturday_start_hour=None, - saturday_end_hour=None, - ) - assert res == expected - assert sql is None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.35, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - }, - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" 
- amenities_query_embedding2, - 0.35, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: cloudsql_mysql.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[Any], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None - - -async def test_get_flight(ds: cloudsql_mysql.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - 
flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: cloudsql_mysql.Client, - airline: str, - number: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - 
departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: cloudsql_mysql.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - assert res == expected - assert sql is None - - -async def test_insert_ticket(ds: cloudsql_mysql.Client): - await ds.insert_ticket( - "1", - "test", - "test", - "UA", - "1532", - "SFO", - "DEN", - "2025-01-01 05:50:00", - "2025-01-01 09:23:00", - ) - - -async def test_list_tickets(ds: cloudsql_mysql.Client): - res, sql = await ds.list_tickets("1") - expected = [ - { - "user_name": "test", - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": datetime.strptime( - "2025-01-01 05:50:00", "%Y-%m-%d 
%H:%M:%S" - ), - "arrival_time": datetime.strptime( - "2025-01-01 09:23:00", "%Y-%m-%d %H:%M:%S" - ), - } - ] - - assert res == expected - assert sql is None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - None, - ), - pytest.param( - { - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:50:00", - }, - models.Flight( - id=0, - airline="UA", - flight_number="1532", - departure_airport="SFO", - arrival_airport="DEN", - departure_time=datetime.strptime( - "2025-01-01 05:50:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 09:23:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="E49", - arrival_gate="D6", - ), - None, - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket( - ds: cloudsql_mysql.Client, params, expected_data, expected_sql -): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.35, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. 
Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.35, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). 
Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: cloudsql_mysql.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[str], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None diff --git a/retrieval_service/datastore/providers/cloudsql_postgres.py b/retrieval_service/datastore/providers/cloudsql_postgres.py deleted file mode 100644 index 364b3fe11..000000000 --- a/retrieval_service/datastore/providers/cloudsql_postgres.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from typing import Any, Literal, Optional - -import asyncpg -from google.cloud.sql.connector import Connector, RefreshStrategy -from pgvector.asyncpg import register_vector -from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -import models - -from .. 
import datastore -from .postgres import Client as PostgresClient - -CLOUD_SQL_PG_IDENTIFIER = "cloudsql-postgres" - - -class Config(BaseModel, datastore.AbstractConfig): - kind: Literal["cloudsql-postgres"] - project: str - region: str - instance: str - user: str - password: str - database: str - - -class Client(datastore.Client[Config]): - __pg_client: PostgresClient - __connector: Optional[Connector] = None - - @datastore.classproperty - def kind(cls): - return CLOUD_SQL_PG_IDENTIFIER - - def __init__(self, async_engine: AsyncEngine): - self.__pg_client = PostgresClient(async_engine) - - @classmethod - async def create(cls, config: Config) -> "Client": - async def getconn() -> asyncpg.Connection: - if cls.__connector is None: - loop = asyncio.get_running_loop() - cls.__connector = Connector( - loop=loop, refresh_strategy=RefreshStrategy.LAZY - ) - - conn: asyncpg.Connection = await cls.__connector.connect_async( - # Cloud SQL instance connection name - f"{config.project}:{config.region}:{config.instance}", - "asyncpg", - user=f"{config.user}", - password=f"{config.password}", - db=f"{config.database}", - ) - await register_vector(conn) - return conn - - async_engine = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, - ) - if async_engine is None: - raise TypeError("async_engine not instantiated") - return cls(async_engine) - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - await self.__pg_client.initialize_data(airports, amenities, flights, policies) - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - return await self.__pg_client.export_data() - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_client.get_airport_by_id(id) - - async 
def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - return await self.__pg_client.get_airport_by_iata(iata) - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - return await self.__pg_client.search_airports(country, city, name) - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - return await self.__pg_client.get_amenity(id) - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_client.amenities_search( - query_embedding, similarity_threshold, top_k - ) - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_client.get_flight(flight_id) - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_client.search_flights_by_number(airline, number) - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - return await self.__pg_client.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - return await self.__pg_client.validate_ticket( - airline, flight_number, departure_airport, departure_time - ) - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - await 
self.__pg_client.insert_ticket( - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time, - ) - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - return await self.__pg_client.list_tickets(user_id) - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - return await self.__pg_client.policies_search( - query_embedding, similarity_threshold, top_k - ) - - async def close(self): - await self.__pg_client.close() diff --git a/retrieval_service/datastore/providers/cloudsql_postgres_test.py b/retrieval_service/datastore/providers/cloudsql_postgres_test.py deleted file mode 100644 index 6a9ef0702..000000000 --- a/retrieval_service/datastore/providers/cloudsql_postgres_test.py +++ /dev/null @@ -1,758 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from datetime import datetime -from typing import Any, AsyncGenerator, List - -import asyncpg -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore -from google.cloud.sql.connector import Connector - -import models - -from .. import datastore -from . 
import cloudsql_postgres -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_user() -> str: - return get_env_var("DB_USER", "name of a postgres user") - - -@pytest.fixture(scope="module") -def db_pass() -> str: - return get_env_var("DB_PASS", "password for the postgres user") - - -@pytest.fixture(scope="module") -def db_project() -> str: - return get_env_var("DB_PROJECT", "project id for google cloud") - - -@pytest.fixture(scope="module") -def db_region() -> str: - return get_env_var("DB_REGION", "region for cloud sql instance") - - -@pytest.fixture(scope="module") -def db_instance() -> str: - return get_env_var("DB_INSTANCE", "instance for cloud sql") - - -@pytest_asyncio.fixture(scope="module") -async def create_db( - db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str -) -> AsyncGenerator[str, None]: - db_name = get_env_var("DB_NAME", "name of a postgres database") - loop = asyncio.get_running_loop() - connector = Connector(loop=loop) - project_instance = f"{db_project}:{db_region}:{db_instance}" - # Database does not exist, create it. 
- sys_conn: asyncpg.Connection = await connector.connect_async( - project_instance, - "asyncpg", - user=f"{db_user}", - password=f"{db_pass}", - db="postgres", - ) - await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') - await sys_conn.execute(f'CREATE DATABASE "{db_name}";') - conn: asyncpg.Connection = await connector.connect_async( - project_instance, - "asyncpg", - user=f"{db_user}", - password=f"{db_pass}", - db=f"{db_name}", - ) - await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") - await conn.close() - yield db_name - await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') - await sys_conn.close() - - -@pytest_asyncio.fixture(scope="module") -async def ds( - create_db: str, - db_user: str, - db_pass: str, - db_project: str, - db_region: str, - db_instance: str, -) -> AsyncGenerator[datastore.Client, None]: - cfg = cloudsql_postgres.Config( - kind="cloudsql-postgres", - user=db_user, - password=db_pass, - database=create_db, - project=db_project, - region=db_region, - instance=db_instance, - ) - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - yield ds - await ds.close() - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["changed"] == [] - assert file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - - -async def test_export_dataset(ds: cloudsql_postgres.Client): - airports, amenities, flights, policies = await ds.export_data() - - airports_ds_path = 
"../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: cloudsql_postgres.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is not None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: cloudsql_postgres.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is not None - - -search_airports_test_data = [ - pytest.param( - "Philippines", - 
"San jose", - None, - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: cloudsql_postgres.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is not None - - -async def test_get_amenity(ds: cloudsql_postgres.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - 
saturday_start_hour=None, - saturday_end_hour=None, - ) - assert res == expected - assert sql is not None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.35, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - }, - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" - amenities_query_embedding2, - 0.35, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.1, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: cloudsql_postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[Any], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None - - -async def test_get_flight(ds: cloudsql_postgres.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - 
arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is not None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: cloudsql_postgres.Client, - airline: str, - number: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is not None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - 
"2025-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: cloudsql_postgres.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await 
ds.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - assert res == expected - assert sql is not None - - -async def test_insert_ticket(ds: cloudsql_postgres.Client): - await ds.insert_ticket( - "1", - "test", - "test", - "UA", - "1532", - "SFO", - "DEN", - "2025-01-01 05:50:00", - "2025-01-01 09:23:00", - ) - - -async def test_list_tickets(ds: cloudsql_postgres.Client): - res, sql = await ds.list_tickets("1") - expected = [ - { - "user_name": "test", - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": datetime.strptime( - "2025-01-01 05:50:00", "%Y-%m-%d %H:%M:%S" - ), - "arrival_time": datetime.strptime( - "2025-01-01 09:23:00", "%Y-%m-%d %H:%M:%S" - ), - } - ] - - assert res == expected - assert sql is not None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:50:00", - }, - models.Flight( - id=0, - airline="UA", - flight_number="1532", - departure_airport="SFO", - arrival_airport="DEN", - departure_time=datetime.strptime( - "2025-01-01 05:50:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 09:23:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="E49", - arrival_gate="D6", - ), - 'SELECT *
FROM flights
WHERE airline ILIKE UA
AND flight_number ILIKE 1532
AND departure_airport ILIKE SFO
AND departure_time = 2025-01-01 05:50:00', - ), - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - 'SELECT *
FROM flights
WHERE airline ILIKE UA
AND flight_number ILIKE 1158
AND departure_airport ILIKE SFO
AND departure_time = 2025-01-01 05:57:00', - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket( - ds: cloudsql_postgres.Client, params, expected_data, expected_sql -): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.35, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.35, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). 
Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.35, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: cloudsql_postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[str], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None diff --git a/retrieval_service/datastore/providers/firestore.py b/retrieval_service/datastore/providers/firestore.py deleted file mode 100644 index f52390a4d..000000000 --- a/retrieval_service/datastore/providers/firestore.py +++ /dev/null @@ -1,539 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -from datetime import datetime, timedelta -from typing import Any, Literal, Optional - -from google.cloud.firestore import AsyncClient # type: ignore -from google.cloud.firestore_v1.async_collection import AsyncCollectionReference -from google.cloud.firestore_v1.async_query import AsyncQuery -from google.cloud.firestore_v1.base_query import FieldFilter -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.vector import Vector -from pydantic import BaseModel - -import models - -from .. import datastore - -FIRESTORE_IDENTIFIER = "firestore" - - -class Config(BaseModel, datastore.AbstractConfig): - kind: Literal["firestore"] - projectId: Optional[str] - - -class Client(datastore.Client[Config]): - __client: AsyncClient - - @datastore.classproperty - def kind(cls): - return FIRESTORE_IDENTIFIER - - def __init__(self, client: AsyncClient): - self.__client = client - self.__policies_collection = AsyncQuery(self.__client.collection("policies")) - self.__amenities_collection = AsyncQuery(self.__client.collection("amenities")) - - @classmethod - async def create(cls, config: Config) -> "Client": - return cls(AsyncClient(project=config.projectId)) - - async def __delete_collections( - self, collection_list: list[AsyncCollectionReference] - ): - # Checks if collection exists and deletes all documents - delete_tasks = [] - for collection_ref in collection_list: - collection_exists = collection_ref.limit(1).stream() - if not collection_exists: - continue - - docs = collection_ref.stream() - async for doc in docs: - delete_tasks.append(asyncio.create_task(doc.reference.delete())) - await asyncio.gather(*delete_tasks) - - async def parse_index_info(self, line: str) -> tuple[str, str]: - # Extract collection and index-id from file path - parts = line.split("/") - collection_name = parts[-3] - index_id = parts[-1] - return collection_name, index_id - - async def __get_indices(self) -> dict[str, str]: - 
list_vector_index_process = await asyncio.create_subprocess_exec( - "gcloud", - "alpha", - "firestore", - "indexes", - "composite", - "list", - "--database=(default)", - "--format=value(name)", # prints name field - stdout=asyncio.subprocess.PIPE, - ) - - # Capture output and ignore stderr - stdout, __ = await list_vector_index_process.communicate() - - # Decode and format output - index_lines = stdout.decode().strip().split("\n") - - indices = {} - - # Create a dict with collections and their corresponding vector index. - for line in index_lines: - if line: - collection, index_id = await self.parse_index_info(line) - indices[collection] = index_id - - return indices - - async def __delete_vector_index(self, indices: list[str]): - # Check if the collection exists and deletes all indexes - for index in indices: - if index: - delete_vector_index = await asyncio.create_subprocess_exec( - "gcloud", - "alpha", - "firestore", - "indexes", - "composite", - "delete", - index, - "--database=(default)", - "--quiet", # Added to suppress delete warning - ) - await delete_vector_index.wait() - - async def __create_vector_index(self, collection_name: str): - create_vector_index = await asyncio.create_subprocess_exec( - "gcloud", - "alpha", - "firestore", - "indexes", - "composite", - "create", - f"--collection-group={collection_name}", - "--query-scope=COLLECTION", - '--field-config=field-path=embedding,vector-config={"dimension":768,"flat":"{}"}', - "--database=(default)", - ) - await create_vector_index.wait() - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - # Check if the collections already exist; if so, delete collections - airports_ref = self.__client.collection("airports") - amenities_ref = self.__client.collection("amenities") - flights_ref = self.__client.collection("flights") - policies_ref = self.__client.collection("policies") - 
await self.__delete_collections( - [airports_ref, amenities_ref, flights_ref, policies_ref] - ) - - # Retrieve vector indexes and check if the collections already exist; if so, delete collections - indices = await self.__get_indices() - amenities_ref = indices.get("amenities", "") - policies_ref = indices.get("policies", "") - await self.__delete_vector_index([amenities_ref, policies_ref]) - - # Initialize collections - create_airports_tasks = [] - for airport in airports: - create_airports_tasks.append( - self.__client.collection("airports") - .document(str(airport.id)) - .set( - { - "iata": airport.iata, - "name": airport.name, - "city": airport.city, - "country": airport.country, - } - ) - ) - await asyncio.gather(*create_airports_tasks) - create_amenities_tasks = [] - for amenity in amenities: - create_amenities_tasks.append( - self.__client.collection("amenities") - .document(str(amenity.id)) - .set( - { - "name": amenity.name, - "description": amenity.description, - "location": amenity.location, - "terminal": amenity.terminal, - "category": amenity.category, - "hour": amenity.hour, - # Firebase does not support datetime.time type - "sunday_start_hour": ( - str(amenity.sunday_start_hour) - if amenity.sunday_start_hour - else None - ), - "sunday_end_hour": ( - str(amenity.sunday_end_hour) - if amenity.sunday_end_hour - else None - ), - "monday_start_hour": ( - str(amenity.monday_start_hour) - if amenity.monday_start_hour - else None - ), - "monday_end_hour": ( - str(amenity.monday_end_hour) - if amenity.monday_end_hour - else None - ), - "tuesday_start_hour": ( - str(amenity.tuesday_start_hour) - if amenity.tuesday_start_hour - else None - ), - "tuesday_end_hour": ( - str(amenity.tuesday_end_hour) - if amenity.tuesday_end_hour - else None - ), - "wednesday_start_hour": ( - str(amenity.wednesday_start_hour) - if amenity.wednesday_start_hour - else None - ), - "wednesday_end_hour": ( - str(amenity.wednesday_end_hour) - if amenity.wednesday_end_hour - else None - 
), - "thursday_start_hour": ( - str(amenity.thursday_start_hour) - if amenity.thursday_start_hour - else None - ), - "thursday_end_hour": ( - str(amenity.thursday_end_hour) - if amenity.thursday_end_hour - else None - ), - "friday_start_hour": ( - str(amenity.friday_start_hour) - if amenity.friday_start_hour - else None - ), - "friday_end_hour": ( - str(amenity.friday_end_hour) - if amenity.friday_end_hour - else None - ), - "saturday_start_hour": ( - str(amenity.saturday_start_hour) - if amenity.saturday_start_hour - else None - ), - "saturday_end_hour": ( - str(amenity.saturday_end_hour) - if amenity.saturday_end_hour - else None - ), - "content": amenity.content, - # Vector type does not support None value - "embedding": Vector(amenity.embedding or []), - } - ) - ) - await asyncio.gather(*create_amenities_tasks) - create_flights_tasks = [] - for flight in flights: - create_flights_tasks.append( - self.__client.collection("flights") - .document(str(flight.id)) - .set( - { - "airline": flight.airline, - "flight_number": flight.flight_number, - "departure_airport": flight.departure_airport, - "arrival_airport": flight.arrival_airport, - "departure_time": flight.departure_time.strftime( - "%Y-%m-%d %H:%M:%S" - ), - "arrival_time": flight.arrival_time.strftime( - "%Y-%m-%d %H:%M:%S" - ), - "departure_gate": flight.departure_gate, - "arrival_gate": flight.arrival_gate, - } - ) - ) - if len(create_flights_tasks) % 10000 == 0: - # avoid gRPC batch write timeout error - await asyncio.gather(*create_flights_tasks) - create_flights_tasks.clear() - await asyncio.gather(*create_flights_tasks) - create_policies_tasks = [] - for policy in policies: - create_policies_tasks.append( - self.__client.collection("policies") - .document(str(policy.id)) - .set( - { - "content": policy.content, - # Vector type does not accept None value - "embedding": Vector(policy.embedding or []), - } - ) - ) - await asyncio.gather(*create_policies_tasks) - - # Initialize single-field vector indexes 
- await self.__create_vector_index("amenities") - await self.__create_vector_index("policies") - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - airport_docs = self.__client.collection("airports").stream() - amenities_docs = self.__client.collection("amenities").stream() - flights_docs = self.__client.collection("flights").stream() - policies_docs = self.__client.collection("policies").stream() - - airports = [] - async for doc in airport_docs: - airport_dict = doc.to_dict() - airport_dict["id"] = doc.id - airports.append(models.Airport.model_validate(airport_dict)) - - amenities = [] - async for doc in amenities_docs: - amenity_dict = doc.to_dict() - amenity_dict["id"] = doc.id - amenity_dict["embedding"] = list(amenity_dict["embedding"]) - amenities.append(models.Amenity.model_validate(amenity_dict)) - - flights = [] - async for doc in flights_docs: - flight_dict = doc.to_dict() - flight_dict["id"] = doc.id - flights.append(models.Flight.model_validate(flight_dict)) - - policies = [] - async for doc in policies_docs: - policy_dict = doc.to_dict() - policy_dict["id"] = doc.id - policy_dict["embedding"] = list(policy_dict["embedding"]) - policies.append(models.Policy.model_validate(policy_dict)) - - return airports, amenities, flights, policies - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - query = self.__client.collection("airports").where( - filter=FieldFilter("id", "==", id) - ) - airport_doc = await query.get() - airport_dict = airport_doc.to_dict() | {"id": airport_doc.id} - return models.Airport.model_validate(airport_dict), None - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - query = self.__client.collection("airports").where( - filter=FieldFilter("iata", "==", iata) - ) - airport_doc = await query.get() - airport_dict = 
airport_doc.to_dict() | {"id": airport_doc.id} - return models.Airport.model_validate(airport_dict), None - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - query = self.__client.collection("airports") - - if country is not None: - query = query.where("country", "==", country) - - if city is not None: - query = query.where("city", "==", city) - - if name is not None: - query = query.where("name", ">=", name).where("name", "<=", name + "\uf8ff") - - query = query.limit(10) - - docs = query.stream() - airports = [] - async for doc in docs: - airport_dict = doc.to_dict() | {"id": doc.id} - airports.append(models.Airport.model_validate(airport_dict)) - return airports, None - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - query = self.__client.collection("amenities").where( - filter=FieldFilter("id", "==", id) - ) - amenity_doc = await query.get() - amenity_dict = amenity_doc.to_dict() | {"id": amenity_doc.id} - amenity_dict["embedding"] = list(amenity_dict["embedding"]) - return models.Amenity.model_validate(amenity_dict), None - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - # Using the same similarity metric to the embedding model's training method - # produce the most accurate result - query = self.__amenities_collection.find_nearest( - vector_field="embedding", - query_vector=Vector(query_embedding), - distance_measure=DistanceMeasure.DOT_PRODUCT, - limit=top_k, - ) - - docs = query.stream() - amenities = [] - async for doc in docs: - amenity_dict = { - "id": doc.id, - "category": doc.get("category"), - "description": doc.get("description"), - "hour": doc.get("hour"), - "location": doc.get("location"), - "name": doc.get("name"), - "terminal": doc.get("terminal"), - } - 
amenities.append(amenity_dict) - return amenities, None - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - query = self.__client.collection("flights").where( - filter=FieldFilter("id", "==", flight_id) - ) - flight_doc = await query.get() - flight_dict = flight_doc.to_dict() | {"id": flight_doc.id} - return models.Flight.model_validate(flight_dict), None - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - query = ( - self.__client.collection("flights") - .where(filter=FieldFilter("airline", "==", airline)) - .where(filter=FieldFilter("flight_number", "==", number)) - .limit(10) - ) - - docs = query.stream() - flights = [] - async for doc in docs: - flight_dict = doc.to_dict() | {"id": doc.id} - flights.append(models.Flight.model_validate(flight_dict)) - return flights, None - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - date_obj = datetime.strptime(date, "%Y-%m-%d").date() - date_timestamp = datetime.combine(date_obj, datetime.min.time()) - query = ( - self.__client.collection("flights") - .where("departure_time", ">=", date_timestamp) - .where("departure_time", "<", date_timestamp + timedelta(days=1)) - .limit(10) - ) - - if departure_airport is None: - query = query.where("departure_airport", "==", departure_airport) - if arrival_airport is None: - query = query.where("arrival_airport", "==", arrival_airport) - - docs = query.stream() - flights = [] - async for doc in docs: - flight_dict = doc.to_dict() | {"id": doc.id} - flights.append(models.Flight.model_validate(flight_dict)) - return flights, None - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - 
raise NotImplementedError("Not Implemented") - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - raise NotImplementedError("Not Implemented") - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - raise NotImplementedError("Not Implemented") - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - query = self.__policies_collection.find_nearest( - vector_field="embedding", - query_vector=Vector(query_embedding), - distance_measure=DistanceMeasure.DOT_PRODUCT, - limit=top_k, - ) - - policies = [] - async for doc in query.stream(): - policies.append(doc.get("content")) - return policies, None - - async def close(self): - self.__client.close() diff --git a/retrieval_service/datastore/providers/firestore_test.py b/retrieval_service/datastore/providers/firestore_test.py deleted file mode 100644 index 8badd0401..000000000 --- a/retrieval_service/datastore/providers/firestore_test.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime -from typing import Dict - -from google.cloud.firestore import AsyncClient, Client # type: ignore -from google.cloud.firestore_v1.base_query import FieldFilter - -import models - -from . import firestore as firestore_provider - - -class MockDocument(Dict): - """ - Mock firestore document. - """ - - id: int - content: Dict - - def __init(self, id, content): - self.id = id - self.content = content - - def to_dict(self): - return self.content - - -class MockCollection(Dict): - """ - Mock firestore collection. - """ - - collection_name: str - documents = Dict[str, MockDocument] - - def __init__(self, collection_name: str): - self.collection_name = collection_name - - def where(self, filter: FieldFilter): - return self.documents - - def select(self, *args): - return self.documents - - -class MockFirestoreClient(AsyncClient): - """ - Mock firestore client. - """ - - collections: Dict[str, MockCollection] - - def __init__(self): - self.collections = {} - - def collection(self, collection_name: str): - return self.collections[collection_name] - - -async def mock_client(mock_firestore_client: MockFirestoreClient) -> Client: - return firestore_provider.Client(mock_firestore_client) - - -async def test_get_airport_by_id(): - fake_id = 1 - mock_document = MockDocument( - fake_id, - { - "iata": "Fake iata", - "name": "Fake name", - "city": "Fake city", - "country": "Fake country", - }, - ) - mock_collection = MockCollection("airports") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["airports"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.get_airport_by_id(fake_id) - expected_res = models.Airport( - id=fake_id, - iata="Fake iata", - name="Fake name", - city="Fake city", - country="Fake country", - ) - assert res == expected_res - - -async def test_get_airport_by_iata(): - fake_id = 1 - fake_iata = "Fake 
iata" - mock_document = MockDocument( - fake_id, - { - "iata": fake_iata, - "name": "Fake name", - "city": "Fake city", - "country": "Fake country", - }, - ) - mock_collection = MockCollection("airports") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["airports"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.get_airport_by_iata(fake_iata) - expected_res = models.Airport( - id=fake_id, - iata=fake_iata, - name="Fake name", - city="Fake city", - country="Fake country", - ) - assert res == expected_res - - -async def test_search_airports(): - fake_id = 3 - fake_name = "Fake name" - fake_country = "Fake country" - fake_city = "Fake city" - mock_document = MockDocument( - fake_id, - { - "iata": "Fake iata", - "name": fake_name, - "city": fake_city, - "country": fake_country, - }, - ) - mock_collection = MockCollection("airports") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["airports"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.search_airports(fake_country, fake_city, fake_name) - expected_res = [ - models.Airport( - id=fake_id, - iata="Fake iata", - name=fake_name, - city=fake_city, - country=fake_country, - ) - ] - - assert res == expected_res - - -async def test_get_amenity(): - fake_id = 2 - mock_document = MockDocument( - fake_id, - { - "name", - "Fake name", - "description", - "Fake description", - "location", - "Fake location", - "terminal", - "Fake terminal", - "category", - "Fake category", - "hour", - "Fake hour", - }, - ) - mock_collection = MockCollection("amenities") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["amenities"] = mock_collection - - mock_client = await 
mock_client(mock_firestore_client) - res = await mock_client.get_amenity(fake_id) - expected_res = models.Amenity( - id=fake_id, - name="Fake name", - description="Fake description", - location="Fake location", - terminal="Fake terminal", - category="Fake category", - hour="Fake hour", - ) - assert res == expected_res - - -async def test_amenities_search(): - fake_id = 3 - mock_document = MockDocument( - fake_id, - { - "name": "Fake name", - "description": "Fake description", - "location": "Fake location", - "terminal": "Fake terminal", - "category": "Fake category", - "hour": "Fake hour", - }, - ) - mock_collection = MockCollection("amenities") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["amenities"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.amenities_search(1, 0.7, 1) - expected_res = [ - models.Amenity( - id=fake_id, - name="Fake name", - description="Fake description", - location="Fake location", - terminal="Fake terminal", - category="Fake category", - hour="Fake hour", - ) - ] - assert res == expected_res - - -async def test_get_flight(): - fake_id = 4 - fake_datetime = datetime.datetime(2023, 11, 14, 12, 30, 45) - mock_document = MockDocument( - fake_id, - { - "airline": "Fake airline", - "flight_number": "Fake flight number", - "departure_airport": "Fake departure airport", - "arrival_airport": "Fake arrival airport", - "departure_time": fake_datetime, - "arrival_time": fake_datetime, - "departure_gate": "fake departure gate", - "arrival_gate": "fake arrival gate", - }, - ) - mock_collection = MockCollection("flights") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["flights"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.get_flight(fake_id) - expected_res = 
models.Flight( - id=fake_id, - airline="Fake airline", - flight_number="Fake flight number", - departure_airport="Fake departure airport", - arrival_airport="Fake arrival airport", - departure_time=fake_datetime, - arrival_time=fake_datetime, - departure_gate="fake departure gate", - arrival_gate="fake arrival gate", - ) - assert res == expected_res - - -async def test_search_flights_by_airports(): - fake_id = 5 - fake_date = "2023-11-14" - fake_datetime = datetime.datetime(2023, 11, 14, 12, 30, 45) - mock_document = MockDocument( - fake_id, - { - "airline": "Fake airline", - "flight_number": "Fake flight number", - "departure_airport": "Fake departure airport", - "arrival_airport": "Fake arrival airport", - "departure_time": fake_datetime, - "arrival_time": fake_datetime, - "departure_gate": "fake departure gate", - "arrival_gate": "fake arrival gate", - }, - ) - mock_collection = MockCollection("flights") - mock_collection.documents[fake_id, mock_document] - mock_firestore_client = MockFirestoreClient() - mock_firestore_client.collection["flights"] = mock_collection - - mock_client = await mock_client(mock_firestore_client) - res = await mock_client.search_flights_by_airport( - fake_date, "Fake departure airport", "Fake arrival airport" - ) - expected_res = [ - models.Flight( - id=fake_id, - airline="Fake airline", - flight_number="Fake flight number", - departure_airport="Fake departure airport", - arrival_airport="Fake arrival airport", - departure_time=fake_datetime, - arrival_time=fake_datetime, - departure_gate="fake departure gate", - arrival_gate="fake arrival gate", - ) - ] - assert res == expected_res diff --git a/retrieval_service/datastore/providers/postgres.py b/retrieval_service/datastore/providers/postgres.py deleted file mode 100644 index a806f01b5..000000000 --- a/retrieval_service/datastore/providers/postgres.py +++ /dev/null @@ -1,614 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# 
you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from datetime import datetime -from ipaddress import IPv4Address, IPv6Address -from typing import Any, Literal, Optional - -import asyncpg -from pgvector.asyncpg import register_vector -from pydantic import BaseModel -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -import models - -from .. import datastore -from ..helpers import format_sql - -POSTGRES_IDENTIFIER = "postgres" - - -class Config(BaseModel, datastore.AbstractConfig): - kind: Literal["postgres"] - host: IPv4Address | IPv6Address = IPv4Address("127.0.0.1") - port: int = 5432 - user: str - password: str - database: str - - -class Client(datastore.Client[Config]): - __async_engine: AsyncEngine - - @datastore.classproperty - def kind(cls): - return POSTGRES_IDENTIFIER - - def __init__(self, async_engine: AsyncEngine): - self.__async_engine = async_engine - - @classmethod - async def create(cls, config: Config) -> "Client": - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await asyncpg.connection.connect( - host=str(config.host), - user=config.user, - password=config.password, - database=config.database, - port=config.port, - ) - await register_vector(conn) - return conn - - async_engine = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, - ) - if async_engine is None: - raise TypeError("async_engine not instantiated") - return cls(async_engine) - - async def initialize_data( - self, - airports: list[models.Airport], - 
amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - async with self.__async_engine.connect() as conn: - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS airports CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE airports( - id INT PRIMARY KEY, - iata TEXT, - name TEXT, - city TEXT, - country TEXT - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" - ), - [ - { - "id": a.id, - "iata": a.iata, - "name": a.name, - "city": a.city, - "country": a.country, - } - for a in airports - ], - ) - - await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE amenities( - id INT PRIMARY KEY, - name TEXT, - description TEXT, - location TEXT, - terminal TEXT, - category TEXT, - hour TEXT, - sunday_start_hour TIME, - sunday_end_hour TIME, - monday_start_hour TIME, - monday_end_hour TIME, - tuesday_start_hour TIME, - tuesday_end_hour TIME, - wednesday_start_hour TIME, - wednesday_end_hour TIME, - thursday_start_hour TIME, - thursday_end_hour TIME, - friday_start_hour TIME, - friday_end_hour TIME, - saturday_start_hour TIME, - saturday_end_hour TIME, - content TEXT NOT NULL, - embedding vector(768) NOT NULL - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO amenities VALUES (:id, :name, :description, :location, - :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour, - :monday_start_hour, :monday_end_hour, :tuesday_start_hour, - :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour, - :thursday_start_hour, :thursday_end_hour, :friday_start_hour, - :friday_end_hour, 
:saturday_start_hour, :saturday_end_hour, :content, :embedding) - """ - ), - [ - { - "id": a.id, - "name": a.name, - "description": a.description, - "location": a.location, - "terminal": a.terminal, - "category": a.category, - "hour": a.hour, - "sunday_start_hour": a.sunday_start_hour, - "sunday_end_hour": a.sunday_end_hour, - "monday_start_hour": a.monday_start_hour, - "monday_end_hour": a.monday_end_hour, - "tuesday_start_hour": a.tuesday_start_hour, - "tuesday_end_hour": a.tuesday_end_hour, - "wednesday_start_hour": a.wednesday_start_hour, - "wednesday_end_hour": a.wednesday_end_hour, - "thursday_start_hour": a.thursday_start_hour, - "thursday_end_hour": a.thursday_end_hour, - "friday_start_hour": a.friday_start_hour, - "friday_end_hour": a.friday_end_hour, - "saturday_start_hour": a.saturday_start_hour, - "saturday_end_hour": a.saturday_end_hour, - "content": a.content, - "embedding": a.embedding, - } - for a in amenities - ], - ) - - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS flights CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE flights( - id INTEGER PRIMARY KEY, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP, - departure_gate TEXT, - arrival_gate TEXT - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO flights VALUES (:id, :airline, :flight_number, - :departure_airport, :arrival_airport, :departure_time, - :arrival_time, :departure_gate, :arrival_gate) - """ - ), - [ - { - "id": f.id, - "airline": f.airline, - "flight_number": f.flight_number, - "departure_airport": f.departure_airport, - "arrival_airport": f.arrival_airport, - "departure_time": f.departure_time, - "arrival_time": f.arrival_time, - "departure_gate": f.departure_gate, - "arrival_gate": f.arrival_gate, - } - for f in flights - ], - ) - - # If the table already 
exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS tickets CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE tickets( - user_id TEXT, - user_name TEXT, - user_email TEXT, - airline TEXT, - flight_number TEXT, - departure_airport TEXT, - arrival_airport TEXT, - departure_time TIMESTAMP, - arrival_time TIMESTAMP - ) - """ - ) - ) - - # If the table already exists, drop it to avoid conflicts - await conn.execute(text("DROP TABLE IF EXISTS policies CASCADE")) - # Create a new table - await conn.execute( - text( - """ - CREATE TABLE policies( - id INT PRIMARY KEY, - content TEXT NOT NULL, - embedding vector(768) NOT NULL - ) - """ - ) - ) - # Insert all the data - await conn.execute( - text( - """ - INSERT INTO policies VALUES (:id, :content, :embedding) - """ - ), - [ - { - "id": p.id, - "content": p.content, - "embedding": p.embedding, - } - for p in policies - ], - ) - await conn.commit() - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - async with self.__async_engine.connect() as conn: - airport_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")) - ) - amenity_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM amenities ORDER BY id ASC""")) - ) - flights_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")) - ) - policy_task = asyncio.create_task( - conn.execute(text("""SELECT * FROM policies ORDER BY id ASC""")) - ) - - airport_results = (await airport_task).mappings().fetchall() - amenity_results = (await amenity_task).mappings().fetchall() - flights_results = (await flights_task).mappings().fetchall() - policy_results = (await policy_task).mappings().fetchall() - - airports = [models.Airport.model_validate(a) for a in airport_results] - amenities = [models.Amenity.model_validate(a) for a in 
amenity_results] - flights = [models.Flight.model_validate(f) for f in flights_results] - policies = [models.Policy.model_validate(p) for p in policy_results] - - return airports, amenities, flights, policies - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """SELECT * FROM airports WHERE id=:id""" - s = text(sql) - params = {"id": id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, format_sql(sql, params) - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """SELECT * FROM airports WHERE iata ILIKE :iata""" - s = text(sql) - params = {"iata": iata} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Airport.model_validate(result) - return res, format_sql(sql, params) - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT * FROM airports - WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country) - AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city) - AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%') - LIMIT 10 - """ - s = text(sql) - params = { - "country": country, - "city": city, - "name": name, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Airport.model_validate(r) for r in results] - return res, format_sql(sql, params) - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ 
- SELECT id, name, description, location, terminal, category, hour - FROM amenities WHERE id=:id - """ - s = text(sql) - params = {"id": id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Amenity.model_validate(result) - return res, format_sql(sql, params) - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT name, description, location, terminal, category, hour - FROM amenities - WHERE (embedding <=> :query_embedding) < :similarity_threshold - ORDER BY (embedding <=> :query_embedding) - LIMIT :top_k - """ - s = text(sql) - params = { - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r for r in results] - return res, format_sql(sql, params) - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE id = :flight_id - """ - s = text(sql) - params = {"flight_id": flight_id} - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - - res = models.Flight.model_validate(result) - return res, format_sql(sql, params) - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE airline = :airline - AND flight_number = :number - LIMIT 10 - """ - s = text(sql) - params = { - "airline": airline, - "number": number, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return 
res, format_sql(sql, params) - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport) - AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport) - AND departure_time >= CAST(:datetime AS timestamp) - AND departure_time < CAST(:datetime AS timestamp) + interval '1 day' - LIMIT 10 - """ - s = text(sql) - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "datetime": datetime.strptime(date, "%Y-%m-%d"), - } - - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [models.Flight.model_validate(r) for r in results] - return res, format_sql(sql, params) - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") - async with self.__async_engine.connect() as conn: - sql = """ - SELECT * FROM flights - WHERE airline ILIKE :airline - AND flight_number ILIKE :flight_number - AND departure_airport ILIKE :departure_airport - AND departure_time = :departure_time - """ - s = text(sql) - params = { - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "departure_time": departure_time_datetime, - } - result = (await conn.execute(s, params)).mappings().fetchone() - - if result is None: - return None, None - res = models.Flight.model_validate(result) - return res, format_sql(sql, params) - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - 
departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S") - arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S") - - async with self.__async_engine.connect() as conn: - s = text( - """ - INSERT INTO tickets ( - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time, - arrival_time - ) VALUES ( - :user_id, - :user_name, - :user_email, - :airline, - :flight_number, - :departure_airport, - :arrival_airport, - :departure_time, - :arrival_time - ); - """ - ) - params = { - "user_id": user_id, - "user_name": user_name, - "user_email": user_email, - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "departure_time": departure_time_datetime, - "arrival_time": arrival_time_datetime, - } - result = (await conn.execute(s, params)).mappings() - await conn.commit() - if not result: - raise Exception("Ticket Insertion failure") - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = :user_id - """ - s = text(sql) - params = { - "user_id": user_id, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r for r in results] - return res, format_sql(sql, params) - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - async with self.__async_engine.connect() as conn: - sql = """ - SELECT content - FROM policies - WHERE (embedding <=> :query_embedding) < :similarity_threshold - ORDER BY (embedding <=> :query_embedding) - LIMIT :top_k - 
""" - s = text(sql) - params = { - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - } - results = (await conn.execute(s, params)).mappings().fetchall() - - res = [r["content"] for r in results] - return res, format_sql(sql, params) - - async def close(self): - await self.__async_engine.dispose() diff --git a/retrieval_service/datastore/providers/postgres_test.py b/retrieval_service/datastore/providers/postgres_test.py deleted file mode 100644 index 0c4cec94c..000000000 --- a/retrieval_service/datastore/providers/postgres_test.py +++ /dev/null @@ -1,651 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from ipaddress import IPv4Address -from typing import Any, AsyncGenerator, List - -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore - -import models - -from .. import datastore -from . 
import postgres -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_user() -> str: - return get_env_var("DB_USER", "name of a postgres user") - - -@pytest.fixture(scope="module") -def db_pass() -> str: - return get_env_var("DB_PASS", "password for the postgres user") - - -@pytest.fixture(scope="module") -def db_name() -> str: - return get_env_var("DB_NAME", "name of a postgres database") - - -@pytest.fixture(scope="module") -def db_host() -> str: - return get_env_var("DB_HOST", "ip address of a postgres database") - - -@pytest_asyncio.fixture(scope="module") -async def ds( - db_user: str, db_pass: str, db_name: str, db_host: str -) -> AsyncGenerator[datastore.Client, None]: - cfg = postgres.Config( - kind="postgres", - user=db_user, - password=db_pass, - database=db_name, - host=IPv4Address(db_host), - ) - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - yield ds - await ds.close() - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["changed"] == [] - assert file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - - -async def test_export_dataset(ds: postgres.Client): - airports, amenities, flights, policies = await ds.export_data() - - airports_ds_path = 
"../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: postgres.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is not None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: postgres.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is not None - - -search_airports_test_data = [ - pytest.param( - "Philippines", - "San jose", - None, 
- [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: postgres.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is not None - - -async def test_get_amenity(ds: postgres.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - saturday_start_hour=None, - saturday_end_hour=None, - ) 
- assert res == expected - assert sql is not None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.35, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - } - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" - amenities_query_embedding2, - 0.35, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.1, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[Any], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None - - -async def test_get_flight(ds: postgres.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d 
%H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is not None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: postgres.Client, airline: str, number: str, expected: List[models.Flight] -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is not None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( 
- "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: postgres.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - assert res == 
expected - assert sql is not None - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.35, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.35, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). 
Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.35, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[str], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is not None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - 'SELECT *
FROM flights
WHERE airline ILIKE UA
AND flight_number ILIKE 1158
AND departure_airport ILIKE SFO
AND departure_time = 2025-01-01 05:57:00', - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket( - ds: postgres.Client, params, expected_data, expected_sql -): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql diff --git a/retrieval_service/datastore/providers/spanner_gsql.py b/retrieval_service/datastore/providers/spanner_gsql.py deleted file mode 100644 index 459e393cd..000000000 --- a/retrieval_service/datastore/providers/spanner_gsql.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from typing import Any, Literal, Optional - -from google.cloud import spanner # type: ignore -from google.cloud.spanner_v1 import JsonObject, param_types -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance -from google.oauth2 import service_account # type: ignore -from pydantic import BaseModel - -import models - -from .. import datastore - -# Identifier for Spanner -SPANNER_IDENTIFIER = "spanner-gsql" - - -# Configuration model for Spanner -class Config(BaseModel, datastore.AbstractConfig): - """ - Configuration model for Spanner. 
- - Attributes: - kind (Literal["spanner"]): Type of datastore. - project (str): Google Cloud project ID. - instance (str): ID of the Spanner instance. - database (str): ID of the Spanner database. - service_account_key_file (str): Service Account Key File. - """ - - kind: Literal["spanner-gsql"] - project: str - instance: str - database: str - service_account_key_file: Optional[str] = None - - -# Client class for interacting with Spanner -class Client(datastore.Client[Config]): - OPERATION_TIMEOUT_SECONDS = 240 - BATCH_SIZE = 1000 - AIRPORT_COLUMNS = ["id", "iata", "name", "city", "country"] - AMENITIES_COLUMNS = [ - "id", - "name", - "description", - "location", - "terminal", - "category", - "hour", - "sunday_start_hour", - "sunday_end_hour", - "monday_start_hour", - "monday_end_hour", - "tuesday_start_hour", - "tuesday_end_hour", - "wednesday_start_hour", - "wednesday_end_hour", - "thursday_start_hour", - "thursday_end_hour", - "friday_start_hour", - "friday_end_hour", - "saturday_start_hour", - "saturday_end_hour", - "content", - "embedding", - ] - FLIGHTS_COLUMNS = [ - "id", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - "departure_gate", - "arrival_gate", - ] - - POLICIES_COLUMNS = ["id", "content", "embedding"] - """ - Client class for interacting with Spanner. - - Attributes: - __client (spanner.Client): Spanner client instance. - __instance_id (str): ID of the Spanner instance. - __database_id (str): ID of the Spanner database. - __instance (Instance): Spanner instance. - __database (Database): Spanner database. - """ - - @datastore.classproperty - def kind(cls): - return SPANNER_IDENTIFIER - - def __init__(self, client: spanner.Client, instance_id: str, database_id: str): - """ - Initialize the Spanner client. - - Args: - client (spanner.Client): Spanner client instance. - instance_id (str): ID of the Spanner instance. - database_id (str): ID of the Spanner database. 
- """ - self.__client = client - self.__instance_id = instance_id - self.__database_id = database_id - - self.__instance = self.__client.instance(self.__instance_id) - self.__database = self.__instance.database(self.__database_id) - - @classmethod - async def create(cls, config: Config) -> "Client": - """ - Create a Spanner client. - - Args: - config (Config): Configuration for creating the client. - - Returns: - Client: Initialized Spanner client. - """ - client: spanner.Client - - if config.service_account_key_file is not None: - credentials = service_account.Credentials.from_service_account_file( - config.service_account_key_file - ) - client = spanner.Client(project=config.project, credentials=credentials) - else: - client = spanner.Client(project=config.project) - - instance_id = config.instance - instance = client.instance(instance_id) - - if not instance.exists(): - raise Exception(f"Instance with id: {instance_id} doesn't exist.") - - database_id = config.database - database = instance.database(database_id) - - if not database.exists(): - raise Exception(f"Database with id: {database_id} doesn't exist.") - - return cls(client, instance_id, database_id) - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - """ - Initialize data in the Spanner database by creating tables and inserting records. - - Args: - airports (list[models.Airport]): list of airports to be initialized. - amenities (list[models.Amenity]): list of amenities to be initialized. - flights (list[models.Flight]): list of flights to be initialized. - policies (list[models.Policy]): list of policies to be initialized. 
- Returns: - None - """ - # Initialize a list to store Data Definition Language (DDL) statements - ddl = [] - - # Create DDL statement to drop the 'airports' table if it exists - ddl.append("DROP TABLE IF EXISTS airports") - - # Create DDL statement to create the 'airports' table - ddl.append( - """ - CREATE TABLE airports( - id INT64, - iata STRING(MAX), - name STRING(MAX), - city STRING(MAX), - country STRING(MAX) - ) PRIMARY KEY(id) - """ - ) - - # Create DDL statement to drop the 'amenities' table if it exists - ddl.append("DROP TABLE IF EXISTS amenities") - - # Create DDL statement to create the 'amenities' table - ddl.append( - """ - CREATE TABLE amenities( - id INT64, - name STRING(MAX), - description STRING(MAX), - location STRING(MAX), - terminal STRING(MAX), - category STRING(MAX), - hour STRING(MAX), - sunday_start_hour STRING(100), - sunday_end_hour STRING(100), - monday_start_hour STRING(100), - monday_end_hour STRING(100), - tuesday_start_hour STRING(100), - tuesday_end_hour STRING(100), - wednesday_start_hour STRING(100), - wednesday_end_hour STRING(100), - thursday_start_hour STRING(100), - thursday_end_hour STRING(100), - friday_start_hour STRING(100), - friday_end_hour STRING(100), - saturday_start_hour STRING(100), - saturday_end_hour STRING(100), - content STRING(MAX) NOT NULL, - embedding ARRAY NOT NULL - ) PRIMARY KEY(id) - """ - ) - - # Create DDL statement to drop the 'flights' table if it exists - ddl.append("DROP TABLE IF EXISTS flights") - - # Create DDL statement to create the 'flights' table - ddl.append( - """ - CREATE TABLE flights( - id INT64, - airline STRING(MAX), - flight_number STRING(MAX), - departure_airport STRING(MAX), - arrival_airport STRING(MAX), - departure_time STRING(100), - arrival_time STRING(100), - departure_gate STRING(MAX), - arrival_gate STRING(MAX) - ) PRIMARY KEY(id) - """ - ) - - # Create DDL statement to drop the 'policies' table if it exists - ddl.append("DROP TABLE IF EXISTS policies") - - # Create DDL 
statement to create the 'policies' table - ddl.append( - """ - CREATE TABLE policies( - id INT64, - content STRING(MAX) NOT NULL, - embedding ARRAY NOT NULL - ) PRIMARY KEY(id) - """ - ) - - # Create DDL statement to drop the 'tickets' table if it exists - ddl.append("DROP TABLE IF EXISTS tickets") - - # Create DDL statement to create the 'tickets' table - ddl.append( - """ - CREATE TABLE tickets( - user_id STRING(MAX), - user_name STRING(MAX), - user_email STRING(MAX), - airline STRING(MAX), - flight_number STRING(MAX), - departure_airport STRING(MAX), - arrival_airport STRING(MAX), - departure_time STRING(100), - arrival_time STRING(100) - ) PRIMARY KEY(user_id, airline, flight_number, departure_time) - """ - ) - - # Update the schema using DDL statements - operation = self.__database.update_ddl(ddl) - - print("Waiting for schema update operation to complete...") - operation.result(self.OPERATION_TIMEOUT_SECONDS) - print("Schema update operation completed") - - # Insert data into 'airports' table using batch operation - - values = [ - tuple(getattr(airport, field) for field in self.AIRPORT_COLUMNS) - for airport in airports - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="airports", - columns=self.AIRPORT_COLUMNS, - values=records, - ) - - # Insert data into 'amenities' table using batch operation - values = [ - tuple( - ( - str(getattr(amenity, field)) - if isinstance(getattr(amenity, field), datetime.time) - else getattr(amenity, field) - ) - for field in self.AMENITIES_COLUMNS - ) - for amenity in amenities - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="amenities", - columns=self.AMENITIES_COLUMNS, - values=records, - ) - - # Insert data into 'flights' table using batch operation - values = [ - tuple( - ( - 
str(getattr(flight, field)) - if isinstance(getattr(flight, field), datetime.datetime) - else getattr(flight, field) - ) - for field in self.FLIGHTS_COLUMNS - ) - for flight in flights - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="flights", - columns=self.FLIGHTS_COLUMNS, - values=records, - ) - - # Insert data into 'policies' table using batch operation - values = [ - tuple(getattr(policy, field) for field in self.POLICIES_COLUMNS) - for policy in policies - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="policies", - columns=self.POLICIES_COLUMNS, - values=records, - ) - - # Return None to indicate successful initialization - return None - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - """ - Export data from the Spanner database. - - Returns: - tuple: A tuple containing lists of airports, amenities, flights, and policies. 
- """ - airports: list = [] - amenities: list = [] - flights: list = [] - policies: list = [] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - airport_results = snapshot.execute_sql( - "SELECT {} FROM airports ORDER BY id ASC".format( - ",".join(self.AIRPORT_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch airports: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in airport_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - amenity_results = snapshot.execute_sql( - "SELECT {} FROM amenities ORDER BY id ASC".format( - ",".join(self.AMENITIES_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch amenities: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - amenities = [ - models.Amenity.model_validate( - {key: value for key, value in zip(self.AMENITIES_COLUMNS, a)} - ) - for a in amenity_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - flights_results = snapshot.execute_sql( - "SELECT {} FROM flights ORDER BY id ASC".format( - ",".join(self.FLIGHTS_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch flights: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query 
results to model instances using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in flights_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - policy_results = snapshot.execute_sql( - "SELECT {} FROM policies ORDER BY id ASC".format( - ",".join(self.POLICIES_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch policies: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - policies = [ - models.Policy.model_validate( - {key: value for key, value in zip(self.POLICIES_COLUMNS, a)} - ) - for a in policy_results - ] - - return airports, amenities, flights, policies - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - """ - Retrieve an airport by its ID. - - Args: - id (int): The ID of the airport. - - Returns: - Optional[models.Airport]: An Airport model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Execute SQL query to fetch airport by ID - result = snapshot.execute_sql( - sql="SELECT * FROM airports WHERE id = @id", - params={"id": id}, - param_types={"id": param_types.INT64}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in result - ] - - return airports[0], None - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - """ - Retrieve an airport by its IATA code. - - Args: - iata (str): The IATA code of the airport. 
- - Returns: - Optional[models.Airport]: An Airport model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Execute SQL query to fetch airport by ID - result = snapshot.execute_sql( - sql="SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(@iata)", - params={"iata": iata}, - param_types={"iata": param_types.STRING}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in result - ] - - return airports[0], None - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - """ - Search for airports based on optional parameters. - - Args: - country (Optional[str]): The country of the airport. - city (Optional[str]): The city of the airport. - name (Optional[str]): The name of the airport. - - Returns: - list[models.Airport]: A list of Airport model instances matching the search criteria. 
- """ - with self.__database.snapshot() as snapshot: - # Construct SQL query based on provided parameters - query = """ - SELECT * FROM airports - WHERE (@country IS NULL OR LOWER(country) LIKE LOWER(@country)) - AND (@city IS NULL OR LOWER(city) LIKE LOWER(@city)) - AND (@name IS NULL OR LOWER(name) LIKE '%' || LOWER(@name) || '%') - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "country": country, - "city": city, - "name": name, - }, - param_types={ - "country": param_types.STRING, - "city": param_types.STRING, - "name": param_types.STRING, - }, - ) - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in results - ] - - return airports, None - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - """ - Retrieves an amenity by its ID. - - Args: - id (int): The ID of the amenity. - - Returns: - Optional[models.Amenity]: An Amenity model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - result = snapshot.execute_sql( - sql=""" - SELECT id, name, description, location, terminal, category, hour FROM amenities - WHERE id = @id - """, - params={"id": id}, - param_types={"id": param_types.INT64}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - amenities = [ - models.Amenity.model_validate( - {key: value for key, value in zip(self.AMENITIES_COLUMNS, a)} - ) - for a in result - ] - - return amenities[0], None - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - """ - Search for amenities based on similarity to a query embedding. 
- - Args: - query_embedding (list[float]): The embedding representing the query. - similarity_threshold (float): The minimum similarity threshold for results. - top_k (int): The maximum number of results to return. - - Returns: - list[models.Amenity]: A list of Amenity model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - query = """ - SELECT name, description, location, terminal, category, hour - FROM ( - SELECT name, description, location, terminal, category, hour, - COSINE_DISTANCE(embedding, @query_embedding) AS similarity - FROM amenities - ) AS sorted_amenities - WHERE (1 - similarity) > @similarity_threshold - ORDER BY similarity - LIMIT @top_k - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - }, - param_types={ - "query_embedding": param_types.Array(param_types.FLOAT64), - "similarity_threshold": param_types.FLOAT64, - "top_k": param_types.INT64, - }, - ) - - amenities = [ - {key: value for key, value in zip(self.AMENITIES_COLUMNS[1:], a)} - for a in results - ] - - return amenities, None - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - """ - Retrieves a flight by its ID. - - Args: - flight_id (int): The ID of the flight. - - Returns: - Optional[models.Flight]: A Flight model instance if found, else None. 
- """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - result = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE id = @flight_id - """, - params={"flight_id": flight_id}, - param_types={"flight_id": param_types.INT64}, - ) - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in result - ] - - return flights[0], None - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - """ - Search for flights by airline and flight number. - - Args: - airline (str): The airline of the flight. - number (str): The flight number. - - Returns: - list[models.Flight]: A list of Flight model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE airline = @airline - AND flight_number = @number - LIMIT 10 - """, - params={"airline": airline, "number": number}, - param_types={ - "airline": param_types.STRING, - "number": param_types.STRING, - }, - ) - - # Convert query result to model instance using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in results - ] - - return flights, None - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - """ - Search for flights by departure and/or arrival airports. - - Args: - date (str): The date of the flights in 'YYYY-MM-DD' format. - departure_airport (str, optional): The departure airport code. Defaults to None. 
- arrival_airport (str, optional): The arrival airport code. Defaults to None. - - Returns: - list[models.Flight]: A list of Flight model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - query = """ - SELECT * FROM flights - WHERE (@departure_airport IS NULL OR LOWER(departure_airport) LIKE LOWER(@departure_airport)) - AND (@arrival_airport IS NULL OR LOWER(arrival_airport) LIKE LOWER(@arrival_airport)) - AND cast(departure_time as TIMESTAMP) >= CAST(@datetime AS TIMESTAMP) - AND cast(departure_time as TIMESTAMP) < TIMESTAMP_ADD(CAST(@datetime AS TIMESTAMP), INTERVAL 1 DAY) - LIMIT 10 - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "datetime": date, - }, - param_types={ - "departure_airport": param_types.STRING, - "arrival_airport": param_types.STRING, - "datetime": param_types.STRING, - }, - ) - - # Convert query results to model instances using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in results - ] - - return flights, None - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE LOWER(airline) LIKE LOWER(@airline) - AND LOWER(flight_number) LIKE LOWER(@flight_number) - AND LOWER(departure_airport) LIKE LOWER(@departure_airport) - AND departure_time = @departure_time - """, - params={ - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "departure_time": departure_time, - }, - param_types={ - "airline": param_types.STRING, 
- "flight_number": param_types.STRING, - "departure_airport": param_types.STRING, - "departure_time": param_types.STRING, - }, - ) - - if results is None: - return None, None - - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in results - ] - - if not flights: - return None, None - return flights[0], None - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - """ - Inserts a ticket into the database. - - Args: - user_id (str): The ID of the user. - user_name (str): The name of the user. - user_email (str): The email of the user. - airline (str): The airline of the flight. - flight_number (str): The flight number. - departure_airport (str): The departure airport code. - arrival_airport (str): The arrival airport code. - departure_time (str): The departure time of the flight. - arrival_time (str): The arrival time of the flight. - """ - departure_time_datetime = datetime.datetime.strptime( - departure_time, "%Y-%m-%d %H:%M:%S" - ) - arrival_time_datetime = datetime.datetime.strptime( - arrival_time, "%Y-%m-%d %H:%M:%S" - ) - - with self.__database.batch() as batch: - batch.insert( - table="tickets", - columns=[ - "user_id", - "user_name", - "user_email", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - ], - values=[ - [ - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time_datetime, - arrival_time_datetime, - ] - ], - ) - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - """ - Retrieves a list of tickets for a user. - - Args: - user_id (str): The ID of the user. 
- """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = @user_id - """, - params={"user_id": user_id}, - param_types={"user_id": param_types.STRING}, - ) - - # Convert query results to model instances using model_validate method - tickets = [r for r in results] - - return tickets, None - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - """ - Search for policies based on similarity to a query embedding. - - Args: - query_embedding (list[float]): The embedding representing the query. - similarity_threshold (float): The minimum similarity threshold for results. - top_k (int): The maximum number of results to return. - - Returns: - list[models.Policy]: A list of Policy model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - query = """ - SELECT content - FROM ( - SELECT content, COSINE_DISTANCE(embedding, @query_embedding) AS similarity - FROM policies - ) AS sorted_policies - WHERE (1 - similarity) > @similarity_threshold - ORDER BY similarity - LIMIT @top_k - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "query_embedding": query_embedding, - "similarity_threshold": similarity_threshold, - "top_k": top_k, - }, - param_types={ - "query_embedding": param_types.Array(param_types.FLOAT64), - "similarity_threshold": param_types.FLOAT64, - "top_k": param_types.INT64, - }, - ) - - # Convert query result to model instance using model_validate method - policies = [a[0] for a in results] - - return policies, None - - async def close(self): - """ - Closes the database client connection. 
- """ - self.__client.close() diff --git a/retrieval_service/datastore/providers/spanner_gsql_test.py b/retrieval_service/datastore/providers/spanner_gsql_test.py deleted file mode 100644 index 8210c4513..000000000 --- a/retrieval_service/datastore/providers/spanner_gsql_test.py +++ /dev/null @@ -1,675 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from ipaddress import IPv4Address -from typing import Any, AsyncGenerator, Generator, List, Optional - -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore -from google.cloud import spanner # type: ignore -from google.cloud.spanner_v1 import JsonObject, param_types -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance - -import models - -from .. import datastore -from . 
import spanner_gsql -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_project() -> str: - return get_env_var("DB_PROJECT", "Google Cloud Project") - - -@pytest.fixture(scope="module") -def db_instance() -> str: - return get_env_var("DB_INSTANCE", "Spanner Instance") - - -@pytest.fixture(scope="module") -def db_name() -> str: - return get_env_var("DB_NAME", "Spanner Database") - - -@pytest.fixture(scope="module") -def create_db( - db_project: str, db_instance: str, db_name: str -) -> Generator[str, None, None]: - client = spanner.Client(project=db_project) - instance = client.instance(db_instance) - - database = instance.database(db_name) - - database.create() - - yield db_name - - database.drop() - client.close() - - -@pytest_asyncio.fixture(scope="module") -async def ds( - create_db: str, - db_project: str, - db_instance: str, -) -> AsyncGenerator[datastore.Client, None]: - cfg = spanner_gsql.Config( - kind="spanner-gsql", - project=db_project, - instance=db_instance, - database=create_db, - ) - - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - - yield ds - - await ds.close() - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["changed"] == [] - assert 
file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - - -async def test_export_dataset(ds: spanner_gsql.Client): - airports, amenities, flights, policies = await ds.export_data() - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: spanner_gsql.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: spanner_gsql.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - 
name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is None - - -search_airports_test_data = [ - pytest.param( - "Philippines", - "San jose", - None, - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. 
Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: spanner_gsql.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is None - - -async def test_get_amenity(ds: spanner_gsql.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - saturday_start_hour=None, - saturday_end_hour=None, - ) - - assert res == expected - assert sql is None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.65, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - }, - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" 
- amenities_query_embedding2, - 0.65, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.9, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: spanner_gsql.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[models.Amenity], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None - - -async def test_get_flight(ds: spanner_gsql.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - 
departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: spanner_gsql.Client, - airline: str, - number: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - 
airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: spanner_gsql.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - assert res == expected - assert sql is None - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.65, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. 
We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.65, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.65, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: spanner_gsql.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[models.Policy], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 12:13:00", 
"%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - None, - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket( - ds: spanner_gsql.Client, params, expected_data, expected_sql -): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql diff --git a/retrieval_service/datastore/providers/spanner_postgres.py b/retrieval_service/datastore/providers/spanner_postgres.py deleted file mode 100644 index 8ebdd347f..000000000 --- a/retrieval_service/datastore/providers/spanner_postgres.py +++ /dev/null @@ -1,1028 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from typing import Any, Literal, Optional - -from google.cloud import spanner # type: ignore -from google.cloud.spanner_v1 import JsonObject, param_types -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance -from google.oauth2 import service_account # type: ignore -from pydantic import BaseModel - -import models - -from .. 
import datastore - -# Identifier for Spanner -SPANNER_IDENTIFIER = "spanner-postgres" - - -# Configuration model for Spanner -class Config(BaseModel, datastore.AbstractConfig): - """ - Configuration model for Spanner. - - Attributes: - kind (Literal["spanner"]): Type of datastore. - project (str): Google Cloud project ID. - instance (str): ID of the Spanner instance. - database (str): ID of the Spanner database. - service_account_key_file (str): Service Account Key File. - """ - - kind: Literal["spanner-postgres"] - project: str - instance: str - database: str - service_account_key_file: Optional[str] = None - - -# Client class for interacting with Spanner -class Client(datastore.Client[Config]): - OPERATION_TIMEOUT_SECONDS = 240 - BATCH_SIZE = 1000 - AIRPORT_COLUMNS = ["id", "iata", "name", "city", "country"] - AMENITIES_COLUMNS = [ - "id", - "name", - "description", - "location", - "terminal", - "category", - "hour", - "sunday_start_hour", - "sunday_end_hour", - "monday_start_hour", - "monday_end_hour", - "tuesday_start_hour", - "tuesday_end_hour", - "wednesday_start_hour", - "wednesday_end_hour", - "thursday_start_hour", - "thursday_end_hour", - "friday_start_hour", - "friday_end_hour", - "saturday_start_hour", - "saturday_end_hour", - "content", - "embedding", - ] - FLIGHTS_COLUMNS = [ - "id", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - "departure_gate", - "arrival_gate", - ] - - POLICIES_COLUMNS = ["id", "content", "embedding"] - """ - Client class for interacting with Spanner. - - Attributes: - __client (spanner.Client): Spanner client instance. - __instance_id (str): ID of the Spanner instance. - __database_id (str): ID of the Spanner database. - __instance (Instance): Spanner instance. - __database (Database): Spanner database. 
- """ - - @datastore.classproperty - def kind(cls): - return SPANNER_IDENTIFIER - - def __init__(self, client: spanner.Client, instance_id: str, database_id: str): - """ - Initialize the Spanner client. - - Args: - client (spanner.Client): Spanner client instance. - instance_id (str): ID of the Spanner instance. - database_id (str): ID of the Spanner database. - """ - self.__client = client - self.__instance_id = instance_id - self.__database_id = database_id - - self.__instance = self.__client.instance(self.__instance_id) - self.__database = self.__instance.database(self.__database_id) - - @classmethod - async def create(cls, config: Config) -> "Client": - """ - Create a Spanner client. - - Args: - config (Config): Configuration for creating the client. - - Returns: - Client: Initialized Spanner client. - """ - client: spanner.Client - - if config.service_account_key_file is not None: - credentials = service_account.Credentials.from_service_account_file( - config.service_account_key_file - ) - client = spanner.Client(project=config.project, credentials=credentials) - else: - client = spanner.Client(project=config.project) - - instance_id = config.instance - instance = client.instance(instance_id) - - if not instance.exists(): - raise Exception(f"Instance with id: {instance_id} doesn't exist.") - - database_id = config.database - database = instance.database(database_id) - - if not database.exists(): - raise Exception(f"Database with id: {database_id} doesn't exist.") - - return cls(client, instance_id, database_id) - - async def initialize_data( - self, - airports: list[models.Airport], - amenities: list[models.Amenity], - flights: list[models.Flight], - policies: list[models.Policy], - ) -> None: - """ - Initialize data in the Spanner database by creating tables and inserting records. - - Args: - airports (list[models.Airport]): list of airports to be initialized. - amenities (list[models.Amenity]): list of amenities to be initialized. 
- flights (list[models.Flight]): list of flights to be initialized. - policies (list[models.Policy]): list of policies to be initialized. - Returns: - None - """ - # Initialize a list to store Data Definition Language (DDL) statements - ddl = [] - - # Create DDL statement to drop the 'airports' table if it exists - ddl.append("DROP TABLE IF EXISTS airports") - - # Create DDL statement to create the 'airports' table - ddl.append( - """ - CREATE TABLE airports( - id BIGINT PRIMARY KEY, - iata VARCHAR, - name VARCHAR, - city VARCHAR, - country VARCHAR - ) - """ - ) - - # Create DDL statement to drop the 'amenities' table if it exists - ddl.append("DROP TABLE IF EXISTS amenities") - - # Create DDL statement to create the 'amenities' table - ddl.append( - """ - CREATE TABLE amenities( - id BIGINT PRIMARY KEY, - name VARCHAR, - description VARCHAR, - location VARCHAR, - terminal VARCHAR, - category VARCHAR, - hour VARCHAR, - sunday_start_hour VARCHAR, - sunday_end_hour VARCHAR, - monday_start_hour VARCHAR, - monday_end_hour VARCHAR, - tuesday_start_hour VARCHAR, - tuesday_end_hour VARCHAR, - wednesday_start_hour VARCHAR, - wednesday_end_hour VARCHAR, - thursday_start_hour VARCHAR, - thursday_end_hour VARCHAR, - friday_start_hour VARCHAR, - friday_end_hour VARCHAR, - saturday_start_hour VARCHAR, - saturday_end_hour VARCHAR, - content VARCHAR NOT NULL, - embedding FLOAT8[] NOT NULL - ) - """ - ) - - # Create DDL statement to drop the 'flights' table if it exists - ddl.append("DROP TABLE IF EXISTS flights") - - # Create DDL statement to create the 'flights' table - ddl.append( - """ - CREATE TABLE flights( - id BIGINT PRIMARY KEY, - airline VARCHAR, - flight_number VARCHAR, - departure_airport VARCHAR, - arrival_airport VARCHAR, - departure_time VARCHAR(100), - arrival_time VARCHAR(100), - departure_gate VARCHAR, - arrival_gate VARCHAR - ) - """ - ) - - # Create DDL statement to drop the 'policies' table if it exists - ddl.append("DROP TABLE IF EXISTS policies") - - # 
Create DDL statement to create the 'policies' table - ddl.append( - """ - CREATE TABLE policies( - id BIGINT PRIMARY KEY, - content VARCHAR NOT NULL, - embedding FLOAT8[] NOT NULL - ) - """ - ) - - # Create DDL statement to drop the 'tickets' table if it exists - ddl.append("DROP TABLE IF EXISTS tickets") - - # Create DDL statement to create the 'tickets' table - ddl.append( - """ - CREATE TABLE tickets( - user_id VARCHAR, - user_name VARCHAR, - user_email VARCHAR, - airline VARCHAR, - flight_number VARCHAR, - departure_airport VARCHAR, - arrival_airport VARCHAR, - departure_time VARCHAR(100), - arrival_time VARCHAR(100), - PRIMARY KEY(user_id, airline, flight_number, departure_time) - ) - """ - ) - - # Update the schema using DDL statements - operation = self.__database.update_ddl(ddl) - - print("Waiting for schema update operation to complete...") - operation.result(self.OPERATION_TIMEOUT_SECONDS) - print("Schema update operation completed") - - # Insert data into 'airports' table using batch operation - - values = [ - tuple(getattr(airport, field) for field in self.AIRPORT_COLUMNS) - for airport in airports - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="airports", - columns=self.AIRPORT_COLUMNS, - values=records, - ) - - # Insert data into 'amenities' table using batch operation - values = [ - tuple( - ( - str(getattr(amenity, field)) - if isinstance(getattr(amenity, field), datetime.time) - else getattr(amenity, field) - ) - for field in self.AMENITIES_COLUMNS - ) - for amenity in amenities - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="amenities", - columns=self.AMENITIES_COLUMNS, - values=records, - ) - - # Insert data into 'flights' table using batch operation - values = [ - tuple( - ( - str(getattr(flight, field)) - 
if isinstance(getattr(flight, field), datetime.datetime) - else getattr(flight, field) - ) - for field in self.FLIGHTS_COLUMNS - ) - for flight in flights - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="flights", - columns=self.FLIGHTS_COLUMNS, - values=records, - ) - - # Insert data into 'policies' table using batch operation - values = [ - tuple(getattr(policy, field) for field in self.POLICIES_COLUMNS) - for policy in policies - ] - - for i in range(0, len(values), self.BATCH_SIZE): - records = values[i : i + self.BATCH_SIZE] - - with self.__database.batch() as batch: - batch.insert( - table="policies", - columns=self.POLICIES_COLUMNS, - values=records, - ) - - # Return None to indicate successful initialization - return None - - async def export_data( - self, - ) -> tuple[ - list[models.Airport], - list[models.Amenity], - list[models.Flight], - list[models.Policy], - ]: - """ - Export data from the Spanner database. - - Returns: - tuple: A tuple containing lists of airports, amenities, flights, and policies. 
- """ - airports: list = [] - amenities: list = [] - flights: list = [] - policies: list = [] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - airport_results = snapshot.execute_sql( - "SELECT {} FROM airports ORDER BY id ASC".format( - ",".join(self.AIRPORT_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch airports: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in airport_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - amenity_results = snapshot.execute_sql( - "SELECT {} FROM amenities ORDER BY id ASC".format( - ",".join(self.AMENITIES_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch amenities: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - amenities = [ - models.Amenity.model_validate( - {key: value for key, value in zip(self.AMENITIES_COLUMNS, a)} - ) - for a in amenity_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - flights_results = snapshot.execute_sql( - "SELECT {} FROM flights ORDER BY id ASC".format( - ",".join(self.FLIGHTS_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch flights: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query 
results to model instances using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in flights_results - ] - - try: - with self.__database.snapshot() as snapshot: - # Execute SQL queries to fetch data from respective tables - policy_results = snapshot.execute_sql( - "SELECT {} FROM policies ORDER BY id ASC".format( - ",".join(self.POLICIES_COLUMNS) - ) - ) - except Exception as e: - # Handle any exceptions, such as database connection errors - print(f"Error occurred while fetch policies: {e}") - # Return empty lists in case of error - return airports, amenities, flights, policies - - # Convert query results to model instances using model_validate method - policies = [ - models.Policy.model_validate( - {key: value for key, value in zip(self.POLICIES_COLUMNS, a)} - ) - for a in policy_results - ] - - return airports, amenities, flights, policies - - async def get_airport_by_id( - self, id: int - ) -> tuple[Optional[models.Airport], Optional[str]]: - """ - Retrieve an airport by its ID. - - Args: - id (int): The ID of the airport. - - Returns: - Optional[models.Airport]: An Airport model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Execute SQL query to fetch airport by ID - result = snapshot.execute_sql( - sql="SELECT * FROM airports WHERE id = $1", - params={"p1": id}, - param_types={"p1": param_types.INT64}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in result - ] - - return airports[0], None - - async def get_airport_by_iata( - self, iata: str - ) -> tuple[Optional[models.Airport], Optional[str]]: - """ - Retrieve an airport by its IATA code. - - Args: - iata (str): The IATA code of the airport. 
- - Returns: - Optional[models.Airport]: An Airport model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Execute SQL query to fetch airport by ID - result = snapshot.execute_sql( - sql="SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER($1)", - params={"p1": iata}, - param_types={"p1": param_types.STRING}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in result - ] - - return airports[0], None - - async def search_airports( - self, - country: Optional[str] = None, - city: Optional[str] = None, - name: Optional[str] = None, - ) -> tuple[list[models.Airport], Optional[str]]: - """ - Search for airports based on optional parameters. - - Args: - country (Optional[str]): The country of the airport. - city (Optional[str]): The city of the airport. - name (Optional[str]): The name of the airport. - - Returns: - list[models.Airport]: A list of Airport model instances matching the search criteria. 
- """ - with self.__database.snapshot() as snapshot: - # Construct SQL query based on provided parameters - query = """ - SELECT * FROM airports - WHERE ($1 IS NULL OR LOWER(country) LIKE LOWER($1)) - AND ($2 IS NULL OR LOWER(city) LIKE LOWER($2)) - AND ($3 IS NULL OR LOWER(name) LIKE '%' || LOWER($3) || '%') - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "p1": country, - "p2": city, - "p3": name, - }, - param_types={ - "p1": param_types.STRING, - "p2": param_types.STRING, - "p3": param_types.STRING, - }, - ) - - # Convert query result to model instance using model_validate method - airports = [ - models.Airport.model_validate( - {key: value for key, value in zip(self.AIRPORT_COLUMNS, a)} - ) - for a in results - ] - - return airports, None - - async def get_amenity( - self, id: int - ) -> tuple[Optional[models.Amenity], Optional[str]]: - """ - Retrieves an amenity by its ID. - - Args: - id (int): The ID of the amenity. - - Returns: - Optional[models.Amenity]: An Amenity model instance if found, else None. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - result = snapshot.execute_sql( - sql=""" - SELECT id, name, description, location, terminal, category, hour FROM amenities - WHERE id = $1 - """, - params={"p1": id}, - param_types={"p1": param_types.INT64}, - ) - - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - amenities = [ - models.Amenity.model_validate( - {key: value for key, value in zip(self.AMENITIES_COLUMNS, a)} - ) - for a in result - ] - - return amenities[0], None - - async def amenities_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[Any], Optional[str]]: - """ - Search for amenities based on similarity to a query embedding. - - Args: - query_embedding (list[float]): The embedding representing the query. 
- similarity_threshold (float): The minimum similarity threshold for results. - top_k (int): The maximum number of results to return. - - Returns: - list[models.Amenity]: A list of Amenity model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - query = """ - SELECT name, description, location, terminal, category, hour - FROM ( - SELECT name, description, location, terminal, category, hour, - spanner.cosine_distance(embedding, $1) AS similarity - FROM amenities - ) AS sorted_amenities - WHERE (1 - similarity) > $2 - ORDER BY similarity - LIMIT $3 - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "p1": query_embedding, - "p2": similarity_threshold, - "p3": top_k, - }, - param_types={ - "p1": param_types.Array(param_types.FLOAT64), - "p2": param_types.FLOAT64, - "p3": param_types.INT64, - }, - ) - - # Convert query result to model instance using model_validate method - amenities = [ - {key: value for key, value in zip(self.AMENITIES_COLUMNS[1:], a)} - for a in results - ] - - return amenities, None - - async def get_flight( - self, flight_id: int - ) -> tuple[Optional[models.Flight], Optional[str]]: - """ - Retrieves a flight by its ID. - - Args: - flight_id (int): The ID of the flight. - - Returns: - Optional[models.Flight]: A Flight model instance if found, else None. 
- """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - result = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE id = $1 - """, - params={"p1": flight_id}, - param_types={"p1": param_types.INT64}, - ) - # Check if result is None - if result is None: - return None, None - - # Convert query result to model instance using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in result - ] - - return flights[0], None - - async def search_flights_by_number( - self, - airline: str, - number: str, - ) -> tuple[list[models.Flight], Optional[str]]: - """ - Search for flights by airline and flight number. - - Args: - airline (str): The airline of the flight. - number (str): The flight number. - - Returns: - list[models.Flight]: A list of Flight model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE airline = $1 - AND flight_number = $2 - LIMIT 10 - """, - params={"p1": airline, "p2": number}, - param_types={ - "p1": param_types.STRING, - "p2": param_types.STRING, - }, - ) - - # Convert query result to model instance using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in results - ] - - return flights, None - - async def search_flights_by_airports( - self, - date: str, - departure_airport: Optional[str] = None, - arrival_airport: Optional[str] = None, - ) -> tuple[list[models.Flight], Optional[str]]: - """ - Search for flights by departure and/or arrival airports. - - Args: - date (str): The date of the flights in 'YYYY-MM-DD' format. - departure_airport (str, optional): The departure airport code. Defaults to None. - arrival_airport (str, optional): The arrival airport code. 
Defaults to None. - - Returns: - list[models.Flight]: A list of Flight model instances matching the search criteria. - """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - - query = """ - SELECT * FROM flights - WHERE (COALESCE($1) IS NULL OR LOWER(departure_airport) LIKE LOWER($1)) - AND (COALESCE($2) IS NULL OR LOWER(arrival_airport) LIKE LOWER($2)) - AND CAST(departure_time as timestamptz) >= CAST($3 AS timestamptz) - AND cast(departure_time as timestamptz) < spanner.timestamptz_add(CAST($3 AS timestamptz), '1 day') - LIMIT 10 - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "p1": departure_airport, - "p2": arrival_airport, - "p3": date, - }, - param_types={ - "p1": param_types.STRING, - "p2": param_types.STRING, - "p3": param_types.STRING, - }, - ) - - # Convert query results to model instances using model_validate method - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in results - ] - - return flights, None - - async def validate_ticket( - self, - airline: str, - flight_number: str, - departure_airport: str, - departure_time: str, - ) -> tuple[Optional[models.Flight], Optional[str]]: - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT * FROM flights - WHERE LOWER(airline) LIKE LOWER($1) - AND LOWER(flight_number) LIKE LOWER($2) - AND LOWER(departure_airport) LIKE LOWER($3) - AND departure_time = $4 - """, - params={ - "p1": airline, - "p2": flight_number, - "p3": departure_airport, - "p4": departure_time, - }, - param_types={ - "p1": param_types.STRING, - "p2": param_types.STRING, - "p3": param_types.STRING, - "p4": param_types.STRING, - }, - ) - - if results is None: - return None, None - - flights = [ - models.Flight.model_validate( - {key: value for key, value in zip(self.FLIGHTS_COLUMNS, a)} - ) - for a in 
results - ] - - if not flights: - return None, None - return flights[0], None - - async def insert_ticket( - self, - user_id: str, - user_name: str, - user_email: str, - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: str, - arrival_time: str, - ): - """ - Inserts a ticket into the database. - - Args: - user_id (str): The ID of the user. - user_name (str): The name of the user. - user_email (str): The email of the user. - airline (str): The airline of the flight. - flight_number (str): The flight number. - departure_airport (str): The departure airport code. - arrival_airport (str): The arrival airport code. - departure_time (str): The departure time of the flight. - arrival_time (str): The arrival time of the flight. - """ - departure_time_datetime = datetime.datetime.strptime( - departure_time, "%Y-%m-%d %H:%M:%S" - ) - arrival_time_datetime = datetime.datetime.strptime( - arrival_time, "%Y-%m-%d %H:%M:%S" - ) - - with self.__database.batch() as batch: - batch.insert( - table="tickets", - columns=[ - "user_id", - "user_name", - "user_email", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - ], - values=[ - [ - user_id, - user_name, - user_email, - airline, - flight_number, - departure_airport, - arrival_airport, - departure_time_datetime, - arrival_time_datetime, - ] - ], - ) - - async def list_tickets( - self, - user_id: str, - ) -> tuple[list[Any], Optional[str]]: - """ - Retrieves a list of tickets for a user. - - Args: - user_id (str): The ID of the user. 
- """ - with self.__database.snapshot() as snapshot: - # Spread SQL query for readability - results = snapshot.execute_sql( - sql=""" - SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets - WHERE user_id = $1 - """, - params={"p1": user_id}, - param_types={"p1": param_types.STRING}, - ) - - # Convert query results to model instances using model_validate method - tickets = [ - models.Ticket.model_validate( - { - key: value - for key, value in zip( - [ - "user_id", - "user_name", - "user_email", - "airline", - "flight_number", - "departure_airport", - "arrival_airport", - "departure_time", - "arrival_time", - ], - a, - ) - } - ) - for a in results - ] - - return tickets, None - - async def policies_search( - self, query_embedding: list[float], similarity_threshold: float, top_k: int - ) -> tuple[list[str], Optional[str]]: - """ - Search for policies based on similarity to a query embedding. - - Args: - query_embedding (list[float]): The embedding representing the query. - similarity_threshold (float): The minimum similarity threshold for results. - top_k (int): The maximum number of results to return. - - Returns: - list[models.Policy]: A list of Policy model instances matching the search criteria. 
- """ - with self.__database.snapshot() as snapshot: - query = """ - SELECT content - FROM ( - SELECT content, spanner.cosine_distance(embedding, $1) AS similarity - FROM policies - ) AS sorted_policies - WHERE (1 - similarity) > $2 - ORDER BY similarity - LIMIT $3 - """ - - # Execute SQL query with parameters - results = snapshot.execute_sql( - sql=query, - params={ - "p1": query_embedding, - "p2": similarity_threshold, - "p3": top_k, - }, - param_types={ - "p1": param_types.Array(param_types.FLOAT64), - "p2": param_types.FLOAT64, - "p3": param_types.INT64, - }, - ) - - # Convert query result to model instance using model_validate method - policies = [a[0] for a in results] - - return policies, None - - async def close(self): - """ - Closes the database client connection. - """ - self.__client.close() diff --git a/retrieval_service/datastore/providers/spanner_postgres_test.py b/retrieval_service/datastore/providers/spanner_postgres_test.py deleted file mode 100644 index 5cc498020..000000000 --- a/retrieval_service/datastore/providers/spanner_postgres_test.py +++ /dev/null @@ -1,685 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime -from ipaddress import IPv4Address -from typing import Any, AsyncGenerator, Generator, List, Optional - -import pytest -import pytest_asyncio -from csv_diff import compare, load_csv # type: ignore -from google.cloud import spanner # type: ignore -from google.cloud.spanner_admin_database_v1.types import DatabaseDialect -from google.cloud.spanner_v1 import JsonObject, param_types -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance - -import models - -from .. import datastore -from . import spanner_postgres -from .test_data import ( - amenities_query_embedding1, - amenities_query_embedding2, - foobar_query_embedding, - policies_query_embedding1, - policies_query_embedding2, -) -from .utils import get_env_var - -pytestmark = pytest.mark.asyncio(scope="module") - - -@pytest.fixture(scope="module") -def db_project() -> str: - return get_env_var("DB_PROJECT", "Google Cloud Project") - - -@pytest.fixture(scope="module") -def db_instance() -> str: - return get_env_var("DB_INSTANCE", "Spanner Instance") - - -@pytest.fixture(scope="module") -def db_name() -> str: - return get_env_var("DB_NAME", "Spanner PG Database") - - -@pytest.fixture(scope="module") -def create_db( - db_project: str, db_instance: str, db_name: str -) -> Generator[str, None, None]: - client = spanner.Client(project=db_project) - instance = client.instance(db_instance) - - database = instance.database( - db_name, - database_dialect=DatabaseDialect.POSTGRESQL, - ) - - database.create() - - yield db_name - - database.drop() - client.close() - - -@pytest_asyncio.fixture(scope="module") -async def ds( - create_db: str, - db_project: str, - db_instance: str, -) -> AsyncGenerator[datastore.Client, None]: - cfg = spanner_postgres.Config( - kind="spanner-postgres", - project=db_project, - instance=db_instance, - database=create_db, - ) - - ds = await datastore.create(cfg) - - airports_ds_path = "../data/airport_dataset.csv" - 
amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, - amenities_ds_path, - flights_ds_path, - policies_ds_path, - ) - await ds.initialize_data(airports, amenities, flights, policies) - - if ds is None: - raise TypeError("datastore creation failure") - - yield ds - - await ds.close() - - -def check_file_diff(file_diff): - assert file_diff["added"] == [] - assert file_diff["removed"] == [] - assert file_diff["changed"] == [] - assert file_diff["columns_added"] == [] - assert file_diff["columns_removed"] == [] - - -async def test_export_dataset(ds: spanner_postgres.Client): - airports, amenities, flights, policies = await ds.export_data() - - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - diff_airports = compare( - load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") - ) - check_file_diff(diff_airports) - - diff_amenities = compare( - load_csv(open(amenities_ds_path), "id"), - load_csv(open(amenities_new_path), "id"), - ) - check_file_diff(diff_amenities) - - diff_flights = compare( - load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") - ) - check_file_diff(diff_flights) - - diff_policies = compare( - load_csv(open(policies_ds_path), "id"), - load_csv(open(policies_new_path), "id"), - ) - 
check_file_diff(diff_policies) - - -async def test_get_airport_by_id(ds: spanner_postgres.Client): - res, sql = await ds.get_airport_by_id(1) - expected = models.Airport( - id=1, - iata="MAG", - name="Madang Airport", - city="Madang", - country="Papua New Guinea", - ) - assert res == expected - assert sql is None - - -@pytest.mark.parametrize( - "iata", - [ - pytest.param("SFO", id="upper_case"), - pytest.param("sfo", id="lower_case"), - ], -) -async def test_get_airport_by_iata(ds: spanner_postgres.Client, iata: str): - res, sql = await ds.get_airport_by_iata(iata) - expected = models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - assert res == expected - assert sql is None - - -search_airports_test_data = [ - pytest.param( - "Philippines", - "San jose", - None, - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=2313, - iata="EUQ", - name="Evelio Javier Airport", - city="San Jose", - country="Philippines", - ), - ], - id="country_and_city_only", - ), - pytest.param( - "united states", - "san francisco", - None, - [ - models.Airport( - id=3270, - iata="SFO", - name="San Francisco International Airport", - city="San Francisco", - country="United States", - ) - ], - id="country_and_name_only", - ), - pytest.param( - None, - "San Jose", - "San Jose", - [ - models.Airport( - id=2299, - iata="SJI", - name="San Jose Airport", - city="San Jose", - country="Philippines", - ), - models.Airport( - id=3548, - iata="SJC", - name="Norman Y. 
Mineta San Jose International Airport", - city="San Jose", - country="United States", - ), - ], - id="city_and_name_only", - ), - pytest.param( - "Foo", - "FOO BAR", - "Foo bar", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) -async def test_search_airports( - ds: spanner_postgres.Client, - country: str, - city: str, - name: str, - expected: List[models.Airport], -): - res, sql = await ds.search_airports(country, city, name) - assert res == expected - assert sql is None - - -async def test_get_amenity(ds: spanner_postgres.Client): - res, sql = await ds.get_amenity(0) - expected = models.Amenity( - id=0, - name="Coffee Shop 732", - description="Serving American cuisine.", - location="Near Gate B12", - terminal="Terminal 3", - category="restaurant", - hour="Daily 7:00 am - 10:00 pm", - sunday_start_hour=None, - sunday_end_hour=None, - monday_start_hour=None, - monday_end_hour=None, - tuesday_start_hour=None, - tuesday_end_hour=None, - wednesday_start_hour=None, - wednesday_end_hour=None, - thursday_start_hour=None, - thursday_end_hour=None, - friday_start_hour=None, - friday_end_hour=None, - saturday_start_hour=None, - saturday_end_hour=None, - ) - - assert res is not None - assert res.name == expected.name - assert res.description == expected.description - assert res.location == expected.location - assert res.terminal == expected.terminal - assert res.category == expected.category - assert res.hour == expected.hour - assert sql is None - - -amenities_search_test_data = [ - pytest.param( - # "Where can I get coffee near gate A6?" - amenities_query_embedding1, - 0.65, - 1, - [ - { - "name": "Coffee Shop 732", - "description": "Serving American cuisine.", - "location": "Near Gate B12", - "terminal": "Terminal 3", - "category": "restaurant", - "hour": "Daily 7:00 am - 10:00 pm", - }, - ], - id="search_coffee_shop", - ), - pytest.param( - # "Where can I look for luxury goods?" 
- amenities_query_embedding2, - 0.65, - 2, - [ - { - "name": "Gucci Duty Free", - "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", - "location": "Gate E9", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - { - "name": "Hermes Duty Free", - "description": "High-end French brand duty-free shop offering luxury goods and accessories.", - "location": "Gate E18", - "terminal": "International Terminal A", - "category": "shop", - "hour": "Daily 7:00 am-10:00 pm", - }, - ], - id="search_luxury_goods", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.9, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data -) -async def test_amenities_search( - ds: spanner_postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[models.Amenity], -): - res, sql = await ds.amenities_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None - - -async def test_get_flight(ds: spanner_postgres.Client): - res, sql = await ds.get_flight(1) - expected = models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime("2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), - arrival_time=datetime.strptime("2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ) - assert res == expected - assert sql is None - - -search_flights_by_number_test_data = [ - pytest.param( - "UA", - "1158", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - 
departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=55455, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="JFK", - departure_time=datetime.strptime( - "2025-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="B50", - arrival_gate="E4", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "UU", - "0000", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "airline, number, expected", search_flights_by_number_test_data -) -async def test_search_flights_by_number( - ds: spanner_postgres.Client, - airline: str, - number: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_number(airline, number) - assert res == expected - assert sql is None - - -search_flights_by_airports_test_data = [ - pytest.param( - "2025-01-01", - "SFO", - "ORD", - [ - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="C38", - arrival_gate="D30", - ), - models.Flight( - id=13, - airline="UA", - flight_number="616", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="A11", - arrival_gate="D8", - ), - models.Flight( - id=25, - airline="AA", - flight_number="242", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E30", - arrival_gate="C1", - ), - models.Flight( - id=109, - 
airline="UA", - flight_number="1640", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:01:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E27", - arrival_gate="C24", - ), - models.Flight( - id=119, - airline="AA", - flight_number="197", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="D25", - arrival_gate="E49", - ), - models.Flight( - id=136, - airline="UA", - flight_number="1564", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime( - "2025-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" - ), - departure_gate="E3", - arrival_gate="C48", - ), - ], - id="successful_airport_search", - ), - pytest.param( - "2025-01-01", - "FOO", - "BAR", - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "date, departure_airport, arrival_airport, expected", - search_flights_by_airports_test_data, -) -async def test_search_flights_by_airports( - ds: spanner_postgres.Client, - date: str, - departure_airport: str, - arrival_airport: str, - expected: List[models.Flight], -): - res, sql = await ds.search_flights_by_airports( - date, departure_airport, arrival_airport - ) - assert res == expected - assert sql is None - - -policies_search_test_data = [ - pytest.param( - # "What is the fee for extra baggage?" - policies_query_embedding1, - 0.65, - 1, - [ - "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. 
We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.", - ], - id="search_extra_baggage_fee", - ), - pytest.param( - # "Can I change my flight?" - policies_query_embedding2, - 0.65, - 2, - [ - "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", - "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", - ], - id="search_flight_delays", - ), - pytest.param( - # "FOO BAR" - foobar_query_embedding, - 0.65, - 1, - [], - id="no_results", - ), -] - - -@pytest.mark.parametrize( - "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data -) -async def test_policies_search( - ds: spanner_postgres.Client, - query_embedding: List[float], - similarity_threshold: float, - top_k: int, - expected: List[models.Policy], -): - res, sql = await ds.policies_search(query_embedding, similarity_threshold, top_k) - assert res == expected - assert sql is None - - -validate_ticket_data = [ - pytest.param( - { - "airline": "UA", - "flight_number": "1158", - "departure_airport": "SFO", - "departure_time": "2025-01-01 05:57:00", - }, - models.Flight( - id=1, - airline="UA", - flight_number="1158", - departure_airport="SFO", - arrival_airport="ORD", - departure_time=datetime.strptime( - "2025-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" - ), - arrival_time=datetime.strptime("2025-01-01 
12:13:00", "%Y-%m-%d %H:%M:%S"), - departure_gate="C38", - arrival_gate="D30", - ), - None, - ), - pytest.param( - { - "airline": "XX", - "flight_number": "9999", - "departure_airport": "ZZZ", - "departure_time": "2025-01-01 05:57:00", - }, - None, - None, - ), -] - - -@pytest.mark.parametrize("params, expected_data, expected_sql", validate_ticket_data) -async def test_validate_ticket( - ds: spanner_postgres.Client, params, expected_data, expected_sql -): - flight, sql = await ds.validate_ticket(**params) - assert flight == expected_data - assert sql == expected_sql diff --git a/retrieval_service/datastore/providers/test_data.py b/retrieval_service/datastore/providers/test_data.py deleted file mode 100644 index 67b8420d4..000000000 --- a/retrieval_service/datastore/providers/test_data.py +++ /dev/null @@ -1,3854 +0,0 @@ -amenities_query_embedding1 = [ - -0.03426577150821686, - -0.02576032467186451, - -0.0257739145308733, - 0.02418350614607334, - -0.012800461612641811, - 0.026428084820508957, - -0.005199299193918705, - -0.014483599923551083, - 0.007685032673180103, - 0.005428903736174107, - -0.008022738620638847, - -0.00879852008074522, - 0.08367078006267548, - -0.040768448263406754, - 0.026927951723337173, - 0.03621876984834671, - -0.030760779976844788, - -0.028195805847644806, - 0.013008243404328823, - -0.010026376694440842, - 0.009289189241826534, - -0.007033844478428364, - -0.0063486588187515736, - -0.005574745126068592, - -0.013823761604726315, - -0.032435957342386246, - 0.0002799989306367934, - 0.01701989397406578, - -0.00906086154282093, - 0.08542285859584808, - 0.045840635895729065, - -0.01797819323837757, - 0.011624608188867569, - -0.018443230539560318, - -0.07987944036722183, - 0.005588903557509184, - 0.012791724875569344, - 0.05254274606704712, - -0.03330793231725693, - 0.019106678664684296, - 0.022927535697817802, - -0.04566822573542595, - -0.06261023879051208, - -0.019059836864471436, - -0.01783635839819908, - 0.03983624652028084, - 
-0.01726730726659298, - -0.019073056057095528, - -0.01582428067922592, - -0.0007198549574241042, - -0.044508110731840134, - 0.013001530431210995, - 0.028854133561253548, - 0.0006031260709278286, - -0.022025788202881813, - 0.03945688158273697, - 0.004475956317037344, - 0.08217790722846985, - -0.0634990930557251, - 0.017360646277666092, - 0.0023998464457690716, - -0.08620959520339966, - -0.03779741749167442, - -0.003792044473811984, - -0.04823293164372444, - -0.0327751599252224, - 0.055752117186784744, - -0.026187188923358917, - 0.007542594335973263, - -0.010054156184196472, - -0.03223911300301552, - 0.019602425396442413, - -0.0059861247427761555, - -0.046868082135915756, - 0.002277666935697198, - -0.02000548131763935, - 0.003470888128504157, - 0.08358046412467957, - 0.018639329820871353, - -0.05532898008823395, - 0.020794754847884178, - 0.015758205205202103, - 0.016124302521348, - 0.035103265196084976, - 0.0031034336425364017, - -0.034351930022239685, - -0.03509620204567909, - -0.00954558327794075, - -0.024984825402498245, - -0.013729190453886986, - 0.002476991154253483, - 0.017008233815431595, - -0.030832083895802498, - 0.030465200543403625, - 0.022249529138207436, - 0.05427328497171402, - 0.06642039120197296, - 0.019389618188142776, - -0.12169259786605835, - 0.009085948579013348, - 0.05069831758737564, - -0.04257209599018097, - 0.04462549090385437, - -0.0796033963561058, - -0.0878254845738411, - -0.039167433977127075, - -0.018684355542063713, - 0.004540180321782827, - 0.09645611047744751, - 0.03089885413646698, - -0.031766340136528015, - -0.030099518597126007, - 0.014820144511759281, - -0.005461242515593767, - 0.0339224636554718, - 0.00081025151303038, - -0.07433383166790009, - -0.023641742765903473, - -0.08635573834180832, - -0.008645744994282722, - -0.03097374550998211, - 0.041431088000535965, - 0.029763847589492798, - 0.01577255129814148, - -0.07926613092422485, - 0.017581826075911522, - 0.027170516550540924, - -0.005928780883550644, - 0.009574304334819317, - 
-0.05559387430548668, - -0.02515755593776703, - -0.016147971153259277, - 0.036077242344617844, - 0.017313888296484947, - -0.0365297831594944, - -0.04114357754588127, - 0.02604534849524498, - -0.04690905660390854, - -0.025948986411094666, - 0.03335786983370781, - -0.040374983102083206, - -0.030743172392249107, - 0.03193559870123863, - 0.011348542757332325, - -0.026815302670001984, - -0.02737625315785408, - -0.055536236613988876, - -0.001741938292980194, - 0.033599790185689926, - 0.04281945899128914, - 0.00144023890607059, - 0.022144541144371033, - -0.00996614433825016, - -0.0033164252527058125, - 0.02024814672768116, - 0.014512380585074425, - -0.04363207146525383, - -0.04704112187027931, - -0.028573350980877876, - 0.019007043913006783, - -0.04736559838056564, - -0.014222265221178532, - -0.04180072247982025, - 0.042079854756593704, - 0.012490890920162201, - -0.007344002835452557, - 0.061116527765989304, - 0.0295159500092268, - 0.00599073339253664, - 0.030128255486488342, - 0.009138560853898525, - -0.04133842512965202, - -0.06662344187498093, - 0.0023828756529837847, - 0.0502401664853096, - 0.026548542082309723, - -0.038357168436050415, - 0.03903193771839142, - -0.053168606013059616, - -0.0677325427532196, - -0.059935491532087326, - 0.020086275413632393, - -0.045598410069942474, - -0.04444769024848938, - -0.09378903359174728, - 0.047032903879880905, - -0.07375598698854446, - 0.056551236659288406, - -0.05137373507022858, - 0.04824788123369217, - -0.05028494819998741, - 0.016014117747545242, - 0.010138050653040409, - 0.02457788959145546, - 0.011855662800371647, - 0.027981966733932495, - -0.0020986273884773254, - -0.03498651087284088, - 0.051092494279146194, - -0.06301183253526688, - -0.0433199405670166, - -0.03251234069466591, - 0.00597750348970294, - 0.0815092995762825, - -0.024481914937496185, - -0.06562220305204391, - 0.04231167957186699, - 0.0020994385704398155, - 0.04726693406701088, - -0.015971772372722626, - 0.039061538875103, - 0.04135257750749588, - 
-0.061783257871866226, - -0.05318614840507507, - 0.010268249548971653, - -0.04598318412899971, - 0.07178954035043716, - 0.006825078744441271, - 0.06711448729038239, - 0.044576290994882584, - -0.0021480885334312916, - 0.012099944986402988, - -0.11658186465501785, - 0.0034082578495144844, - 0.05092315375804901, - 0.01847822405397892, - 0.03226379677653313, - 0.021992335096001625, - 0.053209640085697174, - 0.030645929276943207, - -0.02950444631278515, - 0.01884371228516102, - -0.006298222579061985, - -0.017239868640899658, - 0.007011412642896175, - -0.0006902551976963878, - -0.012976488098502159, - -0.05459235608577728, - 0.019252361729741096, - 0.05829198658466339, - -0.030194489285349846, - 0.011151409707963467, - -0.035857491195201874, - -0.010664726607501507, - -0.03429829329252243, - 0.0050385454669594765, - -0.02156270109117031, - 0.06746513396501541, - 0.03187132999300957, - -0.06623721122741699, - -0.03179844841361046, - -0.0006363365100696683, - -0.029520796611905098, - 0.007758421823382378, - 0.02402757853269577, - 0.009184599854052067, - -0.0030690738931298256, - -0.05611732229590416, - 0.028113234788179398, - -0.01973802223801613, - 0.01237731333822012, - 0.025775298476219177, - -0.02546478435397148, - -0.0026233880780637264, - -0.031242458149790764, - 0.0005816408665850759, - -0.004353170283138752, - -0.0068981824442744255, - 0.022962864488363266, - -8.44376118038781e-05, - 0.007178373634815216, - 0.014454629272222519, - -0.004015014506876469, - 0.04104899615049362, - -0.026276789605617523, - -0.04235663264989853, - -0.08950557559728622, - 0.01752709597349167, - -0.018757643178105354, - 0.011998031288385391, - -0.018200194463133812, - -0.0671669989824295, - 0.013485183008015156, - 0.11210521310567856, - 0.02546621486544609, - 0.001860854565165937, - 0.008789285086095333, - -0.024343576282262802, - -0.0003261480596847832, - 0.015369621105492115, - -0.02684134803712368, - 0.07865234464406967, - 0.03759289160370827, - -0.03329639136791229, - 
0.003315947251394391, - -0.0013569670263677835, - 0.002337598241865635, - -0.003477030899375677, - -0.029452459886670113, - 0.028148850426077843, - -0.008210550993680954, - 0.002662568585947156, - 0.02164490520954132, - 0.032797470688819885, - 0.049701038748025894, - -0.0497996024787426, - -0.024841122329235077, - -0.00040088073001243174, - -0.0004295688704587519, - 0.03656087443232536, - -0.018808044493198395, - -0.027797885239124298, - 0.025932976976037025, - -0.011894632130861282, - -0.03202470391988754, - 0.0027828500606119633, - -0.003750568488612771, - 0.011646115221083164, - -0.021086297929286957, - -0.006843156181275845, - 0.07365952432155609, - -0.03632322698831558, - 0.07528533041477203, - 0.02066589705646038, - 0.09558834135532379, - -0.006310032680630684, - 0.026566460728645325, - 0.0064064874313771725, - 0.023108812049031258, - -0.026019291952252388, - 0.05798571556806564, - -0.009460080415010452, - 0.05695430934429169, - -0.004025998525321484, - -0.0653546154499054, - -0.02196013554930687, - -0.021971065551042557, - 0.03307781741023064, - -0.008588350377976894, - -0.0034699025563895702, - -0.011194705031812191, - 0.005137080326676369, - -0.04894813895225525, - 0.002586323069408536, - -0.010381253436207771, - 0.08378668129444122, - 0.027226565405726433, - 0.011295315809547901, - 0.015983659774065018, - 0.008025172166526318, - -0.033309705555438995, - 0.0740402340888977, - -0.0043640341609716415, - -0.03541756793856621, - 0.024241771548986435, - 0.016692671924829483, - 0.06383056938648224, - 0.014306199736893177, - 0.014089040458202362, - -0.00892566703259945, - 0.014328823424875736, - -0.005649662110954523, - -0.018398432061076164, - -0.020176194608211517, - -0.025548426434397697, - 0.017152652144432068, - -0.008283275179564953, - 0.019667256623506546, - 0.009431501850485802, - -0.01535679493099451, - -0.03565458953380585, - 0.024572866037487984, - -0.03403777256608009, - -0.0326625294983387, - -0.05244242399930954, - -0.005712522193789482, - 
0.03922955319285393, - -0.06832774728536606, - 0.007414190098643303, - 0.05970955267548561, - 0.04348660260438919, - 0.001208509667776525, - -0.004788508638739586, - 0.028216339647769928, - -0.02837369777262211, - 0.04278450086712837, - -0.02247512713074684, - 0.05434102192521095, - -0.0817558690905571, - 0.0025151113513857126, - 0.0002742201613727957, - 0.03834645450115204, - -0.005740721244364977, - 0.013248599134385586, - 0.05529296025633812, - -0.01752236671745777, - -0.010776135139167309, - 0.005296112969517708, - -0.005302086938172579, - -0.022550543770194054, - 0.042537420988082886, - -0.03559837490320206, - -0.006310023367404938, - -0.025245975703001022, - -0.023362139239907265, - -0.025547483935952187, - -0.0064442772418260574, - -0.008855178020894527, - 0.005135504994541407, - -0.003320463700219989, - -0.021954478695988655, - -0.014463184401392937, - 0.058454446494579315, - 0.021995194256305695, - -0.01652427203953266, - -0.03252106532454491, - 0.06388015300035477, - -0.007802637293934822, - 0.022889269515872, - -0.04852554202079773, - -4.248267941875383e-05, - 0.04092083126306534, - -0.030287625268101692, - 0.007485358044505119, - 0.040713220834732056, - 0.03670143708586693, - 0.001497412915341556, - 0.0067846826277673244, - -0.00881416629999876, - 0.030616773292422295, - -0.011219942942261696, - -0.028832504525780678, - -0.007011820562183857, - -0.008518273942172527, - 0.004615526646375656, - -0.009454179555177689, - -0.047546274960041046, - 0.005481477826833725, - 0.007349043153226376, - -0.014099732972681522, - -0.0292937271296978, - -0.002278352389112115, - 0.020548688247799873, - -0.0172805767506361, - -0.04494452849030495, - 0.004361841361969709, - 0.009823571890592575, - 0.001107021002098918, - -0.006808244623243809, - 0.0011773877777159214, - -0.015986323356628418, - 0.02480212412774563, - -0.04127948731184006, - -0.03844968229532242, - -0.05323105677962303, - 0.028368666768074036, - 0.05332525447010994, - 0.00904531218111515, - 
0.006427109241485596, - 0.011720380745828152, - -0.0014891664031893015, - 0.028862567618489265, - -0.022990139201283455, - 0.011468810960650444, - 0.04723145440220833, - -0.016716765239834785, - -0.051652684807777405, - -0.004717829171568155, - 0.02256341092288494, - -0.006018962245434523, - 0.02104402333498001, - -0.06498764455318451, - -0.004732148721814156, - -0.0032248878851532936, - -0.00973446387797594, - 0.05162052810192108, - 0.039369262754917145, - -0.01352009642869234, - -0.012358207255601883, - 0.030109398066997528, - -0.0398034006357193, - 0.030024757608771324, - 0.02019876055419445, - -0.05926764756441116, - 0.018312862142920494, - 0.07650671154260635, - -0.018029402941465378, - -0.04034435749053955, - 0.029159152880311012, - -0.03741726279258728, - -0.021363917738199234, - -0.0065930867567658424, - 0.0163438580930233, - -0.10120650380849838, - -0.005669379606842995, - 0.0097228167578578, - -0.03312637284398079, - 0.01937609538435936, - -0.04617556929588318, - -0.02141471579670906, - -0.0602756068110466, - 0.017043599858880043, - -0.06665118783712387, - 0.020621756091713905, - -0.0649784505367279, - -0.012113559991121292, - -0.004847749602049589, - 0.017426947131752968, - -0.06283388286828995, - -0.01869840919971466, - -0.008959515020251274, - -0.062405869364738464, - -0.022153351455926895, - 0.023545408621430397, - 0.03537554666399956, - -0.0014322495553642511, - 0.038696691393852234, - -0.08671356737613678, - 0.017217906191945076, - 0.09264660626649857, - -0.040027517825365067, - 0.004925523418933153, - -0.023294730111956596, - -0.04532265663146973, - 0.04783519357442856, - 0.021777192130684853, - 0.004456411115825176, - 0.0051250336691737175, - 0.02497895061969757, - -0.00478561082854867, - -0.00011801073560491204, - 0.061384182423353195, - -0.017768600955605507, - -0.07746068388223648, - 0.03317761793732643, - 0.05053428187966347, - -0.02800365351140499, - 0.04379768669605255, - 0.018105633556842804, - -0.0015479091089218855, - 
-0.002680727979168296, - -0.08259643614292145, - 0.008219117298722267, - -0.010825438424944878, - -0.03416779264807701, - -0.03689880296587944, - -0.03198298066854477, - -0.049294233322143555, - 0.020523404702544212, - 0.039070673286914825, - -0.001901070587337017, - -0.00029624669696204364, - 0.0741531029343605, - 0.012070284225046635, - 0.019619300961494446, - -0.002626742934808135, - 0.00486735487356782, - -0.018754566088318825, - 0.05023876205086708, - 0.014961488544940948, - -0.05428607016801834, - -0.03313930705189705, - -0.050635579973459244, - -0.04230504855513573, - 0.016655201092362404, - 0.01574505679309368, - -0.051654569804668427, - 0.019881606101989746, - 0.02128366008400917, - 0.02757614105939865, - -0.010336929000914097, - 0.056945785880088806, - -0.04454180225729942, - 0.06827438622713089, - -0.014460902661085129, - -0.02083568461239338, - -0.011086664162576199, - 0.005639239214360714, - 0.021935952827334404, - 0.02777838334441185, - 0.042444176971912384, - 0.004305605310946703, - 0.039522435516119, - -0.01000981405377388, - 0.012677590362727642, - 0.04667935147881508, - -0.005894504487514496, - -0.004658759571611881, - 0.06209319084882736, - -0.04650682583451271, - -0.07117754220962524, - 0.024717286229133606, - -0.03466859087347984, - -0.017481304705142975, - 0.04632146656513214, - 0.060550544410943985, - -0.053832877427339554, - -0.016148246824741364, - 0.06377317011356354, - -0.02218356914818287, - -0.023291371762752533, - 0.012509875930845737, - -0.058022987097501755, - -0.03695506975054741, - 0.0778346061706543, - -0.002613927936181426, - -0.05347229540348053, - -0.009894736111164093, - 0.0019844567868858576, - 0.046893637627363205, - 0.027064664289355278, - -0.010159524157643318, - -0.04866268113255501, - 0.01145843230187893, - -0.011822600848972797, - 0.04748937115073204, - -0.01740541309118271, - -0.021529359742999077, - -0.03179234266281128, - 0.042745135724544525, - -0.02422766387462616, - -0.021380214020609856, - -0.04915768280625343, - 
-0.015968358144164085, - -0.043358102440834045, - 0.015237350948154926, - -0.029542429372668266, - 0.024833552539348602, - -0.04824815317988396, - 0.04341820627450943, - 0.04361231252551079, - 0.007433712016791105, - 0.013310015201568604, - -0.011933174915611744, - 0.026550767943263054, - 0.036758262664079666, - 0.07831371575593948, - 0.023837171494960785, - 0.00584573345258832, - 0.020032525062561035, - 0.04975404590368271, - -0.051814157515764236, - 0.01568654365837574, - 0.00030132144456729293, - -0.009239316917955875, - -0.010572581551969051, - 0.01581266149878502, - 0.0210961252450943, - 0.01816808432340622, - 0.003100331174209714, - 0.03474388271570206, - -0.015573166310787201, - -0.0084596648812294, - -0.003974889870733023, - 0.0009469022043049335, - -0.025476787239313126, - 0.07956649363040924, - 0.005124109331518412, - -0.04892757162451744, - 0.02605452761054039, - -0.011202532798051834, - -0.015115722082555294, - -0.07456538081169128, - 0.007558709941804409, - 0.04330822452902794, - 0.08615235239267349, - -0.04410251975059509, - 0.015813710168004036, - 0.035662632435560226, - -0.01862969435751438, - 0.03425082564353943, - -0.02227182686328888, - 0.005141818430274725, - 0.05089874565601349, - 0.012701905332505703, - -0.013883288018405437, - 0.027305221185088158, - -0.020592572167515755, - -0.006503250915557146, - 0.025432147085666656, - -0.009313317015767097, - -0.0023737808223813772, - 0.021655172109603882, - 0.001100344816222787, - -0.03597162291407585, - -0.005055297166109085, - 0.04600590094923973, - 0.012476412579417229, - 0.016993260011076927, - 0.09908081591129303, - 0.024321356788277626, - 0.006415405310690403, - 0.02881949581205845, - -0.02533145621418953, - 0.04728936403989792, - -0.005486254580318928, - 0.038305480033159256, - -0.06929311901330948, - 0.00621030805632472, - 0.0162116140127182, - -0.0001760926388669759, - 0.023459700867533684, - 0.03529792279005051, - 0.02748793177306652, - 0.003611242165789008, - -0.05813417211174965, - 
0.004826054442673922, - 0.01651090569794178, - -0.023076370358467102, - 0.08670508861541748, - -0.005762477871030569, - 0.0019142942037433386, - -0.023943807929754257, - -0.02387240342795849, - 0.007661466021090746, - -0.030013950541615486, - -0.008249140344560146, - 0.039298877120018005, - 0.005842762067914009, - 0.034623462706804276, - 0.0962236151099205, - -0.024780360981822014, - 0.013985848985612392, - 0.01719195395708084, - 0.05271543189883232, - -0.00013382977340370417, - -0.07190033048391342, - -0.02255168743431568, - -0.05316685512661934, - 0.03703410178422928, - 0.014754639007151127, - -0.06959973275661469, - -0.0345720611512661, - -0.058119576424360275, - 0.04692500829696655, - -0.026836412027478218, - -0.09481404721736908, - -0.03146347776055336, - -0.0042710076086223125, - -0.004116047639399767, - 0.008759005926549435, - -0.014326146803796291, - 0.04631959646940231, - -0.0201833788305521, - -0.009601149708032608, - -0.004468688275665045, - -0.01033309567719698, - 0.028015540912747383, - -0.03926633670926094, - -0.004663986619561911, - -0.015060227364301682, - 0.07136976718902588, - 0.028410200029611588, - -0.0233630184084177, - -0.0362052284181118, - 0.07289160043001175, - 0.004230049438774586, - 0.004988215863704681, - 0.03202070668339729, - 0.011132704094052315, - -0.01084903348237276, - 0.05584624409675598, - 0.0193212628364563, - 0.0016466749366372824, - -0.023416657000780106, - -0.02541269361972809, - -0.03386934474110603, - 0.00014119570550974458, -] - -amenities_query_embedding2 = [ - 0.055892463773489, - -0.05997491255402565, - -0.013584441505372524, - -0.0265935231000185, - 0.0020448751747608185, - 0.007934683933854103, - 0.005789965856820345, - 0.02990216389298439, - -0.05208924040198326, - 0.019307192414999008, - 0.008684077300131321, - -0.025533346459269524, - 0.05415329709649086, - 0.017287472262978554, - -0.04157889634370804, - -0.02797822095453739, - 0.06811019778251648, - -0.03457444906234741, - -0.024283014237880707, - 
0.017314212396740913, - 0.08240590989589691, - 0.02869713492691517, - 0.03265078365802765, - -0.016291556879878044, - 0.016947949305176735, - -0.03329828381538391, - -0.053065430372953415, - 0.03355345502495766, - -0.031491201370954514, - -0.026516882702708244, - 0.10797826945781708, - -0.07012204080820084, - 0.060921456664800644, - -0.011012264527380466, - -0.043754927814006805, - 0.0029176005627959967, - 0.01608179323375225, - -0.024473711848258972, - 0.01546369306743145, - 0.02891385182738304, - 0.028533777222037315, - -0.09347712248563766, - -0.05527815967798233, - 0.026095839217305183, - -0.036471132189035416, - 0.000700332981068641, - 0.029043281450867653, - -0.05801704153418541, - -0.040640689432621, - -0.0499303862452507, - -0.013027386739850044, - -0.004142343066632748, - -0.025489527732133865, - -0.06124532222747803, - 0.03770961984992027, - -0.045334190130233765, - -0.021565377712249756, - 0.023867793381214142, - -0.021159211173653603, - 0.04146633297204971, - 0.051257506012916565, - -0.05072063207626343, - -0.011530877090990543, - -0.028427356854081154, - -0.048709429800510406, - 0.01796594448387623, - -0.023258958011865616, - 0.039060913026332855, - 0.011104852892458439, - -0.04234814643859863, - 0.03985419124364853, - 0.04790177568793297, - 0.044422321021556854, - -0.04049019142985344, - 0.007075674366205931, - 0.02396037057042122, - 0.02916235849261284, - 0.007539267186075449, - -0.06333117187023163, - 0.011838966980576515, - -0.07863882184028625, - 0.00363808567635715, - -0.015896042808890343, - 0.03301239013671875, - 0.029863040894269943, - -0.010576950386166573, - 0.015188626013696194, - -0.03342222794890404, - -0.0734391137957573, - 0.00660671154037118, - -0.02061229757964611, - 0.0205592829734087, - 0.00014670552627649158, - 0.016545046120882034, - 0.03544733673334122, - 0.00025832452229224145, - -0.050505395978689194, - 0.03861454874277115, - -0.10247048735618591, - 0.030009305104613304, - 0.01338998507708311, - -0.014261230826377869, - 
-0.02073832042515278, - 0.014054902829229832, - -0.0036453783977776766, - 0.015207679010927677, - -0.044717226177453995, - 0.04089798033237457, - 0.02604365535080433, - -0.05490874499082565, - 0.0790773555636406, - -0.09637530893087387, - -0.06302930414676666, - 0.04474630579352379, - 0.013346063904464245, - -0.017725283280014992, - -0.02488657645881176, - -0.03484667092561722, - -0.027166390791535378, - -0.033644240349531174, - -0.032290272414684296, - 0.012302601709961891, - 0.09726711362600327, - 0.06376137584447861, - -0.015719586983323097, - 0.002885644556954503, - 0.0291243065148592, - 0.027186835184693336, - -0.061414457857608795, - -0.06410511583089828, - -0.0746668204665184, - 0.037541959434747696, - 0.013539639301598072, - 0.07725995033979416, - 0.0017328887479379773, - 0.0031629905570298433, - 0.047179438173770905, - -0.015867115929722786, - 0.04769112169742584, - 0.004888491239398718, - 0.02375648356974125, - -0.04010530188679695, - 0.03123117797076702, - 0.021493323147296906, - -0.029389532282948494, - -0.03706381842494011, - -0.016610441729426384, - 0.0070285797119140625, - 5.9323054301785305e-05, - 0.014585467055439949, - -0.008711783215403557, - 0.029944946989417076, - -0.030022991821169853, - 0.05483468994498253, - 0.031048446893692017, - -0.016199683770537376, - -0.014016312547028065, - -0.0394848994910717, - -0.0248811487108469, - 0.014773635193705559, - -0.003382717492058873, - -0.011599132791161537, - 0.04920763894915581, - 0.005124911665916443, - -0.010152447037398815, - 0.01414799876511097, - -0.01695883832871914, - 0.009662442840635777, - 0.0023108262103050947, - -0.008084166795015335, - -0.01871715486049652, - -0.053956836462020874, - -0.06357388198375702, - -0.010554073378443718, - 0.006106268148869276, - -0.018398361280560493, - 0.004335328936576843, - -0.0042707594111561775, - 0.0007066239486448467, - -0.02249768190085888, - -0.05418376997113228, - 0.0659264326095581, - -0.012410280294716358, - -0.00286452891305089, - 
-0.1157439798116684, - 0.034704435616731644, - -0.037926554679870605, - -0.020422611385583878, - -0.07531341165304184, - 0.01949732005596161, - -0.05050870031118393, - -0.002013351069763303, - -0.041672345250844955, - 0.04263600707054138, - 0.031116992235183716, - 0.05678504332900047, - 0.003932067193090916, - -0.1185530573129654, - -0.04860329627990723, - 0.00991850346326828, - -0.0011754331644624472, - 0.042161986231803894, - -0.006052409298717976, - -0.02097479999065399, - -0.008509831503033638, - -0.023507757112383842, - -0.0012523338664323092, - -0.02650841698050499, - 0.003538678400218487, - 0.03498062863945961, - 0.03011840581893921, - 0.020340479910373688, - -0.07838288694620132, - 0.014482667669653893, - -0.036364391446113586, - -0.034286707639694214, - 0.06363901495933533, - -0.07648839056491852, - 0.05878320336341858, - 0.00982689019292593, - 0.056423820555210114, - -0.028613438829779625, - -0.03909299522638321, - 0.001664603129029274, - 0.0850345566868782, - 0.031580306589603424, - 0.0003749633324332535, - -0.019643045961856842, - 0.08948344737291336, - 0.0018674435559660196, - -0.01687905564904213, - 0.041621606796979904, - -0.06001771241426468, - 0.006512368563562632, - -0.028240371495485306, - -0.01006176508963108, - 0.01389401312917471, - 0.014381010085344315, - -0.08013754338026047, - 0.0023332468699663877, - 0.010345798917114735, - -0.018212823197245598, - -0.07603197544813156, - -0.03311338648200035, - -0.021623481065034866, - -0.04325062036514282, - -0.07162046432495117, - 0.061312999576330185, - 0.02050204388797283, - -0.010103263892233372, - -0.00646591791883111, - 0.011036211624741554, - 0.035259805619716644, - 0.057919591665267944, - 0.03259759023785591, - -0.015070066787302494, - -0.038099098950624466, - -0.04394245520234108, - 0.026942312717437744, - 0.01333738211542368, - 0.026306461542844772, - -0.02589692361652851, - 0.020545141771435738, - 0.0002383101382292807, - 0.02731776237487793, - -0.006610419601202011, - -0.03179822489619255, - 
-0.03693592548370361, - 0.002071080496534705, - 0.04704642295837402, - -0.07082756608724594, - -0.008559775538742542, - -0.008444043807685375, - 0.042161811143159866, - -0.03733347728848457, - -0.03319123014807701, - -0.018544470891356468, - 0.04031077399849892, - 0.048427268862724304, - 0.0074465638026595116, - -0.023493941873311996, - -0.06068098545074463, - 0.033717069774866104, - 0.06744818389415741, - -0.03482946380972862, - 0.01618429832160473, - 0.05981732904911041, - -0.0010997216450050473, - 0.04653256759047508, - -0.03578869625926018, - -0.03537687659263611, - 0.01264767162501812, - 0.045370280742645264, - -0.021453427150845528, - -0.050688132643699646, - 0.01132906973361969, - 0.03883067145943642, - 0.01499777939170599, - -0.04145190864801407, - -0.017443379387259483, - 0.011989263817667961, - 0.041495054960250854, - -0.008248438127338886, - -0.0004769432416651398, - 0.005417921114712954, - -0.07346836477518082, - 0.0005546698812395334, - -0.02018367126584053, - -0.02123497985303402, - 0.017262516543269157, - -0.008493483997881413, - -0.08130844682455063, - -0.01236826740205288, - 0.05354618653655052, - -0.016064364463090897, - -0.00548898708075285, - 0.007786482572555542, - -0.031966887414455414, - 0.08632483333349228, - 0.044751930981874466, - -0.03949300944805145, - -0.05634103715419769, - 0.053038112819194794, - -0.01716967113316059, - -0.010559618473052979, - -0.013522796332836151, - -0.051095109432935715, - 0.021458633244037628, - 0.005720146000385284, - -0.018432749435305595, - 0.025449497625231743, - -0.00600121496245265, - 0.06037852540612221, - -0.007450680714100599, - -0.05636236444115639, - 0.04172592610120773, - -0.054465003311634064, - 0.03650405630469322, - 0.043080657720565796, - 0.04329998418688774, - -0.01350818295031786, - 0.014209131710231304, - -0.01629612408578396, - 0.06934478878974915, - -0.001543070306070149, - 0.04252493381500244, - -0.02896171435713768, - -0.010839433409273624, - -0.06216109171509743, - 0.02311992458999157, - 
-0.008387998677790165, - -0.007414316758513451, - -0.03614780679345131, - 0.0005806884146295488, - 0.024239318445324898, - 0.001425982336513698, - -0.03563220798969269, - -0.00833919271826744, - 0.050679534673690796, - -0.041606657207012177, - 0.020802471786737442, - 0.02608087845146656, - -0.042102813720703125, - 0.029114719480276108, - 0.009616547264158726, - -0.01412101648747921, - -0.004666727967560291, - 0.00019587291171774268, - -0.02695455029606819, - 0.01072065532207489, - 0.009656059555709362, - 0.05432412773370743, - -0.06182076036930084, - -0.040298473089933395, - -0.026910820975899696, - -0.01174367219209671, - -0.013932385481894016, - -0.015945740044116974, - 0.012008018791675568, - 0.05074099823832512, - 0.00927682314068079, - -0.03053595870733261, - 0.028868475928902626, - 0.0072925034910440445, - 0.02494952455163002, - 0.037187591195106506, - 0.0159270241856575, - 0.0239708349108696, - -0.02178368531167507, - 0.02790244109928608, - -0.026035508140921593, - 0.02539544552564621, - 0.02678285911679268, - -0.0707574412226677, - 0.020685184746980667, - -0.030424926429986954, - -0.014828452840447426, - -0.030652668327093124, - -0.08170772343873978, - -0.01702142134308815, - 0.007331102155148983, - -0.04227852448821068, - -0.017049476504325867, - -0.05233880877494812, - -0.07402509450912476, - -0.001562946243211627, - -0.03295401111245155, - -0.01274824794381857, - 0.017627358436584473, - -0.0998864471912384, - 0.009559054858982563, - -0.05189035087823868, - 0.022569114342331886, - 0.004390105605125427, - 0.010059800930321217, - -0.04590233415365219, - 0.037164218723773956, - 0.007901381701231003, - -0.03338608145713806, - -0.03087235800921917, - -0.05107805132865906, - 0.0480528250336647, - -0.003765291767194867, - -0.012466994114220142, - -0.03478700667619705, - -0.01581522636115551, - -0.017138347029685974, - 0.036342278122901917, - 0.0014710051473230124, - 0.05627215653657913, - -0.040627673268318176, - -0.015227816067636013, - -0.03684627264738083, - 
-0.028323376551270485, - 0.014107657596468925, - 0.01137453317642212, - 0.016129594296216965, - -0.008044733665883541, - 0.005818014033138752, - -0.026991799473762512, - -0.03653818368911743, - 0.005541219841688871, - -0.06085873767733574, - -0.04035622254014015, - -0.012497778981924057, - -0.05360090732574463, - -0.025158362463116646, - 0.0021022942382842302, - 0.015903254970908165, - -0.0028944017831236124, - -0.030319079756736755, - -0.002622714266180992, - 0.0027678776532411575, - 0.014293392188847065, - 0.002763730473816395, - 0.035011645406484604, - -0.0006696831551380455, - 0.03184981644153595, - -0.017257824540138245, - -0.014679639600217342, - -0.013911917805671692, - 0.01795722357928753, - -0.028336336836218834, - -0.014990454539656639, - 0.02535722590982914, - -0.003903317963704467, - -0.06638816744089127, - -0.022757655009627342, - 0.010679174214601517, - -0.04039209708571434, - 0.040884967893362045, - -0.004443031270056963, - 0.006637635640799999, - -0.016058357432484627, - 0.02666856162250042, - 0.0212579183280468, - -0.003821374848484993, - 0.01934080757200718, - 0.01348326075822115, - 0.045918386429548264, - -0.056099940091371536, - -0.03134442865848541, - -0.0005965607706457376, - -0.033271901309490204, - -0.031199991703033447, - 0.08714185655117035, - -0.011285306885838509, - -0.040237292647361755, - 0.02053520642220974, - -0.02633083425462246, - -0.03897789120674133, - -0.01351873204112053, - -0.003458587918430567, - -0.06597544997930527, - 0.008898239582777023, - 0.029492449015378952, - 0.013226298615336418, - -0.003140013199299574, - -0.05394711717963219, - -0.08130994439125061, - 0.006262777838855982, - 0.017404261976480484, - -0.025333264842629433, - 0.059790898114442825, - -0.0786728709936142, - -0.013297149911522865, - -0.003027330618351698, - 0.030809232965111732, - 0.02023313194513321, - 0.013135876506567001, - 0.02963395230472088, - 0.015442299656569958, - -0.02514062449336052, - -0.010330609045922756, - 0.06010979041457176, - 
0.0017621004953980446, - 0.014102074317634106, - -0.1017540842294693, - -0.0015514338156208396, - 0.001461600884795189, - -0.019353782758116722, - -0.006693688221275806, - -0.014341448433697224, - -0.03346484899520874, - 0.03208518028259277, - 0.020188136026263237, - -0.026048999279737473, - 0.04354298859834671, - -0.01871739886701107, - 0.06643807888031006, - 0.014598089270293713, - 0.022948380559682846, - -0.020140744745731354, - -0.0046681491658091545, - -0.00024890649365261197, - 0.00035457927151583135, - -0.03209498152136803, - -0.02002917230129242, - 0.016402412205934525, - 0.03346584737300873, - -0.0005908212624490261, - 0.023563599213957787, - 0.007390783634036779, - 0.009256760589778423, - -0.027180122211575508, - 0.011099312454462051, - 0.005155880935490131, - 0.002778381807729602, - 0.01159622985869646, - 0.02391306683421135, - 0.044234663248062134, - 0.02620016783475876, - 0.033129554241895676, - 0.001371318125165999, - -0.013309131376445293, - 0.07020150125026703, - -0.0002449248859193176, - -0.053115639835596085, - -0.055809386074543, - 0.0634140819311142, - -0.03936535865068436, - -0.009354985319077969, - -0.016624726355075836, - -0.0019687588792294264, - -0.006904046982526779, - -0.010856819339096546, - 0.04992185905575752, - 0.017130548134446144, - 0.009966826997697353, - -0.04676090180873871, - -0.03064579702913761, - 0.03008275106549263, - -0.012013914994895458, - 0.04427943751215935, - -0.015875188633799553, - -0.05643991008400917, - -0.06719090789556503, - 0.0004910638090223074, - -0.02791833132505417, - -0.02678244560956955, - -0.004079312086105347, - -0.04411236196756363, - 0.07620707154273987, - -0.016606511548161507, - -0.008742762729525566, - 0.006683305837213993, - 0.031639449298381805, - -0.02537885122001171, - 0.0019565275870263577, - -0.03588010370731354, - -0.04876827448606491, - 0.01681654527783394, - -0.017407188192009926, - -0.012411867268383503, - 0.032110169529914856, - -0.0065548536367714405, - -0.006270687095820904, - 
-0.023005347698926926, - 0.06375826895236969, - -0.006516706198453903, - -0.02471497841179371, - -0.01067350897938013, - -0.0006809182232245803, - -0.014999224804341793, - -0.02211812324821949, - 0.010759683325886726, - 0.04021115601062775, - 0.0076550329104065895, - 0.03414943441748619, - 0.011936040595173836, - 0.008099671453237534, - 0.004115550313144922, - 0.008644312620162964, - -0.008241702802479267, - 0.007714245934039354, - 0.0011074553476646543, - -0.0029667378403246403, - -0.07378111779689789, - -0.02807626500725746, - 0.024542508646845818, - -0.043187133967876434, - -0.004124469123780727, - 0.053720057010650635, - -0.008595858700573444, - 0.039867084473371506, - 0.03721282631158829, - -0.0055439346469938755, - 0.0625622496008873, - -0.015861188992857933, - -0.02748195081949234, - 0.0032021852675825357, - 0.007614562287926674, - -0.06985673308372498, - -0.02221621386706829, - 0.005123065784573555, - -0.006325411144644022, - 0.016060974448919296, - -0.031992316246032715, - -0.059219032526016235, - 0.002513501327484846, - 0.02417716570198536, - -0.02441258728504181, - 0.04392332211136818, - -0.04575078561902046, - -0.048653218895196915, - -0.02222994528710842, - 0.011598773300647736, - -0.02178288623690605, - 0.044693127274513245, - -0.04683522880077362, - 0.08879371732473373, - -0.040017858147621155, - -0.04959360137581825, - -0.008063608780503273, - -0.017882592976093292, - -0.07503172010183334, - 0.05371706560254097, - 0.015170070342719555, - -0.03123209998011589, - 0.0035140777472406626, - 0.02233346365392208, - 0.01640397123992443, - -0.014707718044519424, - -0.013613482937216759, - 0.0013233484933152795, - -0.0008791678701527417, - -0.02032262273132801, - 0.02400825545191765, - -0.06783255189657211, - -0.04555197060108185, - -0.028633760288357735, - 0.030921589583158493, - -0.007539829704910517, - 0.021863384172320366, - -0.053305160254240036, - -0.07471157610416412, - -0.010853106155991554, - -0.04685426503419876, - 0.02456243894994259, - 
0.04955914616584778, - -0.06055504083633423, - -0.04984809458255768, - -0.014585047960281372, - -0.013018430210649967, - -0.015312179923057556, - -0.015418448485434055, - -0.0030825685244053602, - -0.022507019340991974, - -0.011475660838186741, - 0.026761630550026894, - 0.019262980669736862, - 0.0024740612134337425, - -0.002789789577946067, - -0.11503353714942932, - -0.02278132364153862, - 0.04104975610971451, - 0.054458316415548325, - -0.02116033062338829, - 0.024971626698970795, - 0.034776389598846436, - -0.022459762170910835, - 0.006885903887450695, - 0.013161196373403072, - -0.0005137320840731263, - -0.01898886077105999, - -0.04583144560456276, - 0.010187031701207161, - 0.02589089423418045, - -0.03014824353158474, - 0.002218270907178521, - -0.014986210502684116, - -0.0097990483045578, - -0.05645289272069931, - 0.00023823024821467698, - -0.053020376712083817, - -0.0043429904617369175, - -0.07122455537319183, - -0.034520212560892105, - -0.0016130845760926604, - -0.011034572497010231, - 0.0709027647972107, - -0.041150301694869995, - -0.043698981404304504, - -0.010003538802266121, - -0.017892440780997276, - -0.001292813685722649, - -0.017644090577960014, - -0.01578078791499138, - -0.03597499430179596, - 0.027997421100735664, - 0.02767106704413891, - -0.0360206738114357, - 0.03190212696790695, - 0.004582636523991823, - 0.03906809911131859, - -0.030055221170186996, - -0.09126430004835129, - 0.01394179929047823, - 0.010483482852578163, - -0.026801828294992447, - 0.020426394417881966, - -0.009973660111427307, - 0.008716212585568428, - 0.03917853534221649, - -0.026502864435315132, - -0.01726888120174408, - 0.021492447704076767, - 0.10126625001430511, - -0.02129351906478405, - -0.014346443116664886, - -0.038301777094602585, - 0.06633483618497849, - -0.029763424769043922, - -0.03446481004357338, - -0.01992025040090084, - 0.03701861947774887, - 0.006249621976166964, - 0.018199164420366287, - -0.013173830695450306, - 0.038745105266571045, - 0.02355598658323288, - 
0.05820638686418533, - 0.017508558928966522, - 0.020933374762535095, - 0.03504528850317001, - -0.03622489422559738, - -0.038554348051548004, - -0.04149714857339859, -] - -foobar_query_embedding = [ - 0.010542947798967361, - -0.0020771699491888285, - 0.018793586641550064, - 0.015323284082114697, - 0.020855776965618134, - -0.027744701132178307, - 0.054254867136478424, - -0.022548092529177666, - 0.01114976592361927, - -0.02701328881084919, - 0.03188479319214821, - -0.017661605030298233, - 0.0019851233810186386, - 0.04048242047429085, - -0.022234367206692696, - -0.03877898305654526, - -0.06809619814157486, - -0.004483973141759634, - 0.07192538678646088, - 0.03258777782320976, - -0.06885764002799988, - -0.01065030600875616, - 0.02601620741188526, - -0.025261731818318367, - -0.026098977774381638, - -0.08830367028713226, - 0.0049090199172496796, - 0.006491732317954302, - -0.03235504776239395, - -0.014904284849762917, - 0.03910832852125168, - 0.008180372416973114, - -0.008882002905011177, - -0.028500935062766075, - -0.005907626356929541, - 0.08642354607582092, - -0.025083258748054504, - 0.028573770076036453, - 0.005551688373088837, - 0.028575535863637924, - 0.027835840359330177, - -0.01694948971271515, - 0.03894244506955147, - -0.00702495314180851, - 0.003198516322299838, - 0.005768909119069576, - -0.027666108682751656, - 0.045704301446676254, - 0.00855960976332426, - -0.03772832453250885, - 0.01113964058458805, - -0.006848837714642286, - 0.0031126488465815783, - 0.021760165691375732, - 0.00305558112449944, - 0.027835113927721977, - -0.02689148113131523, - -0.02474796772003174, - -0.001405426301062107, - -0.017837561666965485, - 0.005955921020358801, - 0.026807915419340134, - 0.02563514932990074, - -0.04626566916704178, - 0.004874375648796558, - 0.012400802224874496, - -0.060049138963222504, - -0.04084951803088188, - 0.018301483243703842, - -0.05275797098875046, - 0.0086642662063241, - 0.03990742936730385, - 0.019040320068597794, - 0.008621668443083763, - 
0.0724521055817604, - 0.0412990003824234, - 0.0014013727195560932, - 0.06253764778375626, - 0.04116380214691162, - -0.03582260012626648, - -0.06224026530981064, - -0.02725234627723694, - -0.0928601399064064, - -0.07584880292415619, - -0.018580840900540352, - 0.02621578238904476, - 0.05248001962900162, - 0.017418449744582176, - -0.008641145192086697, - 0.018957287073135376, - -0.03877270966768265, - -0.06306338310241699, - 0.007788477465510368, - 0.04999104142189026, - -0.03251931816339493, - 0.008869366720318794, - -0.03268330171704292, - -0.005828124471008778, - -0.026510341092944145, - -0.009969801642000675, - 0.0044481391087174416, - -0.034640274941921234, - 0.03356659412384033, - -0.014506028033792973, - 0.04732068255543709, - -0.011050601489841938, - -0.004203138407319784, - 0.03990521654486656, - -0.041007958352565765, - -0.07893070578575134, - -0.05996153876185417, - 0.014063318260014057, - -0.0026169843040406704, - 0.02719777822494507, - -0.015427917242050171, - -0.02064945176243782, - 0.007363075856119394, - -0.03320041671395302, - 0.018084213137626648, - 0.023733800277113914, - -0.030606912449002266, - 0.02048870362341404, - 0.02402838133275509, - 0.005973761901259422, - 0.054204218089580536, - -0.0030790558084845543, - 0.004673156887292862, - -0.027416644617915154, - -0.04442985728383064, - 0.012847314588725567, - 0.07748882472515106, - 0.014493257738649845, - 0.004495652858167887, - -0.04001886025071144, - -0.036252908408641815, - 0.05335517227649689, - -0.01777976006269455, - 0.01082952693104744, - 0.01921195350587368, - 0.03392081335186958, - -0.042148079723119736, - -0.03388960659503937, - 0.007531383074820042, - 0.018161674961447716, - 0.03656009957194328, - -0.027102850377559662, - 0.02251628041267395, - -0.012046369723975658, - -0.05348425731062889, - 0.0022497086320072412, - 0.04016721993684769, - -0.01717129535973072, - 0.009544780477881432, - 0.02665960043668747, - 0.08562447875738144, - 0.057247646152973175, - 0.09216198325157166, - 
0.02830491214990616, - -0.013137255795300007, - -0.07119007408618927, - -0.029031438753008842, - 0.02400599606335163, - 0.0208907388150692, - -0.00500252190977335, - 0.02207385003566742, - -0.03819321095943451, - -0.006595571991056204, - -0.03734014928340912, - -0.007158266380429268, - 0.016129007562994957, - 0.01862095668911934, - -0.11248142272233963, - -0.04325195029377937, - -0.04442116618156433, - 0.07084795832633972, - -0.05036599189043045, - 0.0002208968944614753, - -0.0015836356906220317, - 0.004539810121059418, - -0.02397581934928894, - -0.014102368615567684, - -0.023190686479210854, - 0.03535286337137222, - -0.04591631889343262, - -0.050178974866867065, - -0.01359744742512703, - 0.025772321969270706, - -0.06363644450902939, - -0.03543799743056297, - 0.07043260335922241, - -0.014484401792287827, - -0.012782412581145763, - 0.01116127334535122, - -0.05494656041264534, - -0.002786465920507908, - -0.002305548870936036, - 0.011700322851538658, - -0.11845295876264572, - -0.011913647875189781, - 0.04774351045489311, - -0.013436089269816875, - -0.01566474139690399, - 0.02157527767121792, - 0.006863899063318968, - -0.012185136787593365, - -0.0026727840304374695, - -0.08895088732242584, - 0.003419826738536358, - 0.07213828712701797, - -0.016972415149211884, - 0.00856101792305708, - 0.001426316681317985, - 0.09427885711193085, - 0.01850116066634655, - -0.0258034560829401, - -0.04975779727101326, - 0.014081566594541073, - -0.03875066339969635, - -0.06455642729997635, - -0.03468679264187813, - 0.015874363481998444, - -0.020243149250745773, - 0.035441748797893524, - 0.054731763899326324, - 0.01801624707877636, - -0.005626743193715811, - 0.028992490842938423, - 0.006307659205049276, - 0.0319247767329216, - 0.002187710953876376, - 0.0065626646392047405, - -0.012596727348864079, - 0.057705774903297424, - -0.026091232895851135, - 0.030690236017107964, - 0.02947114408016205, - 0.018958140164613724, - -0.007928173989057541, - -0.013222036883234978, - -0.03803595155477524, - 
0.0013288946356624365, - 0.09856755286455154, - 0.014013550244271755, - 0.07894343137741089, - -0.02690303884446621, - -0.003988398239016533, - -0.0130416639149189, - 0.042919766157865524, - 0.05932505428791046, - -0.035565365105867386, - 0.035141583532094955, - -0.04124683886766434, - 0.025722447782754898, - -0.00018550716049503535, - -0.005109016317874193, - 0.033928681164979935, - 0.00565310986712575, - 0.04128587618470192, - 0.007276812102645636, - -0.025259966030716896, - 0.032601892948150635, - 0.016339223831892014, - 0.028560452163219452, - 0.13033530116081238, - -0.029234036803245544, - -0.006833497900515795, - -0.03143247589468956, - -0.012883774004876614, - 0.03338552638888359, - -0.057341158390045166, - -0.01174136996269226, - -0.0073961601592600346, - 0.03528213128447533, - 0.035989999771118164, - -0.037950076162815094, - 0.005264998879283667, - -0.04013827070593834, - 0.01746419444680214, - -0.0018648888217285275, - 0.04137907922267914, - -0.004513342399150133, - 0.012807865627110004, - 0.010365020483732224, - 0.037427470088005066, - -0.002865932183340192, - -0.0010766880586743355, - -0.04416626691818237, - 0.021367965266108513, - -0.05103495717048645, - -0.015714535489678383, - -0.030045492574572563, - 0.020174261182546616, - -0.036388009786605835, - 0.07772187143564224, - -0.021702803671360016, - -0.04465419054031372, - 0.04566051438450813, - -0.034563228487968445, - 0.01336362212896347, - 0.02437947504222393, - 0.0003132100682705641, - 0.0008901790715754032, - -0.02403002604842186, - -0.06205974146723747, - 0.0011670388048514724, - 0.011326261796057224, - -0.00750376982614398, - -0.00492094224318862, - 0.02063891850411892, - -0.01290293037891388, - 0.0344778448343277, - 0.02274216338992119, - 0.023453932255506516, - -0.0059125120751559734, - -0.011719446629285812, - 0.07904659956693649, - 0.008855306543409824, - -0.04290352389216423, - -0.08023841679096222, - 0.02476671151816845, - -0.01854325272142887, - -0.002190910978242755, - 
0.06412454694509506, - -0.008029614575207233, - 0.01590806618332863, - -0.028129110112786293, - -0.021355995908379555, - 0.012765574268996716, - 0.09463607519865036, - -0.06958933174610138, - -0.02119147591292858, - 0.02453245036303997, - -0.005154046695679426, - -0.0077933818101882935, - 0.011605029925704002, - 0.009440272115170956, - 0.019215304404497147, - 0.06640348583459854, - 0.02133198082447052, - -0.08007435500621796, - -0.0028622981626540422, - -0.0028950911946594715, - -0.037916556000709534, - -0.018876947462558746, - -0.013363506644964218, - 0.0535924918949604, - -0.04713330790400505, - -0.040017012506723404, - 0.008122353814542294, - -0.023449432104825974, - 0.03991822153329849, - 6.734774069627747e-05, - -0.012273556552827358, - 0.003721225541085005, - 0.0016754773678258061, - -0.011123294942080975, - 0.0019946210086345673, - 0.006982120685279369, - 0.029747584834694862, - -0.013039481826126575, - -0.03484058752655983, - 0.022132227197289467, - -0.0546797476708889, - 0.015616540797054768, - 0.023645279929041862, - -0.004792783875018358, - -0.06951714307069778, - 0.02854088507592678, - 0.04693328216671944, - -0.0006264290423132479, - 0.007349343970417976, - 0.015109053812921047, - 0.003017974551767111, - 0.020248563960194588, - -0.0726543590426445, - 0.01134587824344635, - 0.0005060675903223455, - 0.01829889789223671, - 0.04753483459353447, - -0.0037315706722438335, - -0.013771136291325092, - 0.026760051026940346, - 0.00406580651178956, - -0.03558425232768059, - -0.005299488548189402, - 0.018526272848248482, - 0.045592159032821655, - 0.019966932013630867, - 0.018006660044193268, - -0.05187966674566269, - 0.01618654653429985, - 0.026056606322526932, - 0.022463358938694, - -0.01645507477223873, - -0.017413437366485596, - 0.05479249358177185, - -0.014588613994419575, - 0.012588215991854668, - -0.03574773669242859, - -0.016849415376782417, - 0.015653977170586586, - -0.0023553497157990932, - -0.06709419935941696, - -0.07179543375968933, - 
0.015331734903156757, - 0.01766319014132023, - -0.007319381460547447, - 0.0189521461725235, - 0.019614940509200096, - -0.08486206829547882, - 0.0007858230965211987, - -0.0036041629500687122, - 0.014080874621868134, - 0.06779641658067703, - 0.03182194381952286, - -0.04642670601606369, - 0.004083056468516588, - -0.06267423182725906, - -0.0324474573135376, - 0.019781235605478287, - 0.032323069870471954, - 0.0070039681158959866, - -0.005001536570489407, - 0.009073197841644287, - 0.034770116209983826, - 4.4110685848863795e-05, - -0.09198522567749023, - 0.016849718987941742, - -0.03548656404018402, - -0.0692531019449234, - 0.018414774909615517, - -0.0921187475323677, - -0.015344605781137943, - -0.04134054109454155, - -0.028480011969804764, - -0.08146455883979797, - 0.09168414771556854, - 0.07998357713222504, - 0.026014341041445732, - 0.017937038093805313, - -0.028883034363389015, - 0.03057144582271576, - -0.010724442079663277, - 0.018125038594007492, - 0.029578905552625656, - -0.024579638615250587, - -0.07526443898677826, - -0.016231250017881393, - -0.06456200778484344, - -0.030072860419750214, - -0.022895049303770065, - 0.00415444141253829, - 0.01151253841817379, - 0.06331072002649307, - -0.025500191375613213, - -0.00053714046953246, - 0.025350533425807953, - 0.0055648707784712315, - 0.0021499800495803356, - -0.04386467859148979, - -0.02675274945795536, - -0.011274125427007675, - -0.01867513172328472, - 0.025913124904036522, - -0.023076586425304413, - -0.03464184328913689, - 0.011390935629606247, - 0.0028527271933853626, - -0.0400318019092083, - 0.025248177349567413, - -0.04997679218649864, - 0.020129835233092308, - -0.012148474343121052, - 0.014167075976729393, - 0.017614644020795822, - 0.01582793891429901, - 0.053372036665678024, - 0.07793024927377701, - -0.03784414380788803, - -0.028109664097428322, - -0.014291783794760704, - 0.027021605521440506, - 0.02588130347430706, - -0.04984234645962715, - -0.01770542934536934, - -0.008349047042429447, - 0.04212779924273491, - 
0.0053593749180436134, - -0.03806058689951897, - 0.036747727543115616, - 0.05757524445652962, - -0.057155463844537735, - -0.04970157518982887, - -0.055126406252384186, - 0.029877936467528343, - -0.02737145684659481, - 0.011190582066774368, - -0.01531775202602148, - 0.052095942199230194, - 0.019683949649333954, - 0.04340781271457672, - -0.003975852858275175, - 0.014913157559931278, - 0.02827552892267704, - -0.05311066657304764, - -0.011286910623311996, - 0.0062117730267345905, - -0.021665755659341812, - -0.038417086005210876, - 0.037535544484853745, - 0.0058376360684633255, - 0.06614411622285843, - -0.008359941653907299, - -0.056450869888067245, - -0.02791580930352211, - -0.01933380216360092, - -0.01675199158489704, - 0.020247725769877434, - -0.0026980696711689234, - -0.013617953285574913, - -0.008764070458710194, - -0.03527483344078064, - 0.035375695675611496, - -0.0015997681766748428, - 0.035637691617012024, - 0.01963697001338005, - 0.017015134915709496, - -0.0003604411904234439, - -0.039214011281728745, - 0.014239010401070118, - -0.02629069983959198, - 0.0385490283370018, - 0.0002026914880843833, - 0.014476795680820942, - 0.04752976819872856, - -0.07340085506439209, - 0.02503885328769684, - 0.002498525194823742, - 0.054385773837566376, - 0.06879492104053497, - -0.007840770296752453, - -0.031533658504486084, - 0.004206547047942877, - 0.02952091209590435, - 0.04527729004621506, - 0.015435987152159214, - 0.015691671520471573, - 0.027629222720861435, - 0.027335476130247116, - 0.03717225790023804, - 0.03913302347064018, - -0.007201953325420618, - -0.038456082344055176, - -0.013573899865150452, - 0.017286749556660652, - -0.009555193595588207, - -0.039228253066539764, - -0.018771562725305557, - -0.019485214725136757, - -0.05869167298078537, - 0.042072784155607224, - -0.01506104413419962, - -0.0046134754084050655, - -0.02039765939116478, - 0.012185851112008095, - -0.04585464298725128, - 0.07079082727432251, - 0.04161670804023743, - -0.014112657867372036, - 
0.04570164903998375, - -0.04709532856941223, - -0.03184964880347252, - 0.014102606102824211, - 0.010880542919039726, - -0.0010067857801914215, - 0.020858198404312134, - 0.001616304274648428, - -0.027871672064065933, - -0.009556083008646965, - 0.023160360753536224, - -0.03516773879528046, - 0.008302837610244751, - -0.01353617012500763, - 0.05320126563310623, - -0.04701259359717369, - -0.017535002902150154, - -0.018479419872164726, - -0.03794485703110695, - 0.029532013460993767, - -0.01962168514728546, - 0.013537120074033737, - -0.0445764884352684, - -0.0355890728533268, - -0.013421730138361454, - -0.09442584961652756, - -0.05684971809387207, - 0.001214321469888091, - -0.011213802732527256, - 0.021285628899931908, - 0.008936009369790554, - 0.006863136775791645, - 0.015872832387685776, - 0.0485493429005146, - -0.04763637110590935, - 0.005595272406935692, - 0.03826083615422249, - 0.06484928727149963, - 0.05174949765205383, - 0.0064583332277834415, - 0.005473082885146141, - 0.03309480473399162, - 0.061267245560884476, - 0.008687039837241173, - 0.003077984554693103, - -0.0014701125910505652, - 0.00422841589897871, - -0.03934015706181526, - 0.03204263374209404, - 0.029333733022212982, - -0.005301087629050016, - 0.007751498371362686, - 0.023613551631569862, - -0.04376191645860672, - -0.0104624442756176, - 0.04873567074537277, - -0.05050220713019371, - 0.006768053397536278, - 0.04576372355222702, - -0.020884448662400246, - -0.031267423182725906, - -0.019394734874367714, - -0.09519949555397034, - -0.06667064875364304, - -0.06275929510593414, - 0.0037940307520329952, - 0.006049887277185917, - -0.02023051679134369, - -0.0015441244468092918, - 0.0719972774386406, - -0.02210680954158306, - -0.0027332452591508627, - 0.04092469811439514, - -0.036194171756505966, - 0.0028341219294816256, - -0.01672888547182083, - 0.014553734101355076, - -0.06705988943576813, - -0.056773219257593155, - 0.018284067511558533, - -0.0012927469797432423, - -0.029053328558802605, - -0.0294451043009758, - 
-0.006686860229820013, - -0.07940531522035599, - 0.01285732164978981, - 0.030497534200549126, - -0.019285226240754128, - 0.02499590441584587, - -0.024998806416988373, - 0.046604398638010025, - 0.00456125196069479, - -0.0273160170763731, - -0.006036459002643824, - 0.05902529135346413, - -0.06650868058204651, - 0.02275759167969227, - 0.020499039441347122, - 0.0619879812002182, - 0.027179839089512825, - 0.02172078751027584, - -0.009773382917046547, - -0.06453000009059906, - 0.04223477467894554, - -0.027186917141079903, - 0.025517649948596954, - 0.009861636906862259, - -0.00415257690474391, - 0.00037013995461165905, - 0.011218744330108166, - -0.0076008932664990425, - -0.07056346535682678, - -0.01989125646650791, - 0.0023018799256533384, - -0.0039259023033082485, - 0.036238156259059906, - -0.027882106602191925, - -0.0005035923677496612, - 0.02293340489268303, - -0.023957427591085434, - -0.0011390234576538205, - 0.0310504250228405, - 0.031994838267564774, - -0.014038464985787868, - -0.032214924693107605, - -0.044814836233854294, - -0.05194795876741409, - -0.02142512798309326, - 0.04763603210449219, - -0.056629836559295654, - 0.005412807688117027, - -0.05006266385316849, - 0.0285593681037426, - 0.052619025111198425, - 0.014774417504668236, - 0.03433847799897194, - -0.04708310216665268, - -0.0752381980419159, - -0.0006509265513159335, - -0.0514691099524498, - -0.013008086942136288, - 0.0038416720926761627, - -0.0593993179500103, - 0.028144342824816704, - 0.044867511838674545, - 0.0027947607450187206, - 0.03446776419878006, - -0.028873153030872345, - 0.01712701842188835, - -0.04372264817357063, - 0.038617007434368134, - -0.025441637262701988, - 0.021534765139222145, - 0.07294990867376328, - -0.05809745192527771, - -0.017783673480153084, - -0.03648945316672325, - -0.01126911211758852, - -0.013493646867573261, - -0.021482286974787712, - 0.01550940703600645, - -0.08124933391809464, - -0.04867582768201828, - 0.08051152527332306, - 0.04546916484832764, - -0.0038702557794749737, 
- -0.03737856447696686, - 0.027614301070570946, - -0.015709565952420235, - 0.06368324160575867, - -0.02034878358244896, - 0.007513822987675667, - 0.021807905286550522, - -0.0018058198038488626, - -0.03278772532939911, - 0.04744148254394531, - -0.0041910866275429726, - 0.011152463965117931, - -0.06345872581005096, - 0.008331837132573128, - -0.0465482659637928, - -0.059423308819532394, - 0.004401106853038073, - 0.04347091168165207, - 0.04015486687421799, - -0.0498359352350235, - 0.06830709427595139, - 0.02497447282075882, - 0.043160878121852875, - 0.0036167881917208433, - 0.015616455115377903, - -0.010447202250361443, - -0.003197269979864359, - -0.036088019609451294, - 0.052625950425863266, - 0.019057126715779305, - 0.009622062556445599, - -0.057408180087804794, - 0.007988927885890007, -] - -policies_query_embedding1 = [ - 0.019435187801718712, - -0.0037556702736765146, - 0.02552376128733158, - 0.01096564345061779, - 0.015737976878881454, - 0.01130329817533493, - 0.00023166643222793937, - 0.06781462579965591, - -0.019040381535887718, - 0.018963145092129707, - -0.0014944596914574504, - -0.006447169929742813, - -0.010107671841979027, - -0.04205232113599777, - -0.05787590891122818, - 0.029478052631020546, - 0.07108640670776367, - -0.0961247906088829, - -0.001972149359062314, - 0.014465388841927052, - 0.03600253909826279, - -0.01512065902352333, - 0.017059195786714554, - -0.004590804222971201, - 0.006954622454941273, - -0.02905983477830887, - 0.0013063991209492087, - -0.0031190826557576656, - -0.021369535475969315, - 0.010408749803900719, - 0.011391591280698776, - -0.046294450759887695, - -0.023932354524731636, - -0.02043585665524006, - -0.02021188661456108, - 0.01624036394059658, - 0.011137217283248901, - 0.03264450281858444, - -0.0036236264277249575, - -0.004322638735175133, - -0.0004821106558665633, - -0.06352217495441437, - -0.04073949530720711, - 0.0346953347325325, - 0.010889098979532719, - 0.0397348552942276, - 0.03885148838162422, - -0.0017959770048037171, - 
-0.03925124555826187, - -0.08452218770980835, - 0.005138807464390993, - 0.0466533862054348, - 0.07980414479970932, - -0.006792259402573109, - -0.01979881152510643, - -0.012917381711304188, - -0.04023488983511925, - 0.04133393242955208, - 0.004745221231132746, - 0.0013040559133514762, - 0.0058005014434456825, - 0.019697006791830063, - 0.002732495777308941, - -0.004619953688234091, - -0.03908479958772659, - 0.013410205952823162, - -0.021526562049984932, - -0.0008117908146232367, - 0.031792670488357544, - -0.029192306101322174, - -0.04605792462825775, - 0.013882720842957497, - 0.005172772333025932, - 0.024236906319856644, - 0.02311934530735016, - -0.021947242319583893, - 0.054519977420568466, - 0.015746576711535454, - -0.06235334277153015, - -0.019076500087976456, - -0.03980528190732002, - -0.013803640380501747, - -0.0958453118801117, - -0.0001025507808662951, - -0.026332147419452667, - -0.0789344534277916, - -0.07367418706417084, - -0.005086809862405062, - -0.053099703043699265, - 0.020726190879940987, - -0.0018209840636700392, - 0.0871577039361, - -0.010064671747386456, - 0.017551545053720474, - -0.02806132659316063, - -0.0038932692259550095, - 0.031316883862018585, - 0.03234038129448891, - -0.06839751452207565, - 0.016536235809326172, - -0.009240190498530865, - 0.011623494327068329, - -0.006169336847960949, - -0.000926577253267169, - 0.017296604812145233, - 0.02705499157309532, - -0.02804708480834961, - 0.021506542339920998, - 0.035366397351026535, - 0.030248131603002548, - -0.024746578186750412, - 0.0035174682270735502, - 0.04358452185988426, - -0.005591242574155331, - 0.022160520777106285, - -0.019638773053884506, - -0.03563052788376808, - 0.006377449259161949, - -0.03663938492536545, - -0.02226337604224682, - -0.04131054878234863, - 0.009974549524486065, - 0.1250336468219757, - -0.004605799913406372, - -0.01878715120255947, - 0.05509314686059952, - 0.040508922189474106, - -0.022547876462340355, - 0.005452202633023262, - -0.005716193001717329, - 
-0.037371955811977386, - -0.036816444247961044, - 0.006496272515505552, - 0.10744160413742065, - -0.03494149073958397, - 0.025258086621761322, - 0.038687463849782944, - -0.03720446676015854, - 0.0592627115547657, - 0.059475693851709366, - 0.03244549408555031, - 0.03426070511341095, - -0.03519924357533455, - -0.040127288550138474, - -0.0029384256340563297, - -0.03274993598461151, - 0.021802294999361038, - 0.024221377447247505, - 0.02028047665953636, - 0.00155010842718184, - 0.019937604665756226, - 0.08657442778348923, - -0.04831191152334213, - 0.019992614164948463, - -0.007619012147188187, - -0.03595961630344391, - -0.006503098178654909, - -0.0585431270301342, - 0.007151509169489145, - -0.013700856827199459, - -0.016966743394732475, - 0.01462690718472004, - -0.08147680759429932, - 0.004718420561403036, - -0.052766103297472, - -0.04831836372613907, - 0.055080346763134, - -0.039204906672239304, - -0.0492672398686409, - -0.011144218035042286, - -0.0493973083794117, - -0.07486134767532349, - -0.04354774206876755, - 0.0056018526665866375, - 0.007830990478396416, - -0.03584715351462364, - 0.016087781637907028, - -0.027233494445681572, - -0.061622947454452515, - -0.032096587121486664, - 0.001351039856672287, - -0.004982681944966316, - -0.040377501398324966, - 0.0036965718027204275, - -0.12154152244329453, - 0.008651738055050373, - -0.04771796613931656, - -0.03748656436800957, - -0.0729370191693306, - -0.008526206016540527, - -0.10364904999732971, - -0.040909018367528915, - -0.10856490582227707, - 0.020391175523400307, - -0.011522337794303894, - 0.07111772894859314, - 0.07067426294088364, - -0.026042072102427483, - -0.028634272515773773, - 0.013032005168497562, - -0.017624586820602417, - -0.032281868159770966, - -0.014472651295363903, - -0.001706311828456819, - 0.005799585022032261, - -0.02358911745250225, - 0.03330744057893753, - -0.05084405094385147, - -0.02596958912909031, - -6.544052303070202e-05, - 0.01750905066728592, - 0.016637954860925674, - -0.1175457164645195, - 
-0.0004703553859144449, - -0.014331148006021976, - -0.006356349214911461, - 0.13624931871891022, - -0.07946453243494034, - 0.03789275884628296, - -0.0514734648168087, - -0.01155677530914545, - 0.009684423916041851, - -0.11389581859111786, - 0.007444752845913172, - 0.05756825953722, - -0.029828282073140144, - -0.005600295029580593, - 0.016466816887259483, - 0.07256552577018738, - -0.013813336379826069, - -0.0594477653503418, - 0.04681297391653061, - 0.01556143444031477, - 0.014007221907377243, - -0.010759816505014896, - -0.0015058288117870688, - -0.002126917475834489, - 0.02892388589680195, - -0.10400842130184174, - -0.012912698090076447, - -0.08100010454654694, - 0.05549990013241768, - -0.038000620901584625, - -0.03945706784725189, - -0.005736282095313072, - 0.05296698585152626, - -0.029937084764242172, - 0.01153349969536066, - -0.01965814083814621, - -0.009053226560354233, - -0.011235162615776062, - -0.006116161122918129, - -0.02314285933971405, - 0.06143643707036972, - 0.06925277411937714, - -0.05089593306183815, - -0.019367286935448647, - -0.04801476001739502, - 0.03602011501789093, - -0.014578604139387608, - -0.03604239970445633, - 0.003863043151795864, - -0.024723632261157036, - 0.020573832094669342, - -0.009570201858878136, - 0.03774641081690788, - -0.05292463302612305, - 0.006452533416450024, - -0.04205406829714775, - 0.010149180889129639, - -0.018143240362405777, - 0.004127827472984791, - -0.045035213232040405, - 0.02675674669444561, - 0.006915321573615074, - -0.05640145763754845, - 0.004128557629883289, - -0.008709126152098179, - 0.01597878523170948, - -0.0030157528817653656, - -0.023828797042369843, - -0.05701858922839165, - -0.06052121892571449, - 0.05381669104099274, - 0.025057021528482437, - -0.005017107352614403, - 0.016007566824555397, - 0.04422290623188019, - 0.01778363808989525, - -0.044307511299848557, - -0.012402270920574665, - 0.07214607298374176, - 0.0792701318860054, - 0.027898022904992104, - 0.03736334666609764, - -0.01765776053071022, - 
0.031058166176080704, - 0.0004254234372638166, - -0.01205131784081459, - -0.004341015126556158, - -0.011568134650588036, - 0.04156339541077614, - 0.013221288099884987, - -0.021749014034867287, - 0.06962836533784866, - -0.017315620556473732, - 0.08870577812194824, - -0.00873498059809208, - -0.011377223767340183, - -0.011595699936151505, - 0.006607932038605213, - -0.017374547198414803, - -0.022839810699224472, - 0.034594908356666565, - -0.019290849566459656, - -0.006121560465544462, - 0.017753586173057556, - -0.07135862112045288, - 0.04130232334136963, - 0.0009869203204289079, - 0.009459913708269596, - -0.03928222879767418, - 0.029846664518117905, - -0.010229891166090965, - 0.018833884969353676, - -0.021240070462226868, - -0.014747458510100842, - 0.030872568488121033, - 0.03410491719841957, - -0.01231492217630148, - 0.005178756080567837, - -0.01042709406465292, - 0.04743482545018196, - -0.039838749915361404, - -0.031018337234854698, - 0.015309013426303864, - -0.004037679638713598, - 0.004998211748898029, - 0.05449990928173065, - 0.01789567433297634, - -0.045838549733161926, - 0.014405546709895134, - -0.07571280747652054, - 0.03758632019162178, - -0.062378477305173874, - 0.022019602358341217, - -0.05938965454697609, - -0.011664186604321003, - -0.018437281250953674, - -0.012576173059642315, - 0.01091400533914566, - 0.00945031363517046, - 0.005958541762083769, - -0.03366720303893089, - -0.015271728858351707, - 0.023048875853419304, - -0.006725663784891367, - 0.0016373167745769024, - 0.00046704502892680466, - -0.03365272656083107, - -0.007359737996011972, - 0.03848809376358986, - -0.0045979744754731655, - 0.02045128308236599, - -0.018190674483776093, - -0.04796736314892769, - -0.037075746804475784, - -0.018071072176098824, - -0.013390590436756611, - 0.01774747297167778, - 0.008547973819077015, - 0.02622542344033718, - -0.03219887614250183, - -0.0200456902384758, - -0.0033232688438147306, - -0.05828697979450226, - -0.015739664435386658, - 0.013859874568879604, - 
-0.043690990656614304, - -0.007176811341196299, - -0.008432558737695217, - -0.021054115146398544, - 0.002758907387033105, - 0.03890146315097809, - -0.004201301373541355, - -0.006750235799700022, - -0.02212129719555378, - -0.005983707960695028, - 0.0016165575943887234, - 0.011560464277863503, - -0.01283461507409811, - 0.05277907848358154, - -0.03431731462478638, - -0.03150126710534096, - 0.012272159568965435, - -0.043483976274728775, - -0.02715604193508625, - -0.024733325466513634, - -0.018217643722891808, - -0.021643197163939476, - 0.02713724598288536, - 0.04305938631296158, - 0.04261206462979317, - 0.002413937821984291, - -0.018239721655845642, - -0.013090712018311024, - -0.05980465188622475, - 0.011109466664493084, - 0.013083151541650295, - -0.061852823942899704, - 0.01054653525352478, - -0.06278519332408905, - 0.029932409524917603, - 0.039942361414432526, - -0.03127257153391838, - -0.06201130151748657, - -0.009618529118597507, - -7.425264630001038e-05, - 0.031138921156525612, - -0.019810115918517113, - -0.05585085228085518, - 0.06520254164934158, - 0.024034099653363228, - 0.015708057209849358, - 0.050117459148168564, - -0.00744776101782918, - -0.03167617321014404, - 0.0025314788799732924, - 0.0409964881837368, - 0.042706750333309174, - -0.03914619982242584, - 0.022494593635201454, - -0.0411686934530735, - 0.007950341328978539, - -0.0009295025956816971, - -0.04114913567900658, - -0.005723332986235619, - 0.014078161679208279, - 0.06999073922634125, - 0.012080896645784378, - -0.008286076597869396, - 0.05319024249911308, - -0.0015491624362766743, - -0.0018167974194511771, - -0.0024034867528826, - -0.0426836758852005, - -0.02353746071457863, - 0.019222090020775795, - 0.004391263704746962, - 0.016924070194363594, - -0.027421772480010986, - 0.03135634586215019, - 0.030027667060494423, - -0.020796509459614754, - -0.02457408793270588, - -0.001325097749941051, - -0.005604694597423077, - -0.03956734761595726, - 0.021962232887744904, - -0.014150132425129414, - 
-0.030417246744036674, - -0.008683395572006702, - -0.005982583854347467, - 0.06052049621939659, - 0.06659169495105743, - 0.017101123929023743, - -0.03270954266190529, - -0.047633182257413864, - 0.04904041439294815, - -0.016177913174033165, - 0.009877012111246586, - 0.004116414114832878, - -0.030805347487330437, - 0.0485880970954895, - -0.024089228361845016, - 0.05591115355491638, - -0.0046888794749975204, - -0.033939626067876816, - -0.0206813532859087, - 0.04792835935950279, - -0.047564007341861725, - 0.03640164062380791, - -0.024477144703269005, - 0.013948267325758934, - 0.06275200843811035, - 0.07728910446166992, - 0.013532593846321106, - 0.004243429284542799, - 0.008401074446737766, - -0.02796311117708683, - -0.010176070965826511, - -0.03250659629702568, - 0.029259270057082176, - -0.04701896011829376, - 0.0002866119612008333, - -0.0035398928448557854, - -0.016467221081256866, - -0.039023157209157944, - -0.019849354401230812, - -0.018275268375873566, - -0.021512357518076897, - 0.09221479296684265, - -0.06803987175226212, - 0.08957689255475998, - -0.061851803213357925, - 0.002801343100145459, - 0.03331753984093666, - -0.013032764196395874, - 0.005731076933443546, - 0.07219750434160233, - 0.021996982395648956, - 0.01677960343658924, - -0.021797234192490578, - -0.020772438496351242, - 0.009000342339277267, - 0.015125478617846966, - 0.05084208771586418, - -0.06308974325656891, - -0.03930108994245529, - 0.004811308812350035, - -0.03128044679760933, - 0.007836565375328064, - -0.010927228257060051, - -0.055871956050395966, - 0.05007820576429367, - -0.031110195443034172, - 0.004241985268890858, - 0.020734362304210663, - -0.05976051092147827, - 0.04778436943888664, - -0.04699753597378731, - 0.0722452774643898, - -0.00911264680325985, - 0.03627417981624603, - 0.0171580258756876, - 0.03797684237360954, - 0.004993902053683996, - -0.03650178387761116, - 0.0171328317373991, - 0.01566913165152073, - 0.05370726436376572, - -0.014021750539541245, - -0.0021908965427428484, - 
0.02357625775039196, - 0.01654050499200821, - 0.010890313424170017, - 0.0060208262875676155, - -0.07913804799318314, - 0.035021279007196426, - 0.05705339089035988, - 0.005312921479344368, - 0.07994336634874344, - 0.01256528776139021, - 0.05941077321767807, - 0.026312824338674545, - 0.040590472519397736, - 0.00016249150212388486, - -0.018834862858057022, - -0.003587394719943404, - 0.017034705728292465, - -0.020565006881952286, - 0.006719683762639761, - -0.02947266772389412, - -0.05601661279797554, - -0.0600954107940197, - -0.004952855873852968, - 0.005586576648056507, - 0.003577976254746318, - -0.0021878965198993683, - -0.023892691358923912, - -0.006508287973701954, - -0.015506021678447723, - 0.006047308444976807, - 0.03371831402182579, - 0.027723748236894608, - -0.04905441403388977, - -0.013026459142565727, - 0.06604082137346268, - -0.004796930588781834, - -0.00500132841989398, - 0.009556032717227936, - -0.0308690145611763, - 0.060339294373989105, - 0.03652385622262955, - -0.008656700141727924, - -0.0034692573826760054, - 0.026960982009768486, - -0.02149580977857113, - 0.04670587554574013, - -0.08773908764123917, - -0.02298053354024887, - -0.009338187985122204, - -0.07271511852741241, - 0.048206083476543427, - -0.033434897661209106, - 0.025270143523812294, - -0.020232995972037315, - 0.032617174088954926, - 0.06025872379541397, - -0.0409870408475399, - -0.029911495745182037, - 0.0040312763303518295, - 0.06096614524722099, - 0.007751356344670057, - 0.04987076669931412, - 0.0011653571855276823, - -0.008355646394193172, - -0.027534136548638344, - -0.010308800265192986, - -0.05349545180797577, - 0.06088753417134285, - -0.015470409765839577, - -0.007751379162073135, - -0.02721143327653408, - 0.015468763187527657, - -0.03001999855041504, - 0.008844302035868168, - -0.015932371839880943, - -0.023506775498390198, - 0.0425015464425087, - -0.019866015762090683, - -0.015134626999497414, - -0.05903257429599762, - -0.0007920601638033986, - 0.0223627220839262, - 
0.02698272466659546, - -0.05972685664892197, - 0.012724273838102818, - -0.06237699091434479, - -0.018834033980965614, - 0.03595362976193428, - 0.006248063407838345, - -0.034634724259376526, - -0.06366340816020966, - 0.04566069319844246, - 0.02451184391975403, - 0.021615460515022278, - -0.016037533059716225, - 0.009565546177327633, - 0.04028503596782684, - -0.031230447813868523, - -0.011746572330594063, - 0.03748490661382675, - -0.015223887749016285, - -0.009446519427001476, - -0.05522184446454048, - 0.02250606380403042, - 0.03194887563586235, - 0.015682894736528397, - -0.024093562737107277, - 0.011318043805658817, - -0.0003446311457082629, - -0.03661378100514412, - 0.004301781300455332, - -0.00024413350911345333, - 0.014156840741634369, - 0.02502865344285965, - -0.02162214182317257, - 0.04329215735197067, - -0.0080551253631711, - 0.02013627253472805, - 0.0008956206729635596, - -0.0438414141535759, - -0.0029399932827800512, - 0.0421270877122879, - 0.025939369574189186, - -0.02011949196457863, - 0.041561905294656754, - -0.03155352547764778, - -0.0316459946334362, - -0.014933913014829159, - 0.03079991415143013, - 0.017882181331515312, - 0.025511685758829117, - -0.019864212721586227, - -0.008703756146132946, - 0.022553488612174988, - -0.09703342616558075, - -0.014624638482928276, - 0.03448106721043587, - 0.008192705921828747, - -0.06994254142045975, - 0.024999642744660378, - -0.017784584313631058, - -0.018767815083265305, - 0.013196418061852455, - 0.0325394831597805, - -0.01587137207388878, - -0.008665663190186024, - 0.013694265857338905, - 0.03144343942403793, - -0.049686919897794724, - -0.0031140276696532965, - -0.08738628029823303, - 0.04882385954260826, - -0.00831123348325491, - 0.03187264874577522, - -0.02911270596086979, - 0.005735242273658514, - 0.02958846092224121, - -0.026828370988368988, - 0.04184429720044136, - 0.03659920021891594, - 0.013125266879796982, - 0.020556125789880753, - -0.021640243008732796, - -0.044204916805028915, - 0.02662266418337822, - 
-0.040824323892593384, - 0.02497202716767788, - 0.010659066028892994, - -0.01079088356345892, - -0.04107194021344185, - 0.004049273673444986, - -0.06602949649095535, - 0.010830281302332878, - 0.0020915307104587555, - -0.0013070548884570599, - 0.017711516469717026, - -0.008852856233716011, - 0.05698423832654953, - -0.022878462448716164, - 0.018874768167734146, - 0.02949371002614498, - 0.014823920093476772, - -0.004531942307949066, - 0.0065182582475245, - 0.0320458859205246, - -0.0037235443014651537, - 0.06805872172117233, - 0.032741446048021317, - -0.0643078088760376, - 0.022031843662261963, - -0.0305438581854105, - 0.035897932946681976, - -0.02241932787001133, - -0.061598360538482666, - -0.015969673171639442, - 0.0012508367653936148, - -0.050355080515146255, - 0.016142455860972404, - 0.005607514176517725, - 0.03330579027533531, - 0.010560208931565285, - -0.043246183544397354, - 0.0062516843900084496, - -0.03666979819536209, - 0.026078086346387863, - -0.007035407703369856, - 0.03943013399839401, - 0.03862067312002182, - 0.012649385258555412, - -0.005919503979384899, - -0.01943592168390751, - 0.02367492951452732, - 0.006087158340960741, - -0.0512159988284111, - 0.022539883852005005, - 0.0066794115118682384, - 0.007273179013282061, - -0.017144527286291122, - 0.016839638352394104, - 0.027191102504730225, - -0.03326268494129181, - 0.06590717285871506, - -0.002962711500003934, - -0.07598647475242615, - -0.015523041598498821, -] - -policies_query_embedding2 = [ - 1.1489464668557048e-05, - -0.04417964816093445, - -0.013543735258281231, - 0.002646843669936061, - -0.01887943223118782, - -0.010727466084063053, - 0.010153484530746937, - 0.03505353257060051, - 0.002640416845679283, - 0.016271473839879036, - 0.008542485535144806, - -0.041314903646707535, - 0.01286915224045515, - -0.027897773310542107, - -0.023653613403439522, - -0.004843604750931263, - 0.06502345949411392, - -0.11364666372537613, - -0.0010558441281318665, - -0.002368893241509795, - 0.03692156821489334, - 
0.002206284087151289, - 0.037736959755420685, - -0.01009274646639824, - 0.0010775597766041756, - -0.03889409080147743, - 0.0010773736285045743, - 0.01980048231780529, - -0.029793180525302887, - -0.010355527512729168, - 0.025571979582309723, - -0.047521915286779404, - 0.0061652157455682755, - 0.021566171199083328, - -0.03666559234261513, - -0.010032066144049168, - 0.03103993460536003, - -0.01567891053855419, - 0.007794647011905909, - -0.0322665311396122, - 0.005515735596418381, - -0.040037140250205994, - -0.0406297966837883, - -0.0023383263032883406, - -0.004302581772208214, - 0.012860193848609924, - -0.022696463391184807, - -0.005832694936543703, - -0.09517011046409607, - -0.06661532074213028, - -0.002523045288398862, - 0.04561278969049454, - 0.07103309780359268, - -0.031060507521033287, - -0.023533904924988747, - -0.016901874914765358, - 0.009511495940387249, - 0.03325766697525978, - -0.04985547810792923, - -0.0013011045521125197, - 0.045025814324617386, - 0.026514215394854546, - -0.026021206751465797, - 0.009772923775017262, - -0.013860255479812622, - 0.06526196748018265, - -0.004199746996164322, - -0.04982037469744682, - -0.028295688331127167, - -0.01903771050274372, - -0.028212614357471466, - -0.01570351980626583, - -0.02774018421769142, - -0.007101997267454863, - -0.006642892956733704, - -0.010274997912347317, - 0.07225783169269562, - -0.01797972060739994, - -0.051479656249284744, - 0.0021423434372991323, - -0.05144917592406273, - -0.013434980995953083, - -0.06836753338575363, - 0.008877214044332504, - -0.05122477188706398, - -0.053496479988098145, - -0.031720519065856934, - -0.026774529367685318, - -0.07948949933052063, - -0.01962014101445675, - -0.028407633304595947, - 0.07512436807155609, - -0.0065933652222156525, - 0.011147618293762207, - -0.028212834149599075, - 0.025125108659267426, - 0.04263653978705406, - 0.03909575194120407, - -0.05295490846037865, - -0.022351328283548355, - 0.004120570607483387, - 0.02031572163105011, - 0.007082875352352858, - 
-0.040019795298576355, - -0.012000920251011848, - 0.05776089429855347, - 0.0177608672529459, - -0.002005579648539424, - 0.04061266779899597, - 0.043334104120731354, - -0.006321517284959555, - -0.0014862158568575978, - 0.0065847099758684635, - -0.027191162109375, - 0.006672321818768978, - -0.04666347801685333, - -0.08029055595397949, - 0.003386661410331726, - -0.07500205934047699, - -0.04766172170639038, - -0.009060515090823174, - 0.02555493265390396, - 0.09016162902116776, - 0.014216864481568336, - 0.030401045456528664, - 0.01376708410680294, - 0.0394713394343853, - -0.06820525228977203, - 0.028198808431625366, - 0.022309981286525726, - -0.015717053785920143, - -0.012528209015727043, - 0.028997089713811874, - 0.0666099414229393, - -0.024245120584964752, - 0.049451131373643875, - 0.03746980428695679, - -0.07206009328365326, - 0.02122347056865692, - -0.01132186409085989, - 0.005819517187774181, - -0.0005745641537941992, - -0.017243554815649986, - -0.01783209666609764, - -0.01533574890345335, - -0.026931939646601677, - 0.03217088803648949, - 0.04312875494360924, - 0.03636249899864197, - 0.04066362604498863, - -0.022167161107063293, - 0.0792364552617073, - -0.018510060384869576, - 0.027009744197130203, - 0.02188047580420971, - -0.020126109942793846, - -0.016124332323670387, - -0.07692807912826538, - -0.031905438750982285, - 0.008915846236050129, - 0.01963149569928646, - 0.02104085497558117, - -0.012466001324355602, - -0.018818268552422523, - -0.03133704513311386, - -0.017744597047567368, - 0.06747609376907349, - -0.011810868047177792, - 0.03494247421622276, - -0.056223027408123016, - -0.01520370040088892, - -0.08506564795970917, - -0.057824209332466125, - -0.0029911373276263475, - -0.03404572233557701, - 0.0026574810035526752, - -0.047937557101249695, - -0.04159596189856529, - -0.03985041007399559, - -0.05610537528991699, - -0.06080633029341698, - -0.015240334905683994, - 0.011921930126845837, - 0.0016870687250047922, - -0.08828447014093399, - 0.0382922999560833, - 
-0.06361750513315201, - -0.04874303191900253, - -0.09006042033433914, - -0.00468732975423336, - -0.04136015474796295, - -0.031024668365716934, - -0.11642608046531677, - -0.03804741054773331, - 0.009081859141588211, - 0.0488508939743042, - 0.05959388241171837, - -0.043625324964523315, - -0.0021041384898126125, - 0.0468524806201458, - -0.01195160485804081, - -0.015788234770298004, - -0.005863327067345381, - 0.03430390730500221, - -0.07131640613079071, - -0.047090690582990646, - 0.05509907007217407, - -0.011792562901973724, - -0.05020248517394066, - -0.040194734930992126, - 0.03153885528445244, - 0.04532082751393318, - -0.1476796418428421, - 0.004047122318297625, - -0.019465195015072823, - 0.01324824895709753, - 0.10535216331481934, - -0.0744105651974678, - 0.029518023133277893, - -0.028836848214268684, - 0.006898754741996527, - 0.027300164103507996, - -0.04690819978713989, - 0.022650111466646194, - 0.0020053857006132603, - -0.0181961078196764, - 0.037353090941905975, - 0.009082581847906113, - 0.03067804127931595, - -0.000293438759399578, - -0.029023192822933197, - 0.015544848516583443, - 0.0005384765681810677, - 0.04318104684352875, - 0.012431884184479713, - -0.006393855437636375, - -0.00731752160936594, - -0.013319234363734722, - -0.09229796379804611, - -0.02471965178847313, - -0.02731281891465187, - 0.0822460949420929, - -0.019307367503643036, - -0.06337793916463852, - -0.005808855872601271, - 0.05138656869530678, - -0.06086648628115654, - -0.008418021723628044, - -0.0074536786414682865, - -0.06839632242918015, - -0.014659270644187927, - -0.03666446730494499, - -0.047926850616931915, - 0.040804412215948105, - 0.03776337578892708, - -0.09373703598976135, - -0.0008150787907652557, - -0.025711430236697197, - 0.04452098160982132, - -0.0005632721004076302, - -0.006582108326256275, - -0.019862161949276924, - -0.017114149406552315, - -0.012631848454475403, - 0.025622371584177017, - 0.03864235058426857, - -0.05276263877749443, - 0.005399600602686405, - 
-0.022486036643385887, - 0.056733060628175735, - -0.04681537672877312, - 0.016780782490968704, - -0.0005745685193687677, - 0.0076257409527897835, - -0.00762772373855114, - -0.05675653740763664, - -0.012343266047537327, - -0.03405306860804558, - -0.0028795755933970213, - -0.026169532909989357, - 0.0031819886062294245, - -0.055128131061792374, - -0.040804363787174225, - 0.0959225669503212, - 0.04777396097779274, - 0.01248854584991932, - -0.0346091128885746, - 0.0012384552974253893, - 0.016274916008114815, - -0.018916521221399307, - 0.011970599181950092, - 0.07506520301103592, - 0.05780462920665741, - 0.0019387968350201845, - 0.06402605026960373, - -0.005488018970936537, - 0.025572506710886955, - -0.047932952642440796, - 0.01740681752562523, - -0.02845127508044243, - -0.011388413608074188, - 0.04607626795768738, - 0.03145420923829079, - -0.043867554515600204, - 0.08072052150964737, - -0.008866746909916401, - 0.06415791809558868, - -0.03814288228750229, - -0.04280978813767433, - 0.03736559674143791, - 0.05040620639920235, - -0.039302051067352295, - 0.010268078185617924, - 0.047697436064481735, - -0.02107059210538864, - 0.009627968072891235, - -0.01775490678846836, - -0.06285437196493149, - 0.06362050026655197, - -0.01878926157951355, - 0.015059037134051323, - -0.023113472387194633, - 0.04212016612291336, - -0.012490317225456238, - 0.014377717860043049, - 0.005142960697412491, - -0.00973766203969717, - 0.06246551126241684, - 0.03843576833605766, - -0.020927011966705322, - -0.033441685140132904, - -0.0022990789730101824, - 0.04013356938958168, - -0.06098434329032898, - -0.03651748597621918, - 0.007979054935276508, - -0.01455061323940754, - -0.00875249132514, - 0.06168532744050026, - 0.016504911705851555, - -0.04544872045516968, - -0.007270110305398703, - -0.05084022134542465, - -0.010963707230985165, - -0.05904154106974602, - 0.01743290014564991, - -0.020297642797231674, - 0.033035121858119965, - -0.01814180426299572, - -0.04515055939555168, - 0.03297523036599159, - 
-0.020830363035202026, - -0.020806843414902687, - -0.02344667725265026, - 0.0006725008715875447, - -0.03686101734638214, - -0.012858753092586994, - -0.009827155619859695, - 0.033864546567201614, - -0.033192869275808334, - -0.017786452546715736, - 0.00702989986166358, - 0.05421966314315796, - 0.014496312476694584, - -0.005337302573025227, - -0.079369455575943, - -0.039053235203027725, - 0.015828222036361694, - 0.0021103022154420614, - 0.01074950024485588, - -0.012326790019869804, - 0.03158434480428696, - -0.04075644910335541, - -0.009935163892805576, - -0.016160717234015465, - -0.061049576848745346, - -0.04524746909737587, - 0.020469466224312782, - -0.001066140248440206, - -0.03197918459773064, - 0.008812588639557362, - -0.028872298076748848, - -0.0003000098222400993, - -0.0055451588705182076, - 0.04486154764890671, - 0.0006179793854244053, - -0.01180270779877901, - -0.0016177998622879386, - -0.005145156290382147, - -0.04939825460314751, - 0.008156572468578815, - 0.026950005441904068, - -0.04496395215392113, - -0.018535785377025604, - 0.03978446125984192, - -0.04837900027632713, - -0.01902257837355137, - -0.0037644151598215103, - -0.03618226572871208, - 0.01980670914053917, - 0.03061339259147644, - 0.03189018368721008, - 0.047526173293590546, - 0.026180041953921318, - -0.0119468430057168, - -0.0339416079223156, - -0.06880015134811401, - 0.020078599452972412, - 0.02884572744369507, - -0.009989629499614239, - 0.00861865933984518, - -0.04079229757189751, - 0.04214446246623993, - 0.0007846234366297722, - 0.0002016093349084258, - -0.08154798299074173, - 0.021139323711395264, - 0.003933001775294542, - 0.009086069650948048, - -0.027323143556714058, - -0.0027124756015837193, - 0.05645693838596344, - 0.00033277124748565257, - 0.02223312109708786, - 0.030703403055667877, - 0.030566880479454994, - -0.05346106365323067, - -0.008949755690991879, - 0.04021448269486427, - 0.06668312102556229, - -0.014551125466823578, - 0.015177509747445583, - -0.024546785280108452, - 
0.02109510265290737, - -0.022253314033150673, - -0.009550293907523155, - -0.03137696906924248, - 0.012763927690684795, - 0.01797163113951683, - 0.002724010031670332, - -0.03852318227291107, - 0.029054317623376846, - -0.02711714245378971, - 0.013056362047791481, - -0.02054792456328869, - 0.014122270047664642, - 0.024160074070096016, - 0.022181052714586258, - -0.006567671895027161, - 0.01557751651853323, - -0.010897778905928135, - 0.01694948785007, - -0.016120465472340584, - -0.012987310066819191, - 0.03852672129869461, - 0.02026708796620369, - -0.020607654005289078, - -0.01659061573445797, - 0.011586613953113556, - -0.025652078911662102, - -0.0032943193800747395, - 0.018773213028907776, - 0.034710392355918884, - -0.00636656116694212, - 0.04586773365736008, - 0.01566656306385994, - 0.004218233283609152, - 0.00754138408228755, - 0.009710167534649372, - 0.019478436559438705, - 0.04841319099068642, - 0.0025055939331650734, - -0.038806233555078506, - 0.010114329867064953, - 0.003944931086152792, - 0.06764235347509384, - 0.020739911124110222, - -0.04059816524386406, - -0.013803824782371521, - 0.02558905817568302, - -0.06397488713264465, - 0.024432538077235222, - -0.009103739634156227, - -0.0011051943292841315, - 0.05122315138578415, - 0.0757622942328453, - 0.03497244045138359, - 0.016767552122473717, - -0.0028931149281561375, - -0.019212031736969948, - -0.0558941513299942, - -0.025089923292398453, - 0.012393583543598652, - -0.05856006219983101, - -0.024423683062195778, - -0.01248327735811472, - -0.0159304179251194, - -0.042598411440849304, - 0.0050742365419864655, - 0.019651951268315315, - 0.03179909661412239, - 0.09451761096715927, - -0.035356760025024414, - 0.055684156715869904, - -0.04297143220901489, - 0.023263053968548775, - -0.005327092949301004, - -0.005430041812360287, - 0.013364868238568306, - 0.03922300413250923, - -0.011905710212886333, - 0.020538918673992157, - -0.021839020773768425, - -0.059818755835294724, - 0.00883548241108656, - 0.01919606514275074, - 
0.058570701628923416, - -0.04443622753024101, - -0.02693021669983864, - -0.019972149282693863, - -0.04246845841407776, - 0.04894375428557396, - -0.015407809987664223, - -0.03765471652150154, - 0.08200346678495407, - -0.015487522818148136, - -0.009315523318946362, - 0.04076525196433067, - -0.025237547233700752, - 0.046353187412023544, - -0.019918128848075867, - 0.017333505675196648, - 0.014947325922548771, - -0.01633741892874241, - -0.0037335266824811697, - 0.006114261224865913, - 0.0009970386745408177, - -0.03454785794019699, - 0.025357484817504883, - 0.009000148624181747, - 0.061743926256895065, - -0.024646760895848274, - -0.01514875702559948, - -0.020061038434505463, - 0.018796468153595924, - 0.0025695767253637314, - 0.03416294977068901, - -0.0498264878988266, - 0.03171757608652115, - 0.06739833950996399, - -0.021739469841122627, - 0.05334605649113655, - 0.03286312147974968, - 0.01757230795919895, - -0.04312858358025551, - 0.042531151324510574, - -0.03004857338964939, - 0.017469266429543495, - -0.012114943005144596, - 0.03614463284611702, - 0.005759501829743385, - 0.03334067389369011, - -0.036137230694293976, - -0.060897357761859894, - -0.04707951098680496, - -0.030901234596967697, - 0.012699859216809273, - -0.01225703302770853, - -0.03831500932574272, - -0.011210618540644646, - -0.006945494096726179, - -0.06437350064516068, - 0.03976176306605339, - 0.022079946473240852, - 0.051489293575286865, - -0.04161742702126503, - -0.015967076644301414, - 0.051872823387384415, - -0.014327307231724262, - -0.016479510813951492, - 0.01891450025141239, - -0.04847748950123787, - 0.0764191672205925, - 0.0017369415145367384, - 0.008184406906366348, - 0.010970892384648323, - 0.015331893227994442, - 0.008141906000673771, - 0.05485530197620392, - -0.0754239559173584, - -0.009183076210319996, - 0.0025307766627520323, - -0.05106281861662865, - 0.02369060181081295, - -0.02043209597468376, - -0.0021998852025717497, - 0.0037044764030724764, - 0.03230445832014084, - 0.08176115900278091, - 
-0.07895228266716003, - -0.035508979111909866, - -0.03414342924952507, - 0.03688323497772217, - 0.0005612073000520468, - 0.018258703872561455, - -0.006317385006695986, - 0.022497259080410004, - -0.021584322676062584, - -0.0021365575958043337, - -0.04004015773534775, - 0.028079496696591377, - -0.03313641622662544, - 0.014522411860525608, - -0.004684946034103632, - -0.0015415854286402464, - -0.07738209515810013, - -0.0076045384630560875, - -0.00017606545588932931, - -0.03870181366801262, - 0.04653894528746605, - -0.009382392279803753, - -0.010714557953178883, - -0.07811527699232101, - -0.003023742698132992, - 0.04492877423763275, - 0.04807121679186821, - -0.04453440010547638, - 0.018598180264234543, - -0.053693950176239014, - 0.0005651914980262518, - 0.024131203070282936, - 0.011158311739563942, - -0.06330019980669022, - -0.013617516495287418, - 0.013436135835945606, - 0.019802534952759743, - 0.006077812053263187, - -0.007399336434900761, - -0.0070304409600794315, - 0.08649568259716034, - -0.06305396556854248, - 0.008890517055988312, - -0.0044074896723032, - 0.024839160963892937, - 0.006484976503998041, - -0.03687465563416481, - 0.03325795382261276, - 0.025409407913684845, - -0.015886075794696808, - -0.039690643548965454, - -0.013540212996304035, - -0.017590465024113655, - -0.0544017069041729, - 0.004020100925117731, - 0.021805167198181152, - 0.027279887348413467, - 0.04075487330555916, - -0.035505544394254684, - 0.02225283347070217, - -0.01775486208498478, - 0.024338139221072197, - 0.017186734825372696, - -0.004054193384945393, - 0.021896496415138245, - 0.03698623552918434, - 0.002409539418295026, - -0.040563687682151794, - 0.025103451684117317, - -0.04608858749270439, - -0.00833041500300169, - 0.02624930813908577, - 0.03879132866859436, - 0.017509160563349724, - -0.007651817984879017, - 0.01164444163441658, - 0.011853894218802452, - 0.03615027666091919, - -0.05702957138419151, - -0.03462577238678932, - 0.050132736563682556, - -0.003220156067982316, - 
-0.02635158970952034, - 0.007139620371162891, - -0.0008898164960555732, - -0.018587952479720116, - 0.008688241243362427, - 0.020444881170988083, - -0.02641548216342926, - 0.0834483951330185, - 0.01434650830924511, - 0.04376484453678131, - -0.0011908509768545628, - 0.009777247905731201, - -0.13625597953796387, - 0.05653632432222366, - 0.014287661761045456, - 0.027409560978412628, - 0.04181351885199547, - 0.0024189704563468695, - 0.024867961183190346, - -0.0313175730407238, - 0.025688517838716507, - 0.01518984604626894, - 0.0007884484366513789, - 0.012660779990255833, - -0.027535313740372658, - -0.02469700202345848, - 0.03163842484354973, - -0.043047014623880386, - 0.018811671063303947, - 0.03310970589518547, - -0.0001860297634266317, - -0.028571220114827156, - -0.023441867902874947, - -0.03713734820485115, - -0.03950771316885948, - -0.01946307346224785, - -0.0385328084230423, - 0.0654711201786995, - -0.032378461211919785, - 0.010252811945974827, - -0.02723832242190838, - 0.004081403836607933, - -0.005934979300945997, - 0.03747616708278656, - -0.020681727677583694, - -0.01970534957945347, - 0.04276663810014725, - 0.019808653742074966, - 0.0669986680150032, - 0.037935029715299606, - -0.0684472993016243, - 0.04169120267033577, - -0.04295681416988373, - 0.011889652349054813, - -0.017844047397375107, - -0.056302931159734726, - 0.011420484632253647, - 0.019892165437340736, - -0.02913745865225792, - -0.010606188327074051, - 0.017285184934735298, - 0.039103053510189056, - -0.006328175310045481, - -0.033770911395549774, - 0.03475082293152809, - -0.04927230253815651, - 0.038552314043045044, - -0.03748100996017456, - 0.014868093654513359, - 0.04358559474349022, - -0.012105683796107769, - -0.03355974704027176, - -0.03204505518078804, - -0.013308456167578697, - 0.005010353866964579, - -0.02391629107296467, - 0.0060473838821053505, - -0.012235157191753387, - 0.000672084977850318, - 0.013973928056657314, - 0.04044320061802864, - 0.043585196137428284, - 0.011952136643230915, - 
0.04131130129098892, - -0.011513923294842243, - -0.036330271512269974, - -0.0019295994425192475, -] diff --git a/retrieval_service/datastore/providers/utils.py b/retrieval_service/datastore/providers/utils.py deleted file mode 100644 index bbed2d055..000000000 --- a/retrieval_service/datastore/providers/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - - -def get_env_var(key: str, desc: str) -> str: - v = os.environ.get(key) - if v is None: - raise ValueError(f"Must set env var {key} to: {desc}") - return v diff --git a/retrieval_service/example-config-alloydb.yml b/retrieval_service/example-config-alloydb.yml deleted file mode 100644 index 8c278ef86..000000000 --- a/retrieval_service/example-config-alloydb.yml +++ /dev/null @@ -1,11 +0,0 @@ -host: 0.0.0.0 -datastore: - # Example for AlloyDB - kind: "alloydb-postgres" - project: "my-project" - region: "my-region" - cluster: "my-cluster" - instance: "my-instance" - database: "my_database" - user: "my-user" - password: "my-password" diff --git a/retrieval_service/example-config-cloudsql.yml b/retrieval_service/example-config-cloudsql.yml deleted file mode 100644 index c489707af..000000000 --- a/retrieval_service/example-config-cloudsql.yml +++ /dev/null @@ -1,10 +0,0 @@ -host: 0.0.0.0 -datastore: - # Example for Cloud SQL - kind: "cloudsql-engine" - project: "my-project" - region: "my-region" - instance: "my-instance" - database: 
"my_database" - user: "my-user" - password: "my-password" diff --git a/retrieval_service/example-config-spanner.yml b/retrieval_service/example-config-spanner.yml deleted file mode 100644 index 8409147c7..000000000 --- a/retrieval_service/example-config-spanner.yml +++ /dev/null @@ -1,8 +0,0 @@ -host: 0.0.0.0 -datastore: - # Example for Spanner - kind: "spanner-engine" - project: "my-project" - instance: "my-instance" - database: "my_database" - # service_account_key_file: "my-service-account-key/service_accounts_credentials.json" diff --git a/retrieval_service/example-config.yml b/retrieval_service/example-config.yml deleted file mode 100644 index 9d1e89c7f..000000000 --- a/retrieval_service/example-config.yml +++ /dev/null @@ -1,11 +0,0 @@ -host: 0.0.0.0 -# port: 8080 -datastore: - # Example for AlloyDB - kind: "postgres" - host: 127.0.0.1 - # port: 5432 - database: "my_database" - user: "my-user" - password: "my-password" - # clientId: "my-clientId" diff --git a/retrieval_service/models/__init__.py b/retrieval_service/models/__init__.py deleted file mode 100644 index e83b588d0..000000000 --- a/retrieval_service/models/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from .models import Airport, Amenity, Flight, Policy, Ticket diff --git a/retrieval_service/models/models.py b/retrieval_service/models/models.py deleted file mode 100644 index 53349a0de..000000000 --- a/retrieval_service/models/models.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ast -import datetime -from typing import Optional - -from pydantic import BaseModel, ConfigDict, field_validator - - -class Airport(BaseModel): - id: int - iata: str - name: str - city: str - country: str - - -class Amenity(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - id: int - name: str - description: str - location: str - terminal: str - category: str - hour: str - sunday_start_hour: Optional[datetime.time] = None - sunday_end_hour: Optional[datetime.time] = None - monday_start_hour: Optional[datetime.time] = None - monday_end_hour: Optional[datetime.time] = None - tuesday_start_hour: Optional[datetime.time] = None - tuesday_end_hour: Optional[datetime.time] = None - wednesday_start_hour: Optional[datetime.time] = None - wednesday_end_hour: Optional[datetime.time] = None - thursday_start_hour: Optional[datetime.time] = None - thursday_end_hour: Optional[datetime.time] = None - friday_start_hour: Optional[datetime.time] = None - friday_end_hour: Optional[datetime.time] = None - saturday_start_hour: Optional[datetime.time] = None - saturday_end_hour: Optional[datetime.time] = None - 
content: Optional[str] = None - embedding: Optional[list[float]] = None - - @field_validator( - "sunday_start_hour", - "sunday_end_hour", - "monday_start_hour", - "monday_end_hour", - "tuesday_start_hour", - "tuesday_end_hour", - "wednesday_start_hour", - "wednesday_end_hour", - "thursday_start_hour", - "thursday_end_hour", - "friday_start_hour", - "friday_end_hour", - "saturday_start_hour", - "saturday_end_hour", - mode="before", - ) - def replace_none(cls, v): - return v or None - - @field_validator("embedding", mode="before") - def validate(cls, v): - if isinstance(v, str): - v = ast.literal_eval(v) - v = [float(f) for f in v] - return v - - -class Flight(BaseModel): - id: int - airline: str - flight_number: str - departure_airport: str - arrival_airport: str - departure_time: datetime.datetime - arrival_time: datetime.datetime - departure_gate: str - arrival_gate: str - - -class Ticket(BaseModel): - user_id: int - user_name: str - user_email: str - airline: str - flight_number: str - departure_airport: str - arrival_airport: str - departure_time: datetime.datetime - arrival_time: datetime.datetime - - -class Policy(BaseModel): - id: int - content: str - embedding: Optional[list[float]] = None - - @field_validator("embedding", mode="before") - def validate(cls, v): - if isinstance(v, str): - v = ast.literal_eval(v) - v = [float(f) for f in v] - return v diff --git a/retrieval_service/postgres.tests.cloudbuild.yaml b/retrieval_service/postgres.tests.cloudbuild.yaml deleted file mode 100644 index 6d2d2e647..000000000 --- a/retrieval_service/postgres.tests.cloudbuild.yaml +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Create database - name: postgres - secretEnv: - - DB_USER # Use built-in env vars for database connection - - DB_PASS - script: | - #!/usr/bin/env bash - export PGUSER=$DB_USER - export PGPASSWORD=$DB_PASS - echo "SELECT 'CREATE DATABASE ${_DATABASE_NAME}' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '${_DATABASE_NAME}')\gexec" | psql -h ${_DATABASE_HOST} - psql -h ${_DATABASE_HOST} -d ${_DATABASE_NAME} -c 'CREATE EXTENSION IF NOT EXISTS vector;' - - - id: Update config - name: python:3.11 - dir: retrieval_service - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - # Create config - cp example-config.yml config.yml - sed -i "s/127.0.0.1/${_DATABASE_HOST}/g" config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - sed -i "s/my-user/$$DB_USER/g" config.yml - sed -i "s/my-password/$$DB_PASS/g" config.yml - - - id: Run Alloy DB integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_NAME=${_DATABASE_NAME}" - - "DB_HOST=${_DATABASE_HOST}" - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - python -m pytest --cov=datastore.providers.postgres --cov-config=coverage/.postgres-coveragerc datastore/providers/postgres_test.py - - - id: Clean database - name: postgres - secretEnv: - - DB_USER - - DB_PASS - script: | - #!/usr/bin/env bash - export PGUSER=$DB_USER - export 
PGPASSWORD=$DB_PASS - psql -h ${_DATABASE_HOST} -c "DROP DATABASE IF EXISTS ${_DATABASE_NAME};" - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _DATABASE_USER: postgres - _DATABASE_HOST: 127.0.0.1 - -availableSecrets: - secretManager: - - versionName: projects/$PROJECT_ID/secrets/alloy_db_user/versions/latest - env: DB_USER - - versionName: projects/$PROJECT_ID/secrets/alloy_db_pass/versions/latest - env: DB_PASS -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true - pool: - name: projects/$PROJECT_ID/locations/us-central1/workerPools/alloy-private-pool # Necessary for VPC network connection diff --git a/retrieval_service/pyproject.toml b/retrieval_service/pyproject.toml deleted file mode 100644 index d4898d789..000000000 --- a/retrieval_service/pyproject.toml +++ /dev/null @@ -1,14 +0,0 @@ -[tool.isort] -profile = "black" - -[tool.mypy] -python_version = "3.11" -warn_unused_configs = true - -[[tool.mypy.overrides]] -module = ["pgvector.asyncpg"] -ignore_missing_imports = true - -[[tool.mypy.overrides]] -module = ["sqlparse"] -ignore_missing_imports = true diff --git a/retrieval_service/requirements-test.txt b/retrieval_service/requirements-test.txt deleted file mode 100644 index 0a8177cac..000000000 --- a/retrieval_service/requirements-test.txt +++ /dev/null @@ -1,10 +0,0 @@ -asyncpg-stubs==0.30.0 -black==25.1.0 -httpx==0.27.2 -isort==6.0.0 -mypy==1.11.2 -pytest-asyncio==0.24.0 -pytest==8.3.3 -types-PyYAML==6.0.12.20240917 -csv-diff==1.2 -pytest-cov==6.0.0 diff --git a/retrieval_service/requirements.txt b/retrieval_service/requirements.txt deleted file mode 100644 index 07b163d0a..000000000 --- a/retrieval_service/requirements.txt +++ /dev/null @@ -1,23 +0,0 @@ -asyncpg==0.30.0 -fastapi==0.115.0 -google-auth==2.35.0 -google-cloud-firestore==2.19.0 -google-cloud-aiplatform==1.72.0 -google-cloud-spanner==3.49.1 -langchain-core==0.3.18 -pgvector==0.3.5 -pydantic==2.9.0 -uvicorn[standard]==0.31.0 
-cloud-sql-python-connector==1.12.1 -google-cloud-alloydb-connector[asyncpg]==1.4.0 -sqlalchemy[asyncio]==2.0.36 -pandas==2.2.3 -pandas-stubs==2.2.2.240807 -langchain-text-splitters==0.3.0 -langchain-google-vertexai==2.0.7 -asyncio==3.4.3 -datetime==5.5 -pymysql==1.1.1 -types-PyMySQL==1.1.0.20240524 -neo4j==5.26.0 -sqlparse==0.5.2 diff --git a/retrieval_service/run_app.py b/retrieval_service/run_app.py deleted file mode 100644 index 832f858b6..000000000 --- a/retrieval_service/run_app.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import asyncio - -import uvicorn - -from app import init_app, parse_config - - -async def main(): - # Set up argument parsing - parser = argparse.ArgumentParser(description="Run the FastAPI application") - parser.add_argument("--reload", action="store_true", help="Enable auto-reload") - args = parser.parse_args() - - cfg = parse_config("./config.yml") - app = init_app(cfg) - if app is None: - raise TypeError("app not instantiated") - server = uvicorn.Server( - uvicorn.Config( - app, host=str(cfg.host), port=cfg.port, log_level="info", reload=args.reload - ) - ) - await server.serve() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/retrieval_service/run_database_export.py b/retrieval_service/run_database_export.py deleted file mode 100644 index aa7d05a65..000000000 --- a/retrieval_service/run_database_export.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio - -import datastore -from app import parse_config - - -async def main(): - cfg = parse_config("config.yml") - ds = await datastore.create(cfg.datastore) - - airports, amenities, flights, policies = await ds.export_data() - - await ds.close() - - airports_new_path = "../data/airport_dataset.csv.new" - amenities_new_path = "../data/amenity_dataset.csv.new" - flights_new_path = "../data/flights_dataset.csv.new" - policies_new_path = "../data/cymbalair_policy.csv.new" - - await ds.export_dataset( - airports, - amenities, - flights, - policies, - airports_new_path, - amenities_new_path, - flights_new_path, - policies_new_path, - ) - - print("database export done.") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/retrieval_service/run_database_init.py b/retrieval_service/run_database_init.py deleted file mode 100644 index 79ac8e592..000000000 --- a/retrieval_service/run_database_init.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio - -import datastore -from app import parse_config - - -async def main() -> None: - airports_ds_path = "../data/airport_dataset.csv" - amenities_ds_path = "../data/amenity_dataset.csv" - flights_ds_path = "../data/flights_dataset.csv" - policies_ds_path = "../data/cymbalair_policy.csv" - - cfg = parse_config("config.yml") - ds = await datastore.create(cfg.datastore) - airports, amenities, flights, policies = await ds.load_dataset( - airports_ds_path, amenities_ds_path, flights_ds_path, policies_ds_path - ) - await ds.initialize_data(airports, amenities, flights, policies) - await ds.close() - - print("database init done.") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/retrieval_service/run_generate_embeddings.py b/retrieval_service/run_generate_embeddings.py deleted file mode 100644 index c1bf477eb..000000000 --- a/retrieval_service/run_generate_embeddings.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -import csv - -from langchain_google_vertexai import VertexAIEmbeddings - -import models -from app import EMBEDDING_MODEL_NAME - - -async def main() -> None: - embed_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME) - - amenities: list[models.Amenity] = [] - with open("../data/amenity_dataset.csv", "r") as f: - reader = csv.DictReader(f, delimiter=",") - for line in reader: - amenity = models.Amenity.model_validate(line) - if amenity.content: - amenity.embedding = embed_service.embed_query(amenity.content) - amenities.append(amenity) - - policies: list[models.Policy] = [] - with open("../data/cymbalair_policy.csv", "r") as f: - reader = csv.DictReader(f, delimiter=",") - for line in reader: - policy = models.Policy.model_validate(line) - if policy.content: - policy.embedding = embed_service.embed_query(policy.content) - policies.append(policy) - - print("Completed embedding generation.") - - with open("../data/amenity_dataset.csv.new", "w") as f: - col_names = [ - "id", - "name", - "description", - "location", - "terminal", - "category", - "hour", - "sunday_start_hour", - "sunday_end_hour", - "monday_start_hour", - "monday_end_hour", - "tuesday_start_hour", - "tuesday_end_hour", - "wednesday_start_hour", - "wednesday_end_hour", - "thursday_start_hour", - "thursday_end_hour", - "friday_start_hour", - "friday_end_hour", - "saturday_start_hour", - "saturday_end_hour", - "content", - "embedding", - ] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for amenity in amenities: - writer.writerow(amenity.model_dump()) - - with open("../data/cymbalair_policy.csv.new", "w") as f: - col_names = [ - "id", - "content", - "embedding", - ] - writer = csv.DictWriter(f, col_names, delimiter=",") - writer.writeheader() - for policy in policies: - writer.writerow(policy.model_dump()) - - print("Wrote data to CSV.") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git 
a/retrieval_service/run_generate_policy_dataset.py b/retrieval_service/run_generate_policy_dataset.py deleted file mode 100644 index 42e8d0661..000000000 --- a/retrieval_service/run_generate_policy_dataset.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import pandas as pd -from langchain_google_vertexai import VertexAIEmbeddings -from langchain_text_splitters import ( - MarkdownHeaderTextSplitter, - RecursiveCharacterTextSplitter, -) - -from app import EMBEDDING_MODEL_NAME - - -def main() -> None: - policies_ds_path = "../data/cymbalair_policy.csv" - - chunked = text_split(_POLICY) - data_embeddings = vectorize(chunked) - data_embeddings.to_csv(policies_ds_path, index=True, index_label="id") - - print("Done generating policy dataset.") - - -def text_split(data): - headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")] - markdown_splitter = MarkdownHeaderTextSplitter( - headers_to_split_on=headers_to_split_on, strip_headers=False - ) - md_header_splits = markdown_splitter.split_text(data) - - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=30, - length_function=len, - ) - splits = text_splitter.split_documents(md_header_splits) - - chunked = [{"content": s.page_content} for s in splits] - return chunked - - -def vectorize(chunked): - embed_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME) - - def retry_with_backoff(func, *args, 
retry_delay=5, backoff_factor=2, **kwargs): - max_attempts = 3 - retries = 0 - for i in range(max_attempts): - try: - return func(*args, **kwargs) - except Exception as e: - print(f"error: {e}") - retries += 1 - wait = retry_delay * (backoff_factor**retries) - print(f"Retry after waiting for {wait} seconds...") - time.sleep(wait) - - batch_size = 5 - for i in range(0, len(chunked), batch_size): - request = [x["content"] for x in chunked[i : i + batch_size]] - response = retry_with_backoff(embed_service.embed_documents, request) - # Store the retrieved vector embeddings for each chunk back. - for x, e in zip(chunked[i : i + batch_size], response): - x["embedding"] = e - - data_embeddings = pd.DataFrame(chunked) - data_embeddings.head() - return data_embeddings - - -_POLICY = """# Cymbal Air: Passenger Policy -## Ticket Purchase and Changes -Types of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased. -Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee. -Cancellations: Fully refundable tickets can be canceled at any time until 60 minutes prior to scheduled departure for a full refund. For non-refundable tickets, there will be no cost associated with cancellation within 24 hours of ticket purchase. After 24 hours of ticket purchase, there will be a $200 fee for ticket cancellation. Cancellations less than 24 hours before scheduled departure receive no refund. -Payment of refunds: refunds for fully refundable tickets will be returned to the original method of payment within 3-5 business days. 
Refunds for non-refundable tickets, less fees, will be refunded in the form of a trip credit which can be applied to future Cymbal Air flights. -## Baggage -Checked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted. -Carry-on Baggage: Passengers are allowed one carry-on bag and one personal item. These items must meet size and weight restrictions. -Liability: Cymbal Air assumes limited liability for lost, damaged, or delayed baggage. Passengers are encouraged to purchase travel insurance for additional protection. -## Check-in and Boarding -Check-in: Passengers are advised to check in online or at the airport kiosk within the specified timeframes before departure. Check-in deadlines are 1 hour prior to departure time. -Boarding: Boarding will begin approximately 30 minutes prior to departure. Passengers must present a valid boarding pass and government-issued ID. -Gate Closure: Boarding gates close 10 minutes prior to departure. Late passengers may not be permitted to board. -## Special Assistance -Passengers with Disabilities: Cymbal Air is committed to providing accommodations for passengers with disabilities. Please contact us at least 48 hours before departure to arrange assistance. -Unaccompanied Minors: We offer an unaccompanied minor program for children traveling alone. Fees apply. Contact us for details. -Traveling with Pets: Pets may be allowed in the cabin or as checked baggage depending on size and breed. Fees and restrictions apply. Please consult our website. 
-## Overbooking -In the rare event of an overbooked flight, Cymbal Air will first seek volunteers to give up their seats in exchange for compensation. If insufficient volunteers are found, passengers may be denied boarding involuntarily in accordance with our overbooking policy. -## Flight Delays and Cancellations -Cymbal Air strives to maintain on-time performance, but disruptions due to weather, mechanical issues, or other events may occur. In the event of delays or cancellations: -Rebooking: We will make reasonable efforts to rebook affected passengers on the next available flight. -Compensation: Compensation may be provided in certain situations as outlined by our policies and regulations. -""" - -if __name__ == "__main__": - main() diff --git a/retrieval_service/spanner-gsql.tests.cloudbuild.yaml b/retrieval_service/spanner-gsql.tests.cloudbuild.yaml deleted file mode 100644 index 2dbc3f36b..000000000 --- a/retrieval_service/spanner-gsql.tests.cloudbuild.yaml +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -steps: - - id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - - - id: Update config - name: python:3.11 - dir: retrieval_service - script: | - #!/usr/bin/env bash - # Create config - cp example-config-spanner.yml config.yml - sed -i "s/spanner-engine/spanner-gsql/g" config.yml - sed -i "s/my-project/$PROJECT_ID/g" config.yml - sed -i "s/my-instance/${_SPANNER_INSTANCE}/g" config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - - - id: Run Spanner with gsql integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_PROJECT=$PROJECT_ID" - - "DB_INSTANCE=${_SPANNER_INSTANCE}" - - "DB_NAME=${_DATABASE_NAME}" - script: | - #!/usr/bin/env bash - python -m pytest datastore/providers/spanner_gsql_test.py - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _SPANNER_INSTANCE: "my-spanner-gsql-instance" - -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true diff --git a/retrieval_service/spanner-pg.tests.cloudbuild.yaml b/retrieval_service/spanner-pg.tests.cloudbuild.yaml deleted file mode 100644 index 651806893..000000000 --- a/retrieval_service/spanner-pg.tests.cloudbuild.yaml +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -steps: -- id: Install dependencies - name: python:3.11 - dir: retrieval_service - script: pip install -r requirements.txt -r requirements-test.txt --user - -- id: Update config - name: python:3.11 - dir: retrieval_service - script: | - #!/usr/bin/env bash - # Create config - cp example-config-spanner.yml config.yml - sed -i "s/spanner-engine/spanner-postgres/g" config.yml - sed -i "s/my-project/$PROJECT_ID/g" config.yml - sed -i "s/my-instance/${_SPANNER_INSTANCE}/g" config.yml - sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml - -- id: Run Spanner with postgres dialect integration tests - name: python:3.11 - dir: retrieval_service - env: # Set env var expected by tests - - "DB_PROJECT=$PROJECT_ID" - - "DB_INSTANCE=${_SPANNER_INSTANCE}" - - "DB_NAME=${_DATABASE_NAME}-pg" - script: | - #!/usr/bin/env bash - python -m pytest datastore/providers/spanner_postgres_test.py - -substitutions: - _DATABASE_NAME: test_${SHORT_SHA} - _SPANNER_INSTANCE: "my-spanner-pg-instance" - -options: - automapSubstitutions: true - substitutionOption: 'ALLOW_LOOSE' - dynamic_substitutions: true diff --git a/llm_demo/run_app.py b/run_app.py similarity index 85% rename from llm_demo/run_app.py rename to run_app.py index cdca499a1..66591b58b 100644 --- a/llm_demo/run_app.py +++ b/run_app.py @@ -24,12 +24,9 @@ async def main(): PORT = int(os.getenv("PORT", default=8081)) HOST = os.getenv("HOST", default="0.0.0.0") - ORCHESTRATION_TYPE = os.getenv("ORCHESTRATION_TYPE", default="langchain-tools") CLIENT_ID = os.getenv("CLIENT_ID") MIDDLEWARE_SECRET = os.getenv("MIDDLEWARE_SECRET", default="this is a secret") - app = init_app( - ORCHESTRATION_TYPE, client_id=CLIENT_ID, middleware_secret=MIDDLEWARE_SECRET - ) + app = init_app(client_id=CLIENT_ID, middleware_secret=MIDDLEWARE_SECRET) if app is None: raise TypeError("app not instantiated") server = uvicorn.Server(uvicorn.Config(app, host=HOST, port=PORT, log_level="info")) diff --git a/llm_demo/run_evaluation.py b/run_evaluation.py 
similarity index 94% rename from llm_demo/run_evaluation.py rename to run_evaluation.py index ddde259d7..f24087092 100644 --- a/llm_demo/run_evaluation.py +++ b/run_evaluation.py @@ -26,7 +26,7 @@ goldens, run_llm_for_eval, ) -from orchestrator import createOrchestrator +from orchestrator import Orchestrator def export_metrics_table_csv(retrieval: pd.DataFrame, response: pd.DataFrame): @@ -48,7 +48,6 @@ async def main(): USER_ID_TOKEN = os.getenv("USER_ID_TOKEN", default=None) CLIENT_ID = os.getenv("CLIENT_ID", default="") - ORCHESTRATION_TYPE = os.getenv("ORCHESTRATION_TYPE", default="langchain-tools") EXPORT_CSV = bool(os.getenv("EXPORT_CSV", default=False)) RETRIEVAL_EXPERIMENT_NAME = os.getenv( "RETRIEVAL_EXPERIMENT_NAME", default="retrieval-phase-eval" @@ -58,7 +57,7 @@ async def main(): ) # Prepare orchestrator and session - orc = createOrchestrator(ORCHESTRATION_TYPE) + orc = Orchestrator() session_id = str(uuid.uuid4()) session = {"uuid": session_id} await orc.user_session_create(session) diff --git a/llm_demo/static/favicon.png b/static/favicon.png similarity index 100% rename from llm_demo/static/favicon.png rename to static/favicon.png diff --git a/llm_demo/static/index.css b/static/index.css similarity index 99% rename from llm_demo/static/index.css rename to static/index.css index fac6d1c52..0fb21960c 100644 --- a/llm_demo/static/index.css +++ b/static/index.css @@ -363,7 +363,6 @@ div.chat-wrapper div.chat-content div#loader-container span { #trace { position: absolute; width: 400px; - min-height: 200px; max-height: 400px; background: #f8f8f8; padding: 10px; diff --git a/llm_demo/static/index.js b/static/index.js similarity index 97% rename from llm_demo/static/index.js rename to static/index.js index 1cd0bd82b..368181406 100644 --- a/llm_demo/static/index.js +++ b/static/index.js @@ -262,10 +262,14 @@ async function cancelTicket(id) { 'Content-Type': 'application/json' } }); + logMessage("human", "I changed my mind.") + removeTicketChoices(id); + if 
(response.ok) { - logMessage("human", "I changed my mind.") - removeTicketChoices(id); - logMessage("ai", 'Booking declined. What else can I help you with?'); + logMessage("ai", await response.text()); + } else { + console.error(await response.text()) + logMessage("ai", "Sorry, something went wrong. 😢") } } diff --git a/llm_demo/static/logo-header.png b/static/logo-header.png similarity index 100% rename from llm_demo/static/logo-header.png rename to static/logo-header.png diff --git a/llm_demo/static/logo.png b/static/logo.png similarity index 100% rename from llm_demo/static/logo.png rename to static/logo.png diff --git a/llm_demo/static/trace.js b/static/trace.js similarity index 95% rename from llm_demo/static/trace.js rename to static/trace.js index ee75f9fb6..9551b1773 100644 --- a/llm_demo/static/trace.js +++ b/static/trace.js @@ -22,8 +22,10 @@ export function create_trace(toolcalls) { let toolcall = toolcalls[i]; trace += trace_section_title(toolcall.tool_call_id); - trace += trace_header("SQL Executed:"); - trace += trace_sql(toolcall.sql); + if (toolcall.sql) { + trace += trace_header("SQL Executed:"); + trace += trace_sql(toolcall.sql); + } trace += trace_header("Results:"); trace += trace_results(toolcall.results); diff --git a/llm_demo/templates/index.html b/templates/index.html similarity index 100% rename from llm_demo/templates/index.html rename to templates/index.html diff --git a/tools.yml b/tools.yml new file mode 100644 index 000000000..449fa54ad --- /dev/null +++ b/tools.yml @@ -0,0 +1,233 @@ +sources: + my-pg-instance: + kind: cloud-sql-postgres + project: retrieval-app-testing + region: us-central1 + instance: my-cloudsql-pg-instance + database: assistantdemo + user: postgres + password: postgres +authServices: + my_google_service: + kind: google + clientId: 706535509072-qa5v22ur8ik8o513b0538ufo0ne9jfn5.apps.googleusercontent.com +tools: + search_airports: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to list 
all airports matching search criteria. + Takes at least one of country, city, name, or all and returns all matching airports. + The agent can decide to return the results directly to the user. + parameters: + - name: country + type: string + description: Country + default: "" + - name: city + type: string + description: City + default: "" + - name: name + type: string + description: Airport name + default: "" + statement: | + SELECT * FROM airports + WHERE (CAST($1 AS TEXT) = '' OR country ILIKE $1) + AND (CAST($2 AS TEXT) = '' OR city ILIKE $2) + AND (CAST($3 AS TEXT) = '' OR name ILIKE '%' || $3 || '%') + LIMIT 10 + list_flights: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to list flights information matching search criteria. + Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. + If 3-letter iata code is not provided for departure_airport or arrival_airport, use `search_airports` tool to get iata code information. + Do NOT guess a date, ask user for date input if it is not given. Date must be in the following format: YYYY-MM-DD. + The agent can decide to return the results directly to the user. + parameters: + - name: departure_airport + type: string + description: Departure airport 3-letter code + default: "" + - name: arrival_airport + type: string + description: Arrival airport 3-letter code + default: "" + - name: date + type: string + description: Date of flight departure + statement: | + SELECT * FROM flights + WHERE (CAST($1 AS TEXT) = '' OR departure_airport ILIKE $1) + AND (CAST($2 AS TEXT) = '' OR arrival_airport ILIKE $2) + AND departure_time >= CAST($3 AS timestamp) + AND departure_time < CAST($3 AS timestamp) + interval '1 day' + LIMIT 10 + search_flights_by_number: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to get information for a specific flight. + Takes an airline code and flight number and returns info on the flight. 
+ Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. + An airline code is a code for an airline service consisting of a two-character + airline designator followed by a flight number, which is a 1 to 4 digit number. + For example, if given CY 0123, the airline is "CY", and flight_number is "123". + Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". + If the tool returns more than one option, choose the date closest to the current date. + parameters: + - name: airline + type: string + description: Airline unique 2 letter identifier + - name: flight_number + type: string + description: 1 to 4 digit number + statement: | + SELECT * FROM flights + WHERE airline = $1 + AND flight_number = $2 + LIMIT 10 + search_amenities: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to search amenities by name or to recommend airport amenities at SFO. + If the user provides flight info, use the 'search_flights_by_number' tool + first to get gate info and location. + Only recommend amenities that are returned by this query. + Find amenities close to the user by matching the terminal and then comparing + the gate numbers. Gate numbers iterate by letter and number, for example A1 A2 A3 + B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. + parameters: + - name: query + type: string + description: Search query + statement: | + SELECT name, description, location, terminal, category, hour + FROM amenities + WHERE (embedding <=> $1) < 0.5 + ORDER BY (embedding <=> $1) + LIMIT 5 + search_policies: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to search for Cymbal Air passenger policy. + Policies that are listed are unchangeable. + You will not answer any questions outside of the policy given. + Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations.
+ parameters: + - name: query + type: string + description: Search query + statement: | + SELECT content + FROM policies + WHERE (embedding <=> $1) < 0.5 + ORDER BY (embedding <=> $1) + LIMIT 5 + validate_ticket: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to validate ticket before booking. + parameters: + - name: airline + type: string + description: Airline unique 2 letter identifier + - name: flight_number + type: string + description: 1 to 4 digit number + - name: departure_airport + type: string + description: Departure airport 3-letter code + - name: departure_time + type: string + description: Flight departure datetime + statement: | + SELECT * FROM flights + WHERE airline ILIKE $1 + AND flight_number ILIKE $2 + AND departure_airport ILIKE $3 + AND departure_time = $4 + insert_ticket: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to book a flight ticket for the user. + parameters: + - name: user_id + type: string + description: User ID of the logged in user. + authServices: + - name: my_google_service + field: sub + - name: user_name + type: string + description: Name of the logged in user. + authServices: + - name: my_google_service + field: name + - name: user_email + type: string + description: Email ID of the logged in user. 
+ authServices: + - name: my_google_service + field: email + - name: airline + type: string + description: Airline unique 2 letter identifier + - name: flight_number + type: string + description: 1 to 4 digit number + - name: departure_airport + type: string + description: Departure airport 3-letter code + - name: departure_time + type: string + description: Flight departure datetime + - name: arrival_airport + type: string + description: Arrival airport 3-letter code + - name: arrival_time + type: string + description: Flight arrival datetime + statement: | + INSERT INTO tickets ( + user_id, + user_name, + user_email, + airline, + flight_number, + departure_airport, + departure_time, + arrival_airport, + arrival_time + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9); + list_tickets: + kind: postgres-sql + source: my-pg-instance + description: | + Use this tool to list a user's flight tickets. + parameters: + - name: user_id + type: string + description: User ID of the logged in user. + authServices: + - name: my_google_service + field: sub + statement: | + SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets + WHERE user_id = $1 +toolsets: + cymbal_air: + - search_airports + - list_flights + - search_flights_by_number + - search_amenities + - search_policies + - insert_ticket + - list_tickets