Skip to content

mcp-server Transport trait + Toolset for easier Router building #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 9 additions & 114 deletions crates/mcp-server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,120 +1,17 @@
use std::{
pin::Pin,
task::{Context, Poll},
};

use futures::{Future, Stream};
use futures::Future;
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
use pin_project::pin_project;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use std::pin::Pin;
use tower_service::Service;

mod errors;
pub use errors::{BoxError, RouterError, ServerError, TransportError};

pub mod router;
pub use router::Router;
use transport::Transport;

/// A transport layer that handles JSON-RPC messages over byte
#[pin_project]
pub struct ByteTransport<R, W> {
// Reader is a BufReader on the underlying stream (stdin or similar) buffering
// the underlying data across poll calls, we clear one line (\n) during each
// iteration of poll_next from this buffer
#[pin]
reader: BufReader<R>,
#[pin]
writer: W,
}

impl<R, W> ByteTransport<R, W>
where
R: AsyncRead,
W: AsyncWrite,
{
pub fn new(reader: R, writer: W) -> Self {
Self {
// Default BufReader capacity is 8 * 1024, increase this to 2MB to the file size limit
// allows the buffer to have the capacity to read very large calls
reader: BufReader::with_capacity(2 * 1024 * 1024, reader),
writer,
}
}
}

impl<R, W> Stream for ByteTransport<R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
type Item = Result<JsonRpcMessage, TransportError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let mut buf = Vec::new();

let mut reader = this.reader.as_mut();
let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf));
match read_future.as_mut().poll(cx) {
Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF
Poll::Ready(Ok(_)) => {
// Convert to UTF-8 string
let line = match String::from_utf8(buf) {
Ok(s) => s,
Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))),
};
// Log incoming message here before serde conversion to
// track incomplete chunks which are not valid JSON
tracing::info!(json = %line, "incoming message");

// Parse JSON and validate message format
match serde_json::from_str::<serde_json::Value>(&line) {
Ok(value) => {
// Validate basic JSON-RPC structure
if !value.is_object() {
return Poll::Ready(Some(Err(TransportError::InvalidMessage(
"Message must be a JSON object".into(),
))));
}
let obj = value.as_object().unwrap(); // Safe due to check above

// Check jsonrpc version field
if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" {
return Poll::Ready(Some(Err(TransportError::InvalidMessage(
"Missing or invalid jsonrpc version".into(),
))));
}

// Now try to parse as proper message
match serde_json::from_value::<JsonRpcMessage>(value) {
Ok(msg) => Poll::Ready(Some(Ok(msg))),
Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
}
}
Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
}
}
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))),
Poll::Pending => Poll::Pending,
}
}
}

impl<R, W> ByteTransport<R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
pub async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<(), std::io::Error> {
let json = serde_json::to_string(&msg)?;
Pin::new(&mut self.writer)
.write_all(json.as_bytes())
.await?;
Pin::new(&mut self.writer).write_all(b"\n").await?;
Pin::new(&mut self.writer).flush().await?;
Ok(())
}
}
pub mod toolset;
pub mod transport;

/// The main server type that processes incoming requests
pub struct Server<S> {
Expand All @@ -131,11 +28,9 @@ where
Self { service }
}

// TODO transport trait instead of byte transport if we implement others
pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
pub async fn run<T>(self, mut transport: T) -> Result<(), ServerError>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
T: Transport + Unpin,
{
use futures::StreamExt;
let mut service = self.service;
Expand Down Expand Up @@ -193,7 +88,7 @@ where
.write_message(JsonRpcMessage::Response(response))
.await
{
return Err(ServerError::Transport(TransportError::Io(e)));
return Err(ServerError::Transport(e));
}
}
JsonRpcMessage::Response(_)
Expand Down Expand Up @@ -234,7 +129,7 @@ where
});

if let Err(e) = transport.write_message(error_response).await {
return Err(ServerError::Transport(TransportError::Io(e)));
return Err(ServerError::Transport(e));
}
}
}
Expand Down
Empty file added crates/mcp-server/src/server.rs
Empty file.
108 changes: 108 additions & 0 deletions crates/mcp-server/src/toolset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use anyhow::Result;
use mcp_core::handler::ToolHandler;
use mcp_core::Tool;
use mcp_core::{Content, ToolError};
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

pub struct Toolset {
tool_list: HashMap<String, Tool>,
tool_handlers: HashMap<String, ToolHandlerFn>,
}

pub struct ToolsetBuilder {
tool_list: HashMap<String, Tool>,
tool_handlers: HashMap<String, ToolHandlerFn>,
}

impl ToolsetBuilder {
pub fn new() -> Self {
Self {
tool_list: HashMap::new(),
tool_handlers: HashMap::new(),
}
}

pub fn add_tool(mut self, tool: Tool, handler: ToolHandlerFn) -> Self {
let name = tool.name.clone();
self.tool_list.insert(name.clone(), tool);
self.tool_handlers.insert(name, handler);
self
}

pub fn add_tool_from_handler(mut self, tool_handler: impl ToolHandler) -> Self {
let name = tool_handler.name().to_string();
self.tool_list.insert(
name.clone(),
Tool {
name: name.clone(),
description: tool_handler.description().to_string(),
input_schema: tool_handler.schema(),
},
);

let tool_handler = Arc::new(tool_handler);
let handler_fn = Box::new(move |_name: &str, params: Value| {
let tool_handler = Arc::clone(&tool_handler);
Box::pin(async move {
let result = tool_handler.call(params).await?;
let contents: Vec<Content> = serde_json::from_value(result)
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
Ok(contents)
}) as Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send>>
});

self.tool_handlers.insert(name, handler_fn);
self
}

pub fn build(self) -> Toolset {
Toolset {
tool_list: self.tool_list,
tool_handlers: self.tool_handlers,
}
}
}

impl Default for ToolsetBuilder {
fn default() -> Self {
Self::new()
}
}

impl Toolset {
pub fn builder() -> ToolsetBuilder {
ToolsetBuilder::new()
}

pub fn get_tool(&self, name: &str) -> Option<Tool> {
self.tool_list.get(name).cloned()
}

pub async fn call_tool(
&self,
name: &str,
arguements: Value,
) -> Result<Vec<Content>, ToolError> {
let handler = self
.tool_handlers
.get(name)
.ok_or_else(|| ToolError::NotFound(name.to_string()))?;

Ok((handler)(name, arguements).await?)
}

pub fn list_tools(&self) -> Vec<Tool> {
self.tool_list.values().cloned().collect()
}
}

pub type ToolHandlerFn = Box<
dyn Fn(
&str,
Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>,
>;
15 changes: 15 additions & 0 deletions crates/mcp-server/src/transport/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use async_trait::async_trait;
use futures::Stream;
use mcp_core::protocol::JsonRpcMessage;

use crate::TransportError;

pub mod stdio;
pub use stdio::StdioTransport;

/// A trait representing a transport layer for JSON-RPC messages
#[async_trait]
pub trait Transport: Stream<Item = Result<JsonRpcMessage, TransportError>> {
/// Writes a JSON-RPC message to the transport
async fn write_message(&mut self, message: JsonRpcMessage) -> Result<(), TransportError>;
}
Loading