Skip to content

Commit 4d0a85f

Browse files
committed
Handle types we coerce to Strings in MCP
1 parent 0f29501 commit 4d0a85f

File tree

2 files changed

+210
-6
lines changed

2 files changed

+210
-6
lines changed

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import java.io.OutputStream;
1010
import java.io.PrintWriter;
1111
import java.io.StringWriter;
12+
import java.math.BigDecimal;
13+
import java.math.BigInteger;
1214
import java.nio.charset.StandardCharsets;
1315
import java.util.ArrayList;
1416
import java.util.HashMap;
@@ -155,9 +157,9 @@ private void handleRequest(JsonRpcRequest req) {
155157
} else {
156158
// Handle locally
157159
var operation = tool.operation();
158-
var input = req.getParams()
159-
.getMember("arguments")
160-
.asShape(operation.getApiOperation().inputBuilder());
160+
var argumentsDoc = req.getParams().getMember("arguments");
161+
var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema());
162+
var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder());
161163
var output = operation.function().apply(input, null);
162164
var result = CallToolResult.builder()
163165
.content(List.of(TextContent.builder()
@@ -417,14 +419,14 @@ public void awaitCompletion() throws InterruptedException {
417419
done.await();
418420
}
419421

420-
private record Tool(ToolInfo toolInfo, Operation operation, McpServerProxy proxy) {
422+
private record Tool(ToolInfo toolInfo, Operation operation, McpServerProxy proxy, boolean requiredAdapting) {
421423

422424
Tool(ToolInfo toolInfo, Operation operation) {
423-
this(toolInfo, operation, null);
425+
this(toolInfo, operation, null, false);
424426
}
425427

426428
Tool(ToolInfo toolInfo, McpServerProxy proxy) {
427-
this(toolInfo, null, proxy);
429+
this(toolInfo, null, proxy, false);
428430
}
429431
}
430432

@@ -436,6 +438,62 @@ private static String appendSentences(String first, String second) {
436438
return first + second;
437439
}
438440

441+
private static Document adaptDocument(Document doc, Schema schema) {
442+
var fromType = doc.type();
443+
var toType = schema.type();
444+
return switch (toType) {
445+
case BIG_DECIMAL -> switch (fromType) {
446+
case STRING -> Document.of(new BigDecimal(doc.asString()));
447+
case BIG_INTEGER -> doc;
448+
default -> badType(fromType, toType);
449+
};
450+
case BIG_INTEGER ->
451+
switch (fromType) {
452+
case STRING -> Document.of(new BigInteger(doc.asString()));
453+
case BIG_INTEGER -> doc;
454+
default -> badType(fromType, toType);
455+
};
456+
case BLOB -> switch (fromType) {
457+
case STRING -> Document.of(doc.asString().getBytes(StandardCharsets.UTF_8));
458+
case BLOB -> doc;
459+
default -> badType(fromType, toType);
460+
};
461+
case STRUCTURE, UNION -> {
462+
var convertedMembers = new HashMap<String, Document>();
463+
var members = schema.members();
464+
for (var member : members) {
465+
var memberName = member.memberName();
466+
var memberDoc = doc.getMember(memberName);
467+
if (memberDoc != null) {
468+
convertedMembers.put(memberName, adaptDocument(memberDoc, member.memberTarget()));
469+
}
470+
}
471+
yield Document.of(convertedMembers);
472+
}
473+
case LIST, SET -> {
474+
var listMember = schema.listMember();
475+
var convertedList = new ArrayList<Document>();
476+
for (var item : doc.asList()) {
477+
convertedList.add(adaptDocument(item, listMember.memberTarget()));
478+
}
479+
yield Document.of(convertedList);
480+
}
481+
case MAP -> {
482+
var mapValue = schema.mapValueMember();
483+
var convertedMap = new HashMap<String, Document>();
484+
for (var entry : doc.asStringMap().entrySet()) {
485+
convertedMap.put(entry.getKey(), adaptDocument(entry.getValue(), mapValue.memberTarget()));
486+
}
487+
yield Document.of(convertedMap);
488+
}
489+
default -> doc;
490+
};
491+
}
492+
493+
private static Document badType(ShapeType from, ShapeType to) {
494+
throw new RuntimeException("Cannot convert from " + from + " to " + to);
495+
}
496+
439497
public static McpServerBuilder builder() {
440498
return new McpServerBuilder();
441499
}

mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,22 @@
66
package software.amazon.smithy.java.mcp.server;
77

88
import static org.junit.jupiter.api.Assertions.assertEquals;
9+
import static org.junit.jupiter.api.Assertions.assertNotNull;
910
import static org.junit.jupiter.api.Assertions.assertTrue;
1011

12+
import java.math.BigDecimal;
13+
import java.math.BigInteger;
14+
import java.nio.charset.StandardCharsets;
15+
import java.util.List;
1116
import java.util.Map;
17+
import java.util.concurrent.atomic.AtomicReference;
1218
import org.junit.jupiter.api.AfterEach;
1319
import org.junit.jupiter.api.BeforeEach;
1420
import org.junit.jupiter.api.Test;
21+
import software.amazon.smithy.java.client.core.interceptors.ClientInterceptor;
22+
import software.amazon.smithy.java.client.core.interceptors.InputHook;
1523
import software.amazon.smithy.java.core.serde.document.Document;
24+
import software.amazon.smithy.java.dynamicschemas.StructDocument;
1625
import software.amazon.smithy.java.json.JsonCodec;
1726
import software.amazon.smithy.java.json.JsonSettings;
1827
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
@@ -21,6 +30,7 @@
2130
import software.amazon.smithy.java.server.Server;
2231
import software.amazon.smithy.model.Model;
2332
import software.amazon.smithy.model.shapes.ShapeId;
33+
import software.amazon.smithy.model.shapes.ShapeType;
2434

2535
public class McpServerTest {
2636
private static final JsonCodec CODEC = JsonCodec.builder()
@@ -84,6 +94,9 @@ public void validateToolStructure() {
8494
var list = properties.get("list").asStringMap();
8595
assertEquals("array", list.get("type").asString());
8696

97+
var bigDecimal = properties.get("bigDecimalField").asStringMap();
98+
assertEquals("string", bigDecimal.get("type").asString());
99+
87100
var listItems = list.get("items").asStringMap();
88101
assertEquals("object", listItems.get("type").asString());
89102
var listItemProperties = listItems.get("properties").asStringMap();
@@ -106,6 +119,111 @@ public void validateToolStructure() {
106119
validateNestedStructure(doubleNestedProperties);
107120
}
108121

122+
@Test
123+
void testInputAdaptation() {
124+
AtomicReference<StructDocument> capturedInput = new AtomicReference<>();
125+
server = McpServer.builder()
126+
.input(input)
127+
.output(output)
128+
.addService(ProxyService.builder()
129+
.service(ShapeId.from("smithy.test#TestService"))
130+
.proxyEndpoint("http://localhost")
131+
.clientConfigurator(
132+
clientConfigurator -> clientConfigurator.addInterceptor(new ClientInterceptor() {
133+
@Override
134+
public void readBeforeSerialization(InputHook<?, ?> hook) {
135+
capturedInput.set((StructDocument) hook.input());
136+
}
137+
}))
138+
.model(MODEL)
139+
.build())
140+
.build();
141+
142+
server.start();
143+
144+
var bigDecimalValue = BigDecimal.valueOf(Integer.MAX_VALUE).add(BigDecimal.TEN);
145+
var bigIntegerValue = BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.valueOf(100));
146+
var blobValue = "Hello, World!";
147+
var nestedBigDecimalValue = new BigDecimal("123.456");
148+
var nestedBigIntegerValue = new BigInteger("9876543210");
149+
var nestedBlobValue = "Nested blob content";
150+
151+
write("tools/call",
152+
Document.of(
153+
Map.of("name",
154+
Document.of("TestOperation"),
155+
"arguments",
156+
Document.of(Map.of(
157+
"bigDecimalField",
158+
Document.of(bigDecimalValue.toString()),
159+
"bigIntegerField",
160+
Document.of(bigIntegerValue.toString()),
161+
"blobField",
162+
Document.of(blobValue),
163+
"nestedWithBigNumbers",
164+
Document.of(Map.of(
165+
"nestedBigDecimal",
166+
Document.of(nestedBigDecimalValue.toString()),
167+
"nestedBigInteger",
168+
Document.of(nestedBigIntegerValue.toString()),
169+
"nestedBlob",
170+
Document.of(nestedBlobValue),
171+
"bigDecimalList",
172+
Document.of(List.of(
173+
Document.of("100.25"),
174+
Document.of("200.75"))))))))));
175+
assertNotNull(read());
176+
var inputDocument = capturedInput.get();
177+
178+
var bigDecimalField = inputDocument.getMember("bigDecimalField");
179+
assertNotNull(bigDecimalField);
180+
assertEquals(ShapeType.BIG_DECIMAL, bigDecimalField.type());
181+
assertEquals(bigDecimalValue, bigDecimalField.asBigDecimal());
182+
183+
var bigIntegerField = inputDocument.getMember("bigIntegerField");
184+
assertNotNull(bigIntegerField);
185+
assertEquals(ShapeType.BIG_INTEGER, bigIntegerField.type());
186+
assertEquals(bigIntegerValue, bigIntegerField.asBigInteger());
187+
188+
var blobField = inputDocument.getMember("blobField");
189+
assertNotNull(blobField);
190+
assertEquals(ShapeType.BLOB, blobField.type());
191+
assertEquals(blobValue, new String(blobField.asBlob().array(), StandardCharsets.UTF_8));
192+
193+
var nestedWithBigNumbers = inputDocument.getMember("nestedWithBigNumbers");
194+
assertNotNull(nestedWithBigNumbers);
195+
assertEquals(ShapeType.STRUCTURE, nestedWithBigNumbers.type());
196+
197+
var nestedStruct = (StructDocument) nestedWithBigNumbers;
198+
199+
var nestedBigDecimalField = nestedStruct.getMember("nestedBigDecimal");
200+
assertNotNull(nestedBigDecimalField);
201+
assertEquals(ShapeType.BIG_DECIMAL, nestedBigDecimalField.type());
202+
assertEquals(nestedBigDecimalValue, nestedBigDecimalField.asBigDecimal());
203+
204+
var nestedBigIntegerField = nestedStruct.getMember("nestedBigInteger");
205+
assertNotNull(nestedBigIntegerField);
206+
assertEquals(ShapeType.BIG_INTEGER, nestedBigIntegerField.type());
207+
assertEquals(nestedBigIntegerValue, nestedBigIntegerField.asBigInteger());
208+
209+
var nestedBlobField = nestedStruct.getMember("nestedBlob");
210+
assertNotNull(nestedBlobField);
211+
assertEquals(ShapeType.BLOB, nestedBlobField.type());
212+
assertEquals(nestedBlobValue, new String(nestedBlobField.asBlob().array(), StandardCharsets.UTF_8));
213+
214+
var bigDecimalListField = nestedStruct.getMember("bigDecimalList");
215+
assertNotNull(bigDecimalListField);
216+
assertEquals(ShapeType.LIST, bigDecimalListField.type());
217+
var bigDecimalList = bigDecimalListField.asList();
218+
assertEquals(2, bigDecimalList.size());
219+
assertEquals(ShapeType.BIG_DECIMAL, bigDecimalList.get(0).type());
220+
assertEquals(ShapeType.BIG_DECIMAL, bigDecimalList.get(1).type());
221+
assertEquals(new BigDecimal("100.25"), bigDecimalList.get(0).asBigDecimal());
222+
assertEquals(new BigDecimal("200.75"), bigDecimalList.get(1).asBigDecimal());
223+
224+
server.shutdown().join();
225+
}
226+
109227
private void validateNestedStructure(Map<String, Document> properties) {
110228
var nestedStr = properties.get("nestedStr").asStringMap();
111229
assertEquals("string", nestedStr.get("type").asString());
@@ -154,6 +272,7 @@ private JsonRpcResponse read() {
154272
/// A TestOperation
155273
operation TestOperation {
156274
input: TestInput
275+
output: TestInput
157276
}
158277
159278
/// An input for TestOperation with a nested member
@@ -167,6 +286,14 @@ private JsonRpcResponse read() {
167286
list: NestedList
168287
169288
doubleNestedList: DoubleNestedList
289+
290+
bigDecimalField: BigDecimal
291+
292+
bigIntegerField: BigInteger
293+
294+
blobField: Blob
295+
296+
nestedWithBigNumbers: NestedWithBigNumbers
170297
}
171298
172299
list NestedList {
@@ -193,6 +320,25 @@ private JsonRpcResponse read() {
193320
structure Recursive {
194321
/// the nested field that points back to us
195322
nested: Nested
323+
}
324+
325+
/// A structure containing big number types
326+
structure NestedWithBigNumbers {
327+
/// A nested BigDecimal
328+
nestedBigDecimal: BigDecimal
329+
330+
/// A nested BigInteger
331+
nestedBigInteger: BigInteger
332+
333+
/// A nested Blob
334+
nestedBlob: Blob
335+
336+
/// A list of BigDecimals
337+
bigDecimalList: BigDecimalList
338+
}
339+
340+
list BigDecimalList {
341+
member: BigDecimal
196342
}""";
197343

198344
private static final Model MODEL = Model.assembler()

0 commit comments

Comments
 (0)