From 148121fdb3a47ee4eb2e928be8f1e339ea63f47a Mon Sep 17 00:00:00 2001 From: stevo Date: Tue, 4 Mar 2025 19:19:59 -0500 Subject: [PATCH 1/2] added toolset --- crates/mcp-server/src/lib.rs | 3 + crates/mcp-server/src/toolset.rs | 108 +++++++++++++++++++++++++ crates/mcp-server/src/transport/mod.rs | 0 3 files changed, 111 insertions(+) create mode 100644 crates/mcp-server/src/toolset.rs create mode 100644 crates/mcp-server/src/transport/mod.rs diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 7588617d..d4b692d4 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -15,6 +15,9 @@ pub use errors::{BoxError, RouterError, ServerError, TransportError}; pub mod router; pub use router::Router; +pub mod toolset; +pub mod transport; + /// A transport layer that handles JSON-RPC messages over byte #[pin_project] pub struct ByteTransport { diff --git a/crates/mcp-server/src/toolset.rs b/crates/mcp-server/src/toolset.rs new file mode 100644 index 00000000..758133dd --- /dev/null +++ b/crates/mcp-server/src/toolset.rs @@ -0,0 +1,108 @@ +use anyhow::Result; +use mcp_core::handler::ToolHandler; +use mcp_core::Tool; +use mcp_core::{Content, ToolError}; +use serde_json::Value; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +pub struct Toolset { + tool_list: HashMap, + tool_handlers: HashMap, +} + +pub struct ToolsetBuilder { + tool_list: HashMap, + tool_handlers: HashMap, +} + +impl ToolsetBuilder { + pub fn new() -> Self { + Self { + tool_list: HashMap::new(), + tool_handlers: HashMap::new(), + } + } + + pub fn add_tool(mut self, tool: Tool, handler: ToolHandlerFn) -> Self { + let name = tool.name.clone(); + self.tool_list.insert(name.clone(), tool); + self.tool_handlers.insert(name, handler); + self + } + + pub fn add_tool_from_handler(mut self, tool_handler: impl ToolHandler) -> Self { + let name = tool_handler.name().to_string(); + self.tool_list.insert( + name.clone(), + Tool { + name: name.clone(), + description: tool_handler.description().to_string(), + input_schema: tool_handler.schema(), + }, + ); + + let tool_handler = Arc::new(tool_handler); + let handler_fn = Box::new(move |_name: &str, params: Value| { + let tool_handler = Arc::clone(&tool_handler); + Box::pin(async move { + let result = tool_handler.call(params).await?; + let contents: Vec = serde_json::from_value(result) + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + Ok(contents) + }) as Pin, ToolError>> + Send>> + }); + + self.tool_handlers.insert(name, handler_fn); + self + } + + pub fn build(self) -> Toolset { + Toolset { + tool_list: self.tool_list, + tool_handlers: self.tool_handlers, + } + } +} + +impl Default for ToolsetBuilder { + fn default() -> Self { + Self::new() + } +} + +impl Toolset { + pub fn builder() -> ToolsetBuilder { + ToolsetBuilder::new() + } + + pub fn get_tool(&self, name: &str) -> Option { + self.tool_list.get(name).cloned() + } + + pub async fn call_tool( + &self, + name: &str, + arguements: Value, + ) -> Result, ToolError> { + let handler = self + .tool_handlers + .get(name) + .ok_or_else(|| ToolError::NotFound(name.to_string()))?; + + Ok((handler)(name, arguements).await?) + } + + pub fn list_tools(&self) -> Vec { + self.tool_list.values().cloned().collect() + } +} + +pub type ToolHandlerFn = Box< + dyn Fn( + &str, + Value, + ) -> Pin, ToolError>> + Send + 'static>>, +>; diff --git a/crates/mcp-server/src/transport/mod.rs b/crates/mcp-server/src/transport/mod.rs new file mode 100644 index 00000000..e69de29b From 4b125a8ad01f600b6f6b6980bc75166642379b9a Mon Sep 17 00:00:00 2001 From: stevo Date: Tue, 4 Mar 2025 20:11:21 -0500 Subject: [PATCH 2/2] converted transport to trait and converted bytes to stdio transport struct --- crates/mcp-server/src/lib.rs | 122 ++-------------------- crates/mcp-server/src/server.rs | 0 crates/mcp-server/src/transport/mod.rs | 15 +++ crates/mcp-server/src/transport/stdio.rs | 126 +++++++++++++++++++++++ examples/servers/src/axum.rs | 4 +- examples/servers/src/counter_server.rs | 4 +- 6 files changed, 152 insertions(+), 119 deletions(-) create mode 100644 crates/mcp-server/src/server.rs create mode 100644 crates/mcp-server/src/transport/stdio.rs diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index d4b692d4..b68c9572 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -1,12 +1,6 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use futures::{Future, Stream}; +use futures::Future; use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; -use pin_project::pin_project; -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use std::pin::Pin; use tower_service::Service; mod errors; @@ -14,111 +8,11 @@ pub use errors::{BoxError, RouterError, ServerError, TransportError}; pub mod router; pub use router::Router; +use transport::Transport; pub mod toolset; pub mod transport; -/// A transport layer that handles JSON-RPC messages over byte -#[pin_project] -pub struct ByteTransport { - // Reader is a BufReader on the underlying stream (stdin or similar) buffering - // the underlying data across poll calls, we clear one line (\n) during each - // iteration of poll_next from this buffer - #[pin] - reader: BufReader, - #[pin] - writer: W, -} - -impl ByteTransport -where - R: AsyncRead, - W: AsyncWrite, -{ - pub fn new(reader: R, writer: W) -> Self { - Self { - // Default BufReader capacity is 8 * 1024, increase this to 2MB to the file size limit - // allows the buffer to have the capacity to read very large calls - reader: BufReader::with_capacity(2 * 1024 * 1024, reader), - writer, - } - } -} - -impl Stream for ByteTransport -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - let mut buf = Vec::new(); - - let mut reader = this.reader.as_mut(); - let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf)); - match read_future.as_mut().poll(cx) { - Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF - Poll::Ready(Ok(_)) => { - // Convert to UTF-8 string - let line = match String::from_utf8(buf) { - Ok(s) => s, - Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))), - }; - // Log incoming message here before serde conversion to - // track incomplete chunks which are not valid JSON - tracing::info!(json = %line, "incoming message"); - - // Parse JSON and validate message format - match serde_json::from_str::(&line) { - Ok(value) => { - // Validate basic JSON-RPC structure - if !value.is_object() { - return Poll::Ready(Some(Err(TransportError::InvalidMessage( - "Message must be a JSON object".into(), - )))); - } - let obj = value.as_object().unwrap(); // Safe due to check above - - // Check jsonrpc version field - if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" { - return Poll::Ready(Some(Err(TransportError::InvalidMessage( - "Missing or invalid jsonrpc version".into(), - )))); - } - - // Now try to parse as proper message - match serde_json::from_value::(value) { - Ok(msg) => Poll::Ready(Some(Ok(msg))), - Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), - } - } - Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), - } - } - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))), - Poll::Pending => Poll::Pending, - } - } -} - -impl ByteTransport -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - pub async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<(), std::io::Error> { - let json = serde_json::to_string(&msg)?; - Pin::new(&mut self.writer) - .write_all(json.as_bytes()) - .await?; - Pin::new(&mut self.writer).write_all(b"\n").await?; - Pin::new(&mut self.writer).flush().await?; - Ok(()) - } -} - /// The main server type that processes incoming requests pub struct Server { service: S, @@ -134,11 +28,9 @@ where Self { service } } - // TODO transport trait instead of byte transport if we implement others - pub async fn run(self, mut transport: ByteTransport) -> Result<(), ServerError> + pub async fn run(self, mut transport: T) -> Result<(), ServerError> where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, + T: Transport + Unpin, { use futures::StreamExt; let mut service = self.service; @@ -196,7 +88,7 @@ where .write_message(JsonRpcMessage::Response(response)) .await { - return Err(ServerError::Transport(TransportError::Io(e))); + return Err(ServerError::Transport(e)); } } JsonRpcMessage::Response(_) @@ -237,7 +129,7 @@ where }); if let Err(e) = transport.write_message(error_response).await { - return Err(ServerError::Transport(TransportError::Io(e))); + return Err(ServerError::Transport(e)); } } } diff --git a/crates/mcp-server/src/server.rs b/crates/mcp-server/src/server.rs new file mode 100644 index 00000000..e69de29b diff --git a/crates/mcp-server/src/transport/mod.rs b/crates/mcp-server/src/transport/mod.rs index e69de29b..69803ef0 100644 --- a/crates/mcp-server/src/transport/mod.rs +++ b/crates/mcp-server/src/transport/mod.rs @@ -0,0 +1,15 @@ +use async_trait::async_trait; +use futures::Stream; +use mcp_core::protocol::JsonRpcMessage; + +use crate::TransportError; + +pub mod stdio; +pub use stdio::StdioTransport; + +/// A trait representing a transport layer for JSON-RPC messages +#[async_trait] +pub trait Transport: Stream> { + /// Writes a JSON-RPC message to the transport + async fn write_message(&mut self, message: JsonRpcMessage) -> Result<(), TransportError>; +} diff --git a/crates/mcp-server/src/transport/stdio.rs b/crates/mcp-server/src/transport/stdio.rs new file mode 100644 index 00000000..7f0e3e24 --- /dev/null +++ b/crates/mcp-server/src/transport/stdio.rs @@ -0,0 +1,126 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_trait::async_trait; +use futures::{Future, Stream}; +use mcp_core::protocol::JsonRpcMessage; +use pin_project::pin_project; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; + +use super::Transport; +use crate::TransportError; + +/// A transport layer that handles JSON-RPC messages over byte +#[pin_project] +pub struct StdioTransport { + // Reader is a BufReader on the underlying stream (stdin or similar) buffering + // the underlying data across poll calls, we clear one line (\n) during each + // iteration of poll_next from this buffer + #[pin] + reader: BufReader, + #[pin] + writer: W, +} + +impl StdioTransport +where + R: AsyncRead, + W: AsyncWrite, +{ + pub fn new(reader: R, writer: W) -> Self { + Self { + // Default BufReader capacity is 8 * 1024, increase this to 2MB to the file size limit + // allows the buffer to have the capacity to read very large calls + reader: BufReader::with_capacity(2 * 1024 * 1024, reader), + writer, + } + } +} + +impl Stream for StdioTransport +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + let mut buf = Vec::new(); + + let mut reader = this.reader.as_mut(); + let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf)); + match read_future.as_mut().poll(cx) { + Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF + Poll::Ready(Ok(_)) => { + // Convert to UTF-8 string + let line = match String::from_utf8(buf) { + Ok(s) => s, + Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))), + }; + // Log incoming message here before serde conversion to + // track incomplete chunks which are not valid JSON + tracing::info!(json = %line, "incoming message"); + + // Parse JSON and validate message format + match serde_json::from_str::(&line) { + Ok(value) => { + // Validate basic JSON-RPC structure + if !value.is_object() { + return Poll::Ready(Some(Err(TransportError::InvalidMessage( + "Message must be a JSON object".into(), + )))); + } + let obj = value.as_object().unwrap(); // Safe due to check above + + // Check jsonrpc version field + if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" { + return Poll::Ready(Some(Err(TransportError::InvalidMessage( + "Missing or invalid jsonrpc version".into(), + )))); + } + + // Now try to parse as proper message + match serde_json::from_value::(value) { + Ok(msg) => Poll::Ready(Some(Ok(msg))), + Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), + } + } + Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), + } + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))), + Poll::Pending => Poll::Pending, + } + } +} + +#[async_trait] +impl Transport for StdioTransport +where + R: AsyncRead + Unpin + Send, + W: AsyncWrite + Unpin + Send, +{ + async fn write_message(&mut self, message: JsonRpcMessage) -> Result<(), TransportError> { + let json = serde_json::to_string(&message).map_err(|e| TransportError::Json(e))?; + + Pin::new(&mut self.writer) + .write_all(json.as_bytes()) + .await + .map_err(|e| TransportError::Io(e))?; + + Pin::new(&mut self.writer) + .write_all(b"\n") + .await + .map_err(|e| TransportError::Io(e))?; + + Pin::new(&mut self.writer) + .flush() + .await + .map_err(|e| TransportError::Io(e))?; + + Ok(()) + } +} diff --git a/examples/servers/src/axum.rs b/examples/servers/src/axum.rs index 7c3e4b93..abd58dc2 100644 --- a/examples/servers/src/axum.rs +++ b/examples/servers/src/axum.rs @@ -7,7 +7,7 @@ use axum::{ Router, }; use futures::{stream::Stream, StreamExt, TryStreamExt}; -use mcp_server::{ByteTransport, Server}; +use mcp_server::{transport::StdioTransport, Server}; use std::collections::HashMap; use tokio_util::codec::FramedRead; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -114,7 +114,7 @@ async fn sse_handler(State(app): State) -> Sse Result<()> { // Create and run the server let server = Server::new(router); - let transport = ByteTransport::new(stdin(), stdout()); + let transport = StdioTransport::new(stdin(), stdout()); tracing::info!("Server initialized and ready to handle requests"); Ok(server.run(transport).await?)