diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 01395dd9..f373c1e7 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -3,7 +3,7 @@ import os import pathlib import uuid -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from cleanlab_codex import Client, Project from codex.types import ProjectValidateResponse @@ -14,6 +14,7 @@ CallToolsNode, ModelMessage, ModelRequestNode, + ModelResponse, ModelSettings, TextPart, ToolCallPart, @@ -42,6 +43,9 @@ consult_cleanlab, update_prompt_with_guidance, ) +from airline_agent.cleanlab_utils.conversion_utils import ( + convert_string_to_response_message, +) from airline_agent.cleanlab_utils.validate_utils import ( get_tools_in_openai_format, run_cleanlab_validation_logging_tools, @@ -96,6 +100,11 @@ async def airline_chat_streaming( run_id = uuid.uuid4() thread_id = message.thread_id + tool_fallback_formatters: dict[str, Callable[[str], str]] = { + **booking.tool_fallback_fmt, + } + mutate_tools = list(tool_fallback_formatters.keys()) + if thread_id not in cleanlab_enabled_by_thread: cleanlab_enabled_by_thread[thread_id] = cleanlab_enabled elif cleanlab_enabled_by_thread[thread_id] != cleanlab_enabled: @@ -118,6 +127,7 @@ async def airline_chat_streaming( thread_to_messages[thread_id] = [] current_tool_calls: dict[str, ToolCall] = {} + mutate_tool_fallback_response: list[str] = [] original_user_query = message.content if cleanlab_enabled: @@ -158,6 +168,16 @@ async def airline_chat_streaming( request = node.request for request_part in request.parts: if isinstance(request_part, ToolReturnPart) and request_part.tool_call_id in current_tool_calls: + tool_call = current_tool_calls[request_part.tool_call_id] + + if tool_call.tool_name in mutate_tools: + tool_result_str = request_part.model_response_str() + if tool_call.tool_name in 
tool_fallback_formatters: + formatted_result = tool_fallback_formatters[tool_call.tool_name](tool_result_str) + mutate_tool_fallback_response.append(formatted_result) + else: + mutate_tool_fallback_response.append(tool_result_str) + yield RunEventThreadMessage( id=run_id, object=RunEventObject.THREAD_MESSAGE, @@ -185,12 +205,25 @@ async def airline_chat_streaming( if guidance_items else None, ) + + response_content = ( + "I've completed the following for you:\n\n" + "\n\n".join(mutate_tool_fallback_response) + if validation_result.should_guardrail and mutate_tool_fallback_response + else final_response + ) + if ( + response_content != final_response + and updated_message_history + and isinstance(updated_message_history[-1], ModelResponse) + ): + updated_message_history[-1] = convert_string_to_response_message(response_content) + yield RunEventThreadMessage( id=run_id, object=RunEventObject.THREAD_MESSAGE, data=AssistantMessage( thread_id=thread_id, - content=final_response, + content=response_content, metadata=MessageMetadata( original_llm_response=run.result.output, is_expert_answer=validation_result.expert_answer is not None, diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 170e8629..515045f0 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -27,7 +27,7 @@ AGENT_MODEL = "gpt-4o" FALLBACK_RESPONSE = "I'm sorry, but I don't have the information you're looking for. Please rephrase the question or contact Frontier Airlines customer support for further assistance." AGENT_INSTRUCTIONS = f""" -You are an AI customer support agent for Frontier Airlines. You can use tools to access a knowledge base of articles and documents about the airline's services, policies, and procedures. You can help users find flight information and pricing, but you cannot book flights or make reservations. +You are an AI customer support agent for Frontier Airlines. 
You can use tools to access a knowledge base of articles and documents about the airline's services, policies, and procedures. You can help users find flight information and pricing. You can also book flights. ## You have access to the following tools: - search — find candidate articles by query (keep top-k small, ≤5), returns title/snippet/path. @@ -35,6 +35,11 @@ - list_directory — list directory structure to make more informed searches. - search_flights — search available flights by origin and destination airport codes (IATA) and departure date (YYYY-MM-DD). Always ask for the departure date if the user doesn't provide it. - get_fare_details — retrieve fare bundle pricing, included services, and add-ons for a specific flight. +- book_flights — book one or more flights for the current user. Requires list of flight IDs and fare bundle type (basic, economy, premium, business; defaults to basic). Returns booking confirmation with booking ID and total price. +- get_booking — retrieve booking details by booking ID. +- get_my_bookings — retrieve all confirmed bookings for the current user. +- add_service_to_booking — add an eligible service (bags, seat selection, etc.) to a specific flight within a booking. +- check_in — complete check-in for a specific flight in a booking. - get_flight_timings — get check-in, boarding, and door-close timing windows for a flight. - get_flight_status — get the latest status, gates, and delay information for a flight. @@ -43,12 +48,15 @@ - Answer primarily based on information from retrieved content, unless the question is simply to clarify broadly understood aspects of commercial air travel (such as standard security procedures, boarding processes, or common airline terminology). - If a missing detail blocks tool use, ask one short clarifying question. If not blocking, proceed and state your assumption. - Don't dump raw tool output—summarize clearly. 
+- When booking multiple flights (outbound and return), include all flight IDs in a single book_flights call. ## Response Guidelines: - Answer questions primarily based on information you look up in the knowledge base. - For requests that involve general airline knowledge that is not specific to Frontier Airlines (e.g., common terms, standard processes, and widely known industry roles), you may rely on your own knowledge if the knowledge base does not add important Frontier-specific details. - When responding to user, never use phrases like "according to the knowledge base", "I couldn't find anything in the knowledge base", etc. When responding to user, treat the retrieved knowledge base content as your own knowledge, not something you are referencing or searching through. - **If the user asks something unrelated to Frontier Airlines or air travel, politely refuse and redirect the conversation. Do not attempt to fulfill or improvise unrelated requests.** +- When a booking is successfully created, provide the booking ID and confirmation details clearly. +- If you book flights, provide the booking ID and summarize the flights booked and total price. - Avoid hedging language (e.g., “typically,” “generally,” “usually”) when the information is known and factual. Be clear and assertive in your response, and do not speculate. 
## Context: diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 8a16184a..720fd200 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,4 +1,6 @@ +import json import random +from collections.abc import Callable from datetime import date, datetime, timedelta from typing import Any @@ -185,6 +187,33 @@ def book_flights( return booking + def _format_service_type(self, service_type: str) -> str: + """Format service type from snake_case to Title Case with spaces.""" + return service_type.replace("_", " ").title() + + def _book_flights_to_string(self, result: str) -> str: + """Format book_flights result for fallback response.""" + try: + booking_data = json.loads(result) + booking = Booking.model_validate(booking_data) + + flight_parts = [] + for i, flight_booking in enumerate(booking.flights, 1): + add_ons_text = ( + f". You also have the following add ons: {', '.join(self._format_service_type(ao.service_type) for ao in flight_booking.add_ons)}" + if flight_booking.add_ons + else "" + ) + flight_parts.append( + f"{i}. Flight {flight_booking.flight_id} ({flight_booking.fare_type} fare) ({flight_booking.currency} ${flight_booking.price_total:.2f}){add_ons_text}" + ) + + flights_text = "\n".join(flight_parts) + except (json.JSONDecodeError, ValueError): + return result + else: + return f"Booking confirmed with booking ID {booking.booking_id} for a total price of {booking.currency} ${booking.total_price:.2f}.\n\nFlights:\n{flights_text}" + def get_booking(self, booking_id: str) -> Booking: """ Retrieve a booking by its booking ID. 
def _add_service_to_booking_to_string(self, result: str) -> str:
    """Format an add_service_to_booking result for the fallback response.

    Only the flight holding the most recently added add-on (by ``added_at``)
    is reported, so untouched flights that already had add-ons are not
    described as newly modified.

    Args:
        result: The tool result as a JSON string.

    Returns:
        A sentence naming the added service (with its price when positive),
        the flight it was added to, and that flight's full add-on list. If
        ``result`` cannot be parsed or validated as a ``Booking``, or the
        booking has no add-ons at all, ``result`` is returned unchanged.
    """
    try:
        booking_data = json.loads(result)
        booking = Booking.model_validate(booking_data)
    except (json.JSONDecodeError, ValueError):
        # Validation failures are presumably pydantic ValidationError, which
        # is a ValueError subclass; either way fall back to the raw output.
        return result

    # Pair each flight with its own newest add-on, then pick the globally
    # newest pair: that is the flight this tool call just modified.
    newest_per_flight = [
        (max(flight_booking.add_ons, key=lambda ao: ao.added_at), flight_booking)
        for flight_booking in booking.flights
        if flight_booking.add_ons
    ]
    if not newest_per_flight:
        # Nothing to describe; an empty string would yield a blank fallback message.
        return result

    latest_addon, flight_booking = max(newest_per_flight, key=lambda pair: pair[0].added_at)
    add_ons_text = ", ".join(self._format_service_type(ao.service_type) for ao in flight_booking.add_ons)
    price_text = f" ({booking.currency} ${latest_addon.price:.2f})" if latest_addon.price > 0 else ""
    return (
        f"Added {self._format_service_type(latest_addon.service_type)}{price_text} to flight {flight_booking.flight_id}."
        f"\nYour add-ons for this flight are: {add_ons_text}"
    )


def _check_in_to_string(self, result: str) -> str:
    """Format a check_in result for the fallback response.

    Args:
        result: The tool result as a JSON string.

    Returns:
        One line per checked-in flight, including the seat assignment when one
        is set; a generic booking-level sentence when no flight is marked as
        checked in; or ``result`` unchanged when it cannot be parsed or
        validated as a ``Booking``.
    """
    try:
        booking_data = json.loads(result)
        booking = Booking.model_validate(booking_data)
    except (json.JSONDecodeError, ValueError):
        return result

    # NOTE(review): every flight flagged checked_in is listed; the payload does
    # not identify which flight this particular call checked in.
    checked_in_flights = [fb for fb in booking.flights if fb.checked_in]
    if not checked_in_flights:
        return f"Checked in for booking {booking.booking_id}."

    check_in_parts = []
    for flight_booking in checked_in_flights:
        seat_text = (
            f" Your seat assignment is {flight_booking.seat_assignment}." if flight_booking.seat_assignment else ""
        )
        check_in_parts.append(f"Checked in for flight {flight_booking.flight_id}.{seat_text}")
    return "\n".join(check_in_parts)


@property
def tools(self) -> FunctionToolset:
    """Return the toolset exposing the flight search, booking, and status tools."""
    return FunctionToolset(
        tools=[
            self.search_flights,
            self.get_fare_details,
            self.book_flights,
            self.get_booking,
            self.get_my_bookings,
            self.add_service_to_booking,
            self.check_in,
            self.get_flight_timings,
            self.get_flight_status,
        ]
    )


@property
def tool_fallback_fmt(self) -> dict[str, Callable[[str], str]]:
    """Return formatters for mutative tools in this toolset.

    Keys are tool names; each value turns that tool's serialized result into
    user-facing fallback text.
    """
    return {
        "book_flights": self._book_flights_to_string,
        "add_service_to_booking": self._add_service_to_booking_to_string,
        "check_in": self._check_in_to_string,
    }
@contextmanager
def mock_cleanlab_validation_with_guardrail() -> Iterator[None]:
    """Patch cleanlab validation so every result has should_guardrail=True.

    The real ``run_cleanlab_validation_logging_tools`` still runs; only the
    ``should_guardrail`` flag on its validation result is overwritten before
    the tuple is handed back to ``airline_chat``.
    """

    def force_guardrail(*args: Any, **kwargs: Any) -> tuple[list[ModelMessage], str, ProjectValidateResponse]:
        outcome = run_cleanlab_validation_logging_tools(*args, **kwargs)
        history, response_text, validation = outcome
        validation.should_guardrail = True
        return history, response_text, validation

    target = "airline_agent.backend.services.airline_chat.run_cleanlab_validation_logging_tools"
    with patch(target, side_effect=force_guardrail):
        yield


def test_book_flight_fallback() -> None:
    """A guardrailed booking reply is replaced by the formatted booking summary."""
    expected = "I've completed the following for you:\n\nBooking confirmed with booking ID BK-A3B1799D for a total price of USD $80.84.\n\nFlights:\n1. Flight F9-SFO-LGA-2025-11-11T17:00 (basic fare) (USD $80.84)"
    with mock_cleanlab_validation_with_guardrail():
        agent = Agent()
        answer, _ = agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
    assert answer == expected


def test_add_service_to_booking_fallback() -> None:
    """A guardrailed add-service reply is replaced by the formatted add-on summary."""
    expected = "I've completed the following for you:\n\nAdded Checked Bag (USD $37.91) to flight F9-SFO-LGA-2025-11-11T17:00.\nYour add-ons for this flight are: Checked Bag"
    agent = Agent()
    # The booking itself is made without the forced guardrail.
    agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
    with mock_cleanlab_validation_with_guardrail():
        answer, _ = agent.chat("Add a checked bag to my booking")
    assert answer == expected


def test_check_in_fallback() -> None:
    """A guardrailed check-in reply is replaced by the formatted check-in summary."""
    expected = "I've completed the following for you:\n\nChecked in for flight F9-SFO-LGA-2025-11-11T17:00. Your seat assignment is 19F."
    agent = Agent()
    # The booking itself is made without the forced guardrail.
    agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
    with mock_cleanlab_validation_with_guardrail():
        answer, _ = agent.chat("Check me in for my flight")
    assert answer == expected