diff --git a/Cargo.toml b/Cargo.toml index abd0d0e..4ee2cb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ tracing = "0.1.41" # axum axum = { version = "0.8.1", optional = true } -mime = { version = "0.3.17", optional = true} +mime = { version = "0.3.17", optional = true } # pubsub tokio-stream = { version = "0.1.17", optional = true } @@ -41,11 +41,12 @@ futures-util = { version = "0.3.31", optional = true } [dev-dependencies] tempfile = "3.15.0" tracing-subscriber = "0.3.19" +axum = { version = "*", features = ["macros"] } [features] default = ["axum", "ws", "ipc"] axum = ["dep:axum", "dep:mime"] -pubsub = ["dep:tokio-stream"] +pubsub = ["dep:tokio-stream", "axum?/ws"] ipc = ["pubsub", "dep:interprocess"] ws = ["pubsub", "dep:tokio-tungstenite", "dep:futures-util"] diff --git a/src/lib.rs b/src/lib.rs index 1d5ba76..cb67dc9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,10 +86,36 @@ //! # }} //! ``` //! -//! For WS and IPC connections, the `pubsub` module provides implementations of -//! the `Connect` trait for [`std::net::SocketAddr`] to create simple WS -//! servers, and [`interprocess::local_socket::ListenerOptions`] to create -//! simple IPC servers. +//! Routers can also be served over axum websockets. When both `axum` and +//! `pubsub` features are enabled, the `pubsub` module provides +//! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This +//! handler will serve the router over websockets at a specific route. The +//! router is a property of the `AxumWsCfg` object, and is passed to the +//! handler via axum's `State` extractor. +//! +//! ```no_run +//! # #[cfg(all(feature = "axum", feature = "pubsub"))] +//! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +//! # { +//! # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +//! // The config object contains the tokio runtime handle, and the +//! // notification buffer size. +//! let cfg = AxumWsCfg::new(router); +//! +//! axum +//! .route("/ws", axum::routing::any(ajj_websocket)) +//! .with_state(cfg) +//! # }} +//! ``` +//! +//! For IPC and non-axum WebSocket connections, the `pubsub` module provides +//! implementations of the `Connect` trait for [`std::net::SocketAddr`] to +//! create simple WS servers, and +//! [`interprocess::local_socket::ListenerOptions`] to create simple IPC +//! servers. We generally recommend using `axum` for WebSocket connections, as +//! it provides a more complete and robust implementation, however, users +//! needing additional control, or wanting to avoid the `axum` dependency +//! can use the `pubsub` module directly. //! //! ```no_run //! # #[cfg(feature = "pubsub")] diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs new file mode 100644 index 0000000..338b099 --- /dev/null +++ b/src/pubsub/axum.rs @@ -0,0 +1,301 @@ +//! WebSocket connection manager for [`axum`] +//! +//! How this works: +//! `axum` does not provide a connection pattern that allows us to iplement +//! [`Listener`] or [`Connect`] directly. Instead, it uses a +//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means +//! that we cannot use the [`Listener`] trait directly. Instead, we make a +//! [`AxumWsCfg`] that will be the [`State`] for our handler. +//! +//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this +//! case. +//! +//! [`Connect`]: crate::pubsub::Connect + +use crate::{ + pubsub::{shared::ConnectionManager, Listener}, + Router, +}; +use axum::{ + extract::{ + ws::{Message, WebSocket}, + State, WebSocketUpgrade, + }, + response::Response, +}; +use bytes::Bytes; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, Stream, StreamExt, +}; +use serde_json::value::RawValue; +use std::{ + convert::Infallible, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tokio::runtime::Handle; +use tracing::debug; + +pub(crate) type SendHalf = SplitSink; +pub(crate) type RecvHalf = SplitStream; + +struct AxumListener; + +impl Listener for AxumListener { + type RespSink = SendHalf; + + type ReqStream = WsJsonStream; + + type Error = Infallible; + + async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> { + unreachable!() + } +} + +/// Configuration details for WebSocket connections using [`axum::extract::ws`]. +/// +/// The main points of configuration are: +/// - The runtime [`Handle`] on which to execute tasks, which can be set with +/// [`Self::with_handle`]. This defaults to the current thread's runtime +/// handle. +/// - The notification buffer size per client, which can be set with +/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] +/// module documentation for more details. +/// +/// This struct is used as the [`State`] for the [`ajj_websocket`] handler, and +/// should be created from a fully-configured [`Router<()>`]. +/// +/// # Note +/// +/// If [`AxumWsCfg`] is NOT used within a `tokio` runtime, +/// [`AxumWsCfg::with_handle`] MUST be called to set the runtime handle before +/// any requests are routed. Attempting to execute a task without an active +/// runtime will result in a panic. +/// +/// # Example +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router, handle: tokio::runtime::Handle) { +/// let cfg = AxumWsCfg::from(router) +/// .with_handle(handle) +/// .with_notification_buffer_per_client(10); +/// # }} +/// ``` +#[derive(Clone)] +pub struct AxumWsCfg { + inner: Arc, +} + +impl core::fmt::Debug for AxumWsCfg { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("AxumWsCfg") + .field( + "notification_buffer_per_client", + &self.inner.notification_buffer_per_task, + ) + .field("next_id", &self.inner.next_id) + .finish() + } +} + +impl From> for AxumWsCfg { + fn from(router: Router<()>) -> Self { + Self::new(router) + } +} + +impl AxumWsCfg { + /// Create a new [`AxumWsCfg`] with the given [`Router`]. + pub fn new(router: Router<()>) -> Self { + Self { + inner: ConnectionManager::new(router).into(), + } + } + + fn into_inner(self) -> ConnectionManager { + match Arc::try_unwrap(self.inner) { + Ok(inner) => inner, + Err(arc) => ConnectionManager { + root_tasks: arc.root_tasks.clone(), + next_id: arc.next_id.clone(), + router: arc.router.clone(), + notification_buffer_per_task: arc.notification_buffer_per_task, + }, + } + } + + /// Set the handle on which to execute tasks. + pub fn with_handle(self, handle: Handle) -> Self { + Self { + inner: self.into_inner().with_handle(handle).into(), + } + } + + /// Set the notification buffer size per client. See the [`crate::pubsub`] + /// module documentation for more details. + pub fn with_notification_buffer_per_client( + self, + notification_buffer_per_client: usize, + ) -> Self { + Self { + inner: self + .into_inner() + .with_notification_buffer_per_client(notification_buffer_per_client) + .into(), + } + } +} + +/// Axum handler for WebSocket connections. +/// +/// Used to serve [`crate::Router`]s over WebSocket connections via [`axum`]'s +/// built-in WebSocket support. This handler is used in conjunction with +/// [`AxumWsCfg`], which is passed as the [`State`] to the handler. +/// +/// # Examples +/// +/// Basic usage: +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// +/// axum +/// .route("/ws", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// # }} +/// ``` +/// +/// The [`Router`] is a property of the [`AxumWsCfg`]. This means it is not +/// paramterized until the [`axum::Router::with_state`] method is called. This +/// has two significant consequences: +/// 1. You can easily register the same [`Router`] with multiple handlers. +/// 2. In order to register a second [`Router`] you need a second [`AxumWsCfg`]. +/// +/// Registering the same [`Router`] with multiple handlers: +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// +/// axum +/// .route("/ws", axum::routing::any(ajj_websocket)) +/// .route("/super-secret-ws", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// # }} +/// ``` +/// +/// Registering a second [`Router`] at a different path: +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, other_router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// let other_cfg = AxumWsCfg::new(other_router); +/// +/// axum +/// .route("/really-cool-ws-1", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// .route("/even-cooler-ws-2", axum::routing::any(ajj_websocket)) +/// .with_state(other_cfg) +/// # }} +/// ``` +pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State) -> Response { + ws.on_upgrade(move |ws| { + let (sink, stream) = ws.split(); + + state + .inner + .handle_new_connection::(stream.into(), sink); + + async {} + }) +} + +/// Simple stream adapter for extracting text from a [`WebSocket`]. +#[derive(Debug)] +struct WsJsonStream { + inner: RecvHalf, + complete: bool, +} + +impl From for WsJsonStream { + fn from(inner: RecvHalf) -> Self { + Self { + inner, + complete: false, + } + } +} + +impl WsJsonStream { + /// Handle an incoming [`Message`] + fn handle(&self, message: Message) -> Result, &'static str> { + match message { + Message::Text(text) => Ok(Some(text.into())), + Message::Close(Some(frame)) => { + let s = "Received close frame with data"; + let reason = format!("{} ({})", frame.reason, frame.code); + debug!(%reason, "{}", &s); + Err(s) + } + Message::Close(None) => { + let s = "WS client has gone away"; + debug!("{}", &s); + Err(s) + } + _ => Ok(None), + } + } +} + +impl Stream for WsJsonStream { + type Item = Bytes; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.complete { + return Poll::Ready(None); + } + + let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else { + self.complete = true; + return Poll::Ready(None); + }; + + match self.handle(msg) { + Ok(Some(item)) => return Poll::Ready(Some(item)), + Ok(None) => continue, + Err(_) => self.complete = true, + } + } + } +} + +impl crate::pubsub::JsonSink for SendHalf { + type Error = axum::Error; + + async fn send_json(&mut self, json: Box) -> Result<(), Self::Error> { + self.send(Message::text(json.get())).await + } +} diff --git a/src/pubsub/mod.rs b/src/pubsub/mod.rs index abb68dd..64dfb79 100644 --- a/src/pubsub/mod.rs +++ b/src/pubsub/mod.rs @@ -105,3 +105,8 @@ pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out}; #[cfg(feature = "ws")] mod ws; + +#[cfg(feature = "axum")] +mod axum; +#[cfg(feature = "axum")] +pub use axum::{ajj_websocket, AxumWsCfg}; diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index c4c324f..74ed4b3 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -5,7 +5,8 @@ use crate::{ }; use core::fmt; use serde_json::value::RawValue; -use tokio::{pin, select, sync::mpsc, task::JoinHandle}; +use std::sync::{atomic::AtomicU64, Arc}; +use tokio::{pin, runtime::Handle, select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; use tokio_util::sync::WaitForCancellationFutureOwned; use tracing::{debug, debug_span, error, trace, Instrument}; @@ -32,10 +33,7 @@ where /// This future is a simple loop that accepts new connections, and uses /// the [`ConnectionManager`] to handle them. pub(crate) async fn task_future(self) { - let ListenerTask { - listener, - mut manager, - } = self; + let ListenerTask { listener, manager } = self; loop { let (resp_sink, req_stream) = match listener.accept().await { @@ -64,7 +62,7 @@ where pub(crate) struct ConnectionManager { pub(crate) root_tasks: TaskSet, - pub(crate) next_id: ConnectionId, + pub(crate) next_id: Arc, pub(crate) router: crate::Router<()>, @@ -72,11 +70,43 @@ pub(crate) struct ConnectionManager { } impl ConnectionManager { + /// Create a new [`ConnectionManager`] with the given [`crate::Router`]. + pub(crate) fn new(router: crate::Router<()>) -> Self { + Self { + root_tasks: Default::default(), + next_id: AtomicU64::new(0).into(), + router, + notification_buffer_per_task: DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT, + } + } + + /// Set the root task set. This should generally not be used after tasks + /// have been spawned. + pub(crate) fn with_root_tasks(mut self, root_tasks: TaskSet) -> Self { + self.root_tasks = root_tasks; + self + } + + /// Set the handle, overriding the root tasks. This should generally not be + /// used after tasks have been spawned. + pub(crate) fn with_handle(mut self, handle: Handle) -> Self { + self.root_tasks = self.root_tasks.with_handle(handle); + self + } + + /// Set the notification buffer size per task. + pub(crate) const fn with_notification_buffer_per_client( + mut self, + notification_buffer_per_client: usize, + ) -> Self { + self.notification_buffer_per_task = notification_buffer_per_client; + self + } + /// Increment the connection ID counter and return an unused ID. - fn next_id(&mut self) -> ConnectionId { - let id = self.next_id; - self.next_id += 1; - id + fn next_id(&self) -> ConnectionId { + self.next_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) } /// Get a clone of the router. @@ -114,7 +144,7 @@ impl ConnectionManager { } /// Spawn a new [`RouteTask`] and [`WriteTask`] for a connection. - fn spawn_tasks(&mut self, requests: In, connection: Out) { + fn spawn_tasks(&self, requests: In, connection: Out) { let conn_id = self.next_id(); let (rt, wt) = self.make_tasks::(conn_id, requests, connection); rt.spawn(); @@ -123,7 +153,7 @@ impl ConnectionManager { /// Handle a new connection, enrolling it in the write task, and spawning /// its route task. - fn handle_new_connection(&mut self, requests: In, connection: Out) { + pub(crate) fn handle_new_connection(&self, requests: In, connection: Out) { self.spawn_tasks::(requests, connection); } } diff --git a/src/pubsub/trait.rs b/src/pubsub/trait.rs index 39281e0..94658f3 100644 --- a/src/pubsub/trait.rs +++ b/src/pubsub/trait.rs @@ -99,16 +99,16 @@ pub trait Connect: Send + Sync + Sized { let root_tasks: TaskSet = handle.into(); let notification_buffer_per_task = self.notification_buffer_size(); + let manager = ConnectionManager::new(router) + .with_root_tasks(root_tasks.clone()) + .with_notification_buffer_per_client(notification_buffer_per_task); + ListenerTask { listener: self.make_listener().await?, - manager: ConnectionManager { - next_id: 0, - router, - notification_buffer_per_task, - root_tasks: root_tasks.clone(), - }, + manager, } .spawn(); + Ok(root_tasks.into()) } } diff --git a/src/tasks.rs b/src/tasks.rs index f445d34..1bb9cef 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -20,14 +20,14 @@ pub(crate) struct TaskSet { impl From for TaskSet { fn from(handle: Handle) -> Self { - Self::with_handle(handle) + Self::from_handle(handle) } } #[allow(dead_code)] // used in pubsub and axum features impl TaskSet { /// Create a new [`TaskSet`] with a handle. - pub(crate) fn with_handle(handle: Handle) -> Self { + pub(crate) fn from_handle(handle: Handle) -> Self { Self { tasks: TaskTracker::new(), token: CancellationToken::new(), @@ -35,6 +35,14 @@ impl TaskSet { } } + /// Change the handle for the task set. This is used to spawn tasks on a + /// specific runtime.This should generally not be called after tasks have + /// been spawned. + pub(crate) fn with_handle(mut self, handle: Handle) -> Self { + self.handle = Some(handle); + self + } + /// Get a handle to the runtime that the task set is running on. /// /// ## Panics diff --git a/tests/axum_ws.rs b/tests/axum_ws.rs new file mode 100644 index 0000000..4c1339d --- /dev/null +++ b/tests/axum_ws.rs @@ -0,0 +1,34 @@ +#![cfg(all(feature = "ws", feature = "axum"))] + +mod common; +use common::{test_router, ws_client::ws_client}; + +use ajj::pubsub::AxumWsCfg; +use axum::routing::any; + +const SOCKET_STR: &str = "127.0.0.1:3399"; +const URL: &str = "ws://127.0.0.1:3399"; + +/// Serve the WebSocket server using Axum. +async fn serve() { + let router = test_router(); + + let axum_router = axum::Router::new() + .route("/", any(ajj::pubsub::ajj_websocket)) + .with_state::<()>(AxumWsCfg::new(router)); + + let listener = tokio::net::TcpListener::bind(SOCKET_STR).await.unwrap(); + axum::serve(listener, axum_router).await.unwrap(); +} + +#[tokio::test] +async fn test_ws() { + let _server = tokio::spawn(serve()); + + // Give the server a moment to start + tokio::time::sleep(std::time::Duration::from_millis(250)).await; + + let mut client = ws_client(URL).await; + common::basic_tests(&mut client).await; + common::batch_tests(&mut client).await; +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 7c8db3b..b42212a 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,5 @@ +pub mod ws_client; + use ajj::{HandlerCtx, Router}; use serde_json::{json, Value}; use std::time::Duration; diff --git a/tests/common/ws_client.rs b/tests/common/ws_client.rs new file mode 100644 index 0000000..d6a0c16 --- /dev/null +++ b/tests/common/ws_client.rs @@ -0,0 +1,54 @@ +#![cfg(all(feature = "pubsub", any(feature = "ws", feature = "axum")))] + +use super::TestClient; +use futures_util::{SinkExt, StreamExt}; +use tokio_tungstenite::{ + tungstenite::{client::IntoClientRequest, Message}, + MaybeTlsStream, WebSocketStream, +}; + +/// Create a WebSocket client for testing. +#[allow(dead_code)] +pub async fn ws_client(s: &str) -> WsClient { + let request = s.into_client_request().unwrap(); + let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + + WsClient { socket, id: 0 } +} + +pub struct WsClient { + socket: WebSocketStream>, + id: usize, +} + +impl WsClient { + async fn send_inner(&mut self, msg: &S) { + self.socket + .send(Message::Text(serde_json::to_string(msg).unwrap().into())) + .await + .unwrap(); + } + + async fn recv_inner(&mut self) -> D { + match self.socket.next().await.unwrap().unwrap() { + Message::Text(text) => serde_json::from_str(&text).unwrap(), + _ => panic!("unexpected message type"), + } + } +} + +impl TestClient for WsClient { + fn next_id(&mut self) -> usize { + let id = self.id; + self.id += 1; + id + } + + async fn send_raw(&mut self, msg: &S) { + self.send_inner(msg).await; + } + + async fn recv(&mut self) -> D { + self.recv_inner().await + } +} diff --git a/tests/ws.rs b/tests/ws.rs index 21543d2..0b661db 100644 --- a/tests/ws.rs +++ b/tests/ws.rs @@ -1,15 +1,10 @@ #![cfg(feature = "ws")] mod common; -use common::{test_router, TestClient}; +use common::{test_router, ws_client::ws_client}; use ajj::pubsub::{Connect, ServerShutdown}; -use futures_util::{SinkExt, StreamExt}; use std::net::{Ipv4Addr, SocketAddr}; -use tokio_tungstenite::{ - tungstenite::{client::IntoClientRequest, Message}, - MaybeTlsStream, WebSocketStream, -}; const WS_SOCKET: SocketAddr = SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3383); @@ -20,54 +15,10 @@ async fn serve_ws() -> ServerShutdown { WS_SOCKET.serve(router).await.unwrap() } -struct WsClient { - socket: WebSocketStream>, - id: usize, -} - -impl WsClient { - async fn send_inner(&mut self, msg: &S) { - self.socket - .send(Message::Text(serde_json::to_string(msg).unwrap().into())) - .await - .unwrap(); - } - - async fn recv_inner(&mut self) -> D { - match self.socket.next().await.unwrap().unwrap() { - Message::Text(text) => serde_json::from_str(&text).unwrap(), - _ => panic!("unexpected message type"), - } - } -} - -impl TestClient for WsClient { - fn next_id(&mut self) -> usize { - let id = self.id; - self.id += 1; - id - } - - async fn send_raw(&mut self, msg: &S) { - self.send_inner(msg).await; - } - - async fn recv(&mut self) -> D { - self.recv_inner().await - } -} - -async fn ws_client() -> WsClient { - let request = WS_SOCKET_STR.into_client_request().unwrap(); - let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); - - WsClient { socket, id: 0 } -} - #[tokio::test] async fn test_ws() { let _server = serve_ws().await; - let mut client = ws_client().await; + let mut client = ws_client(WS_SOCKET_STR).await; common::basic_tests(&mut client).await; common::batch_tests(&mut client).await; }