Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ dependencies {
compileOnly group: 'com.google.code.gson', name: 'gson', version: "${versions.gson}"
compileOnly group: 'org.json', name: 'json', version: '20231013'
testImplementation group: 'org.json', name: 'json', version: '20231013'
compileOnly('io.modelcontextprotocol.sdk:mcp:0.12.1')
testImplementation('io.modelcontextprotocol.sdk:mcp:0.12.1')
implementation('com.google.guava:guava:32.1.3-jre') {
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,140 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Set;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.utils.StringUtils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.spec.McpSchema;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

@Log4j2
public class MLMcpServerRequest extends ActionRequest {

private static final int MAX_ID_LENGTH = 1000;
private static final int MAX_REQUEST_SIZE = 10 * 1024 * 1024;
private static final Set<String> VALID_METHODS = Set
.of(
McpSchema.METHOD_INITIALIZE,
McpSchema.METHOD_NOTIFICATION_INITIALIZED,
McpSchema.METHOD_PING,
McpSchema.METHOD_NOTIFICATION_PROGRESS,
McpSchema.METHOD_TOOLS_LIST,
McpSchema.METHOD_TOOLS_CALL,
McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED,
McpSchema.METHOD_RESOURCES_LIST,
McpSchema.METHOD_RESOURCES_READ,
McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED,
McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED,
McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
McpSchema.METHOD_RESOURCES_SUBSCRIBE,
McpSchema.METHOD_RESOURCES_UNSUBSCRIBE,
McpSchema.METHOD_PROMPT_LIST,
McpSchema.METHOD_PROMPT_GET,
McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED,
McpSchema.METHOD_COMPLETION_COMPLETE,
McpSchema.METHOD_LOGGING_SET_LEVEL,
McpSchema.METHOD_NOTIFICATION_MESSAGE,
McpSchema.METHOD_ROOTS_LIST,
McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE,
McpSchema.METHOD_ELICITATION_CREATE
);

@Getter
private String requestBody;
private McpSchema.JSONRPCMessage message;

public MLMcpServerRequest(StreamInput in) throws IOException {
super(in);
this.requestBody = in.readString();
validateAndParseRequest(in.readString());
}

public MLMcpServerRequest(String requestBody) {
this.requestBody = requestBody;
validateAndParseRequest(requestBody);
}

private void validateAndParseRequest(String requestBody) {
if (requestBody == null || requestBody.isEmpty()) {
throw new IllegalArgumentException("Request body cannot be null or empty");
}
if (requestBody.length() > MAX_REQUEST_SIZE) {
throw new IllegalArgumentException("Request body exceeds maximum size of " + MAX_REQUEST_SIZE + " bytes");
}

try {
message = McpSchema.deserializeJsonRpcMessage(new ObjectMapper(), requestBody);
} catch (Exception e) {
log.error("Parse error: " + e.getMessage(), e);
throw new IllegalArgumentException("Failed to parse JSON-RPC message: " + e.getMessage(), e);
}

validateMessage();
}

private void validateMessage() {
if (!McpSchema.JSONRPC_VERSION.equals(message.jsonrpc())) {
throw new IllegalArgumentException("Invalid jsonrpc version. Expected '2.0' but got '" + message.jsonrpc() + "'");
}

if (message instanceof McpSchema.JSONRPCRequest request) {
validateRequestId(request.id());
validateMethod(request.method());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to also validate request.params?

} else if (message instanceof McpSchema.JSONRPCNotification notification) {
validateMethod(notification.method());
} else if (message instanceof McpSchema.JSONRPCResponse) {
throw new IllegalArgumentException("JSON-RPC responses are not accepted as incoming messages");
} else {
throw new IllegalArgumentException("Unknown JSON-RPC message type: " + message.getClass().getName());
}
}

private void validateRequestId(Object id) {
if (id == null) {
throw new IllegalArgumentException("Request ID cannot be null");
}
if (!(id instanceof String || id instanceof Integer || id instanceof Long)) {
throw new IllegalArgumentException("Request ID must be a string or integer, but got: " + id.getClass().getSimpleName());
}
if (id instanceof String) {
String idStr = (String) id;
if (idStr.length() > MAX_ID_LENGTH) {
throw new IllegalArgumentException("Request ID exceeds maximum length of " + MAX_ID_LENGTH + " characters");
}
if (!StringUtils.matchesSafePattern(idStr)) {
throw new IllegalArgumentException("Request ID " + StringUtils.SAFE_INPUT_DESCRIPTION);
}
}
}

private void validateMethod(String method) {
if (method == null || method.isEmpty()) {
throw new IllegalArgumentException("Method cannot be null or empty");
}
if (!VALID_METHODS.contains(method)) {
throw new IllegalArgumentException("Invalid MCP method: '" + method + "'. Must be one of the supported MCP methods.");
}
}

public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(requestBody);
// Serialize the message back to JSON string
ObjectMapper objectMapper = new ObjectMapper();
try {
String jsonString = objectMapper.writeValueAsString(message);
out.writeString(jsonString);
} catch (JsonProcessingException e) {
throw new IOException("Failed to serialize JSON-RPC message", e);
}
}

public static MLMcpServerRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Loading
Loading