@@ -32,6 +32,7 @@ async def handle_sse(request):
32
32
"""
33
33
34
34
import logging
35
+ from collections .abc import Callable
35
36
from contextlib import asynccontextmanager
36
37
from typing import Any
37
38
from urllib .parse import quote
@@ -41,6 +42,7 @@ async def handle_sse(request):
41
42
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
42
43
from pydantic import ValidationError
43
44
from sse_starlette import EventSourceResponse
45
+ from starlette .background import BackgroundTask
44
46
from starlette .requests import Request
45
47
from starlette .responses import Response
46
48
from starlette .types import Receive , Scope , Send
@@ -79,7 +81,13 @@ def __init__(self, endpoint: str) -> None:
79
81
logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
80
82
81
83
@asynccontextmanager
82
- async def connect_sse (self , scope : Scope , receive : Receive , send : Send ):
84
+ async def connect_sse (
85
+ self ,
86
+ scope : Scope ,
87
+ receive : Receive ,
88
+ send : Send ,
89
+ callback : Callable [[], None ] | None = None ,
90
+ ):
83
91
if scope ["type" ] != "http" :
84
92
logger .error ("connect_sse received non-HTTP request" )
85
93
raise ValueError ("connect_sse can only handle HTTP requests" )
@@ -120,9 +128,19 @@ async def sse_writer():
120
128
}
121
129
)
122
130
131
+ async def _remove_stream_writer () -> None :
132
+ await read_stream_writer .aclose ()
133
+ await write_stream_reader .aclose ()
134
+ del self ._read_stream_writers [session_id ]
135
+ if callback :
136
+ callback ()
137
+ logger .debug (f"Closed SSE session with ID: { session_id } " )
138
+
123
139
async with anyio .create_task_group () as tg :
124
140
response = EventSourceResponse (
125
- content = sse_stream_reader , data_sender_callable = sse_writer
141
+ content = sse_stream_reader ,
142
+ data_sender_callable = sse_writer ,
143
+ background = BackgroundTask (_remove_stream_writer ),
126
144
)
127
145
logger .debug ("Starting SSE response task" )
128
146
tg .start_soon (response , scope , receive , send )
0 commit comments