2 changes: 1 addition & 1 deletion .env.example
@@ -76,4 +76,4 @@ OPENWEATHERMAP_API_KEY=
# LANGFUSE Configuration
#LANGFUSE_TRACING=true
#LANGFUSE_PUBLIC_KEY=pk-...
-#LANGFUSE_SECRET_KEY=sk-lf-....
+#LANGFUSE_SECRET_KEY=sk-lf-....
10 changes: 7 additions & 3 deletions pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
"pandas ~=2.2.3",
"psycopg[binary,pool] ~=3.2.4",
"pyarrow >=18.1.0",
"pydantic ~=2.10.1",
"pydantic ~=2.11.1",
"pydantic-settings ~=2.6.1",
"pypdf ~=5.3.0",
"pyowm ~=3.3.0",
@@ -55,7 +55,7 @@ dependencies = [
"streamlit ~=1.40.1",
"tiktoken >=0.8.0",
"uvicorn ~=0.32.1",

"ag-ui-protocol>=0.1.5",
]

[dependency-groups]
@@ -74,7 +74,7 @@ dev = [
# To install run: `uv sync --frozen --only-group client`
client = [
"httpx~=0.27.2",
"pydantic ~=2.10.1",
"pydantic~=2.11.1",
"python-dotenv ~=1.0.1",
"streamlit~=1.40.1",
]
@@ -100,3 +100,7 @@ exclude = "src/streamlit_app.py"
[[tool.mypy.overrides]]
module = ["numexpr.*"]
follow_untyped_imports = true

[[tool.mypy.overrides]]
module = ["ag_ui.*"]
ignore_missing_imports = true
11 changes: 11 additions & 0 deletions src/core/llm.py
@@ -12,6 +12,7 @@

from core.settings import settings
from schema.models import (
AlibabaQWenModelName,
AllModelEnum,
AnthropicModelName,
AWSModelName,
@@ -31,6 +32,7 @@
| {m: m.value for m in OpenAICompatibleName}
| {m: m.value for m in AzureOpenAIModelName}
| {m: m.value for m in DeepseekModelName}
| {m: m.value for m in AlibabaQWenModelName}
| {m: m.value for m in AnthropicModelName}
| {m: m.value for m in GoogleModelName}
| {m: m.value for m in VertexAIModelName}
@@ -104,6 +106,15 @@
openai_api_base="https://api.deepseek.com",
openai_api_key=settings.DEEPSEEK_API_KEY,
)
if model_name in AlibabaQWenModelName:
    return ChatOpenAI(
        model=api_model_name,
        temperature=0.5,
        streaming=True,
        openai_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1/",
        openai_api_key=settings.ALIBABA_QWEN_API_KEY,
        # extra_body={"top_k": 5},
    )
if model_name in AnthropicModelName:
return ChatAnthropic(model=api_model_name, temperature=0.5, streaming=True)
if model_name in GoogleModelName:
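For illustration, a minimal sketch of exercising the new branch (hedged: the surrounding factory function is assumed to be this module's get_model, which this hunk does not show, and ALIBABA_QWEN_API_KEY must be set):

from core.llm import get_model
from schema.models import AlibabaQWenModelName

# Hypothetical usage -- DashScope's compatible-mode endpoint speaks the OpenAI
# API, so the returned object is an ordinary ChatOpenAI instance.
model = get_model(AlibabaQWenModelName.QWEN_PLUS)
response = model.invoke("Briefly introduce yourself.")
print(response.content)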
7 changes: 7 additions & 0 deletions src/core/settings.py
@@ -14,6 +14,7 @@
from pydantic_settings import BaseSettings, SettingsConfigDict

from schema.models import (
AlibabaQWenModelName,
AllModelEnum,
AnthropicModelName,
AWSModelName,
@@ -58,6 +59,7 @@

OPENAI_API_KEY: SecretStr | None = None
DEEPSEEK_API_KEY: SecretStr | None = None
ALIBABA_QWEN_API_KEY: SecretStr | None = None
ANTHROPIC_API_KEY: SecretStr | None = None
GOOGLE_API_KEY: SecretStr | None = None
GOOGLE_APPLICATION_CREDENTIALS: SecretStr | None = None
@@ -124,6 +126,7 @@
Provider.OPENAI: self.OPENAI_API_KEY,
Provider.OPENAI_COMPATIBLE: self.COMPATIBLE_BASE_URL and self.COMPATIBLE_MODEL,
Provider.DEEPSEEK: self.DEEPSEEK_API_KEY,
Provider.ALIBABA_QWEN: self.ALIBABA_QWEN_API_KEY,
Provider.ANTHROPIC: self.ANTHROPIC_API_KEY,
Provider.GOOGLE: self.GOOGLE_API_KEY,
Provider.VERTEXAI: self.GOOGLE_APPLICATION_CREDENTIALS,
@@ -151,6 +154,10 @@
if self.DEFAULT_MODEL is None:
self.DEFAULT_MODEL = DeepseekModelName.DEEPSEEK_CHAT
self.AVAILABLE_MODELS.update(set(DeepseekModelName))
case Provider.ALIBABA_QWEN:
    if self.DEFAULT_MODEL is None:
        self.DEFAULT_MODEL = AlibabaQWenModelName.QWEN_PLUS
    self.AVAILABLE_MODELS.update(set(AlibabaQWenModelName))
case Provider.ANTHROPIC:
if self.DEFAULT_MODEL is None:
self.DEFAULT_MODEL = AnthropicModelName.HAIKU_3
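A minimal sketch of the resulting default-model behavior (hedged: assumes the settings class here is named Settings, reads environment variables at construction, and that no other provider key is configured):

import os

os.environ["ALIBABA_QWEN_API_KEY"] = "sk-placeholder"  # hypothetical key

from core.settings import Settings
from schema.models import AlibabaQWenModelName

settings = Settings()
# With only the Qwen key set, configuration falls through to the ALIBABA_QWEN case above:
assert settings.DEFAULT_MODEL == AlibabaQWenModelName.QWEN_PLUS
assert set(AlibabaQWenModelName) <= settings.AVAILABLE_MODELS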
10 changes: 10 additions & 0 deletions src/schema/models.py
@@ -7,6 +7,7 @@ class Provider(StrEnum):
OPENAI_COMPATIBLE = auto()
AZURE_OPENAI = auto()
DEEPSEEK = auto()
ALIBABA_QWEN = auto()
ANTHROPIC = auto()
GOOGLE = auto()
VERTEXAI = auto()
@@ -36,6 +37,14 @@ class DeepseekModelName(StrEnum):
DEEPSEEK_CHAT = "deepseek-chat"


class AlibabaQWenModelName(StrEnum):
"""https://help.aliyun.com/zh/model-studio/user-guide/text-generation/"""

QWEN_MAX = "qwen-max"
QWEN_PLUS = "qwen-plus"
QWEN_TURBO = "qwen-turbo"


class AnthropicModelName(StrEnum):
"""https://docs.anthropic.com/en/docs/about-claude/models#model-names"""

@@ -102,6 +111,7 @@ class FakeModelName(StrEnum):
| OpenAICompatibleName
| AzureOpenAIModelName
| DeepseekModelName
| AlibabaQWenModelName
| AnthropicModelName
| GoogleModelName
| VertexAIModelName
5 changes: 5 additions & 0 deletions src/schema/schema.py
@@ -75,6 +75,11 @@ class StreamInput(UserInput):
default=True,
)

stream_protocol: Literal["sse", "agui"] = Field(
description="The protocol to use for streaming the agent's response.",
default="sse",
)

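For illustration, a request opting into the new protocol could carry a body like this (hedged sketch; message and stream_tokens are the existing UserInput/StreamInput fields):

payload = {
    "message": "What is the weather in Tokyo?",
    "stream_tokens": True,
    "stream_protocol": "agui",  # new field; "sse" remains the default
}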

class ToolCall(TypedDict):
"""Represents a request to call a tool."""
149 changes: 145 additions & 4 deletions src/service/service.py
@@ -7,6 +7,15 @@
from typing import Annotated, Any
from uuid import UUID, uuid4

# AG-UI Protocol imports
from ag_ui.core import (
EventType,
RunErrorEvent,
RunFinishedEvent,
RunStartedEvent,
)
from ag_ui.core.events import TextMessageChunkEvent
from ag_ui.encoder import EventEncoder
from fastapi import APIRouter, Depends, FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -34,6 +43,7 @@
)
from service.utils import (
convert_message_content_to_string,
convert_message_to_agui_events,
langchain_to_chat_message,
remove_tool_calls,
)
@@ -305,6 +315,129 @@
yield "data: [DONE]\n\n"


async def message_generator_agui(
    user_input: StreamInput,
    agent_id: str = DEFAULT_AGENT,
) -> AsyncGenerator[str, None]:
    """
    Generate a stream of messages from the agent using the AG-UI protocol.
    This closely mirrors message_generator but outputs AG-UI compatible events.
    """
    agent: Pregel = get_agent(agent_id)
    kwargs, run_id = await _handle_input(user_input, agent)

    # Create event encoder for AG-UI protocol
    encoder = EventEncoder()

    try:
        # Send run started event
        yield encoder.encode(
            RunStartedEvent(
                type=EventType.RUN_STARTED,
                thread_id=kwargs["config"]["configurable"]["thread_id"],
                run_id=str(run_id),
            )
        )

        # Process streamed events from the graph - same structure as message_generator
        async for stream_event in agent.astream(
            **kwargs, stream_mode=["updates", "messages", "custom"]
        ):
            if not isinstance(stream_event, tuple):
                continue

            stream_mode, event = stream_event
            new_messages = []

            # Handle updates - same logic as the original
            if stream_mode == "updates":
                for node, updates in event.items():
                    if node == "__interrupt__":
                        interrupt: Interrupt
                        for interrupt in updates:
                            new_messages.append(AIMessage(content=interrupt.value))
                        continue
                    updates = updates or {}
                    update_messages = updates.get("messages", [])
                    # Special cases for using the langgraph-supervisor library
                    if node == "supervisor":
                        ai_messages = [msg for msg in update_messages if isinstance(msg, AIMessage)]
                        if ai_messages:
                            update_messages = [ai_messages[-1]]
                    if node in ("research_expert", "math_expert"):
                        msg = ToolMessage(
                            content=update_messages[0].content,
                            name=node,
                            tool_call_id="",
                        )
                        update_messages = [msg]
                    new_messages.extend(update_messages)

            # Handle custom events - same logic as the original
            if stream_mode == "custom":
                new_messages = [event]

            # Process message parts - similar to the original but simpler
            processed_messages = []
            current_message: dict[str, Any] = {}
            for message in new_messages:
                if isinstance(message, tuple):
                    key, value = message
                    current_message[key] = value
                else:
                    if current_message:
                        processed_messages.append(_create_ai_message(current_message))
                        current_message = {}
                    processed_messages.append(message)

            if current_message:
                processed_messages.append(_create_ai_message(current_message))

            # Convert messages to AG-UI events - this is the main difference
            for message in processed_messages:
                # Skip re-sent input messages
                if isinstance(message, HumanMessage) and message.content == user_input.message:
                    continue

                # Convert each message to appropriate AG-UI events
                async for agui_event in convert_message_to_agui_events(message, encoder):
                    yield agui_event

            # Handle token streaming - similar to the original
            if stream_mode == "messages":
                if not user_input.stream_tokens:
                    continue
                msg, metadata = event
                if "skip_stream" in metadata.get("tags", []):
                    continue
                if not isinstance(msg, AIMessageChunk):
                    continue
                content = remove_tool_calls(msg.content)
                if content:
                    # Convert to an AG-UI text content event
                    message_id = str(uuid4())
                    yield encoder.encode(
                        TextMessageChunkEvent(
                            type=EventType.TEXT_MESSAGE_CHUNK,
                            message_id=message_id,
                            delta=convert_message_content_to_string(content),
                        )
                    )

        # Send run finished event
        yield encoder.encode(
            RunFinishedEvent(
                type=EventType.RUN_FINISHED,
                thread_id=kwargs["config"]["configurable"]["thread_id"],
                run_id=str(run_id),
            )
        )

    except Exception as e:
        logger.error(f"Error in AG-UI message generator: {e}")
        yield encoder.encode(RunErrorEvent(type=EventType.RUN_ERROR, message=str(e)))
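For reference, a minimal sketch of what EventEncoder produces for a single event (hedged: the exact wire format is defined by the ag-ui-protocol package, not by this PR):

from ag_ui.core import EventType, RunStartedEvent
from ag_ui.encoder import EventEncoder

encoder = EventEncoder()
frame = encoder.encode(
    RunStartedEvent(type=EventType.RUN_STARTED, thread_id="thread-1", run_id="run-1")
)
# frame is an SSE-framed string along the lines of:
#   data: {"type":"RUN_STARTED","threadId":"thread-1","runId":"run-1"}
print(frame)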


def _create_ai_message(parts: dict) -> AIMessage:
sig = inspect.signature(AIMessage)
valid_keys = set(sig.parameters)
@@ -343,10 +476,18 @@
Set `stream_tokens=false` to return intermediate messages but not token-by-token.
"""
-    return StreamingResponse(
-        message_generator(user_input, agent_id),
-        media_type="text/event-stream",
-    )
+    if user_input.stream_protocol == "sse":
+        return StreamingResponse(
+            message_generator(user_input, agent_id),
+            media_type="text/event-stream",
+        )
+    elif user_input.stream_protocol == "agui":
+        return StreamingResponse(
+            message_generator_agui(user_input, agent_id),
+            media_type="text/event-stream",
+        )
+    else:
+        raise HTTPException(status_code=400, detail="Invalid stream protocol")
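For illustration, a client consuming the AG-UI stream might look like the sketch below (hedged: the host, port, and exact stream route are assumptions; httpx is already a client-group dependency in pyproject.toml above):

import httpx

# Hypothetical service URL -- adjust host, port, and route to your deployment.
with httpx.stream(
    "POST",
    "http://localhost:8080/stream",
    json={"message": "Hi!", "stream_protocol": "agui"},
) as response:
    for line in response.iter_lines():
        if line:
            print(line)  # SSE-framed AG-UI events: RUN_STARTED, chunks, RUN_FINISHED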



@router.post("/feedback")