Skip to content

Commit a204997

Browse files
SeanChinJunKaimichaellatman
authored andcommitted
feat: minor refactoring
1 parent 0010767 commit a204997

File tree

2 files changed

+126
-85
lines changed

2 files changed

+126
-85
lines changed

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 123 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,24 @@ import io.ktor.server.response.*
77
import io.ktor.server.sse.*
88
import io.modelcontextprotocol.kotlin.sdk.*
99
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
10+
import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID
1011
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
1112
import kotlinx.serialization.encodeToString
13+
import kotlinx.serialization.json.JsonArray
14+
import kotlinx.serialization.json.JsonElement
15+
import kotlinx.serialization.json.JsonObject
16+
import kotlinx.serialization.json.decodeFromJsonElement
1217
import kotlin.collections.HashMap
1318
import kotlin.concurrent.atomics.AtomicBoolean
1419
import kotlin.concurrent.atomics.ExperimentalAtomicApi
1520
import kotlin.uuid.ExperimentalUuidApi
1621
import kotlin.uuid.Uuid
1722

23+
/**
24+
* Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages.
25+
*
26+
* Creates a new StreamableHttp server transport.
27+
*/
1828
@OptIn(ExperimentalAtomicApi::class)
1929
public class StreamableHttpServerTransport(
2030
private val isStateful: Boolean = false,
@@ -55,7 +65,8 @@ public class StreamableHttpServerTransport(
5565
}
5666

5767
val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId")
58-
val correspondingStream = streamMapping[streamId] ?: error("No connection established for request id $requestId")
68+
val correspondingStream =
69+
streamMapping[streamId] ?: error("No connection established for request id $requestId")
5970
val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId")
6071

6172
if (!enableJSONResponse) {
@@ -66,32 +77,33 @@ public class StreamableHttpServerTransport(
6677
}
6778

6879
requestResponseMapping[requestId] = message
69-
val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
80+
val relatedIds =
81+
requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
7082
val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null }
7183

72-
if (allResponsesReady) {
73-
if (enableJSONResponse) {
74-
correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString())
75-
correspondingCall.response.status(HttpStatusCode.OK)
76-
if (sessionId != null) {
77-
correspondingCall.response.header("Mcp-Session-Id", sessionId!!)
78-
}
79-
val responses = relatedIds.map{ requestResponseMapping[it] }
80-
if (responses.size == 1) {
81-
correspondingCall.respond(responses[0]!!)
82-
} else {
83-
correspondingCall.respond(responses)
84-
}
85-
callMapping.remove(streamId)
84+
if (!allResponsesReady) return
85+
86+
if (enableJSONResponse) {
87+
correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString())
88+
correspondingCall.response.status(HttpStatusCode.OK)
89+
if (sessionId != null) {
90+
correspondingCall.response.header(MCP_SESSION_ID, sessionId!!)
91+
}
92+
val responses = relatedIds.map { requestResponseMapping[it] }
93+
if (responses.size == 1) {
94+
correspondingCall.respond(responses[0]!!)
8695
} else {
87-
correspondingStream.close()
88-
streamMapping.remove(streamId)
96+
correspondingCall.respond(responses)
8997
}
98+
callMapping.remove(streamId)
99+
} else {
100+
correspondingStream.close()
101+
streamMapping.remove(streamId)
102+
}
90103

91-
for (id in relatedIds) {
92-
requestToStreamMapping.remove(id)
93-
requestResponseMapping.remove(id)
94-
}
104+
for (id in relatedIds) {
105+
requestToStreamMapping.remove(id)
106+
requestResponseMapping.remove(id)
95107
}
96108

97109
}
@@ -110,47 +122,13 @@ public class StreamableHttpServerTransport(
110122
@OptIn(ExperimentalUuidApi::class)
111123
public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) {
112124
try {
113-
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()
125+
if (!validateHeaders(call)) return
114126

115-
if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) {
116-
call.response.status(HttpStatusCode.NotAcceptable)
117-
call.respond(
118-
JSONRPCResponse(
119-
id = null,
120-
error = JSONRPCError(
121-
code = ErrorCode.Unknown(-32000),
122-
message = "Not Acceptable: Client must accept both application/json and text/event-stream"
123-
)
124-
)
125-
)
126-
return
127-
}
127+
val messages = parseBody(call)
128128

129-
val contentType = call.request.contentType()
130-
if (contentType != ContentType.Application.Json) {
131-
call.response.status(HttpStatusCode.UnsupportedMediaType)
132-
call.respond(
133-
JSONRPCResponse(
134-
id = null,
135-
error = JSONRPCError(
136-
code = ErrorCode.Unknown(-32000),
137-
message = "Unsupported Media Type: Content-Type must be application/json"
138-
)
139-
)
140-
)
141-
return
142-
}
143-
144-
val body = call.receiveText()
145-
val messages = mutableListOf<JSONRPCMessage>()
146-
147-
if (body.startsWith("[")) {
148-
messages.addAll(McpJson.decodeFromString<List<JSONRPCMessage>>(body))
149-
} else {
150-
messages.add(McpJson.decodeFromString(body))
151-
}
129+
if (messages.isEmpty()) return
152130

153-
val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == "initialize" }
131+
val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method.Defined.Initialize.value }
154132
if (hasInitializationRequest) {
155133
if (initialized.load() && sessionId != null) {
156134
call.response.status(HttpStatusCode.BadRequest)
@@ -184,38 +162,37 @@ public class StreamableHttpServerTransport(
184162
sessionId = Uuid.random().toString()
185163
}
186164
initialized.store(true)
165+
}
187166

188-
if (!validateSession(call)) {
189-
return
190-
}
191-
192-
val hasRequests = messages.any { it is JSONRPCRequest }
193-
val streamId = Uuid.random().toString()
167+
if (!validateSession(call)) {
168+
return
169+
}
194170

195-
if (!hasRequests){
196-
call.respondNullable(HttpStatusCode.Accepted)
197-
} else {
198-
if (!enableJSONResponse) {
199-
call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString())
171+
val hasRequests = messages.any { it is JSONRPCRequest }
172+
val streamId = Uuid.random().toString()
200173

201-
if (sessionId != null) {
202-
call.response.header("Mcp-Session-Id", sessionId!!)
203-
}
204-
}
174+
if (!hasRequests) {
175+
call.respondNullable(HttpStatusCode.Accepted)
176+
} else {
177+
if (!enableJSONResponse) {
178+
call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString())
205179

206-
for (message in messages) {
207-
if (message is JSONRPCRequest) {
208-
streamMapping[streamId] = session
209-
callMapping[streamId] = call
210-
requestToStreamMapping[message.id] = streamId
211-
}
180+
if (sessionId != null) {
181+
call.response.header(MCP_SESSION_ID, sessionId!!)
212182
}
213183
}
184+
214185
for (message in messages) {
215-
_onMessage.invoke(message)
186+
if (message is JSONRPCRequest) {
187+
streamMapping[streamId] = session
188+
callMapping[streamId] = call
189+
requestToStreamMapping[message.id] = streamId
190+
}
216191
}
217192
}
218-
193+
for (message in messages) {
194+
_onMessage.invoke(message)
195+
}
219196
} catch (e: Exception) {
220197
call.response.status(HttpStatusCode.BadRequest)
221198
call.respond(
@@ -251,7 +228,7 @@ public class StreamableHttpServerTransport(
251228
}
252229

253230
if (sessionId != null) {
254-
call.response.header("Mcp-Session-Id", sessionId!!)
231+
call.response.header(MCP_SESSION_ID, sessionId!!)
255232
}
256233

257234
if (streamMapping[standalone] != null) {
@@ -281,7 +258,7 @@ public class StreamableHttpServerTransport(
281258
call.respondNullable(HttpStatusCode.OK)
282259
}
283260

284-
public suspend fun validateSession(call: ApplicationCall): Boolean {
261+
private suspend fun validateSession(call: ApplicationCall): Boolean {
285262
if (sessionId == null) {
286263
return true
287264
}
@@ -301,4 +278,65 @@ public class StreamableHttpServerTransport(
301278
}
302279
return true
303280
}
281+
282+
private suspend fun validateHeaders(call: ApplicationCall): Boolean {
283+
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()
284+
285+
if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) {
286+
call.response.status(HttpStatusCode.NotAcceptable)
287+
call.respond(
288+
JSONRPCResponse(
289+
id = null,
290+
error = JSONRPCError(
291+
code = ErrorCode.Unknown(-32000),
292+
message = "Not Acceptable: Client must accept both application/json and text/event-stream"
293+
)
294+
)
295+
)
296+
return false
297+
}
298+
299+
val contentType = call.request.contentType()
300+
if (contentType != ContentType.Application.Json) {
301+
call.response.status(HttpStatusCode.UnsupportedMediaType)
302+
call.respond(
303+
JSONRPCResponse(
304+
id = null,
305+
error = JSONRPCError(
306+
code = ErrorCode.Unknown(-32000),
307+
message = "Unsupported Media Type: Content-Type must be application/json"
308+
)
309+
)
310+
)
311+
return false
312+
}
313+
314+
return true
315+
}
316+
317+
private suspend fun parseBody(
318+
call: ApplicationCall,
319+
): List<JSONRPCMessage> {
320+
val messages = mutableListOf<JSONRPCMessage>()
321+
when (val body = call.receive<JsonElement>()) {
322+
is JsonObject -> messages.add(McpJson.decodeFromJsonElement(body))
323+
is JsonArray -> messages.addAll(McpJson.decodeFromJsonElement<List<JSONRPCMessage>>(body))
324+
else -> {
325+
call.response.status(HttpStatusCode.BadRequest)
326+
call.respond(
327+
JSONRPCResponse(
328+
id = null,
329+
error = JSONRPCError(
330+
code = ErrorCode.Defined.InvalidRequest,
331+
message = "Invalid Request: Server already initialized"
332+
)
333+
)
334+
)
335+
return listOf()
336+
}
337+
}
338+
return messages
339+
}
340+
341+
304342
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package io.modelcontextprotocol.kotlin.sdk.shared
2+
3+
internal const val MCP_SESSION_ID = "mcp-session-id"

0 commit comments

Comments
 (0)