@@ -1,101 +1,124 @@
-import asyncio
 import json
 import logging
-from typing import AsyncGenerator, Awaitable, List
+from typing import Awaitable, List
 
 from aiostream import stream
-from fastapi import Request
+from fastapi import BackgroundTasks, Request
 from fastapi.responses import StreamingResponse
+from llama_index.core.chat_engine.types import StreamingAgentChatResponse
+from llama_index.core.schema import NodeWithScore
 
-from app.api.routers.models import ChatData, Message
+from app.api.routers.events import EventCallbackHandler
+from app.api.routers.models import ChatData, Message, SourceNodes
 from app.api.services.suggestion import NextQuestionSuggestion
 
 logger = logging.getLogger("uvicorn")
 
 
 class VercelStreamResponse(StreamingResponse):
     """
-    Base class to convert the response from the chat engine to the streaming format expected by Vercel
+    Class to convert the response from the chat engine to the streaming format expected by Vercel
     """
 
     TEXT_PREFIX = "0:"
     DATA_PREFIX = "8:"
 
-    def __init__(self, request: Request, chat_data: ChatData, *args, **kwargs):
-        self.request = request
-        self.chat_data = chat_data
-        content = self.content_generator(*args, **kwargs)
+    def __init__(
+        self,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: Awaitable[StreamingAgentChatResponse],
+        chat_data: ChatData,
+        background_tasks: BackgroundTasks,
+    ):
+        content = VercelStreamResponse.content_generator(
+            request, event_handler, response, chat_data, background_tasks
+        )
         super().__init__(content=content)
 
-    async def content_generator(self, event_handler, events):
-        stream = self._create_stream(
-            self.request, self.chat_data, event_handler, events
+    @classmethod
+    async def content_generator(
+        cls,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: Awaitable[StreamingAgentChatResponse],
+        chat_data: ChatData,
+        background_tasks: BackgroundTasks,
+    ):
+        chat_response_generator = cls._chat_response_generator(
+            response, background_tasks, event_handler, chat_data
        )
+        event_generator = cls._event_generator(event_handler)
+
+        # Merge the chat response generator and the event generator
+        combine = stream.merge(chat_response_generator, event_generator)
         is_stream_started = False
-        try:
-            async with stream.stream() as streamer:
-                async for output in streamer:
-                    if not is_stream_started:
-                        is_stream_started = True
-                        # Stream a blank message to start the stream
-                        yield self.convert_text("")
-
-                    yield output
-        except asyncio.CancelledError:
-            logger.info("Stopping workflow")
-            await event_handler.cancel_run()
-        except Exception as e:
-            logger.error(
-                f"Unexpected error in content_generator: {str(e)}", exc_info=True
-            )
-        finally:
-            logger.info("The stream has been stopped!")
+        async with combine.stream() as streamer:
+            async for output in streamer:
+                if not is_stream_started:
+                    is_stream_started = True
+                    # Stream a blank message to start displaying the response in the UI
+                    yield cls.convert_text("")
 
-    def _create_stream(
-        self,
-        request: Request,
+                yield output
+
+                if await request.is_disconnected():
+                    break
+
+    @classmethod
+    async def _event_generator(cls, event_handler: EventCallbackHandler):
+        """
+        Yield the events from the event handler
+        """
+        async for event in event_handler.async_event_gen():
+            event_response = event.to_response()
+            if event_response is not None:
+                yield cls.convert_data(event_response)
+
+    @classmethod
+    async def _chat_response_generator(
+        cls,
+        response: Awaitable[StreamingAgentChatResponse],
+        background_tasks: BackgroundTasks,
+        event_handler: EventCallbackHandler,
         chat_data: ChatData,
-        event_handler: Awaitable,
-        events: AsyncGenerator,
-        verbose: bool = True,
     ):
-        # Yield the text response
-        async def _chat_response_generator():
-            result = await event_handler
-            final_response = ""
-
-            if isinstance(result, AsyncGenerator):
-                async for token in result:
-                    final_response += str(token.delta)
-                    yield self.convert_text(token.delta)
-            else:
-                if hasattr(result, "response"):
-                    content = result.response.message.content
-                    if content:
-                        for token in content:
-                            final_response += str(token)
-                            yield self.convert_text(token)
-
-            # Generate next questions if next question prompt is configured
-            question_data = await self._generate_next_questions(
-                chat_data.messages, final_response
-            )
-            if question_data:
-                yield self.convert_data(question_data)
+        """
+        Yield the text response and source nodes from the chat engine
+        """
+        # Wait for the response from the chat engine
+        result = await response
+
+        # Once we have the source nodes, start a background task to download the files (if needed)
+        cls._process_response_nodes(result.source_nodes, background_tasks)
+
+        # Yield the source nodes
+        yield cls.convert_data(
+            {
+                "type": "sources",
+                "data": {
+                    "nodes": [
+                        SourceNodes.from_source_node(node).model_dump()
+                        for node in result.source_nodes
+                    ]
+                },
+            }
+        )
 
-            # TODO: stream sources
+        final_response = ""
+        async for token in result.async_response_gen():
+            final_response += token
+            yield cls.convert_text(token)
 
-        # Yield the events from the event handler
-        async def _event_generator():
-            async for event in events:
-                event_response = event.to_response()
-                if verbose:
-                    logger.debug(event_response)
-                if event_response is not None:
-                    yield self.convert_data(event_response)
+        # Generate next questions if next question prompt is configured
+        question_data = await cls._generate_next_questions(
+            chat_data.messages, final_response
+        )
+        if question_data:
+            yield cls.convert_data(question_data)
 
-        combine = stream.merge(_chat_response_generator(), _event_generator())
-        return combine
+        # The chat response generator is the leading stream; once it is finished, also finish the event stream
+        event_handler.is_done = True
 
     @classmethod
     def convert_text(cls, token: str):
@@ -108,6 +131,24 @@ def convert_data(cls, data: dict):
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"
 
+    @staticmethod
+    def _process_response_nodes(
+        source_nodes: List[NodeWithScore],
+        background_tasks: BackgroundTasks,
+    ):
+        try:
+            # Start background tasks to download documents from LlamaCloud if needed
+            from app.engine.service import LLamaCloudFileService  # type: ignore
+
+            LLamaCloudFileService.download_files_from_nodes(
+                source_nodes, background_tasks
+            )
+        except ImportError:
+            logger.debug(
+                "LlamaCloud is not configured. Skipping post processing of nodes"
+            )
+            pass
+
     @staticmethod
     async def _generate_next_questions(chat_history: List[Message], response: str):
         questions = await NextQuestionSuggestion.suggest_next_questions(
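For reference, the `TEXT_PREFIX`/`DATA_PREFIX` constants follow the Vercel AI SDK data stream protocol: text deltas go out as `0:`-prefixed JSON strings and structured payloads (sources, events, suggested questions) as `8:`-prefixed JSON arrays. `convert_data`'s body is visible in the unchanged context above; `convert_text`'s is not, so the sketch below assumes it JSON-encodes the token the same way:

```python
import json

TEXT_PREFIX = "0:"
DATA_PREFIX = "8:"


def convert_text(token: str) -> str:
    # Assumed implementation: each text delta is a JSON-encoded string
    return f"{TEXT_PREFIX}{json.dumps(token)}\n"


def convert_data(data: dict) -> str:
    # Mirrors the visible implementation: the payload is wrapped in a JSON array
    return f"{DATA_PREFIX}[{json.dumps(data)}]\n"


print(convert_text("Hello"), end="")
# 0:"Hello"
print(convert_data({"type": "sources", "data": {"nodes": []}}), end="")
# 8:[{"type": "sources", "data": {"nodes": []}}]
```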
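The new constructor signature is meant to be driven by the chat route: the caller builds an `EventCallbackHandler`, registers it on the chat engine's callback manager, and passes the *un-awaited* `astream_chat(...)` coroutine in as `response`, so `_event_generator` can already drain callback events while `_chat_response_generator` is still waiting on the engine. A minimal caller sketch; `get_chat_engine`, the module paths, and the `Message` field names are assumptions, not part of this diff:

```python
# Hypothetical caller sketch: get_chat_engine and the Message field names
# are assumptions; only VercelStreamResponse's signature comes from this diff.
from fastapi import APIRouter, BackgroundTasks, Request
from llama_index.core.llms import ChatMessage

from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import ChatData
from app.api.routers.vercel_response import VercelStreamResponse
from app.engine import get_chat_engine

chat_router = r = APIRouter()


@r.post("")
async def chat(request: Request, data: ChatData, background_tasks: BackgroundTasks):
    chat_engine = get_chat_engine()

    event_handler = EventCallbackHandler()
    chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore

    last_message = data.messages[-1]
    history = [ChatMessage(role=m.role, content=m.content) for m in data.messages[:-1]]

    # Deliberately not awaited: the pending coroutine is the
    # Awaitable[StreamingAgentChatResponse] that _chat_response_generator awaits,
    # so events can stream while the engine is still producing the response.
    response = chat_engine.astream_chat(last_message.content, history)

    return VercelStreamResponse(request, event_handler, response, data, background_tasks)
```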