diff --git a/orchestrator/__init__.py b/agent/__init__.py similarity index 89% rename from orchestrator/__init__.py rename to agent/__init__.py index f357920e9..05ba37457 100644 --- a/orchestrator/__init__.py +++ b/agent/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .orchestrator import Orchestrator +from .agent import Agent -__ALL__ = ["Orchestrator"] +__ALL__ = ["Agent"] diff --git a/orchestrator/orchestrator.py b/agent/agent.py similarity index 72% rename from orchestrator/orchestrator.py rename to agent/agent.py index 8a926a592..45d8b0977 100644 --- a/orchestrator/orchestrator.py +++ b/agent/agent.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,20 +16,15 @@ import os import uuid from datetime import datetime -from typing import Annotated, Any, Dict, List, Literal, Optional, Sequence, TypedDict +from typing import Any, Dict, List, Optional, Sequence -from aiohttp import ClientSession, TCPConnector -from fastapi import HTTPException from langchain.globals import set_verbose # type: ignore from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableConfig, RunnableLambda -from langchain_core.tools import StructuredTool +from langchain_core.prompts import ChatPromptTemplate from langgraph.checkpoint.base import empty_checkpoint from langgraph.checkpoint.memory import MemorySaver from pytz import timezone -from ..orchestrator import BaseOrchestrator, classproperty from .react_graph import create_graph from .tools import initialize_tools @@ -41,13 +36,12 @@ } -class LangGraphOrchestrator(BaseOrchestrator): +class Agent: MODEL = "gemini-2.0-flash-001" _user_sessions: Dict[str, str] # aiohttp context connector = None - client: Optional[ClientSession] = None def __init__(self): self._user_sessions = {} @@ -72,14 +66,19 @@ async def user_session_decline_ticket(self, uuid: str) -> dict[str, Any]: async def user_session_create(self, session: dict[str, Any]): """Create and load an agent executor with tools and LLM.""" - client = await self.create_client_session() if self._langgraph_app is None: print("Initializing graph..") - tools = await initialize_tools(client) - prompt = self.create_prompt_template(tools) + tools, insert_ticket, validate_ticket = await initialize_tools() + prompt = self.create_prompt_template() checkpointer = MemorySaver() langgraph_app = await create_graph( - tools, checkpointer, prompt, self.MODEL, client, DEBUG + tools, + insert_ticket, + validate_ticket, + checkpointer, + prompt, + self.MODEL, + DEBUG, ) self._checkpointer = checkpointer self._langgraph_app = langgraph_app @@ -95,7 +94,6 @@ async def user_session_create(self, session: dict[str, Any]): config = self.get_config(session_id) self._langgraph_app.update_state(config, {"messages": history}) self._user_sessions[session_id] = "" - self.client = client async def user_session_invoke( self, uuid: str, user_prompt: Optional[str] @@ -106,10 +104,7 @@ async def user_session_invoke( ) if user_prompt: user_query = [HumanMessage(content=user_prompt)] - app_input = { - "messages": user_query, - "user_id_token": self.get_user_id_token(uuid), - } + app_input = {"messages": user_query} else: app_input = None final_state = await self._langgraph_app.ainvoke( @@ -165,45 +160,18 @@ def user_session_reset(self, session: dict[str, Any], uuid: str): # Update state with message history self._langgraph_app.update_state(config, {"messages": history}) - def get_user_session(self, uuid: str): - raise NotImplementedError("Irrelevant to LangGraph.") - def set_user_session_header(self, uuid: str, user_id_token: str): self._user_sessions[uuid] = user_id_token def get_user_id_token(self, uuid: str) -> Optional[str]: return self._user_sessions.get(uuid) - async def get_connector(self) -> TCPConnector: - if self.connector is None: - self.connector = TCPConnector(limit=100) - return self.connector - - async def create_client_session(self) -> ClientSession: - return ClientSession( - connector=await self.get_connector(), - connector_owner=False, - headers={}, - raise_for_status=True, - ) - - def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: - # Create new prompt template - tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] - ) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = FORMAT_INSTRUCTIONS.format( - tool_names=tool_names, - ) + def create_prompt_template(self) -> ChatPromptTemplate: current_datetime = "Today's date and current time is {cur_datetime}." template = "\n\n".join( [ PREFIX, current_datetime, - TOOLS_PREFIX, - tool_strings, - format_instructions, SUFFIX, ] ) @@ -242,7 +210,15 @@ def get_base_history(self, session: dict[str, Any]): return BASE_HISTORY def get_config(self, uuid: str): - return {"configurable": {"thread_id": uuid, "checkpoint_ns": ""}} + return { + "configurable": { + "thread_id": uuid, + "auth_token_getters": { + "my_google_service": lambda: self.get_user_id_token(uuid) + }, + "checkpoint_ns": "", + }, + } async def user_session_signout(self, uuid: str): checkpoint = empty_checkpoint() @@ -252,10 +228,6 @@ async def user_session_signout(self, uuid: str): ) del self._user_sessions[uuid] - async def close_clients(self): - if self.client: - await self.client.close() - PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. @@ -268,43 +240,9 @@ async def close_clients(self): require passing results from one query to another. Using the latest AI models, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should -not answer questions about other peoples information for privacy reasons. +not answer questions about other people's information for privacy reasons. Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air -as well as ammenities of San Francisco Airport.""" - -TOOLS_PREFIX = """ -TOOLS: ------- -Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are: - -""" - -FORMAT_INSTRUCTIONS = """ -When responding, please output a response in one of two formats: - -** Option 1:** -Use this is you want to use a tool. -Markdown code snippet formatted in the following schema: -```json -{{{{ - "action": string, \ The action to take. Must be one of {tool_names} - "action_input": string \ The input to the action -}}}} -``` - -**Option 2:** -Use this if you want to respond directly to the human. -Markdown code snippet formatted following schema: -```json -{{{{ - "action": "Final Answer", - "action_input": string \ You should put what you want to return to user here -}}}} -``` -""" - -SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate. - -Remember to respond with a markdown code snippet of a json a single action, and NOTHING else. -""" +as well as amenities of San Francisco Airport.""" + +SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate.""" diff --git a/orchestrator/react_graph.py b/agent/react_graph.py similarity index 69% rename from orchestrator/react_graph.py rename to agent/react_graph.py index b05134f1e..e86eda0be 100644 --- a/orchestrator/react_graph.py +++ b/agent/react_graph.py @@ -16,13 +16,10 @@ import uuid from typing import Annotated, Literal, Sequence, TypedDict -from aiohttp import ClientSession from langchain_core.messages import ( AIMessage, BaseMessage, - HumanMessage, ToolCall, - ToolMessage, ) from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import RunnableConfig, RunnableLambda @@ -30,40 +27,36 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages -from langgraph.managed import IsLastStep +from langgraph.prebuilt import ToolNode +from toolbox_langchain import ToolboxTool -from .tool_node import ToolNode -from .tools import ( - TicketInfo, - get_confirmation_needing_tools, - insert_ticket, - validate_ticket, -) +from .tools import get_confirmation_needing_tools class UserState(TypedDict): """ - State with messages and ClientSession for each session/user. + State with messages for each session/user. """ messages: Annotated[Sequence[BaseMessage], add_messages] - user_id_token: str - is_last_step: IsLastStep async def create_graph( - tools, + tools: list[ToolboxTool], + insert_ticket: ToolboxTool, + validate_ticket: ToolboxTool, checkpointer: MemorySaver, prompt: ChatPromptTemplate, model_name: str, - client: ClientSession, debug: bool, ): """ Creates a graph that works with a chat model that utilizes tool calling. Args: - tools: A list of StructuredTools that will bind with the chat model. + tools: A list of ToolboxTools that will bind with the chat model. + insert_ticket: A ToolboxTool that inserts ticket for the logged in user. + validate_ticket: A ToolboxTool that validates the given flight data. checkpointer: The checkpoint saver object. This is useful for persisting the state of the graph (e.g., as chat memory). prompt: Initial prompt for the model. This applies to messages before they @@ -85,11 +78,13 @@ async def create_graph( tool_node = ToolNode(tools) # model node - # TODO: Use .bind_tools(tools) to bind the tools with the LLM. model = ChatVertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0) + # Bind the tools with the LLM. + model_with_tools = model.bind_tools(tools) + # Add the prompt to the model to create a model runnable - model_runnable = prompt | model + model_runnable = prompt | model_with_tools async def acall_model(state: UserState, config: RunnableConfig): """ @@ -98,43 +93,6 @@ async def acall_model(state: UserState, config: RunnableConfig): """ messages = state["messages"] res = await model_runnable.ainvoke({"messages": messages}, config) - - # TODO: Remove the temporary fix of parsing LLM response and invoking - # tools until we use bind_tools API and have automatic response parsing - # and tool calling. (see - # https://langchain-ai.github.io/langgraph/#example) - if "```json" in res.content: - try: - response = str(res.content).replace("```json", "").replace("```", "") - json_response = json.loads(response) - action = json_response.get("action") - action_input = json_response.get("action_input") - if action == "Final Answer": - res = AIMessage(content=action_input) - else: - res = AIMessage( - content="suggesting a tool call", - tool_calls=[ - ToolCall( - id=str(uuid.uuid4()), name=action, args=action_input - ) - ], - ) - except Exception as e: - json_response = response - res = AIMessage( - content="Sorry, failed to generate the right format for response" - ) - - # if model exceed the number of steps and has not yet return a final answer - if state["is_last_step"] and hasattr(res, "tool_calls"): - return { - "messages": [ - AIMessage( - content="Sorry, need more steps to process this request.", - ) - ] - } return {"messages": [res]} def agent_should_continue( @@ -151,7 +109,7 @@ def agent_should_continue( for tool_call in last_message.tool_calls: tool_name = tool_call["name"] if tool_name in confirmation_needing_tools: - if tool_name == "Insert Ticket": + if tool_name == "insert_ticket": return "booking_validation" return "continue" # Otherwise, we stop (reply to the user) @@ -164,13 +122,12 @@ async def booking_validation_node(state: UserState, config: RunnableConfig): """ messages = state["messages"] last_message = messages[-1] - user_id_token = state["user_id_token"] if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: tool_call = last_message.tool_calls[0] # Run ticket validation and return the correct ticket information - flight_info = await validate_ticket( - client, tool_call.get("args"), user_id_token - ) + flight_info = await validate_ticket.ainvoke(tool_call.get("args")) + flight_info = json.loads(flight_info) + flight_info = flight_info[0] new_message = AIMessage( content="Please confirm if you would like to book the ticket.", @@ -203,20 +160,11 @@ async def insert_ticket_node(state: UserState, config: RunnableConfig): """ messages = state["messages"] last_message = messages[-1] - user_id_token = state["user_id_token"] # Run insert ticket if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: tool_call = last_message.tool_calls[0] tool_args = tool_call.get("args") - ticket_info = TicketInfo(**tool_args) - output = await insert_ticket(client, ticket_info, user_id_token) - tool_call_id = tool_call.get("id") - tool_message = ToolMessage( - content=output, name="Insert Ticket", tool_call_id=tool_call_id - ) - human_message = HumanMessage(content="Looks good to me.") - ai_message = AIMessage(content=output) - return {"messages": [human_message, tool_message, ai_message]} + await insert_ticket.ainvoke(tool_args) # Define constant node strings AGENT_NODE = "agent" diff --git a/agent/tools.py b/agent/tools.py new file mode 100644 index 000000000..d67af7d4d --- /dev/null +++ b/agent/tools.py @@ -0,0 +1,38 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Optional + +from toolbox_langchain import ToolboxClient + +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +TOOLBOX_URL = os.getenv("TOOLBOX_URL", default="http://127.0.0.1:5000") + + +# Tools for agent +async def initialize_tools(): + client = ToolboxClient(TOOLBOX_URL) + tools = await client.aload_toolset("cymbal_air") + + # Load insert_ticket and validate_ticket tools separately to implement + # human-in-the-loop. + insert_ticket = await client.aload_tool("insert_ticket") + validate_ticket = await client.aload_tool("validate_ticket") + + return (tools, insert_ticket, validate_ticket) + + +def get_confirmation_needing_tools(): + return ["insert_ticket"] diff --git a/app.py b/app.py index efe4ca387..1629ea44c 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ from markdown import markdown from starlette.middleware.sessions import SessionMiddleware -from orchestrator import Orchestrator +from agent import Agent routes = APIRouter() templates = Jinja2Templates(directory="templates") @@ -39,7 +39,7 @@ async def lifespan(app: FastAPI): print("Loading application...") yield # FastAPI app shutdown event - await app.state.orchestrator.close_clients() + print("Application shutdown.") @routes.get("/") @@ -47,12 +47,12 @@ async def lifespan(app: FastAPI): async def index(request: Request): """Render the default template.""" # User session setup - orchestrator = request.app.state.orchestrator + agent = request.app.state.agent session = request.session # check if token and user info is still valid if "uuid" in session: - user_id_token = orchestrator.get_user_id_token(session["uuid"]) + user_id_token = agent.get_user_id_token(session["uuid"]) if user_id_token: if session.get("user_info") and not get_user_info( user_id_token, request.app.state.client_id @@ -61,8 +61,8 @@ async def index(request: Request): elif not user_id_token and "user_info" in session: await logout_google(request) - if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): - await orchestrator.user_session_create(session) + if "uuid" not in session or not agent.user_session_exist(session["uuid"]): + await agent.user_session_create(session) return templates.TemplateResponse( "index.html", @@ -102,8 +102,8 @@ async def login_google( session["user_info"] = user_info # create new request session - orchestrator = request.app.state.orchestrator - orchestrator.set_user_session_header(session["uuid"], str(user_id_token)) + agent = request.app.state.agent + agent.set_user_session_header(session["uuid"], str(user_id_token)) print("Logged in to Google.") welcome_text = ( @@ -131,9 +131,9 @@ async def logout_google( raise HTTPException(status_code=400, detail="No session to reset.") uuid = request.session["uuid"] - orchestrator = request.app.state.orchestrator - if orchestrator.user_session_exist(uuid): - await orchestrator.user_session_signout(uuid) + agent = request.app.state.agent + if agent.user_session_exist(uuid): + await agent.user_session_signout(uuid) request.session.clear() @@ -150,8 +150,8 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["history"].append({"type": "human", "data": {"content": prompt}}) - orchestrator = request.app.state.orchestrator - response = await orchestrator.user_session_invoke(request.session["uuid"], prompt) + agent = request.app.state.agent + response = await agent.user_session_invoke(request.session["uuid"], prompt) output = response.get("output") confirmation = response.get("confirmation") trace = response.get("trace") @@ -177,10 +177,8 @@ async def book_flight(request: Request, params: str = Body(embed=True)): raise HTTPException( status_code=400, detail="Error: Invoke index handler before start chatting" ) - orchestrator = request.app.state.orchestrator - response = await orchestrator.user_session_insert_ticket( - request.session["uuid"], params - ) + agent = request.app.state.agent + response = await agent.user_session_insert_ticket(request.session["uuid"], params) # Note in the history, that the ticket has been successfully booked request.session["history"].append( {"type": "ai", "data": {"content": "I have booked your ticket."}} @@ -193,8 +191,8 @@ async def decline_flight(request: Request): """Handler for LangChain chat requests""" # Note in the history, that the ticket was not booked # This is helpful in case of reloads so there doesn't seem to be a break in communication. - orchestrator = request.app.state.orchestrator - response = await orchestrator.user_session_decline_ticket(request.session["uuid"]) + agent = request.app.state.agent + await agent.user_session_decline_ticket(request.session["uuid"]) request.session["history"].append( {"type": "ai", "data": {"content": "Please confirm if you would like to book."}} ) @@ -212,11 +210,11 @@ def reset(request: Request): raise HTTPException(status_code=400, detail="No session to reset.") uuid = request.session["uuid"] - orchestrator = request.app.state.orchestrator - if not orchestrator.user_session_exist(uuid): + agent = request.app.state.agent + if not agent.user_session_exist(uuid): raise HTTPException(status_code=500, detail="Current user session not found") - orchestrator.user_session_reset(request.session, uuid) + agent.user_session_reset(request.session, uuid) def get_user_info(user_id_token: str, client_id: str) -> dict[str, str]: @@ -243,7 +241,7 @@ def init_app( # FastAPI setup app = FastAPI(lifespan=lifespan) app.state.client_id = client_id - app.state.orchestrator = Orchestrator() + app.state.agent = Agent() app.include_router(routes) app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(SessionMiddleware, secret_key=middleware_secret) diff --git a/evaluation/eval_golden.py b/evaluation/eval_golden.py index bdce26ee8..59b2f6c47 100644 --- a/evaluation/eval_golden.py +++ b/evaluation/eval_golden.py @@ -21,7 +21,7 @@ class ToolCall(BaseModel): """ - Represents tool call by orchestration. + Represents tool call by agent. """ name: str diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py index 8208687f8..ef2e95826 100644 --- a/evaluation/evaluation.py +++ b/evaluation/evaluation.py @@ -20,14 +20,14 @@ from vertexai.evaluation import EvalTask from vertexai.evaluation import _base as evaluation_base -from orchestrator import Orchestrator +from agent import Agent from .eval_golden import EvalData, ToolCall from .metrics import response_phase_metrics, retrieval_phase_metrics async def run_llm_for_eval( - eval_list: List[EvalData], orc: Orchestrator, session: Dict, session_id: str + eval_list: List[EvalData], orc: Agent, session: Dict, session_id: str ) -> List[EvalData]: """ Generate llm_tool_calls and llm_output for golden dataset query. diff --git a/orchestrator/tools.py b/orchestrator/tools.py deleted file mode 100644 index 18ed8da11..000000000 --- a/orchestrator/tools.py +++ /dev/null @@ -1,455 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -from dataclasses import dataclass -from datetime import date, datetime -from typing import Any, Dict, Optional - -import aiohttp -import google.oauth2.id_token # type: ignore -from google.auth import compute_engine # type: ignore -from google.auth.transport.requests import Request # type: ignore -from langchain_core.tools import StructuredTool -from pydantic import BaseModel, Field - -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - - -def filter_none_values(params: Dict) -> Dict: - return {key: value for key, value in params.items() if value is not None} - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - -def get_headers(client: aiohttp.ClientSession, user_id_token: str): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - headers["User-Id-Token"] = f"Bearer {user_id_token}" - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -# Tools -class AirportSearchInput(BaseModel): - country: Optional[str] = Field(description="Country") - city: Optional[str] = Field(description="City") - name: Optional[str] = Field(description="Airport name") - user_id_token: Optional[str] - - -def generate_search_airports(client: aiohttp.ClientSession): - async def search_airports(country: str, city: str, name: str, user_id_token: str): - params = { - "country": country, - "city": city, - "name": name, - } - response = await client.get( - url=f"{BASE_URL}/airports/search", - params=filter_none_values(params), - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - if len(response_json) < 1: - return "There are no airports matching that query. Let the user know there are no results." - else: - return response_json - - return search_airports - - -class FlightNumberInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - user_id_token: Optional[str] - - -def generate_search_flights_by_number(client: aiohttp.ClientSession): - async def search_flights_by_number( - airline: str, flight_number: str, user_id_token: str - ): - response = await client.get( - url=f"{BASE_URL}/flights/search", - params={"airline": airline, "flight_number": flight_number}, - headers=get_headers(client, user_id_token), - ) - - return await response.json() - - return search_flights_by_number - - -class ListFlightsInput(BaseModel): - departure_airport: Optional[str] = Field( - description="Departure airport 3-letter code", - ) - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - date: str = Field(description="Date of flight departure") - user_id_token: Optional[str] - - -def generate_list_flights(client: aiohttp.ClientSession): - async def list_flights( - departure_airport: str, - arrival_airport: str, - date: str, - user_id_token: str, - ): - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "date": date, - } - response = await client.get( - url=f"{BASE_URL}/flights/search", - params=filter_none_values(params), - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - if len(response_json) < 1: - return { - "results": "There are no flights matching that query. Let the user know there are no results." - } - else: - return response_json - - return list_flights - - -class QueryInput(BaseModel): - query: str = Field(description="Search query") - user_id_token: Optional[str] - - -def generate_search_amenities(client: aiohttp.ClientSession): - async def search_amenities(query: str, user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/amenities/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client, user_id_token), - ) - - response = await response.json() - return response - - return search_amenities - - -def generate_search_policies(client: aiohttp.ClientSession): - async def search_policies(query: str, user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/policies/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client, user_id_token), - ) - - response = await response.json() - return response - - return search_policies - - -class TicketInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - departure_airport: str = Field( - description="Departure airport 3-letter code", - ) - departure_time: datetime = Field(description="Flight departure datetime") - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - arrival_time: Optional[datetime] = Field(description="Flight arrival datetime") - - -def generate_insert_ticket(client: aiohttp.ClientSession): - async def insert_ticket( - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: datetime, - arrival_time: datetime, - ): - return {"results": f"Booking ticket on {airline} {flight_number}"} - - return insert_ticket - - -@dataclass -class TicketInfo: - airline: str - flight_number: str - departure_airport: str - departure_time: str - arrival_airport: str - arrival_time: str - - -async def insert_ticket( - client: aiohttp.ClientSession, ticket_info: TicketInfo, user_id_token: str -): - response = await client.post( - url=f"{BASE_URL}/tickets/insert", - params={ - "airline": ticket_info.airline, - "flight_number": ticket_info.flight_number, - "departure_airport": ticket_info.departure_airport, - "arrival_airport": ticket_info.arrival_airport, - "departure_time": ticket_info.departure_time.replace("T", " "), - "arrival_time": ticket_info.arrival_time.replace("T", " "), - }, - headers=get_headers(client, user_id_token), - ) - response = await response.json() - return "Flight booking successful." - - -async def validate_ticket( - client: aiohttp.ClientSession, ticket_info: Dict[Any, Any], user_id_token: str -): - response = await client.get( - url=f"{BASE_URL}/tickets/validate", - params=filter_none_values( - { - "airline": ticket_info.get("airline"), - "flight_number": ticket_info.get("flight_number"), - "departure_airport": ticket_info.get("departure_airport"), - "departure_time": ticket_info.get("departure_time", "").replace( - "T", " " - ), - } - ), - headers=get_headers(client, user_id_token), - ) - response_json = await response.json() - response_results = response_json.get("results") - - flight_info = { - "airline": response_results.get("airline"), - "flight_number": response_results.get("flight_number"), - "departure_airport": response_results.get("departure_airport"), - "arrival_airport": response_results.get("arrival_airport"), - "departure_time": response_results.get("departure_time"), - "arrival_time": response_results.get("arrival_time"), - } - return flight_info - - -def generate_list_tickets(client: aiohttp.ClientSession): - async def list_tickets(user_id_token: str): - response = await client.get( - url=f"{BASE_URL}/tickets/list", - headers=get_headers(client, user_id_token), - ) - - response_json = await response.json() - tickets = response_json.get("results") - if len(tickets) == 0: - return { - "results": "There are no upcoming tickets", - "sql": response_json.get("sql"), - } - else: - return response_json - - return list_tickets - - -# Tools for agent -async def initialize_tools(client: aiohttp.ClientSession): - return [ - StructuredTool.from_function( - coroutine=generate_search_airports(client), - name="Search Airport", - description=""" - Use this tool to list all airports matching search criteria. - Takes at least one of country, city, name, or all and returns all matching airports. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - country, city, name. - Example: - {{ - "country": "United States", - "city": "San Francisco", - "name": null - }} - Example: - {{ - "country": null, - "city": "Goroka", - "name": "Goroka" - }} - Example: - {{ - "country": "Mexico", - "city": null, - "name": null - }} - """, - args_schema=AirportSearchInput, - ), - StructuredTool.from_function( - coroutine=generate_search_flights_by_number(client), - name="Search Flights By Flight Number", - description=""" - Use this tool to get information for a specific flight. - Takes an airline code and flight number and returns info on the flight. - Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. - A airline code is a code for an airline service consisting of two-character - airline designator and followed by flight number, which is 1 to 4 digit number. - For example, if given CY 0123, the airline is "CY", and flight_number is "123". - Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". - If the tool returns more than one option choose the date closes to today. - Example: - {{ - "airline": "CY", - "flight_number": "888", - }} - Example: - {{ - "airline": "DL", - "flight_number": "1234", - }} - """, - args_schema=FlightNumberInput, - ), - StructuredTool.from_function( - coroutine=generate_list_flights(client), - name="List Flights", - description=""" - Use this tool to list flights information matching search criteria. - Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. - If 3-letter iata code is not provided for departure_airport or arrival_airport, use search airport tools to get iata code information. - Do NOT guess a date, ask user for date input if it is not given. Date must be in the following format: YYYY-MM-DD. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": null, - "date": 2025-10-30" - }} - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": "SEA", - "date": "2025-11-01" - }} - Example: - {{ - "departure_airport": null, - "arrival_airport": "SFO", - "date": "2025-01-01" - }} - """, - args_schema=ListFlightsInput, - ), - StructuredTool.from_function( - coroutine=generate_search_amenities(client), - name="Search Amenities", - description=""" - Use this tool to search amenities by name or to recommended airport amenities at SFO. - If user provides flight info, use 'Search Flights by Flight Number' - first to get gate info and location. - Only recommend amenities that are returned by this query. - Find amenities close to the user by matching the terminal and then comparing - the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 - B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_search_policies(client), - name="Search Policies", - description=""" - Use this tool to search for cymbal air passenger policy. - Policy that are listed is unchangeable. - You will not answer any questions outside of the policy given. - Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. - Input of this tool must be in JSON format and include one `query` input. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_insert_ticket(client), - name="Insert Ticket", - description=""" - Use this tool to book a flight ticket for the user. - Example: - {{ - "airline": "AA", - "flight_number": "452", - "departure_airport": "LAX", - "arrival_airport": "SFO", - "departure_time": "2025-01-01 05:50:00", - "arrival_time": "2025-01-01 09:23:00" - }} - Example: - {{ - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": "2025-01-08 05:50:00", - "arrival_time": "2025-01-08 09:23:00" - }} - Example: - {{ - "airline": "OO", - "flight_number": "6307", - "departure_airport": "SFO", - "arrival_airport": "MSP", - "departure_time": "2025-10-28 20:13:00", - "arrival_time": "2025-10-28 21:07:00" - }} - """, - args_schema=TicketInput, - ), - StructuredTool.from_function( - coroutine=generate_list_tickets(client), - name="List Tickets", - description=""" - Use this tool to list a user's flight tickets. - Takes no input and returns a list of current user's flight tickets. - Input is always empty JSON blob. Example: {{}} - """, - ), - ] - - -def get_confirmation_needing_tools(): - return ["Insert Ticket"] diff --git a/requirements.txt b/requirements.txt index 8c244916a..f21a5e9ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ types-pytz==2025.1.0.20250204 langgraph==0.4.8 pandas-stubs==2.2.2.240807 pandas==2.2.3 -pydantic==2.9.0 \ No newline at end of file +pydantic==2.9.0 +toolbox-langchain==0.3.0 \ No newline at end of file diff --git a/run_evaluation.py b/run_evaluation.py index f24087092..890bb999c 100644 --- a/run_evaluation.py +++ b/run_evaluation.py @@ -26,7 +26,7 @@ goldens, run_llm_for_eval, ) -from orchestrator import Orchestrator +from agent import Agent def export_metrics_table_csv(retrieval: pd.DataFrame, response: pd.DataFrame): @@ -56,8 +56,8 @@ async def main(): "RESPONSE_EXPERIMENT_NAME", default="response-phase-eval" ) - # Prepare orchestrator and session - orc = Orchestrator() + # Prepare agent and session + orc = Agent() session_id = str(uuid.uuid4()) session = {"uuid": session_id} await orc.user_session_create(session)