
Commit ab3cad2

fix wrong update
1 parent d6b1dcf commit ab3cad2

File tree

2 files changed: +112 -76 lines changed


templates/components/multiagent/python/app/api/routers/vercel_response.py

Lines changed: 1 addition & 6 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import logging
-from typing import AsyncGenerator, Awaitable, Generator, List
+from typing import AsyncGenerator, Awaitable, List
 
 from aiostream import stream
 from app.api.routers.models import ChatData, Message
@@ -67,11 +67,6 @@ async def _chat_response_generator():
                 async for token in result:
                     final_response += str(token.delta)
                     yield self.convert_text(token.delta)
-            elif isinstance(result, Generator):
-                for chunk in result:
-                    chunk_str = str(chunk)
-                    final_response += chunk_str
-                    yield self.convert_text(chunk_str)
             else:
                 if hasattr(result, "response"):
                     content = result.response.message.content
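
Both files frame their output in the Vercel AI data stream protocol: convert_text wraps each token with the "0:" prefix and convert_data wraps a JSON payload with the "8:" prefix, one frame per line. Below is a minimal standalone sketch of that framing, based on the class constants and the convert_data body visible in these diffs; the JSON-escaping inside convert_text is an assumption, since that method's body is outside the shown hunks:

    import json

    TEXT_PREFIX = "0:"  # Vercel stream prefix for text tokens
    DATA_PREFIX = "8:"  # Vercel stream prefix for structured data

    def convert_text(token: str) -> str:
        # Escape quotes and newlines so every frame stays on a single line
        # (assumed to match the class method, whose body is not in these hunks)
        return f"{TEXT_PREFIX}{json.dumps(token)}\n"

    def convert_data(data: dict) -> str:
        # Mirrors the convert_data body shown in the second file's diff
        data_str = json.dumps(data)
        return f"{DATA_PREFIX}[{data_str}]\n"

    print(convert_text("Hello"), end="")  # -> 0:"Hello"
    print(convert_data({"type": "sources", "data": {"nodes": []}}), end="")  # -> 8:[{"type": "sources", "data": {"nodes": []}}]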

templates/types/streaming/fastapi/app/api/routers/vercel_response.py

Lines changed: 111 additions & 70 deletions
@@ -1,101 +1,124 @@
-import asyncio
 import json
 import logging
-from typing import AsyncGenerator, Awaitable, List
+from typing import Awaitable, List
 
 from aiostream import stream
-from fastapi import Request
+from fastapi import BackgroundTasks, Request
 from fastapi.responses import StreamingResponse
+from llama_index.core.chat_engine.types import StreamingAgentChatResponse
+from llama_index.core.schema import NodeWithScore
 
-from app.api.routers.models import ChatData, Message
+from app.api.routers.events import EventCallbackHandler
+from app.api.routers.models import ChatData, Message, SourceNodes
 from app.api.services.suggestion import NextQuestionSuggestion
 
 logger = logging.getLogger("uvicorn")
 
 
 class VercelStreamResponse(StreamingResponse):
     """
-    Base class to convert the response from the chat engine to the streaming format expected by Vercel
+    Class to convert the response from the chat engine to the streaming format expected by Vercel
     """
 
     TEXT_PREFIX = "0:"
     DATA_PREFIX = "8:"
 
-    def __init__(self, request: Request, chat_data: ChatData, *args, **kwargs):
-        self.request = request
-        self.chat_data = chat_data
-        content = self.content_generator(*args, **kwargs)
+    def __init__(
+        self,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: Awaitable[StreamingAgentChatResponse],
+        chat_data: ChatData,
+        background_tasks: BackgroundTasks,
+    ):
+        content = VercelStreamResponse.content_generator(
+            request, event_handler, response, chat_data, background_tasks
+        )
         super().__init__(content=content)
 
-    async def content_generator(self, event_handler, events):
-        stream = self._create_stream(
-            self.request, self.chat_data, event_handler, events
+    @classmethod
+    async def content_generator(
+        cls,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: Awaitable[StreamingAgentChatResponse],
+        chat_data: ChatData,
+        background_tasks: BackgroundTasks,
+    ):
+        chat_response_generator = cls._chat_response_generator(
+            response, background_tasks, event_handler, chat_data
        )
+        event_generator = cls._event_generator(event_handler)
+
+        # Merge the chat response generator and the event generator
+        combine = stream.merge(chat_response_generator, event_generator)
         is_stream_started = False
-        try:
-            async with stream.stream() as streamer:
-                async for output in streamer:
-                    if not is_stream_started:
-                        is_stream_started = True
-                        # Stream a blank message to start the stream
-                        yield self.convert_text("")
-
-                    yield output
-        except asyncio.CancelledError:
-            logger.info("Stopping workflow")
-            await event_handler.cancel_run()
-        except Exception as e:
-            logger.error(
-                f"Unexpected error in content_generator: {str(e)}", exc_info=True
-            )
-        finally:
-            logger.info("The stream has been stopped!")
+        async with combine.stream() as streamer:
+            async for output in streamer:
+                if not is_stream_started:
+                    is_stream_started = True
+                    # Stream a blank message to start displaying the response in the UI
+                    yield cls.convert_text("")
 
-    def _create_stream(
-        self,
-        request: Request,
+                yield output
+
+                if await request.is_disconnected():
+                    break
+
+    @classmethod
+    async def _event_generator(cls, event_handler: EventCallbackHandler):
+        """
+        Yield the events from the event handler
+        """
+        async for event in event_handler.async_event_gen():
+            event_response = event.to_response()
+            if event_response is not None:
+                yield cls.convert_data(event_response)
+
+    @classmethod
+    async def _chat_response_generator(
+        cls,
+        response: Awaitable[StreamingAgentChatResponse],
+        background_tasks: BackgroundTasks,
+        event_handler: EventCallbackHandler,
         chat_data: ChatData,
-        event_handler: Awaitable,
-        events: AsyncGenerator,
-        verbose: bool = True,
     ):
-        # Yield the text response
-        async def _chat_response_generator():
-            result = await event_handler
-            final_response = ""
-
-            if isinstance(result, AsyncGenerator):
-                async for token in result:
-                    final_response += str(token.delta)
-                    yield self.convert_text(token.delta)
-            else:
-                if hasattr(result, "response"):
-                    content = result.response.message.content
-                    if content:
-                        for token in content:
-                            final_response += str(token)
-                            yield self.convert_text(token)
-
-            # Generate next questions if next question prompt is configured
-            question_data = await self._generate_next_questions(
-                chat_data.messages, final_response
-            )
-            if question_data:
-                yield self.convert_data(question_data)
+        """
+        Yield the text response and source nodes from the chat engine
+        """
+        # Wait for the response from the chat engine
+        result = await response
+
+        # Once we got a source node, start a background task to download the files (if needed)
+        cls._process_response_nodes(result.source_nodes, background_tasks)
+
+        # Yield the source nodes
+        yield cls.convert_data(
+            {
+                "type": "sources",
+                "data": {
+                    "nodes": [
+                        SourceNodes.from_source_node(node).model_dump()
+                        for node in result.source_nodes
+                    ]
+                },
+            }
+        )
 
-        # TODO: stream sources
+        final_response = ""
+        async for token in result.async_response_gen():
+            final_response += token
+            yield cls.convert_text(token)
 
-        # Yield the events from the event handler
-        async def _event_generator():
-            async for event in events:
-                event_response = event.to_response()
-                if verbose:
-                    logger.debug(event_response)
-                if event_response is not None:
-                    yield self.convert_data(event_response)
+        # Generate next questions if next question prompt is configured
+        question_data = await cls._generate_next_questions(
+            chat_data.messages, final_response
+        )
+        if question_data:
+            yield cls.convert_data(question_data)
 
-        combine = stream.merge(_chat_response_generator(), _event_generator())
-        return combine
+        # the text_generator is the leading stream, once it's finished, also finish the event stream
+        event_handler.is_done = True
 
     @classmethod
     def convert_text(cls, token: str):
@@ -108,6 +131,24 @@ def convert_data(cls, data: dict):
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"
 
+    @staticmethod
+    def _process_response_nodes(
+        source_nodes: List[NodeWithScore],
+        background_tasks: BackgroundTasks,
+    ):
+        try:
+            # Start background tasks to download documents from LlamaCloud if needed
+            from app.engine.service import LLamaCloudFileService  # type: ignore
+
+            LLamaCloudFileService.download_files_from_nodes(
+                source_nodes, background_tasks
+            )
+        except ImportError:
+            logger.debug(
+                "LlamaCloud is not configured. Skipping post processing of nodes"
+            )
+            pass
+
     @staticmethod
     async def _generate_next_questions(chat_history: List[Message], response: str):
         questions = await NextQuestionSuggestion.suggest_next_questions(
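
With this rewrite the streaming template's response class is self-contained: the route hands over the request, the event handler, the awaitable chat response, the chat data, and FastAPI's BackgroundTasks, and the class merges the token and event streams itself. A hedged sketch of what a calling route could look like under the new signature; the route path, the get_chat_engine factory, the ChatData helper methods, and the callback wiring are illustrative assumptions, not part of this commit:

    from fastapi import APIRouter, BackgroundTasks, Request

    from app.api.routers.events import EventCallbackHandler
    from app.api.routers.models import ChatData
    from app.api.routers.vercel_response import VercelStreamResponse
    from app.engine import get_chat_engine  # assumed factory returning a LlamaIndex chat engine

    chat_router = APIRouter()

    @chat_router.post("")
    async def chat(request: Request, data: ChatData, background_tasks: BackgroundTasks):
        # Assumed ChatData helpers; the exact names are illustrative
        last_message_content = data.get_last_message_content()
        messages = data.get_history_messages()

        event_handler = EventCallbackHandler()
        chat_engine = get_chat_engine()
        chat_engine.callback_manager.handlers.append(event_handler)  # assumed event wiring

        # Pass the coroutine without awaiting it: _chat_response_generator awaits it
        # after the stream has started, so events can be forwarded while the engine works
        response = chat_engine.astream_chat(last_message_content, messages)
        return VercelStreamResponse(request, event_handler, response, data, background_tasks)

Because the response is awaited only inside _chat_response_generator, events emitted through the EventCallbackHandler can reach the client while the engine is still working; stream.merge then interleaves them with the response tokens.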
