Skip to content

Commit 44ab1c9

Browse files
authored
add tunnel to optimize sse sending (#74)
1 parent 90e8c69 commit 44ab1c9

File tree

3 files changed

+188
-25
lines changed

3 files changed

+188
-25
lines changed

internal/sse/sse.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,46 @@ func FormatJSONRPCEvent(w io.Writer, eventType string, id interface{}, data inte
156156
}
157157
return nil
158158
}
159+
160+
// EventBatch represents a single event in a batch operation
161+
type EventBatch struct {
162+
EventType string
163+
ID interface{}
164+
Data interface{}
165+
}
166+
167+
// FormatJSONRPCEventBatch formats multiple JSON-RPC events in a single write operation.
168+
// This reduces the number of write calls and improves performance for high-frequency event streams.
169+
// All events are formatted into a single buffer and written once.
170+
func FormatJSONRPCEventBatch(w io.Writer, events []EventBatch) error {
171+
if len(events) == 0 {
172+
return nil
173+
}
174+
175+
// Pre-allocate buffer for better performance
176+
// Rough estimation: each event ~200-500 bytes
177+
var buf bytes.Buffer
178+
for _, event := range events {
179+
// Create a JSON-RPC response with the data as the result
180+
response := jsonrpc.NewNotificationResponse(event.ID, event.Data)
181+
// Marshal the entire JSON-RPC envelope
182+
jsonData, err := json.Marshal(response)
183+
if err != nil {
184+
return fmt.Errorf("failed to marshal JSON-RPC SSE event data: %w", err)
185+
}
186+
// Format according to text/event-stream specification
187+
// event: <eventType>
188+
// data: <jsonrpc_envelope>
189+
// <empty line>
190+
if _, err := fmt.Fprintf(&buf, "event: %s\ndata: %s\n\n", event.EventType, string(jsonData)); err != nil {
191+
return fmt.Errorf("failed to format JSON-RPC SSE event: %w", err)
192+
}
193+
}
194+
195+
// Write the entire batch in one operation
196+
if _, err := w.Write(buf.Bytes()); err != nil {
197+
return fmt.Errorf("failed to write JSON-RPC SSE event batch: %w", err)
198+
}
199+
200+
return nil
201+
}

server/server.go

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func (s *A2AServer) parseJSONRPCRequest(w http.ResponseWriter, body io.ReadClose
284284

285285
// routeJSONRPCMethod routes the request to the appropriate handler based on the method.
286286
func (s *A2AServer) routeJSONRPCMethod(ctx context.Context, w http.ResponseWriter, request jsonrpc.Request) {
287-
log.Infof("Received JSON-RPC request (ID: %v, Method: %s)", request.ID, request.Method)
287+
log.Debugf("Received JSON-RPC request (ID: %v, Method: %s)", request.ID, request.Method)
288288

289289
switch request.Method {
290290

@@ -660,30 +660,9 @@ func handleSSEStream(
660660
// Use request context to detect client disconnection.
661661
clientClosed := ctx.Done()
662662

663-
// --- Event Forwarding Loop ---
664-
for {
665-
select {
666-
case event, ok := <-eventsChan:
667-
if !ok {
668-
flusher.Flush()
669-
return // End the handler.
670-
}
671-
if err := sendSSEEvent(w, rpcID, &event); err != nil {
672-
if err == errUnknownEvent {
673-
log.Warnf("Unknown event type received for request ID: %s: %T. Skipping.", rpcID, event)
674-
continue
675-
}
676-
log.Errorf("Error writing SSE JSON-RPC event for request ID: %s (client likely disconnected): %v", rpcID, err)
677-
return
678-
}
679-
// Flush the buffer to ensure the event is sent immediately.
680-
flusher.Flush()
681-
case <-clientClosed:
682-
// Client disconnected (request context canceled).
683-
log.Infof("SSE client disconnected for request ID: %s. Closing stream.", rpcID)
684-
return // Exit the handler.
685-
}
686-
}
663+
// Use optimized tunnel for batching events
664+
tunnel := newSSETunnel(w, flusher, rpcID)
665+
tunnel.start(ctx, eventsChan, clientClosed)
687666
}
688667

689668
func sendSSEEvent(w http.ResponseWriter, rpcID string, event interface{}) error {

server/sse_tunnel.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"time"
7+
8+
"trpc.group/trpc-go/trpc-a2a-go/internal/sse"
9+
"trpc.group/trpc-go/trpc-a2a-go/log"
10+
"trpc.group/trpc-go/trpc-a2a-go/protocol"
11+
)
12+
13+
const (
14+
defaultSSEBatchSize = 20
15+
defaultSSEFlushInterval = 1 * time.Second
16+
)
17+
18+
// sseTunnel optimizes SSE streaming by batching events before sending
19+
type sseTunnel struct {
20+
w http.ResponseWriter
21+
flusher http.Flusher
22+
rpcID string
23+
batchSize int
24+
flushInterval time.Duration
25+
batch []protocol.StreamingMessageEvent
26+
}
27+
28+
// newSSETunnel creates a new SSE tunnel with default settings
29+
func newSSETunnel(w http.ResponseWriter, flusher http.Flusher, rpcID string) *sseTunnel {
30+
return &sseTunnel{
31+
w: w,
32+
flusher: flusher,
33+
rpcID: rpcID,
34+
batchSize: defaultSSEBatchSize,
35+
flushInterval: defaultSSEFlushInterval,
36+
batch: make([]protocol.StreamingMessageEvent, 0, defaultSSEBatchSize),
37+
}
38+
}
39+
40+
// start runs the SSE tunnel with event batching optimization
41+
func (t *sseTunnel) start(ctx context.Context, eventsChan <-chan protocol.StreamingMessageEvent, clientClosed <-chan struct{}) {
42+
ticker := time.NewTicker(t.flushInterval)
43+
defer ticker.Stop()
44+
45+
for {
46+
select {
47+
case event, ok := <-eventsChan:
48+
if !ok {
49+
// Channel closed, flush any remaining events and exit
50+
if len(t.batch) > 0 {
51+
t.flushBatch()
52+
}
53+
return
54+
}
55+
56+
// Add event to batch
57+
t.batch = append(t.batch, event)
58+
59+
// Flush if batch is full
60+
if len(t.batch) >= t.batchSize {
61+
ticker.Reset(t.flushInterval)
62+
if !t.flushBatch() {
63+
return // Error occurred, exit
64+
}
65+
}
66+
67+
case <-ticker.C:
68+
// Periodic flush for any accumulated events
69+
if len(t.batch) > 0 {
70+
if !t.flushBatch() {
71+
return // Error occurred, exit
72+
}
73+
}
74+
75+
case <-clientClosed:
76+
// Client disconnected
77+
log.Infof("SSE client disconnected for request ID: %s. Closing stream.", t.rpcID)
78+
return
79+
}
80+
}
81+
}
82+
83+
// flushBatch sends all events in the current batch as a single write operation
84+
func (t *sseTunnel) flushBatch() bool {
85+
if len(t.batch) == 0 {
86+
return true
87+
}
88+
89+
// Convert to batch format for efficient processing
90+
events := make([]sse.EventBatch, 0, len(t.batch))
91+
92+
// Process all events in the batch
93+
for _, event := range t.batch {
94+
eventType, err := t.getEventType(&event)
95+
if err != nil {
96+
if err == errUnknownEvent {
97+
log.Warnf("Unknown event type received for request ID: %s: %T. Skipping.", t.rpcID, event)
98+
continue
99+
}
100+
log.Errorf("Error determining event type for request ID: %s: %v", t.rpcID, err)
101+
return false
102+
}
103+
104+
// Add to batch events
105+
events = append(events, sse.EventBatch{
106+
EventType: eventType,
107+
ID: t.rpcID,
108+
Data: &event,
109+
})
110+
}
111+
112+
// Write the entire batch using optimized batch function
113+
if err := sse.FormatJSONRPCEventBatch(t.w, events); err != nil {
114+
log.Errorf("Error writing SSE batch for request ID: %s (client likely disconnected): %v", t.rpcID, err)
115+
return false
116+
}
117+
118+
// Clear the batch and flush to client
119+
t.batch = t.batch[:0]
120+
t.flusher.Flush()
121+
122+
return true
123+
}
124+
125+
// getEventType determines the SSE event type from a StreamingMessageEvent
126+
func (t *sseTunnel) getEventType(event *protocol.StreamingMessageEvent) (string, error) {
127+
actualEvent := event.Result
128+
129+
switch actualEvent.(type) {
130+
case *protocol.TaskStatusUpdateEvent:
131+
return protocol.EventStatusUpdate, nil
132+
case *protocol.TaskArtifactUpdateEvent:
133+
return protocol.EventArtifactUpdate, nil
134+
case *protocol.Message:
135+
return protocol.EventMessage, nil
136+
case *protocol.Task:
137+
return protocol.EventTask, nil
138+
default:
139+
return "", errUnknownEvent
140+
}
141+
}

0 commit comments

Comments
 (0)