diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 2461684a..59701845 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -28,6 +28,7 @@ url = "2.4" tower = "0.5" axum = "0.8" reqwest = "0.12" +clap = { version = "4.0", features = ["derive"] } [[example]] name = "clients_sse" @@ -51,4 +52,8 @@ path = "src/collection.rs" [[example]] name = "clients_oauth_client" -path = "src/auth/oauth_client.rs" \ No newline at end of file +path = "src/auth/oauth_client.rs" + +[[example]] +name = "clients_progress_test_client" +path = "src/progress_test_client.rs" \ No newline at end of file diff --git a/examples/clients/README.md b/examples/clients/README.md index 0b171961..0d30adf5 100644 --- a/examples/clients/README.md +++ b/examples/clients/README.md @@ -4,6 +4,7 @@ This directory contains Model Context Protocol (MCP) client examples implemented ## Example List + ### SSE Client (`sse.rs`) A client that communicates with an MCP server using Server-Sent Events (SSE) transport. @@ -57,6 +58,15 @@ A client demonstrating how to authenticate with an MCP server using OAuth. - Establishes an authorized connection to the MCP server using the acquired access token - Demonstrates how to use the authorized connection to retrieve available tools and prompts +### Progress Test Client (`progress_test_client.rs`) + +A client that communicates with an MCP server using progress notifications. + +- Launches the `cargo run --example clients_progress_test_client -- --transport {stdio|sse|http|all}` to test the progress notifications +- Connects to the server using different transport methods +- Tests the progress notifications + + ## How to Run Each example can be run using Cargo: diff --git a/examples/clients/src/progress_test_client.rs b/examples/clients/src/progress_test_client.rs new file mode 100644 index 00000000..2037b3ed --- /dev/null +++ b/examples/clients/src/progress_test_client.rs @@ -0,0 +1,388 @@ +use std::{ + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +use anyhow::Result; +use clap::{Parser, ValueEnum}; +use rmcp::{ + ClientHandler, ServiceExt, + model::{ + CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation, + ProgressNotificationParam, + }, + service::{NotificationContext, RoleClient}, + transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess}, +}; +use tokio::{process::Command, time::sleep}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[derive(Debug, Clone, ValueEnum)] +enum TransportType { + Stdio, + Sse, + Http, + All, +} + +#[derive(Parser)] +#[command(name = "progress-test-client")] +#[command(about = "RMCP Progress Notification Test Client")] +struct Args { + /// Transport type to test + #[arg(short, long, value_enum, default_value_t = TransportType::Stdio)] + transport: TransportType, + + /// Number of records to process + #[arg(short, long, default_value_t = 10)] + records: u32, + + /// SSE server URL + #[arg(long, default_value = "http://127.0.0.1:8000/sse")] + sse_url: String, + + /// HTTP server URL + #[arg(long, default_value = "http://127.0.0.1:8001/mcp")] + http_url: String, +} + +#[derive(Debug, Clone)] +struct ProgressTracker { + start_time: Instant, + progress_count: Arc>, +} + +impl ProgressTracker { + fn new() -> Self { + Self { + start_time: Instant::now(), + progress_count: Arc::new(Mutex::new(0)), + } + } + + fn handle_progress(&self, params: &ProgressNotificationParam) { + if let Ok(mut count) = self.progress_count.lock() { + *count += 1; + } + let elapsed = self.start_time.elapsed(); + + tracing::info!( + "Progress update [{}]: {}/{} - {} (elapsed: {:.1}s)", + params.progress_token.0, + params.progress, + params.total.unwrap_or(0), + params.message.as_deref().unwrap_or(""), + elapsed.as_secs_f64() + ); + } + + fn print_summary(&self) { + let elapsed = self.start_time.elapsed(); + if let Ok(count) = self.progress_count.lock() { + tracing::info!("Total progress notifications received: {}", *count); + } + tracing::info!("Total time elapsed: {:.2}s", elapsed.as_secs_f64()); + } +} + +// Progress-aware client handler +#[derive(Debug, Clone)] +struct ProgressAwareClient { + tracker: Arc>>, +} + +impl ProgressAwareClient { + fn new() -> Self { + Self { + tracker: Arc::new(Mutex::new(None)), + } + } + + fn start_tracking(&self) { + if let Ok(mut tracker) = self.tracker.lock() { + *tracker = Some(ProgressTracker::new()); + } + } + + fn stop_tracking(&self) { + if let Ok(mut tracker_opt) = self.tracker.lock() { + if let Some(tracker) = tracker_opt.take() { + tracker.print_summary(); + } + } + } +} + +impl ClientHandler for ProgressAwareClient { + async fn on_progress( + &self, + params: ProgressNotificationParam, + _context: NotificationContext, + ) { + if let Ok(tracker_opt) = self.tracker.lock() { + if let Some(tracker) = tracker_opt.as_ref() { + tracker.handle_progress(¶ms); + } + } + } + + fn get_info(&self) -> ClientInfo { + ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "progress-test-client".to_string(), + version: "1.0.0".to_string(), + }, + } + } +} + +async fn test_stdio_transport(records: u32) -> Result<()> { + tracing::info!("Testing STDIO Transport"); + tracing::info!("================================"); + + // Start server process + let mut server_cmd = Command::new("cargo"); + server_cmd + .current_dir("../servers") + .arg("run") + .arg("--example") + .arg("servers_progress_demo") + .arg("--") + .arg("stdio"); + + // Create progress-aware client handler + let client_handler = ProgressAwareClient::new(); + client_handler.start_tracking(); + let client_handler_clone = client_handler.clone(); + + let service = client_handler + .serve(TokioChildProcess::new(server_cmd)?) + .await?; + + // Initialize + let server_info = service.peer_info(); + if let Some(info) = server_info { + tracing::info!("Connected to server: {:?}", info.server_info.name); + } + + // List tools + let tools = service.list_all_tools().await?; + tracing::info!( + "Available tools: {:?}", + tools.iter().map(|t| &t.name).collect::>() + ); + + // Call stream processor tool + tracing::info!("Starting to process {} records...", records); + let tool_result = service + .call_tool(CallToolRequestParam { + name: "stream_processor".into(), + arguments: None, + }) + .await?; + + if let Some(content) = tool_result.content.first() { + if let Some(text) = content.as_text() { + tracing::info!("Processing completed: {}", text.text); + } + } + + service.cancel().await?; + client_handler_clone.stop_tracking(); + tracing::info!("STDIO transport test completed successfully!"); + Ok(()) +} +// Test SSE transport, must run the server with `cargo run --example servers_progress_demo -- sse` in the servers directory +async fn test_sse_transport(sse_url: &str, records: u32) -> Result<()> { + tracing::info!("Testing SSE Transport"); + tracing::info!("========================="); + tracing::info!("SSE URL: {}", sse_url); + + // Wait a bit for server to be ready + sleep(Duration::from_secs(1)).await; + + let transport = SseClientTransport::start(sse_url).await?; + + // Create progress-aware client handler + let client_handler = ProgressAwareClient::new(); + client_handler.start_tracking(); + let client_handler_clone = client_handler.clone(); + + let client = client_handler.serve(transport).await.inspect_err(|e| { + tracing::error!("SSE client error: {:?}", e); + })?; + + // Initialize + let server_info = client.peer_info(); + if let Some(info) = server_info { + tracing::info!("Connected to server: {:?}", info.server_info.name); + } + + // List tools + let tools = client.list_tools(Default::default()).await?; + tracing::info!( + "Available tools: {:?}", + tools.tools.iter().map(|t| &t.name).collect::>() + ); + + // Call stream processor tool + tracing::info!("Starting to process {} records...", records); + let tool_result = client + .call_tool(CallToolRequestParam { + name: "stream_processor".into(), + arguments: None, + }) + .await?; + + if let Some(content) = tool_result.content.first() { + if let Some(text) = content.as_text() { + tracing::info!("Processing completed: {}", text.text); + } + } + + client.cancel().await?; + client_handler_clone.stop_tracking(); + tracing::info!("SSE transport test completed successfully!"); + Ok(()) +} +// Test HTTP transport, must run the server with `cargo run --example servers_progress_demo -- http` in the servers directory +async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { + tracing::info!("Testing HTTP Streaming Transport"); + tracing::info!("====================================="); + tracing::info!("HTTP URL: {}", http_url); + + // Wait a bit for server to be ready + sleep(Duration::from_secs(1)).await; + + let transport = StreamableHttpClientTransport::from_uri(http_url); + + // Create progress-aware client handler + let client_handler = ProgressAwareClient::new(); + client_handler.start_tracking(); + let client_handler_clone = client_handler.clone(); + + let client = client_handler.serve(transport).await.inspect_err(|e| { + tracing::error!("HTTP client error: {:?}", e); + })?; + + // Initialize + let server_info = client.peer_info(); + if let Some(info) = server_info { + tracing::info!("Connected to server: {:?}", info.server_info.name); + } + + // List tools + let tools = client.list_tools(Default::default()).await?; + tracing::info!( + "Available tools: {:?}", + tools.tools.iter().map(|t| &t.name).collect::>() + ); + + // Call stream processor tool + tracing::info!("Starting to process {} records...", records); + let tool_result = client + .call_tool(CallToolRequestParam { + name: "stream_processor".into(), + arguments: None, + }) + .await?; + + if let Some(content) = tool_result.content.first() { + if let Some(text) = content.as_text() { + tracing::info!("processing completed: {}", text.text); + } + } + + client.cancel().await?; + client_handler_clone.stop_tracking(); + tracing::info!("HTTP transport test completed successfully!"); + Ok(()) +} + +async fn run_single_test(transport_type: &TransportType, args: &Args) -> Result { + match transport_type { + TransportType::Stdio => { + test_stdio_transport(args.records).await?; + Ok(true) + } + TransportType::Sse => match test_sse_transport(&args.sse_url, args.records).await { + Ok(_) => Ok(true), + Err(e) => { + tracing::error!("SSE test failed: {}", e); + Ok(false) + } + }, + TransportType::Http => match test_http_transport(&args.http_url, args.records).await { + Ok(_) => Ok(true), + Err(e) => { + tracing::error!("HTTP test failed: {}", e); + Ok(false) + } + }, + TransportType::All => { + // This case is handled in main + Ok(true) + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::info!("RMCP Progress Notification Test Client"); + tracing::info!("=========================================="); + + match &args.transport { + TransportType::All => { + // Test all transport types + let mut results = std::collections::HashMap::new(); + + for transport_type in [ + TransportType::Stdio, + TransportType::Sse, + TransportType::Http, + ] { + let transport_name = format!("{:?}", transport_type).to_uppercase(); + tracing::info!("\n"); + + match run_single_test(&transport_type, &args).await { + Ok(success) => { + results.insert(transport_name, success); + } + Err(e) => { + tracing::error!("{} test failed: {}", transport_name, e); + results.insert(transport_name, false); + } + } + + sleep(Duration::from_secs(2)).await; + } + + // Print summary + tracing::info!("\n=========================================="); + tracing::info!("Test Results Summary"); + tracing::info!("=========================================="); + + for (transport, success) in &results { + let status = if *success { "PASSED" } else { "FAILED" }; + tracing::info!(" {:<10}: {}", transport, status); + } + } + transport_type => { + run_single_test(transport_type, &args).await?; + } + } + + Ok(()) +} diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 776b5f9e..f0601f60 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -1,5 +1,3 @@ - - [package] name = "mcp-server-examples" version = "0.1.5" @@ -9,6 +7,7 @@ publish = false [dependencies] rmcp = { path = "../../crates/rmcp", features = [ "server", + "client", "transport-sse-server", "transport-io", "transport-streamable-http-server", @@ -78,3 +77,7 @@ path = "src/simple_auth_sse.rs" [[example]] name = "counter_hyper_streamable_http" path = "src/counter_hyper_streamable_http.rs" + +[[example]] +name = "servers_progress_demo" +path = "src/progress_demo_test.rs" diff --git a/examples/servers/README.md b/examples/servers/README.md index c80b130a..14280d06 100644 --- a/examples/servers/README.md +++ b/examples/servers/README.md @@ -72,6 +72,14 @@ A simplified OAuth example showing basic token-based authentication. - Simplified authentication flow - Good starting point for adding authentication to MCP servers +### Progress Demo Server (`progress_demo.rs`) + +A server that demonstrates progress notifications during long-running operations. + +- Provides a stream_processor tool that generates progress notifications +- Demonstrates progress notifications during long-running operations +- Can be run with `cargo run --example servers_progress_demo -- {stdio|sse|http|all}` + ## How to Run Each example can be run using Cargo: @@ -97,6 +105,7 @@ cargo run --example servers_complex_auth_sse # Run the simple OAuth SSE server cargo run --example servers_simple_auth_sse + ``` ## Testing with MCP Inspector diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs index 5919bccd..674a8b51 100644 --- a/examples/servers/src/common/mod.rs +++ b/examples/servers/src/common/mod.rs @@ -1,3 +1,4 @@ pub mod calculator; pub mod counter; pub mod generic_service; +pub mod progress_demo; diff --git a/examples/servers/src/common/progress_demo.rs b/examples/servers/src/common/progress_demo.rs new file mode 100644 index 00000000..0738aa87 --- /dev/null +++ b/examples/servers/src/common/progress_demo.rs @@ -0,0 +1,127 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::Stream; +use rmcp::{Error as McpError, RoleServer, ServerHandler, model::*, service::RequestContext, tool}; +use serde_json::json; +use tokio_stream::StreamExt; +use tracing::debug; + +// a Stream data source that generates data in chunks +#[derive(Clone)] +struct StreamDataSource { + data: Vec, + chunk_size: usize, + position: usize, +} + +impl StreamDataSource { + pub fn new(data: Vec, chunk_size: usize) -> Self { + Self { + data, + chunk_size, + position: 0, + } + } + pub fn from_text(text: &str) -> Self { + Self::new(text.as_bytes().to_vec(), 1) + } +} + +impl Stream for StreamDataSource { + type Item = Result, io::Error>; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + if this.position >= this.data.len() { + return Poll::Ready(None); + } + + let start = this.position; + let end = (start + this.chunk_size).min(this.data.len()); + let chunk = this.data[start..end].to_vec(); + this.position = end; + Poll::Ready(Some(Ok(chunk))) + } +} + +#[derive(Clone)] +pub struct ProgressDemo { + data_source: StreamDataSource, +} + +#[tool(tool_box)] +impl ProgressDemo { + #[allow(dead_code)] + pub fn new() -> Self { + Self { + data_source: StreamDataSource::from_text("Hello, world!"), + } + } + #[tool(description = "Process data stream with progress updates")] + async fn stream_processor( + &self, + ctx: RequestContext, + ) -> Result { + let mut counter = 0u32; + + let mut data_source = self.data_source.clone(); + loop { + let chunk = data_source.next().await; + if chunk.is_none() { + break; + } + + let chunk = chunk.unwrap().unwrap(); + let chunk_str = String::from_utf8_lossy(&chunk); + counter += 1; + // create progress notification param + let progress_param = ProgressNotificationParam { + progress_token: ProgressToken(NumberOrString::Number(counter)), + progress: counter, + total: None, + message: Some(chunk_str.to_string()), + }; + + match ctx.peer.notify_progress(progress_param).await { + Ok(_) => { + debug!("Processed record: {}", chunk_str); + } + Err(e) => { + return Err(McpError::internal_error( + format!("Failed to notify progress: {}", e), + Some(json!({ + "record": chunk_str, + "progress": counter, + "error": e.to_string() + })), + )); + } + } + } + + Ok(CallToolResult::success(vec![Content::text(format!( + "Processed {} records successfully", + counter + ))])) + } +} + +#[tool(tool_box)] +impl ServerHandler for ProgressDemo { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2024_11_05, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation::from_build_env(), + instructions: Some( + "This server demonstrates progress notifications during long-running operations. \ + Use the tools to see real-time progress updates for batch processing" + .to_string(), + ), + } + } +} diff --git a/examples/servers/src/progress_demo_test.rs b/examples/servers/src/progress_demo_test.rs new file mode 100644 index 00000000..ebebbc4d --- /dev/null +++ b/examples/servers/src/progress_demo_test.rs @@ -0,0 +1,177 @@ +use std::env; + +use rmcp::{ + ServiceExt, + transport::{ + sse_server::{SseServer, SseServerConfig}, + stdio, + streamable_http_server::{StreamableHttpService, session::local::LocalSessionManager}, + }, +}; + +mod common; +use common::progress_demo::ProgressDemo; + +const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; +const HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Get transport mode from environment variable or command line argument + let transport_mode = env::args() + .nth(1) + .unwrap_or_else(|| env::var("TRANSPORT_MODE").unwrap_or_else(|_| "stdio".to_string())); + + match transport_mode.as_str() { + "stdio" => run_stdio().await, + "sse" => run_sse().await, + "http" | "streamhttp" => run_streamable_http().await, + "all" => run_all_transports().await, + _ => { + eprintln!( + "Usage: {} [stdio|sse|http|all]", + env::args().next().unwrap() + ); + eprintln!("Supported transport modes:"); + eprintln!(" stdio - Standard input/output (default)"); + eprintln!( + " sse - Server-Sent Events at http://{}", + SSE_BIND_ADDRESS + ); + eprintln!( + " http - Streamable HTTP at http://{}", + HTTP_BIND_ADDRESS + ); + eprintln!(" all - Run all transports simultaneously"); + std::process::exit(1); + } + } +} + +async fn run_stdio() -> anyhow::Result<()> { + let server = ProgressDemo::new(); + let service = server.serve(stdio()).await.inspect_err(|e| { + tracing::error!("stdio serving error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +} + +async fn run_sse() -> anyhow::Result<()> { + println!("Running SSE server"); + let config = SseServerConfig { + bind: SSE_BIND_ADDRESS.parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: tokio_util::sync::CancellationToken::new(), + sse_keep_alive: None, + }; + + let (sse_server, router) = SseServer::new(config); + + // Start the HTTP server for SSE + let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; + let ct = sse_server.config.ct.child_token(); + + let server = axum::serve(listener, router).with_graceful_shutdown(async move { + ct.cancelled().await; + tracing::info!("SSE server cancelled"); + }); + + tokio::spawn(async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "SSE server shutdown with error"); + } + }); + + // Start the MCP service with SSE transport + let ct = sse_server.with_service(ProgressDemo::new); + + tracing::info!( + "Progress Demo SSE server started at http://{}/sse", + SSE_BIND_ADDRESS + ); + tracing::info!("Press Ctrl+C to shutdown"); + + tokio::signal::ctrl_c().await?; + ct.cancel(); + Ok(()) +} + +async fn run_streamable_http() -> anyhow::Result<()> { + println!("Running Streamable HTTP server"); + let service = StreamableHttpService::new( + || Ok(ProgressDemo::new()), + LocalSessionManager::default().into(), + Default::default(), + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind(HTTP_BIND_ADDRESS).await?; + + tracing::info!( + "Progress Demo HTTP server started at http://{}/mcp", + HTTP_BIND_ADDRESS + ); + tracing::info!("Press Ctrl+C to shutdown"); + + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() }) + .await; + + Ok(()) +} + +async fn run_all_transports() -> anyhow::Result<()> { + println!("Running all transports"); + // Start SSE server + let sse_config = SseServerConfig { + bind: SSE_BIND_ADDRESS.parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: tokio_util::sync::CancellationToken::new(), + sse_keep_alive: None, + }; + + let (sse_server, sse_router) = SseServer::new(sse_config); + let sse_listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; + let sse_ct = sse_server.config.ct.child_token(); + + // Start Streamable HTTP server + let http_service = StreamableHttpService::new( + || Ok(ProgressDemo::new()), + LocalSessionManager::default().into(), + Default::default(), + ); + let http_router = axum::Router::new().nest_service("/mcp", http_service); + let http_listener = tokio::net::TcpListener::bind(HTTP_BIND_ADDRESS).await?; + + // Start SSE HTTP server + let sse_http_server = + axum::serve(sse_listener, sse_router).with_graceful_shutdown(async move { + sse_ct.cancelled().await; + tracing::info!("SSE server cancelled"); + }); + + tokio::spawn(async move { + if let Err(e) = sse_http_server.await { + tracing::error!(error = %e, "SSE server shutdown with error"); + } + }); + + // Start Streamable HTTP server + tokio::spawn(async move { + let _ = axum::serve(http_listener, http_router) + .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() }) + .await; + }); + + // Start MCP service with SSE + let mcp_sse_ct = sse_server.with_service(ProgressDemo::new); + + tokio::signal::ctrl_c().await?; + mcp_sse_ct.cancel(); + + Ok(()) +}