Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions orchestrator/__init__.py → agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .orchestrator import Orchestrator
from .agent import Agent

__ALL__ = ["Orchestrator"]
__ALL__ = ["Agent"]
118 changes: 28 additions & 90 deletions orchestrator/orchestrator.py → agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -16,20 +16,15 @@
import os
import uuid
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Optional, Sequence, TypedDict
from typing import Any, Dict, List, Optional, Sequence

from aiohttp import ClientSession, TCPConnector
from fastapi import HTTPException
from langchain.globals import set_verbose # type: ignore
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import StructuredTool
from langchain_core.prompts import ChatPromptTemplate
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import MemorySaver
from pytz import timezone

from ..orchestrator import BaseOrchestrator, classproperty
from .react_graph import create_graph
from .tools import initialize_tools

Expand All @@ -41,13 +36,12 @@
}


class LangGraphOrchestrator(BaseOrchestrator):
class Agent:
MODEL = "gemini-2.0-flash-001"

_user_sessions: Dict[str, str]
# aiohttp context
connector = None
client: Optional[ClientSession] = None

def __init__(self):
self._user_sessions = {}
Expand All @@ -72,14 +66,19 @@ async def user_session_decline_ticket(self, uuid: str) -> dict[str, Any]:

async def user_session_create(self, session: dict[str, Any]):
"""Create and load an agent executor with tools and LLM."""
client = await self.create_client_session()
if self._langgraph_app is None:
print("Initializing graph..")
tools = await initialize_tools(client)
prompt = self.create_prompt_template(tools)
tools, insert_ticket, validate_ticket = await initialize_tools()
prompt = self.create_prompt_template()
checkpointer = MemorySaver()
langgraph_app = await create_graph(
tools, checkpointer, prompt, self.MODEL, client, DEBUG
tools,
insert_ticket,
validate_ticket,
checkpointer,
prompt,
self.MODEL,
DEBUG,
)
self._checkpointer = checkpointer
self._langgraph_app = langgraph_app
Expand All @@ -95,7 +94,6 @@ async def user_session_create(self, session: dict[str, Any]):
config = self.get_config(session_id)
self._langgraph_app.update_state(config, {"messages": history})
self._user_sessions[session_id] = ""
self.client = client

async def user_session_invoke(
self, uuid: str, user_prompt: Optional[str]
Expand All @@ -106,10 +104,7 @@ async def user_session_invoke(
)
if user_prompt:
user_query = [HumanMessage(content=user_prompt)]
app_input = {
"messages": user_query,
"user_id_token": self.get_user_id_token(uuid),
}
app_input = {"messages": user_query}
else:
app_input = None
final_state = await self._langgraph_app.ainvoke(
Expand Down Expand Up @@ -165,45 +160,18 @@ def user_session_reset(self, session: dict[str, Any], uuid: str):
# Update state with message history
self._langgraph_app.update_state(config, {"messages": history})

def get_user_session(self, uuid: str):
raise NotImplementedError("Irrelevant to LangGraph.")

def set_user_session_header(self, uuid: str, user_id_token: str):
self._user_sessions[uuid] = user_id_token

def get_user_id_token(self, uuid: str) -> Optional[str]:
return self._user_sessions.get(uuid)

async def get_connector(self) -> TCPConnector:
if self.connector is None:
self.connector = TCPConnector(limit=100)
return self.connector

async def create_client_session(self) -> ClientSession:
return ClientSession(
connector=await self.get_connector(),
connector_owner=False,
headers={},
raise_for_status=True,
)

def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate:
# Create new prompt template
tool_strings = "\n".join(
[f"> {tool.name}: {tool.description}" for tool in tools]
)
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = FORMAT_INSTRUCTIONS.format(
tool_names=tool_names,
)
def create_prompt_template(self) -> ChatPromptTemplate:
current_datetime = "Today's date and current time is {cur_datetime}."
template = "\n\n".join(
[
PREFIX,
current_datetime,
TOOLS_PREFIX,
tool_strings,
format_instructions,
SUFFIX,
]
)
Expand Down Expand Up @@ -242,7 +210,15 @@ def get_base_history(self, session: dict[str, Any]):
return BASE_HISTORY

def get_config(self, uuid: str):
return {"configurable": {"thread_id": uuid, "checkpoint_ns": ""}}
return {
"configurable": {
"thread_id": uuid,
"auth_token_getters": {
"my_google_service": lambda: self.get_user_id_token(uuid)
},
"checkpoint_ns": "",
},
}

async def user_session_signout(self, uuid: str):
checkpoint = empty_checkpoint()
Expand All @@ -252,10 +228,6 @@ async def user_session_signout(self, uuid: str):
)
del self._user_sessions[uuid]

async def close_clients(self):
if self.client:
await self.client.close()


PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs.

Expand All @@ -268,43 +240,9 @@ async def close_clients(self):
require passing results from one query to another. Using the latest AI models, Assistant is able to
generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should
not answer questions about other peoples information for privacy reasons.
not answer questions about other people's information for privacy reasons.

Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
as well as ammenities of San Francisco Airport."""

TOOLS_PREFIX = """
TOOLS:
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:

"""

FORMAT_INSTRUCTIONS = """
When responding, please output a response in one of two formats:

** Option 1:**
Use this is you want to use a tool.
Markdown code snippet formatted in the following schema:
```json
{{{{
"action": string, \ The action to take. Must be one of {tool_names}
"action_input": string \ The input to the action
}}}}
```

**Option 2:**
Use this if you want to respond directly to the human.
Markdown code snippet formatted following schema:
```json
{{{{
"action": "Final Answer",
"action_input": string \ You should put what you want to return to user here
}}}}
```
"""

SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate.

Remember to respond with a markdown code snippet of a json a single action, and NOTHING else.
"""
as well as amenities of San Francisco Airport."""

SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate."""
90 changes: 19 additions & 71 deletions orchestrator/react_graph.py → agent/react_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,54 +16,47 @@
import uuid
from typing import Annotated, Literal, Sequence, TypedDict

from aiohttp import ClientSession
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolCall,
ToolMessage,
)
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_google_vertexai import ChatVertexAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode
from toolbox_langchain import ToolboxTool

from .tool_node import ToolNode
from .tools import (
TicketInfo,
get_confirmation_needing_tools,
insert_ticket,
validate_ticket,
)
from .tools import get_confirmation_needing_tools


class UserState(TypedDict):
"""
State with messages and ClientSession for each session/user.
State with messages for each session/user.
"""

messages: Annotated[Sequence[BaseMessage], add_messages]
user_id_token: str
is_last_step: IsLastStep


async def create_graph(
tools,
tools: list[ToolboxTool],
insert_ticket: ToolboxTool,
validate_ticket: ToolboxTool,
checkpointer: MemorySaver,
prompt: ChatPromptTemplate,
model_name: str,
client: ClientSession,
debug: bool,
):
"""
Creates a graph that works with a chat model that utilizes tool calling.

Args:
tools: A list of StructuredTools that will bind with the chat model.
tools: A list of ToolboxTools that will bind with the chat model.
insert_ticket: A ToolboxTool that inserts a ticket for the logged-in user.
validate_ticket: A ToolboxTool that validates the given flight data.
checkpointer: The checkpoint saver object. This is useful for persisting
the state of the graph (e.g., as chat memory).
prompt: Initial prompt for the model. This applies to messages before they
Expand All @@ -85,11 +78,13 @@ async def create_graph(
tool_node = ToolNode(tools)

# model node
# TODO: Use .bind_tools(tools) to bind the tools with the LLM.
model = ChatVertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0)

# Bind the tools with the LLM.
model_with_tools = model.bind_tools(tools)

# Add the prompt to the model to create a model runnable
model_runnable = prompt | model
model_runnable = prompt | model_with_tools

async def acall_model(state: UserState, config: RunnableConfig):
"""
Expand All @@ -98,43 +93,6 @@ async def acall_model(state: UserState, config: RunnableConfig):
"""
messages = state["messages"]
res = await model_runnable.ainvoke({"messages": messages}, config)

# TODO: Remove the temporary fix of parsing LLM response and invoking
# tools until we use bind_tools API and have automatic response parsing
# and tool calling. (see
# https://langchain-ai.github.io/langgraph/#example)
if "```json" in res.content:
try:
response = str(res.content).replace("```json", "").replace("```", "")
json_response = json.loads(response)
action = json_response.get("action")
action_input = json_response.get("action_input")
if action == "Final Answer":
res = AIMessage(content=action_input)
else:
res = AIMessage(
content="suggesting a tool call",
tool_calls=[
ToolCall(
id=str(uuid.uuid4()), name=action, args=action_input
)
],
)
except Exception as e:
json_response = response
res = AIMessage(
content="Sorry, failed to generate the right format for response"
)

# if model exceed the number of steps and has not yet return a final answer
if state["is_last_step"] and hasattr(res, "tool_calls"):
return {
"messages": [
AIMessage(
content="Sorry, need more steps to process this request.",
)
]
}
return {"messages": [res]}

def agent_should_continue(
Expand All @@ -151,7 +109,7 @@ def agent_should_continue(
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
if tool_name in confirmation_needing_tools:
if tool_name == "Insert Ticket":
if tool_name == "insert_ticket":
return "booking_validation"
return "continue"
# Otherwise, we stop (reply to the user)
Expand All @@ -164,13 +122,12 @@ async def booking_validation_node(state: UserState, config: RunnableConfig):
"""
messages = state["messages"]
last_message = messages[-1]
user_id_token = state["user_id_token"]
if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0:
tool_call = last_message.tool_calls[0]
# Run ticket validation and return the correct ticket information
flight_info = await validate_ticket(
client, tool_call.get("args"), user_id_token
)
flight_info = await validate_ticket.ainvoke(tool_call.get("args"))
flight_info = json.loads(flight_info)
flight_info = flight_info[0]

new_message = AIMessage(
content="Please confirm if you would like to book the ticket.",
Expand Down Expand Up @@ -203,20 +160,11 @@ async def insert_ticket_node(state: UserState, config: RunnableConfig):
"""
messages = state["messages"]
last_message = messages[-1]
user_id_token = state["user_id_token"]
# Run insert ticket
if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0:
tool_call = last_message.tool_calls[0]
tool_args = tool_call.get("args")
ticket_info = TicketInfo(**tool_args)
output = await insert_ticket(client, ticket_info, user_id_token)
tool_call_id = tool_call.get("id")
tool_message = ToolMessage(
content=output, name="Insert Ticket", tool_call_id=tool_call_id
)
human_message = HumanMessage(content="Looks good to me.")
ai_message = AIMessage(content=output)
return {"messages": [human_message, tool_message, ai_message]}
await insert_ticket.ainvoke(tool_args)

# Define constant node strings
AGENT_NODE = "agent"
Expand Down
Loading