From e82eed4f12c7cc716584c1fac16e4d16749ffd31 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Sun, 2 Mar 2025 00:17:08 +0800 Subject: [PATCH 1/6] add axum sse example --- crates/mcp-server/Cargo.toml | 10 + .../mcp-server/examples/sse-axim/counter.rs | 180 ++++++++++++++++++ crates/mcp-server/examples/sse-axim/main.rs | 177 +++++++++++++++++ crates/mcp-server/src/lib.rs | 3 +- 4 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 crates/mcp-server/examples/sse-axim/counter.rs create mode 100644 crates/mcp-server/examples/sse-axim/main.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index 6baa7fe8..075d2c19 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -23,3 +23,13 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" async-trait = "0.1" + + +[dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio-util = { version = "0.7", features = ["io", "codec"]} +rand = { version = "0.8" } + +[[example]] +name = "example-sse-axum" +path = "examples/sse-axum/main.rs" diff --git a/crates/mcp-server/examples/sse-axim/counter.rs b/crates/mcp-server/examples/sse-axim/counter.rs new file mode 100644 index 00000000..557f3a0b --- /dev/null +++ b/crates/mcp-server/examples/sse-axim/counter.rs @@ -0,0 +1,180 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + +use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_server::router::CapabilitiesBuilder; +use serde_json::Value; +use tokio::sync::Mutex; + +#[derive(Clone)] +pub struct CounterRouter { + counter: Arc>, +} + +impl CounterRouter { + pub fn new() -> Self { + Self { + counter: Arc::new(Mutex::new(0)), + } + } + + async fn increment(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter += 1; + Ok(*counter) + } + + async fn decrement(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter -= 1; + Ok(*counter) + } + + async fn get_value(&self) -> Result { + let counter = self.counter.lock().await; + Ok(*counter) + } + + fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { + Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() + } +} + +impl mcp_server::Router for CounterRouter { + fn name(&self) -> String { + "counter".to_string() + } + + fn instructions(&self) -> String { + "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() + } + + fn capabilities(&self) -> ServerCapabilities { + CapabilitiesBuilder::new() + .with_tools(false) + .with_resources(false, false) + .with_prompts(false) + .build() + } + + fn list_tools(&self) -> Vec { + vec![ + Tool::new( + "increment".to_string(), + "Increment the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "decrement".to_string(), + "Decrement the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "get_value".to_string(), + "Get the current counter value".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + ] + } + + fn call_tool( + &self, + tool_name: &str, + _arguments: Value, + ) -> Pin, ToolError>> + Send + 'static>> { + let this = self.clone(); + let tool_name = tool_name.to_string(); + + Box::pin(async move { + match tool_name.as_str() { + "increment" => { + let value = this.increment().await?; + Ok(vec![Content::text(value.to_string())]) + } + "decrement" => { + let value = this.decrement().await?; + Ok(vec![Content::text(value.to_string())]) + } + "get_value" => { + let value = this.get_value().await?; + Ok(vec![Content::text(value.to_string())]) + } + _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), + } + }) + } + + fn list_resources(&self) -> Vec { + vec![ + self._create_resource_text("str:////Users/to/some/path/", "cwd"), + self._create_resource_text("memo://insights", "memo-name"), + ] + } + + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>> { + let uri = uri.to_string(); + Box::pin(async move { + match uri.as_str() { + "str:////Users/to/some/path/" => { + let cwd = "/Users/to/some/path/"; + Ok(cwd.to_string()) + } + "memo://insights" => { + let memo = + "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; + Ok(memo.to_string()) + } + _ => Err(ResourceError::NotFound(format!( + "Resource {} not found", + uri + ))), + } + }) + } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } +} + diff --git a/crates/mcp-server/examples/sse-axim/main.rs b/crates/mcp-server/examples/sse-axim/main.rs new file mode 100644 index 00000000..8106618d --- /dev/null +++ b/crates/mcp-server/examples/sse-axim/main.rs @@ -0,0 +1,177 @@ +use axum::{ + body::Body, + extract::{Query, State}, + http::StatusCode, + response::sse::{Event, Sse}, + routing::get, + Router, +}; +use futures::{stream::Stream, StreamExt, TryStreamExt}; +use mcp_server::{ByteTransport, Server}; +use std::collections::HashMap; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use anyhow::Result; +use mcp_server::router::RouterService; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; +use tracing_subscriber::{self}; +mod counter; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + pub fn router(&self) -> Router { + Router::new() + .route("/sse", get(sse_handler).post(post_event_handler)) + .with_state(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + body: Body, +) -> Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let mut body = body.into_data_stream(); + if let (_, Some(size)) = body.size_hint() { + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + } + // calculate the body size + let mut size = 0; + while let Some(chunk) = body.next().await { + let Ok(chunk) = chunk else { + return Err(StatusCode::BAD_REQUEST); + }; + size += chunk.len(); + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + write_stream + .write_all(&chunk) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + write_stream + .write_u8(b'\n') + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler(State(app): State) -> Sse>> { + // it's 4KB + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + { + let session = session.clone(); + tokio::spawn(async move { + let router = RouterService(counter::CounterRouter::new()); + let server = Server::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app.txs.write().await.remove(&session); + }); + } + + use tokio_util::codec::{Decoder, FramedRead}; + #[derive(Default)] + pub struct LinesBytesCodec; + impl Decoder for LinesBytesCodec { + type Item = tokio_util::bytes::Bytes; + type Error = io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + if let Some(end) = src + .iter() + .enumerate() + .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) + { + let line = src.split_to(end); + let _char_next_line = src.split_to(1); + Ok(Some(line.freeze())) + } else { + Ok(None) + } + } + } + + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("?sessionId={session}")), + )) + .chain( + FramedRead::new(s2c_read, LinesBytesCodec) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + .and_then(move |bytes| match std::str::from_utf8(&bytes) { + Ok(message) => futures::future::ok(Event::default().event("message").data(message)), + Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), + }), + ); + Sse::new(stream) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; + + tracing::debug!("listening on {}", listener.local_addr()?); + axum::serve(listener, App::new().router()).await +} diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 97594523..7588617d 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -142,7 +142,8 @@ where tracing::info!("Server started"); while let Some(msg_result) = transport.next().await { - let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered(); + let _span = tracing::span!(tracing::Level::INFO, "message_processing"); + let _enter = _span.enter(); match msg_result { Ok(msg) => { match msg { From 4663da3a12977125974a9bcaf773690fe58a6607 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 09:07:04 +0800 Subject: [PATCH 2/6] fix: make the sse axum folder name correct --- crates/mcp-server/examples/{sse-axim => sse-axum}/counter.rs | 0 crates/mcp-server/examples/{sse-axim => sse-axum}/main.rs | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename crates/mcp-server/examples/{sse-axim => sse-axum}/counter.rs (100%) rename crates/mcp-server/examples/{sse-axim => sse-axum}/main.rs (100%) diff --git a/crates/mcp-server/examples/sse-axim/counter.rs b/crates/mcp-server/examples/sse-axum/counter.rs similarity index 100% rename from crates/mcp-server/examples/sse-axim/counter.rs rename to crates/mcp-server/examples/sse-axum/counter.rs diff --git a/crates/mcp-server/examples/sse-axim/main.rs b/crates/mcp-server/examples/sse-axum/main.rs similarity index 100% rename from crates/mcp-server/examples/sse-axim/main.rs rename to crates/mcp-server/examples/sse-axum/main.rs From 830f2a6b3541ca01b421d67c1979e2945a2b77f0 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 09:58:17 +0800 Subject: [PATCH 3/6] feat: add common modules for Axum SSE example This commit adds supporting modules for the Axum Server-Sent Events (SSE) example: - Added `counter.rs` with a simple counter router implementation - Added `jsonrpc_frame_codec.rs` for decoding JSON-RPC frames - Created a `mod.rs` to expose these modules - Removed the previous SSE example configuration from Cargo.toml --- crates/mcp-server/Cargo.toml | 6 +--- .../examples/{sse-axum/main.rs => axum.rs} | 30 +++---------------- .../examples/{sse-axum => common}/counter.rs | 3 +- .../examples/common/jsonrpc_frame_codec.rs | 24 +++++++++++++++ crates/mcp-server/examples/common/mod.rs | 2 ++ 5 files changed, 32 insertions(+), 33 deletions(-) rename crates/mcp-server/examples/{sse-axum/main.rs => axum.rs} (84%) rename crates/mcp-server/examples/{sse-axum => common}/counter.rs (99%) create mode 100644 crates/mcp-server/examples/common/jsonrpc_frame_codec.rs create mode 100644 crates/mcp-server/examples/common/mod.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index 075d2c19..f06bfeda 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -28,8 +28,4 @@ async-trait = "0.1" [dev-dependencies] axum = { version = "0.8", features = ["macros"] } tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } - -[[example]] -name = "example-sse-axum" -path = "examples/sse-axum/main.rs" +rand = { version = "0.8" } \ No newline at end of file diff --git a/crates/mcp-server/examples/sse-axum/main.rs b/crates/mcp-server/examples/axum.rs similarity index 84% rename from crates/mcp-server/examples/sse-axum/main.rs rename to crates/mcp-server/examples/axum.rs index 8106618d..7c3e4b93 100644 --- a/crates/mcp-server/examples/sse-axum/main.rs +++ b/crates/mcp-server/examples/axum.rs @@ -9,6 +9,7 @@ use axum::{ use futures::{stream::Stream, StreamExt, TryStreamExt}; use mcp_server::{ByteTransport, Server}; use std::collections::HashMap; +use tokio_util::codec::FramedRead; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use anyhow::Result; @@ -19,7 +20,8 @@ use tokio::{ sync::Mutex, }; use tracing_subscriber::{self}; -mod counter; +mod common; +use common::counter; type C2SWriter = Arc>>; type SessionId = Arc; @@ -121,37 +123,13 @@ async fn sse_handler(State(app): State) -> Sse Result, Self::Error> { - if let Some(end) = src - .iter() - .enumerate() - .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) - { - let line = src.split_to(end); - let _char_next_line = src.split_to(1); - Ok(Some(line.freeze())) - } else { - Ok(None) - } - } - } - let stream = futures::stream::once(futures::future::ok( Event::default() .event("endpoint") .data(format!("?sessionId={session}")), )) .chain( - FramedRead::new(s2c_read, LinesBytesCodec) + FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .and_then(move |bytes| match std::str::from_utf8(&bytes) { Ok(message) => futures::future::ok(Event::default().event("message").data(message)), diff --git a/crates/mcp-server/examples/sse-axum/counter.rs b/crates/mcp-server/examples/common/counter.rs similarity index 99% rename from crates/mcp-server/examples/sse-axum/counter.rs rename to crates/mcp-server/examples/common/counter.rs index 557f3a0b..a1081984 100644 --- a/crates/mcp-server/examples/sse-axum/counter.rs +++ b/crates/mcp-server/examples/common/counter.rs @@ -176,5 +176,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} - +} \ No newline at end of file diff --git a/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs b/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs new file mode 100644 index 00000000..4368118c --- /dev/null +++ b/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs @@ -0,0 +1,24 @@ +use tokio_util::codec::Decoder; + +#[derive(Default)] +pub struct JsonRpcFrameCodec; +impl Decoder for JsonRpcFrameCodec { + type Item = tokio_util::bytes::Bytes; + type Error = tokio::io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + if let Some(end) = src + .iter() + .enumerate() + .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) + { + let line = src.split_to(end); + let _char_next_line = src.split_to(1); + Ok(Some(line.freeze())) + } else { + Ok(None) + } + } +} diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs new file mode 100644 index 00000000..02b1a761 --- /dev/null +++ b/crates/mcp-server/examples/common/mod.rs @@ -0,0 +1,2 @@ +pub mod counter; +pub mod jsonrpc_frame_codec; \ No newline at end of file From 4a4871231082b2e680f11d71eef0d83df7eeb797 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 17:44:41 +0800 Subject: [PATCH 4/6] refactor: remove main.rs and update Cargo.toml for WASM compatibility This commit makes two key changes: 1. Removes the main.rs file and make it a example 2. Add a wasi_std_io example --- crates/mcp-server/Cargo.toml | 13 +- crates/mcp-server/examples/tokio_std_io.rs | 36 ++++ crates/mcp-server/examples/wasi_std_io.rs | 130 ++++++++++++ crates/mcp-server/src/main.rs | 217 --------------------- 4 files changed, 176 insertions(+), 220 deletions(-) create mode 100644 crates/mcp-server/examples/tokio_std_io.rs create mode 100644 crates/mcp-server/examples/wasi_std_io.rs delete mode 100644 crates/mcp-server/src/main.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index f06bfeda..34d1f2c9 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -14,7 +14,7 @@ mcp-macros = { workspace = true } serde = { version = "1.0.216", features = ["derive"] } serde_json = "1.0.133" schemars = "0.8" -tokio = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["io-util"] } tower = { version = "0.4", features = ["timeout"] } tower-service = "0.3" futures = "0.3" @@ -26,6 +26,13 @@ async-trait = "0.1" [dev-dependencies] -axum = { version = "0.8", features = ["macros"] } tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } \ No newline at end of file +rand = { version = "0.8" } + +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio = { version = "1", features = ["full"] } + +[target.'cfg(target_arch = "wasm32")'.dev-dependencies] +tokio = { version = "1", features = ["io-util", "rt", "time", "macros"] } +wasi = { version = "0.11.0+wasi-snapshot-preview1" } diff --git a/crates/mcp-server/examples/tokio_std_io.rs b/crates/mcp-server/examples/tokio_std_io.rs new file mode 100644 index 00000000..9c4261fb --- /dev/null +++ b/crates/mcp-server/examples/tokio_std_io.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use mcp_server::router::RouterService; +use mcp_server::{ByteTransport, Server}; +use tokio::io::{stdin, stdout}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{self, EnvFilter}; +mod common; +use common::counter::CounterRouter; + +#[tokio::main] +async fn main() -> Result<()> { + // Set up file appender for logging + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); + + // Initialize the tracing subscriber with file and stdout logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(file_appender) + .with_target(false) + .with_thread_ids(true) + .with_file(true) + .with_line_number(true) + .init(); + + tracing::info!("Starting MCP server"); + + // Create an instance of our counter router + let router = RouterService(CounterRouter::new()); + + // Create and run the server + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + + tracing::info!("Server initialized and ready to handle requests"); + Ok(server.run(transport).await?) +} diff --git a/crates/mcp-server/examples/wasi_std_io.rs b/crates/mcp-server/examples/wasi_std_io.rs new file mode 100644 index 00000000..0f94b8ca --- /dev/null +++ b/crates/mcp-server/examples/wasi_std_io.rs @@ -0,0 +1,130 @@ +use mcp_server::{router::RouterService, ByteTransport, Server}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::EnvFilter; +mod common; +use anyhow::Result; +use common::counter::CounterRouter; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<()> { + // Set up file appender for logging + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); + + // Initialize the tracing subscriber with file and stdout logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(file_appender) + .with_target(false) + .with_thread_ids(true) + .with_file(true) + .with_line_number(true) + .init(); + + tracing::info!("Starting MCP server"); + + // Create an instance of our counter router + let router = RouterService(CounterRouter::new()); + + // Create and run the server + let server = Server::new(router); + + let transport = ByteTransport::new( + async_io::AsyncInputStream::get(), + async_io::AsyncOutputStream::get(), + ); + + tracing::info!("Server initialized and ready to handle requests"); + Ok(server.run(transport).await?) +} + +mod async_io { + use tokio::io::{AsyncRead, AsyncWrite}; + use wasi::{Fd, FD_STDIN, FD_STDOUT}; + + pub struct AsyncInputStream { + fd: Fd, + } + + impl AsyncInputStream { + pub fn get() -> Self { + Self { fd: FD_STDIN } + } + } + + impl AsyncRead for AsyncInputStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let mut temp_buf = vec![0u8; buf.remaining()]; + unsafe { + match wasi::fd_read( + self.fd, + &[wasi::Iovec { + buf: temp_buf.as_mut_ptr(), + buf_len: temp_buf.len(), + }], + ) { + Ok(n) => { + buf.put_slice(&temp_buf[..n]); + std::task::Poll::Ready(Ok(())) + } + Err(err) => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("WASI read error: {}", err), + ))), + } + } + } + } + + pub struct AsyncOutputStream { + fd: Fd, + } + + impl AsyncOutputStream { + pub fn get() -> Self { + Self { fd: FD_STDOUT } + } + } + + impl AsyncWrite for AsyncOutputStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + unsafe { + match wasi::fd_write( + self.fd, + &[wasi::Ciovec { + buf: buf.as_ptr(), + buf_len: buf.len(), + }], + ) { + Ok(n) => std::task::Poll::Ready(Ok(n)), + Err(err) => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("WASI write error: {}", err), + ))), + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // WASI 没有显式的 flush 操作 + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_flush(cx) + } + } +} diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs deleted file mode 100644 index 907cc1b1..00000000 --- a/crates/mcp-server/src/main.rs +++ /dev/null @@ -1,217 +0,0 @@ -use anyhow::Result; -use mcp_core::content::Content; -use mcp_core::handler::{PromptError, ResourceError}; -use mcp_core::prompt::{Prompt, PromptArgument}; -use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; -use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use mcp_server::{ByteTransport, Router, Server}; -use serde_json::Value; -use std::{future::Future, pin::Pin, sync::Arc}; -use tokio::{ - io::{stdin, stdout}, - sync::Mutex, -}; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -use tracing_subscriber::{self, EnvFilter}; - -// A simple counter service that demonstrates the Router trait -#[derive(Clone)] -struct CounterRouter { - counter: Arc>, -} - -impl CounterRouter { - fn new() -> Self { - Self { - counter: Arc::new(Mutex::new(0)), - } - } - - async fn increment(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter += 1; - Ok(*counter) - } - - async fn decrement(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter -= 1; - Ok(*counter) - } - - async fn get_value(&self) -> Result { - let counter = self.counter.lock().await; - Ok(*counter) - } - - fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { - Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() - } -} - -impl Router for CounterRouter { - fn name(&self) -> String { - "counter".to_string() - } - - fn instructions(&self) -> String { - "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new() - .with_tools(false) - .with_resources(false, false) - .with_prompts(false) - .build() - } - - fn list_tools(&self) -> Vec { - vec![ - Tool::new( - "increment".to_string(), - "Increment the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "decrement".to_string(), - "Decrement the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "get_value".to_string(), - "Get the current counter value".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - ] - } - - fn call_tool( - &self, - tool_name: &str, - _arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - - Box::pin(async move { - match tool_name.as_str() { - "increment" => { - let value = this.increment().await?; - Ok(vec![Content::text(value.to_string())]) - } - "decrement" => { - let value = this.decrement().await?; - Ok(vec![Content::text(value.to_string())]) - } - "get_value" => { - let value = this.get_value().await?; - Ok(vec![Content::text(value.to_string())]) - } - _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), - } - }) - } - - fn list_resources(&self) -> Vec { - vec![ - self._create_resource_text("str:////Users/to/some/path/", "cwd"), - self._create_resource_text("memo://insights", "memo-name"), - ] - } - - fn read_resource( - &self, - uri: &str, - ) -> Pin> + Send + 'static>> { - let uri = uri.to_string(); - Box::pin(async move { - match uri.as_str() { - "str:////Users/to/some/path/" => { - let cwd = "/Users/to/some/path/"; - Ok(cwd.to_string()) - } - "memo://insights" => { - let memo = - "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(memo.to_string()) - } - _ => Err(ResourceError::NotFound(format!( - "Resource {} not found", - uri - ))), - } - }) - } - - fn list_prompts(&self) -> Vec { - vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required agrument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - match prompt_name.as_str() { - "example_prompt" => { - let prompt = "This is an example prompt with your message here: '{message}'"; - Ok(prompt.to_string()) - } - _ => Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))), - } - }) - } -} - -#[tokio::main] -async fn main() -> Result<()> { - // Set up file appender for logging - let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); - - // Initialize the tracing subscriber with file and stdout logging - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) - .with_writer(file_appender) - .with_target(false) - .with_thread_ids(true) - .with_file(true) - .with_line_number(true) - .init(); - - tracing::info!("Starting MCP server"); - - // Create an instance of our counter router - let router = RouterService(CounterRouter::new()); - - // Create and run the server - let server = Server::new(router); - let transport = ByteTransport::new(stdin(), stdout()); - - tracing::info!("Server initialized and ready to handle requests"); - Ok(server.run(transport).await?) -} From ffb797f1005ccc3787b57cf9484887a94d43730d Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 17:53:28 +0800 Subject: [PATCH 5/6] refactor: simplify WASI I/O handling in example and fmt Merged AsyncInputStream and AsyncOutputStream into a single WasiFd struct with std_in() and std_out() methods, reducing code duplication and improving clarity of the WASI standard I/O example. --- crates/mcp-server/examples/common/counter.rs | 9 +++++-- crates/mcp-server/examples/common/mod.rs | 2 +- crates/mcp-server/examples/wasi_std_io.rs | 28 +++++++------------- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/crates/mcp-server/examples/common/counter.rs b/crates/mcp-server/examples/common/counter.rs index a1081984..4936a612 100644 --- a/crates/mcp-server/examples/common/counter.rs +++ b/crates/mcp-server/examples/common/counter.rs @@ -1,6 +1,11 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_core::{ + handler::{PromptError, ResourceError}, + prompt::{Prompt, PromptArgument}, + protocol::ServerCapabilities, + Content, Resource, Tool, ToolError, +}; use mcp_server::router::CapabilitiesBuilder; use serde_json::Value; use tokio::sync::Mutex; @@ -176,4 +181,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} \ No newline at end of file +} diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs index 02b1a761..2a456e37 100644 --- a/crates/mcp-server/examples/common/mod.rs +++ b/crates/mcp-server/examples/common/mod.rs @@ -1,2 +1,2 @@ pub mod counter; -pub mod jsonrpc_frame_codec; \ No newline at end of file +pub mod jsonrpc_frame_codec; diff --git a/crates/mcp-server/examples/wasi_std_io.rs b/crates/mcp-server/examples/wasi_std_io.rs index 0f94b8ca..1424a502 100644 --- a/crates/mcp-server/examples/wasi_std_io.rs +++ b/crates/mcp-server/examples/wasi_std_io.rs @@ -28,10 +28,7 @@ async fn main() -> Result<()> { // Create and run the server let server = Server::new(router); - let transport = ByteTransport::new( - async_io::AsyncInputStream::get(), - async_io::AsyncOutputStream::get(), - ); + let transport = ByteTransport::new(async_io::WasiFd::std_in(), async_io::WasiFd::std_out()); tracing::info!("Server initialized and ready to handle requests"); Ok(server.run(transport).await?) @@ -41,17 +38,20 @@ mod async_io { use tokio::io::{AsyncRead, AsyncWrite}; use wasi::{Fd, FD_STDIN, FD_STDOUT}; - pub struct AsyncInputStream { + pub struct WasiFd { fd: Fd, } - impl AsyncInputStream { - pub fn get() -> Self { + impl WasiFd { + pub fn std_in() -> Self { Self { fd: FD_STDIN } } + pub fn std_out() -> Self { + Self { fd: FD_STDOUT } + } } - impl AsyncRead for AsyncInputStream { + impl AsyncRead for WasiFd { fn poll_read( self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, @@ -79,17 +79,7 @@ mod async_io { } } - pub struct AsyncOutputStream { - fd: Fd, - } - - impl AsyncOutputStream { - pub fn get() -> Self { - Self { fd: FD_STDOUT } - } - } - - impl AsyncWrite for AsyncOutputStream { + impl AsyncWrite for WasiFd { fn poll_write( self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, From 957f183b3fa37ffce72f91fc47684484f8a51934 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 4 Mar 2025 09:42:38 +0800 Subject: [PATCH 6/6] move examples to root --- crates/mcp-server/Cargo.toml | 5 - crates/mcp-server/examples/common/mod.rs | 2 - examples/servers/Cargo.toml | 9 + .../examples => examples/servers/src}/axum.rs | 0 .../servers/src}/common/counter.rs | 9 +- .../src}/common/jsonrpc_frame_codec.rs | 0 examples/servers/src/common/mod.rs | 2 + examples/servers/src/counter_server.rs | 191 +----------------- 8 files changed, 23 insertions(+), 195 deletions(-) delete mode 100644 crates/mcp-server/examples/common/mod.rs rename {crates/mcp-server/examples => examples/servers/src}/axum.rs (100%) rename {crates/mcp-server/examples => examples/servers/src}/common/counter.rs (96%) rename {crates/mcp-server/examples => examples/servers/src}/common/jsonrpc_frame_codec.rs (100%) create mode 100644 examples/servers/src/common/mod.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index f06bfeda..048f18ff 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -24,8 +24,3 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" async-trait = "0.1" - -[dev-dependencies] -axum = { version = "0.8", features = ["macros"] } -tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } \ No newline at end of file diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs deleted file mode 100644 index 02b1a761..00000000 --- a/crates/mcp-server/examples/common/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod counter; -pub mod jsonrpc_frame_codec; \ No newline at end of file diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 5a64c19c..c8d47783 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" futures = "0.3" +[dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio-util = { version = "0.7", features = ["io", "codec"]} +rand = { version = "0.8" } + [[example]] name = "counter-server" path = "src/counter_server.rs" + +[[example]] +name = "axum" +path = "src/axum.rs" \ No newline at end of file diff --git a/crates/mcp-server/examples/axum.rs b/examples/servers/src/axum.rs similarity index 100% rename from crates/mcp-server/examples/axum.rs rename to examples/servers/src/axum.rs diff --git a/crates/mcp-server/examples/common/counter.rs b/examples/servers/src/common/counter.rs similarity index 96% rename from crates/mcp-server/examples/common/counter.rs rename to examples/servers/src/common/counter.rs index a1081984..4936a612 100644 --- a/crates/mcp-server/examples/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,6 +1,11 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_core::{ + handler::{PromptError, ResourceError}, + prompt::{Prompt, PromptArgument}, + protocol::ServerCapabilities, + Content, Resource, Tool, ToolError, +}; use mcp_server::router::CapabilitiesBuilder; use serde_json::Value; use tokio::sync::Mutex; @@ -176,4 +181,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} \ No newline at end of file +} diff --git a/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs b/examples/servers/src/common/jsonrpc_frame_codec.rs similarity index 100% rename from crates/mcp-server/examples/common/jsonrpc_frame_codec.rs rename to examples/servers/src/common/jsonrpc_frame_codec.rs diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs new file mode 100644 index 00000000..2a456e37 --- /dev/null +++ b/examples/servers/src/common/mod.rs @@ -0,0 +1,2 @@ +pub mod counter; +pub mod jsonrpc_frame_codec; diff --git a/examples/servers/src/counter_server.rs b/examples/servers/src/counter_server.rs index 907cc1b1..d9fdf614 100644 --- a/examples/servers/src/counter_server.rs +++ b/examples/servers/src/counter_server.rs @@ -1,192 +1,11 @@ use anyhow::Result; -use mcp_core::content::Content; -use mcp_core::handler::{PromptError, ResourceError}; -use mcp_core::prompt::{Prompt, PromptArgument}; -use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; -use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use mcp_server::{ByteTransport, Router, Server}; -use serde_json::Value; -use std::{future::Future, pin::Pin, sync::Arc}; -use tokio::{ - io::{stdin, stdout}, - sync::Mutex, -}; +use mcp_server::router::RouterService; +use mcp_server::{ByteTransport, Server}; +use tokio::io::{stdin, stdout}; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; -// A simple counter service that demonstrates the Router trait -#[derive(Clone)] -struct CounterRouter { - counter: Arc>, -} - -impl CounterRouter { - fn new() -> Self { - Self { - counter: Arc::new(Mutex::new(0)), - } - } - - async fn increment(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter += 1; - Ok(*counter) - } - - async fn decrement(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter -= 1; - Ok(*counter) - } - - async fn get_value(&self) -> Result { - let counter = self.counter.lock().await; - Ok(*counter) - } - - fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { - Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() - } -} - -impl Router for CounterRouter { - fn name(&self) -> String { - "counter".to_string() - } - - fn instructions(&self) -> String { - "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new() - .with_tools(false) - .with_resources(false, false) - .with_prompts(false) - .build() - } - - fn list_tools(&self) -> Vec { - vec![ - Tool::new( - "increment".to_string(), - "Increment the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "decrement".to_string(), - "Decrement the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "get_value".to_string(), - "Get the current counter value".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - ] - } - - fn call_tool( - &self, - tool_name: &str, - _arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - - Box::pin(async move { - match tool_name.as_str() { - "increment" => { - let value = this.increment().await?; - Ok(vec![Content::text(value.to_string())]) - } - "decrement" => { - let value = this.decrement().await?; - Ok(vec![Content::text(value.to_string())]) - } - "get_value" => { - let value = this.get_value().await?; - Ok(vec![Content::text(value.to_string())]) - } - _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), - } - }) - } - - fn list_resources(&self) -> Vec { - vec![ - self._create_resource_text("str:////Users/to/some/path/", "cwd"), - self._create_resource_text("memo://insights", "memo-name"), - ] - } - - fn read_resource( - &self, - uri: &str, - ) -> Pin> + Send + 'static>> { - let uri = uri.to_string(); - Box::pin(async move { - match uri.as_str() { - "str:////Users/to/some/path/" => { - let cwd = "/Users/to/some/path/"; - Ok(cwd.to_string()) - } - "memo://insights" => { - let memo = - "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(memo.to_string()) - } - _ => Err(ResourceError::NotFound(format!( - "Resource {} not found", - uri - ))), - } - }) - } - - fn list_prompts(&self) -> Vec { - vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required agrument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - match prompt_name.as_str() { - "example_prompt" => { - let prompt = "This is an example prompt with your message here: '{message}'"; - Ok(prompt.to_string()) - } - _ => Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))), - } - }) - } -} +mod common; #[tokio::main] async fn main() -> Result<()> { @@ -206,7 +25,7 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); // Create an instance of our counter router - let router = RouterService(CounterRouter::new()); + let router = RouterService(common::counter::CounterRouter::new()); // Create and run the server let server = Server::new(router);