Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/airline_agent/backend/services/airline_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pathlib
import uuid
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

from cleanlab_codex import Client, Project
from codex.types import ProjectValidateResponse
Expand All @@ -14,6 +14,7 @@
CallToolsNode,
ModelMessage,
ModelRequestNode,
ModelResponse,
ModelSettings,
TextPart,
ToolCallPart,
Expand Down Expand Up @@ -42,6 +43,9 @@
consult_cleanlab,
update_prompt_with_guidance,
)
from airline_agent.cleanlab_utils.conversion_utils import (
convert_string_to_response_message,
)
from airline_agent.cleanlab_utils.validate_utils import (
get_tools_in_openai_format,
run_cleanlab_validation_logging_tools,
Expand Down Expand Up @@ -96,6 +100,11 @@ async def airline_chat_streaming(
run_id = uuid.uuid4()
thread_id = message.thread_id

tool_fallback_formatters: dict[str, Callable[[str], str]] = {
**booking.tool_fallback_fmt,
}
mutate_tools = list(tool_fallback_formatters.keys())

if thread_id not in cleanlab_enabled_by_thread:
cleanlab_enabled_by_thread[thread_id] = cleanlab_enabled
elif cleanlab_enabled_by_thread[thread_id] != cleanlab_enabled:
Expand All @@ -118,6 +127,7 @@ async def airline_chat_streaming(
thread_to_messages[thread_id] = []

current_tool_calls: dict[str, ToolCall] = {}
mutate_tool_fallback_response: list[str] = []

original_user_query = message.content
if cleanlab_enabled:
Expand Down Expand Up @@ -158,6 +168,16 @@ async def airline_chat_streaming(
request = node.request
for request_part in request.parts:
if isinstance(request_part, ToolReturnPart) and request_part.tool_call_id in current_tool_calls:
tool_call = current_tool_calls[request_part.tool_call_id]

if tool_call.tool_name in mutate_tools:
tool_result_str = request_part.model_response_str()
if tool_call.tool_name in tool_fallback_formatters:
formatted_result = tool_fallback_formatters[tool_call.tool_name](tool_result_str)
mutate_tool_fallback_response.append(formatted_result)
else:
mutate_tool_fallback_response.append(tool_result_str)

yield RunEventThreadMessage(
id=run_id,
object=RunEventObject.THREAD_MESSAGE,
Expand Down Expand Up @@ -185,12 +205,25 @@ async def airline_chat_streaming(
if guidance_items
else None,
)

response_content = (
"I've completed the following for you:\n\n" + "\n\n".join(mutate_tool_fallback_response)
if validation_result.should_guardrail and mutate_tool_fallback_response
else final_response
)
if (
response_content != final_response
and updated_message_history
and isinstance(updated_message_history[-1], ModelResponse)
):
updated_message_history[-1] = convert_string_to_response_message(response_content)

yield RunEventThreadMessage(
id=run_id,
object=RunEventObject.THREAD_MESSAGE,
data=AssistantMessage(
thread_id=thread_id,
content=final_response,
content=response_content,
metadata=MessageMetadata(
original_llm_response=run.result.output,
is_expert_answer=validation_result.expert_answer is not None,
Expand Down
10 changes: 9 additions & 1 deletion src/airline_agent/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@
AGENT_MODEL = "gpt-4o"
FALLBACK_RESPONSE = "I'm sorry, but I don't have the information you're looking for. Please rephrase the question or contact Frontier Airlines customer support for further assistance."
AGENT_INSTRUCTIONS = f"""
You are an AI customer support agent for Frontier Airlines. You can use tools to access a knowledge base of articles and documents about the airline's services, policies, and procedures. You can help users find flight information and pricing, but you cannot book flights or make reservations.
You are an AI customer support agent for Frontier Airlines. You can use tools to access a knowledge base of articles and documents about the airline's services, policies, and procedures. You can help users find flight information and pricing. You can also book flights.

## You have access to the following tools:
- search — find candidate articles by query (keep top-k small, ≤5), returns title/snippet/path.
- get_article — get the full article by its path.
- list_directory — list directory structure to make more informed searches.
- search_flights — search available flights by origin and destination airport codes (IATA) and departure date (YYYY-MM-DD). Always ask for the departure date if the user doesn't provide it.
- get_fare_details — retrieve fare bundle pricing, included services, and add-ons for a specific flight.
- book_flights — book one or more flights for the current user. Requires list of flight IDs and fare bundle type (basic, economy, premium, business; defaults to basic). Returns booking confirmation with booking ID and total price.
- get_booking — retrieve booking details by booking ID.
- get_my_bookings — retrieve all confirmed bookings for the current user.
- add_service_to_booking — add an eligible service (bags, seat selection, etc.) to a specific flight within a booking.
- check_in — complete check-in for a specific flight in a booking.
- get_flight_timings — get check-in, boarding, and door-close timing windows for a flight.
- get_flight_status — get the latest status, gates, and delay information for a flight.

Expand All @@ -43,12 +48,15 @@
- Answer primarily based on information from retrieved content, unless the question is simply to clarify broadly understood aspects of commercial air travel (such as standard security procedures, boarding processes, or common airline terminology).
- If a missing detail blocks tool use, ask one short clarifying question. If not blocking, proceed and state your assumption.
- Don't dump raw tool output—summarize clearly.
- When booking multiple flights (outbound and return), include all flight IDs in a single book_flights call.

## Response Guidelines:
- Answer questions primarily based on information you look up in the knowledge base.
- For requests that involve general airline knowledge that is not specific to Frontier Airlines (e.g., common terms, standard processes, and widely known industry roles), you may rely on your own knowledge if the knowledge base does not add important Frontier-specific details.
- When responding to user, never use phrases like "according to the knowledge base", "I couldn't find anything in the knowledge base", etc. When responding to user, treat the retrieved knowledge base content as your own knowledge, not something you are referencing or searching through.
- **If the user asks something unrelated to Frontier Airlines or air travel, politely refuse and redirect the conversation. Do not attempt to fulfill or improvise unrelated requests.**
- When a booking is successfully created, provide the booking ID and confirmation details clearly.
- If you book flights, provide the booking ID and summarize the flights booked and total price.
- Avoid hedging language (e.g., “typically,” “generally,” “usually”) when the information is known and factual. Be clear and assertive in your response, and do not speculate.

## Context:
Expand Down
96 changes: 92 additions & 4 deletions src/airline_agent/tools/booking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import random
from collections.abc import Callable
from datetime import date, datetime, timedelta
from typing import Any

Expand Down Expand Up @@ -185,6 +187,33 @@ def book_flights(

return booking

def _format_service_type(self, service_type: str) -> str:
"""Format service type from snake_case to Title Case with spaces."""
return service_type.replace("_", " ").title()

def _book_flights_to_string(self, result: str) -> str:
"""Format book_flights result for fallback response."""
try:
booking_data = json.loads(result)
booking = Booking.model_validate(booking_data)

flight_parts = []
for i, flight_booking in enumerate(booking.flights, 1):
add_ons_text = (
f". You also have the following add ons: {', '.join(self._format_service_type(ao.service_type) for ao in flight_booking.add_ons)}"
if flight_booking.add_ons
else ""
)
flight_parts.append(
f"{i}. Flight {flight_booking.flight_id} ({flight_booking.fare_type} fare) ({flight_booking.currency} ${flight_booking.price_total:.2f}){add_ons_text}"
)

flights_text = "\n".join(flight_parts)
except (json.JSONDecodeError, ValueError):
return result
else:
return f"Booking confirmed with booking ID {booking.booking_id} for a total price of {booking.currency} ${booking.total_price:.2f}.\n\nFlights:\n{flights_text}"

def get_booking(self, booking_id: str) -> Booking:
"""
Retrieve a booking by its booking ID.
Expand Down Expand Up @@ -325,6 +354,32 @@ def add_service_to_booking(

return booking

def _add_service_to_booking_to_string(self, result: str) -> str:
"""Format add_service_to_booking result for fallback response."""
try:
booking_data = json.loads(result)
booking = Booking.model_validate(booking_data)

# Find the flight that has the most recently added add-on
# (the one that was just modified)
service_parts = []
for flight_booking in booking.flights:
if flight_booking.add_ons:
# Get the most recently added add-on
latest_addon = max(flight_booking.add_ons, key=lambda ao: ao.added_at)
add_ons_list = [self._format_service_type(ao.service_type) for ao in flight_booking.add_ons]
add_ons_text = ", ".join(add_ons_list)
price_text = f" ({booking.currency} ${latest_addon.price:.2f})" if latest_addon.price > 0 else ""
service_parts.append(
f"Added {self._format_service_type(latest_addon.service_type)}{price_text} to flight {flight_booking.flight_id}.\nYour add-ons for this flight are: {add_ons_text}"
)

services_text = "\n\n".join(service_parts)
except (json.JSONDecodeError, ValueError):
return result
else:
return services_text

def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str:
"""Assign a seat to a flight booking based on preferences, fare type, or randomly."""
# Check if any seat selection add-on exists with an assignment
Expand Down Expand Up @@ -468,6 +523,29 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking:

return booking

def _check_in_to_string(self, result: str) -> str:
"""Format check_in result for fallback response."""
try:
booking_data = json.loads(result)
booking = Booking.model_validate(booking_data)
except (json.JSONDecodeError, ValueError):
return result

# Find the flight that was checked in
checked_in_flights = [fb for fb in booking.flights if fb.checked_in]

if not checked_in_flights:
return f"Checked in for booking {booking.booking_id}."

check_in_parts = []
for flight_booking in checked_in_flights:
seat_text = (
f" Your seat assignment is {flight_booking.seat_assignment}." if flight_booking.seat_assignment else ""
)
check_in_parts.append(f"Checked in for flight {flight_booking.flight_id}.{seat_text}")

return "\n".join(check_in_parts) if check_in_parts else f"Checked in for booking {booking.booking_id}."

def get_flight_timings(self, flight_id: str) -> dict[str, Any]:
"""
Get all timing windows for a flight (check-in, boarding, doors close, etc.).
Expand Down Expand Up @@ -562,15 +640,25 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]:

@property
def tools(self) -> FunctionToolset:
# For now, only include read-only/informational tools.
#
# State mutation tools (book_flights, add_service_to_booking, check_in)
# and booking lookup tools (get_booking, get_my_bookings) are excluded.
return FunctionToolset(
tools=[
self.search_flights,
self.get_fare_details,
self.book_flights,
self.get_booking,
self.get_my_bookings,
self.add_service_to_booking,
self.check_in,
self.get_flight_timings,
self.get_flight_status,
]
)

@property
def tool_fallback_fmt(self) -> dict[str, Callable[[str], str]]:
"""Return formatters for mutative tools in this toolset."""
return {
"book_flights": self._book_flights_to_string,
"add_service_to_booking": self._add_service_to_booking_to_string,
"check_in": self._check_in_to_string,
}
66 changes: 58 additions & 8 deletions tests/func/test_booking.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
import pytest
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from unittest.mock import patch

from codex.types.project_validate_response import ProjectValidateResponse
from pydantic_ai.messages import ModelMessage

from airline_agent.cleanlab_utils.validate_utils import run_cleanlab_validation_logging_tools
from airline_agent.util import TestAgent as Agent
from tests.judge import assert_judge


@contextmanager
def mock_cleanlab_validation_with_guardrail() -> Iterator[None]:
"""Context manager that mocks cleanlab validation to always set should_guardrail=True."""

def mock_validation(*args: Any, **kwargs: Any) -> tuple[list[ModelMessage], str, ProjectValidateResponse]:
updated_message_history, final_response, validation_result = run_cleanlab_validation_logging_tools(
*args, **kwargs
)
validation_result.should_guardrail = True
return updated_message_history, final_response, validation_result

with patch(
"airline_agent.backend.services.airline_chat.run_cleanlab_validation_logging_tools", side_effect=mock_validation
):
yield


def test_search_flights() -> None:
agent = Agent(cleanlab_enabled=False)
answer, _ = agent.chat("I want to fly from SFO to JFK on November 12, 2025")
Expand All @@ -29,7 +53,6 @@ def test_get_fare_details() -> None:
)


@pytest.mark.skip(reason="book_flights tool disabled")
def test_book_single_flight() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("I need a flight from SFO to JFK on November 12, 2025")
Expand All @@ -44,7 +67,6 @@ def test_book_single_flight() -> None:
)


@pytest.mark.skip(reason="book_flights tool disabled")
def test_book_round_trip() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("Find flights from OAK to LGA on November 13, 2025")
Expand All @@ -60,7 +82,16 @@ def test_book_round_trip() -> None:
)


@pytest.mark.skip(reason="book_flights and get_my_bookings tools disabled")
def test_book_flight_fallback() -> None:
with mock_cleanlab_validation_with_guardrail():
agent = Agent()
answer, _ = agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
assert (
answer
== "I've completed the following for you:\n\nBooking confirmed with booking ID BK-A3B1799D for a total price of USD $80.84.\n\nFlights:\n1. Flight F9-SFO-LGA-2025-11-11T17:00 (basic fare) (USD $80.84)"
)


def test_retrieve_booking() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("Find a flight from SJC to JFK on November 12, 2025")
Expand All @@ -75,7 +106,6 @@ def test_retrieve_booking() -> None:
)


@pytest.mark.skip(reason="book_flights and add_service_to_booking tools disabled")
def test_add_service_to_booking() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("Show me flights from SFO to EWR on November 14, 2025")
Expand All @@ -90,10 +120,20 @@ def test_add_service_to_booking() -> None:
)


@pytest.mark.skip(reason="book_flights and check_in tools disabled")
def test_add_service_to_booking_fallback() -> None:
agent = Agent()
agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
with mock_cleanlab_validation_with_guardrail():
answer, _ = agent.chat("Add a checked bag to my booking")
assert (
answer
== "I've completed the following for you:\n\nAdded Checked Bag (USD $37.91) to flight F9-SFO-LGA-2025-11-11T17:00.\nYour add-ons for this flight are: Checked Bag"
)


def test_check_in() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("Find flights from SFO to JFK on November 12, 2025")
agent.chat("Find flights from SFO to JFK on November 6, 2025")
agent.chat("Book the first available flight")
answer, _ = agent.chat("Check me in for my flight")
assert_judge(
Expand All @@ -105,6 +145,17 @@ def test_check_in() -> None:
)


def test_check_in_fallback() -> None:
agent = Agent()
agent.chat("book me the F9 707 flight from SFO to LGA on 11/11")
with mock_cleanlab_validation_with_guardrail():
answer, _ = agent.chat("Check me in for my flight")
assert (
answer
== "I've completed the following for you:\n\nChecked in for flight F9-SFO-LGA-2025-11-11T17:00. Your seat assignment is 19F."
)


def test_flight_status() -> None:
agent = Agent(cleanlab_enabled=False)
agent.chat("Show me flights from OAK to LGA on November 12, 2025")
Expand Down Expand Up @@ -168,7 +219,6 @@ def test_no_date_provided() -> None:
)


@pytest.mark.skip(reason="get_my_bookings tool disabled")
def test_no_existing_bookings() -> None:
agent = Agent(cleanlab_enabled=False)
answer, _ = agent.chat("Show me my bookings")
Expand Down