@@ -7,14 +7,24 @@ import io.ktor.server.response.*
7
7
import io.ktor.server.sse.*
8
8
import io.modelcontextprotocol.kotlin.sdk.*
9
9
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
10
+ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID
10
11
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
11
12
import kotlinx.serialization.encodeToString
13
+ import kotlinx.serialization.json.JsonArray
14
+ import kotlinx.serialization.json.JsonElement
15
+ import kotlinx.serialization.json.JsonObject
16
+ import kotlinx.serialization.json.decodeFromJsonElement
12
17
import kotlin.collections.HashMap
13
18
import kotlin.concurrent.atomics.AtomicBoolean
14
19
import kotlin.concurrent.atomics.ExperimentalAtomicApi
15
20
import kotlin.uuid.ExperimentalUuidApi
16
21
import kotlin.uuid.Uuid
17
22
23
+ /* *
24
+ * Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages.
25
+ *
26
+ * Creates a new StreamableHttp server transport.
27
+ */
18
28
@OptIn(ExperimentalAtomicApi ::class )
19
29
public class StreamableHttpServerTransport (
20
30
private val isStateful : Boolean = false ,
@@ -55,7 +65,8 @@ public class StreamableHttpServerTransport(
55
65
}
56
66
57
67
val streamId = requestToStreamMapping[requestId] ? : error(" No connection established for request id $requestId " )
58
- val correspondingStream = streamMapping[streamId] ? : error(" No connection established for request id $requestId " )
68
+ val correspondingStream =
69
+ streamMapping[streamId] ? : error(" No connection established for request id $requestId " )
59
70
val correspondingCall = callMapping[streamId] ? : error(" No connection established for request id $requestId " )
60
71
61
72
if (! enableJSONResponse) {
@@ -66,32 +77,33 @@ public class StreamableHttpServerTransport(
66
77
}
67
78
68
79
requestResponseMapping[requestId] = message
69
- val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
80
+ val relatedIds =
81
+ requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
70
82
val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null }
71
83
72
- if (allResponsesReady) {
73
- if (enableJSONResponse) {
74
- correspondingCall.response.headers.append(ContentType .toString(), ContentType .Application .Json .toString())
75
- correspondingCall.response.status(HttpStatusCode .OK )
76
- if (sessionId != null ) {
77
- correspondingCall.response.header(" Mcp-Session-Id" , sessionId!! )
78
- }
79
- val responses = relatedIds.map{ requestResponseMapping[it] }
80
- if (responses.size == 1 ) {
81
- correspondingCall.respond(responses[0 ]!! )
82
- } else {
83
- correspondingCall.respond(responses)
84
- }
85
- callMapping.remove(streamId)
84
+ if (! allResponsesReady) return
85
+
86
+ if (enableJSONResponse) {
87
+ correspondingCall.response.headers.append(ContentType .toString(), ContentType .Application .Json .toString())
88
+ correspondingCall.response.status(HttpStatusCode .OK )
89
+ if (sessionId != null ) {
90
+ correspondingCall.response.header(MCP_SESSION_ID , sessionId!! )
91
+ }
92
+ val responses = relatedIds.map { requestResponseMapping[it] }
93
+ if (responses.size == 1 ) {
94
+ correspondingCall.respond(responses[0 ]!! )
86
95
} else {
87
- correspondingStream.close()
88
- streamMapping.remove(streamId)
96
+ correspondingCall.respond(responses)
89
97
}
98
+ callMapping.remove(streamId)
99
+ } else {
100
+ correspondingStream.close()
101
+ streamMapping.remove(streamId)
102
+ }
90
103
91
- for (id in relatedIds) {
92
- requestToStreamMapping.remove(id)
93
- requestResponseMapping.remove(id)
94
- }
104
+ for (id in relatedIds) {
105
+ requestToStreamMapping.remove(id)
106
+ requestResponseMapping.remove(id)
95
107
}
96
108
97
109
}
@@ -110,47 +122,13 @@ public class StreamableHttpServerTransport(
110
122
@OptIn(ExperimentalUuidApi ::class )
111
123
public suspend fun handlePostRequest (call : ApplicationCall , session : ServerSSESession ) {
112
124
try {
113
- val acceptHeader = call.request.headers[ " Accept " ]?.split( " , " ) ? : listOf ()
125
+ if ( ! validateHeaders(call)) return
114
126
115
- if (! acceptHeader.contains(" text/event-stream" ) || ! acceptHeader.contains(" application/json" )) {
116
- call.response.status(HttpStatusCode .NotAcceptable )
117
- call.respond(
118
- JSONRPCResponse (
119
- id = null ,
120
- error = JSONRPCError (
121
- code = ErrorCode .Unknown (- 32000 ),
122
- message = " Not Acceptable: Client must accept both application/json and text/event-stream"
123
- )
124
- )
125
- )
126
- return
127
- }
127
+ val messages = parseBody(call)
128
128
129
- val contentType = call.request.contentType()
130
- if (contentType != ContentType .Application .Json ) {
131
- call.response.status(HttpStatusCode .UnsupportedMediaType )
132
- call.respond(
133
- JSONRPCResponse (
134
- id = null ,
135
- error = JSONRPCError (
136
- code = ErrorCode .Unknown (- 32000 ),
137
- message = " Unsupported Media Type: Content-Type must be application/json"
138
- )
139
- )
140
- )
141
- return
142
- }
143
-
144
- val body = call.receiveText()
145
- val messages = mutableListOf<JSONRPCMessage >()
146
-
147
- if (body.startsWith(" [" )) {
148
- messages.addAll(McpJson .decodeFromString<List <JSONRPCMessage >>(body))
149
- } else {
150
- messages.add(McpJson .decodeFromString(body))
151
- }
129
+ if (messages.isEmpty()) return
152
130
153
- val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == " initialize " }
131
+ val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method . Defined . Initialize .value }
154
132
if (hasInitializationRequest) {
155
133
if (initialized.load() && sessionId != null ) {
156
134
call.response.status(HttpStatusCode .BadRequest )
@@ -184,38 +162,37 @@ public class StreamableHttpServerTransport(
184
162
sessionId = Uuid .random().toString()
185
163
}
186
164
initialized.store(true )
165
+ }
187
166
188
- if (! validateSession(call)) {
189
- return
190
- }
191
-
192
- val hasRequests = messages.any { it is JSONRPCRequest }
193
- val streamId = Uuid .random().toString()
167
+ if (! validateSession(call)) {
168
+ return
169
+ }
194
170
195
- if (! hasRequests){
196
- call.respondNullable(HttpStatusCode .Accepted )
197
- } else {
198
- if (! enableJSONResponse) {
199
- call.response.headers.append(ContentType .toString(), ContentType .Text .EventStream .toString())
171
+ val hasRequests = messages.any { it is JSONRPCRequest }
172
+ val streamId = Uuid .random().toString()
200
173
201
- if (sessionId != null ) {
202
- call.response.header(" Mcp-Session-Id" , sessionId!! )
203
- }
204
- }
174
+ if (! hasRequests) {
175
+ call.respondNullable(HttpStatusCode .Accepted )
176
+ } else {
177
+ if (! enableJSONResponse) {
178
+ call.response.headers.append(ContentType .toString(), ContentType .Text .EventStream .toString())
205
179
206
- for (message in messages) {
207
- if (message is JSONRPCRequest ) {
208
- streamMapping[streamId] = session
209
- callMapping[streamId] = call
210
- requestToStreamMapping[message.id] = streamId
211
- }
180
+ if (sessionId != null ) {
181
+ call.response.header(MCP_SESSION_ID , sessionId!! )
212
182
}
213
183
}
184
+
214
185
for (message in messages) {
215
- _onMessage .invoke(message)
186
+ if (message is JSONRPCRequest ) {
187
+ streamMapping[streamId] = session
188
+ callMapping[streamId] = call
189
+ requestToStreamMapping[message.id] = streamId
190
+ }
216
191
}
217
192
}
218
-
193
+ for (message in messages) {
194
+ _onMessage .invoke(message)
195
+ }
219
196
} catch (e: Exception ) {
220
197
call.response.status(HttpStatusCode .BadRequest )
221
198
call.respond(
@@ -251,7 +228,7 @@ public class StreamableHttpServerTransport(
251
228
}
252
229
253
230
if (sessionId != null ) {
254
- call.response.header(" Mcp-Session-Id " , sessionId!! )
231
+ call.response.header(MCP_SESSION_ID , sessionId!! )
255
232
}
256
233
257
234
if (streamMapping[standalone] != null ) {
@@ -281,7 +258,7 @@ public class StreamableHttpServerTransport(
281
258
call.respondNullable(HttpStatusCode .OK )
282
259
}
283
260
284
- public suspend fun validateSession (call : ApplicationCall ): Boolean {
261
+ private suspend fun validateSession (call : ApplicationCall ): Boolean {
285
262
if (sessionId == null ) {
286
263
return true
287
264
}
@@ -301,4 +278,65 @@ public class StreamableHttpServerTransport(
301
278
}
302
279
return true
303
280
}
281
+
282
+ private suspend fun validateHeaders (call : ApplicationCall ): Boolean {
283
+ val acceptHeader = call.request.headers[" Accept" ]?.split(" ," ) ? : listOf ()
284
+
285
+ if (! acceptHeader.contains(" text/event-stream" ) || ! acceptHeader.contains(" application/json" )) {
286
+ call.response.status(HttpStatusCode .NotAcceptable )
287
+ call.respond(
288
+ JSONRPCResponse (
289
+ id = null ,
290
+ error = JSONRPCError (
291
+ code = ErrorCode .Unknown (- 32000 ),
292
+ message = " Not Acceptable: Client must accept both application/json and text/event-stream"
293
+ )
294
+ )
295
+ )
296
+ return false
297
+ }
298
+
299
+ val contentType = call.request.contentType()
300
+ if (contentType != ContentType .Application .Json ) {
301
+ call.response.status(HttpStatusCode .UnsupportedMediaType )
302
+ call.respond(
303
+ JSONRPCResponse (
304
+ id = null ,
305
+ error = JSONRPCError (
306
+ code = ErrorCode .Unknown (- 32000 ),
307
+ message = " Unsupported Media Type: Content-Type must be application/json"
308
+ )
309
+ )
310
+ )
311
+ return false
312
+ }
313
+
314
+ return true
315
+ }
316
+
317
+ private suspend fun parseBody (
318
+ call : ApplicationCall ,
319
+ ): List <JSONRPCMessage > {
320
+ val messages = mutableListOf<JSONRPCMessage >()
321
+ when (val body = call.receive<JsonElement >()) {
322
+ is JsonObject -> messages.add(McpJson .decodeFromJsonElement(body))
323
+ is JsonArray -> messages.addAll(McpJson .decodeFromJsonElement<List <JSONRPCMessage >>(body))
324
+ else -> {
325
+ call.response.status(HttpStatusCode .BadRequest )
326
+ call.respond(
327
+ JSONRPCResponse (
328
+ id = null ,
329
+ error = JSONRPCError (
330
+ code = ErrorCode .Defined .InvalidRequest ,
331
+ message = " Invalid Request: Server already initialized"
332
+ )
333
+ )
334
+ )
335
+ return listOf ()
336
+ }
337
+ }
338
+ return messages
339
+ }
340
+
341
+
304
342
}
0 commit comments