From 0558286b2f07176b1e56e54b8c79ebc7920ef9e7 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 09:41:11 -0300 Subject: [PATCH 01/10] feat: use axum WS --- Cargo.toml | 5 +- src/lib.rs | 30 +++++- src/pubsub/axum.rs | 229 +++++++++++++++++++++++++++++++++++++++++++ src/pubsub/mod.rs | 5 + src/pubsub/shared.rs | 52 +++++++--- src/pubsub/trait.rs | 12 +-- 6 files changed, 309 insertions(+), 24 deletions(-) create mode 100644 src/pubsub/axum.rs diff --git a/Cargo.toml b/Cargo.toml index abd0d0e..02500f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ tower = { version = "0.5.2", features = ["util"] } tracing = "0.1.41" # axum -axum = { version = "0.8.1", optional = true } -mime = { version = "0.3.17", optional = true} +axum = { version = "0.8.1", optional = true, features = ["ws"] } +mime = { version = "0.3.17", optional = true } # pubsub tokio-stream = { version = "0.1.17", optional = true } @@ -41,6 +41,7 @@ futures-util = { version = "0.3.31", optional = true } [dev-dependencies] tempfile = "3.15.0" tracing-subscriber = "0.3.19" +axum = { version = "*", features = ["macros"] } [features] default = ["axum", "ws", "ipc"] diff --git a/src/lib.rs b/src/lib.rs index 1d5ba76..226b1a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,10 +86,32 @@ //! # }} //! ``` //! -//! For WS and IPC connections, the `pubsub` module provides implementations of -//! the `Connect` trait for [`std::net::SocketAddr`] to create simple WS -//! servers, and [`interprocess::local_socket::ListenerOptions`] to create -//! simple IPC servers. +//! Routers can also be served over axum websockets. When both `axum` and +//! `pubsub` features are enabled, the `pubsub` module provides +//! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This +//! handler will serve the router over websockets at a specific route. +//! +//! ```no_run +//! # #[cfg(all(feature = "axum", feature = "pubsub"))] +//! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +//! # use std::sync::Arc; +//! # { +//! # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +//! // The config object contains the tokio runtime handle, and the +//! // notification buffer size. +//! let cfg = AxumWsCfg::new(router); +//! +//! axum +//! .route("/ws", axum::routing::any(ajj_websocket)) +//! .with_state(cfg) +//! # }} +//! ``` +//! +//! For IPC and non-axum WebSocket connections, the `pubsub` module provides +//! implementations of the `Connect` trait for [`std::net::SocketAddr`] to +//! create simple WS servers, and +//! [`interprocess::local_socket::ListenerOptions`] to create simple IPC +//! servers. //! //! ```no_run //! # #[cfg(feature = "pubsub")] diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs new file mode 100644 index 0000000..56e59c8 --- /dev/null +++ b/src/pubsub/axum.rs @@ -0,0 +1,229 @@ +//! WebSocket connection manager for [`axum`] +//! +//! How this works: +//! `axum` does not provide a connection pattern that allows us to iplement +//! [`Listener`] or [`Connect`] directly. Instead, it uses a +//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means +//! that we cannot use the [`Listener`] trait directly. Instead, we make a +//! [`AxumWsCfg`] that will be the [`State`] for our handler. +//! +//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this +//! case. +//! +//! [`Connect`]: crate::pubsub::Connect + +use crate::{ + pubsub::{shared::ConnectionManager, Listener}, + Router, +}; +use axum::{ + extract::{ + ws::{Message, WebSocket}, + State, WebSocketUpgrade, + }, + response::Response, +}; +use bytes::Bytes; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, Stream, StreamExt, +}; +use serde_json::value::RawValue; +use std::{ + convert::Infallible, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tokio::runtime::Handle; +use tracing::debug; + +pub(crate) type SendHalf = SplitSink; +pub(crate) type RecvHalf = SplitStream; + +struct AxumListener; + +impl Listener for AxumListener { + type RespSink = SendHalf; + + type ReqStream = WsJsonStream; + + type Error = Infallible; + + fn accept( + &self, + ) -> impl std::prelude::rust_2024::Future< + Output = Result<(Self::RespSink, Self::ReqStream), Self::Error>, + > + Send { + async { unreachable!() } + } +} + +/// Configuration details for WebSocket connections using [`axum::extract::ws`]. +/// +/// The main points of configuration are: +/// - The runtime [`Handle`] on which to execute tasks, which can be set with +/// [`Self::with_handle`]. +/// - The notification buffer size per client, which can be set with +/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] +/// module documentation for more details. +#[derive(Clone)] +pub struct AxumWsCfg { + inner: Arc, +} + +impl core::fmt::Debug for AxumWsCfg { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("AxumWsCfg") + .field( + "notification_buffer_per_client", + &self.inner.notification_buffer_per_task, + ) + .field("next_id", &self.inner.next_id) + .finish() + } +} + +impl From> for AxumWsCfg { + fn from(router: Router<()>) -> Self { + Self::new(router) + } +} + +impl AxumWsCfg { + /// Create a new [`AxumWsCfg`] with the given [`Router`]. + pub fn new(router: Router<()>) -> Self { + Self { + inner: ConnectionManager::new(router).into(), + } + } + + fn into_inner(self) -> ConnectionManager { + match Arc::try_unwrap(self.inner) { + Ok(inner) => inner, + Err(arc) => ConnectionManager { + root_tasks: arc.root_tasks.clone(), + next_id: arc.next_id.clone(), + router: arc.router.clone(), + notification_buffer_per_task: arc.notification_buffer_per_task, + }, + } + } + + /// Set the handle on which to execute tasks. + pub fn with_handle(self, handle: Handle) -> Self { + Self { + inner: self.into_inner().with_handle(handle).into(), + } + } + + /// Set the notification buffer size per client. + pub fn with_notification_buffer_per_client( + self, + notification_buffer_per_client: usize, + ) -> Self { + Self { + inner: self + .into_inner() + .with_notification_buffer_per_client(notification_buffer_per_client) + .into(), + } + } +} + +/// Axum handler for WebSocket connections. Used to serve +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # use std::sync::Arc; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// +/// axum +/// .route("/ws", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// # }} +/// ``` +pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State>) -> Response { + ws.on_upgrade(move |ws| { + let (sink, stream) = ws.split(); + + state + .inner + .handle_new_connection::(stream.into(), sink); + + async {} + }) +} + +/// Simple stream adapter for extracting text from a [`WebSocket`]. +#[derive(Debug)] +struct WsJsonStream { + inner: RecvHalf, + complete: bool, +} + +impl From for WsJsonStream { + fn from(inner: RecvHalf) -> Self { + Self { + inner, + complete: false, + } + } +} + +impl WsJsonStream { + /// Handle an incoming [`Message`] + fn handle(&self, message: Message) -> Result, &'static str> { + match message { + Message::Text(text) => Ok(Some(text.into())), + Message::Close(Some(frame)) => { + let s = "Received close frame with data"; + let reason = format!("{} ({})", frame.reason, frame.code); + debug!(%reason, "{}", &s); + Err(s) + } + Message::Close(None) => { + let s = "WS client has gone away"; + debug!("{}", &s); + Err(s) + } + _ => Ok(None), + } + } +} + +impl Stream for WsJsonStream { + type Item = Bytes; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.complete { + return Poll::Ready(None); + } + + let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else { + self.complete = true; + return Poll::Ready(None); + }; + + match self.handle(msg) { + Ok(Some(item)) => return Poll::Ready(Some(item)), + Ok(None) => continue, + Err(_) => self.complete = true, + } + } + } +} + +impl crate::pubsub::JsonSink for SendHalf { + type Error = axum::Error; + + async fn send_json(&mut self, json: Box) -> Result<(), Self::Error> { + self.send(Message::text(json.get())).await + } +} diff --git a/src/pubsub/mod.rs b/src/pubsub/mod.rs index abb68dd..64dfb79 100644 --- a/src/pubsub/mod.rs +++ b/src/pubsub/mod.rs @@ -105,3 +105,8 @@ pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out}; #[cfg(feature = "ws")] mod ws; + +#[cfg(feature = "axum")] +mod axum; +#[cfg(feature = "axum")] +pub use axum::{ajj_websocket, AxumWsCfg}; diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index c4c324f..6a0e92f 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -5,7 +5,8 @@ use crate::{ }; use core::fmt; use serde_json::value::RawValue; -use tokio::{pin, select, sync::mpsc, task::JoinHandle}; +use std::sync::{atomic::AtomicU64, Arc}; +use tokio::{pin, runtime::Handle, select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; use tokio_util::sync::WaitForCancellationFutureOwned; use tracing::{debug, debug_span, error, trace, Instrument}; @@ -32,10 +33,7 @@ where /// This future is a simple loop that accepts new connections, and uses /// the [`ConnectionManager`] to handle them. pub(crate) async fn task_future(self) { - let ListenerTask { - listener, - mut manager, - } = self; + let ListenerTask { listener, manager } = self; loop { let (resp_sink, req_stream) = match listener.accept().await { @@ -64,7 +62,7 @@ where pub(crate) struct ConnectionManager { pub(crate) root_tasks: TaskSet, - pub(crate) next_id: ConnectionId, + pub(crate) next_id: Arc, pub(crate) router: crate::Router<()>, @@ -72,11 +70,41 @@ pub(crate) struct ConnectionManager { } impl ConnectionManager { + /// Create a new [`ConnectionManager`] with the given [`crate::Router`]. + pub(crate) fn new(router: crate::Router<()>) -> Self { + Self { + root_tasks: Handle::current().into(), + next_id: AtomicU64::new(0).into(), + router, + notification_buffer_per_task: DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT, + } + } + + /// Set the root task set. + pub(crate) fn with_root_tasks(mut self, root_tasks: TaskSet) -> Self { + self.root_tasks = root_tasks; + self + } + + /// Set the handle, overriding the root tasks. + pub(crate) fn with_handle(mut self, handle: Handle) -> Self { + self.root_tasks = handle.into(); + self + } + + /// Set the notification buffer size per task. + pub(crate) fn with_notification_buffer_per_client( + mut self, + notification_buffer_per_client: usize, + ) -> Self { + self.notification_buffer_per_task = notification_buffer_per_client; + self + } + /// Increment the connection ID counter and return an unused ID. - fn next_id(&mut self) -> ConnectionId { - let id = self.next_id; - self.next_id += 1; - id + fn next_id(&self) -> ConnectionId { + self.next_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) } /// Get a clone of the router. @@ -114,7 +142,7 @@ impl ConnectionManager { } /// Spawn a new [`RouteTask`] and [`WriteTask`] for a connection. - fn spawn_tasks(&mut self, requests: In, connection: Out) { + fn spawn_tasks(&self, requests: In, connection: Out) { let conn_id = self.next_id(); let (rt, wt) = self.make_tasks::(conn_id, requests, connection); rt.spawn(); @@ -123,7 +151,7 @@ impl ConnectionManager { /// Handle a new connection, enrolling it in the write task, and spawning /// its route task. - fn handle_new_connection(&mut self, requests: In, connection: Out) { + pub(crate) fn handle_new_connection(&self, requests: In, connection: Out) { self.spawn_tasks::(requests, connection); } } diff --git a/src/pubsub/trait.rs b/src/pubsub/trait.rs index 39281e0..94658f3 100644 --- a/src/pubsub/trait.rs +++ b/src/pubsub/trait.rs @@ -99,16 +99,16 @@ pub trait Connect: Send + Sync + Sized { let root_tasks: TaskSet = handle.into(); let notification_buffer_per_task = self.notification_buffer_size(); + let manager = ConnectionManager::new(router) + .with_root_tasks(root_tasks.clone()) + .with_notification_buffer_per_client(notification_buffer_per_task); + ListenerTask { listener: self.make_listener().await?, - manager: ConnectionManager { - next_id: 0, - router, - notification_buffer_per_task, - root_tasks: root_tasks.clone(), - }, + manager, } .spawn(); + Ok(root_tasks.into()) } } From 5a075d643d4e8372d1393057bf61d3e7524603fc Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 09:43:49 -0300 Subject: [PATCH 02/10] lint: clippy --- src/pubsub/axum.rs | 10 +++------- src/pubsub/shared.rs | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs index 56e59c8..7d6cedc 100644 --- a/src/pubsub/axum.rs +++ b/src/pubsub/axum.rs @@ -50,12 +50,8 @@ impl Listener for AxumListener { type Error = Infallible; - fn accept( - &self, - ) -> impl std::prelude::rust_2024::Future< - Output = Result<(Self::RespSink, Self::ReqStream), Self::Error>, - > + Send { - async { unreachable!() } + async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> { + unreachable!() } } @@ -65,7 +61,7 @@ impl Listener for AxumListener { /// - The runtime [`Handle`] on which to execute tasks, which can be set with /// [`Self::with_handle`]. /// - The notification buffer size per client, which can be set with -/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] +/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] /// module documentation for more details. #[derive(Clone)] pub struct AxumWsCfg { diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 6a0e92f..8e8e285 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -93,7 +93,7 @@ impl ConnectionManager { } /// Set the notification buffer size per task. - pub(crate) fn with_notification_buffer_per_client( + pub(crate) const fn with_notification_buffer_per_client( mut self, notification_buffer_per_client: usize, ) -> Self { From 250eccb39e66db6ac6c167a1f3bf38bbbe52d856 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 09:46:25 -0300 Subject: [PATCH 03/10] fix: dep spec better spec good now --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 02500f1..4ee2cb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ tower = { version = "0.5.2", features = ["util"] } tracing = "0.1.41" # axum -axum = { version = "0.8.1", optional = true, features = ["ws"] } +axum = { version = "0.8.1", optional = true } mime = { version = "0.3.17", optional = true } # pubsub @@ -46,7 +46,7 @@ axum = { version = "*", features = ["macros"] } [features] default = ["axum", "ws", "ipc"] axum = ["dep:axum", "dep:mime"] -pubsub = ["dep:tokio-stream"] +pubsub = ["dep:tokio-stream", "axum?/ws"] ipc = ["pubsub", "dep:interprocess"] ws = ["pubsub", "dep:tokio-tungstenite", "dep:futures-util"] From 3e90c1867cb46cb572940ae17ec68556a612cace Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 10:05:41 -0300 Subject: [PATCH 04/10] feat: axum_ws tests --- src/pubsub/axum.rs | 2 +- tests/axum_ws.rs | 35 ++++++++++++++++++++++++++ tests/common/mod.rs | 2 ++ tests/common/ws_client.rs | 53 +++++++++++++++++++++++++++++++++++++++ tests/ws.rs | 53 ++------------------------------------- 5 files changed, 93 insertions(+), 52 deletions(-) create mode 100644 tests/axum_ws.rs create mode 100644 tests/common/ws_client.rs diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs index 7d6cedc..b536556 100644 --- a/src/pubsub/axum.rs +++ b/src/pubsub/axum.rs @@ -144,7 +144,7 @@ impl AxumWsCfg { /// .with_state(cfg) /// # }} /// ``` -pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State>) -> Response { +pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State) -> Response { ws.on_upgrade(move |ws| { let (sink, stream) = ws.split(); diff --git a/tests/axum_ws.rs b/tests/axum_ws.rs new file mode 100644 index 0000000..ea1a917 --- /dev/null +++ b/tests/axum_ws.rs @@ -0,0 +1,35 @@ +#![cfg(all(feature = "ws", feature = "axum"))] + +mod common; +use common::{test_router, ws_client::ws_client}; + +use ajj::pubsub::AxumWsCfg; +use axum::routing::any; + +// const SOCKET: SocketAddr = SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3399); +const SOCKET_STR: &str = "127.0.0.1:3399"; +const URL: &str = "ws://127.0.0.1:3399"; + +/// Serve the WebSocket server using Axum. +async fn serve() { + let router = test_router(); + + let axum_router = axum::Router::new() + .route("/", any(ajj::pubsub::ajj_websocket)) + .with_state::<()>(AxumWsCfg::new(router)); + + let listener = tokio::net::TcpListener::bind(SOCKET_STR).await.unwrap(); + axum::serve(listener, axum_router).await.unwrap(); +} + +#[tokio::test] +async fn test_ws() { + let _server = tokio::spawn(serve()); + + // Give the server a moment to start + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let mut client = ws_client(URL).await; + common::basic_tests(&mut client).await; + common::batch_tests(&mut client).await; +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 7c8db3b..b42212a 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,5 @@ +pub mod ws_client; + use ajj::{HandlerCtx, Router}; use serde_json::{json, Value}; use std::time::Duration; diff --git a/tests/common/ws_client.rs b/tests/common/ws_client.rs new file mode 100644 index 0000000..a69e383 --- /dev/null +++ b/tests/common/ws_client.rs @@ -0,0 +1,53 @@ +#![cfg(all(feature = "pubsub", any(feature = "ws", feature = "axum")))] + +use super::TestClient; +use futures_util::{SinkExt, StreamExt}; +use tokio_tungstenite::{ + tungstenite::{client::IntoClientRequest, Message}, + MaybeTlsStream, WebSocketStream, +}; + +/// Create a WebSocket client for testing. +pub async fn ws_client(s: &str) -> WsClient { + let request = s.into_client_request().unwrap(); + let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + + WsClient { socket, id: 0 } +} + +pub struct WsClient { + socket: WebSocketStream>, + id: usize, +} + +impl WsClient { + async fn send_inner(&mut self, msg: &S) { + self.socket + .send(Message::Text(serde_json::to_string(msg).unwrap().into())) + .await + .unwrap(); + } + + async fn recv_inner(&mut self) -> D { + match self.socket.next().await.unwrap().unwrap() { + Message::Text(text) => serde_json::from_str(&text).unwrap(), + _ => panic!("unexpected message type"), + } + } +} + +impl TestClient for WsClient { + fn next_id(&mut self) -> usize { + let id = self.id; + self.id += 1; + id + } + + async fn send_raw(&mut self, msg: &S) { + self.send_inner(msg).await; + } + + async fn recv(&mut self) -> D { + self.recv_inner().await + } +} diff --git a/tests/ws.rs b/tests/ws.rs index 21543d2..0b661db 100644 --- a/tests/ws.rs +++ b/tests/ws.rs @@ -1,15 +1,10 @@ #![cfg(feature = "ws")] mod common; -use common::{test_router, TestClient}; +use common::{test_router, ws_client::ws_client}; use ajj::pubsub::{Connect, ServerShutdown}; -use futures_util::{SinkExt, StreamExt}; use std::net::{Ipv4Addr, SocketAddr}; -use tokio_tungstenite::{ - tungstenite::{client::IntoClientRequest, Message}, - MaybeTlsStream, WebSocketStream, -}; const WS_SOCKET: SocketAddr = SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3383); @@ -20,54 +15,10 @@ async fn serve_ws() -> ServerShutdown { WS_SOCKET.serve(router).await.unwrap() } -struct WsClient { - socket: WebSocketStream>, - id: usize, -} - -impl WsClient { - async fn send_inner(&mut self, msg: &S) { - self.socket - .send(Message::Text(serde_json::to_string(msg).unwrap().into())) - .await - .unwrap(); - } - - async fn recv_inner(&mut self) -> D { - match self.socket.next().await.unwrap().unwrap() { - Message::Text(text) => serde_json::from_str(&text).unwrap(), - _ => panic!("unexpected message type"), - } - } -} - -impl TestClient for WsClient { - fn next_id(&mut self) -> usize { - let id = self.id; - self.id += 1; - id - } - - async fn send_raw(&mut self, msg: &S) { - self.send_inner(msg).await; - } - - async fn recv(&mut self) -> D { - self.recv_inner().await - } -} - -async fn ws_client() -> WsClient { - let request = WS_SOCKET_STR.into_client_request().unwrap(); - let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); - - WsClient { socket, id: 0 } -} - #[tokio::test] async fn test_ws() { let _server = serve_ws().await; - let mut client = ws_client().await; + let mut client = ws_client(WS_SOCKET_STR).await; common::basic_tests(&mut client).await; common::batch_tests(&mut client).await; } From 1f494f0bfdda9a1a9f65c3bacb8126ffc6cfedd2 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 10:08:25 -0300 Subject: [PATCH 05/10] lint: clippy --- tests/common/ws_client.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/common/ws_client.rs b/tests/common/ws_client.rs index a69e383..d6a0c16 100644 --- a/tests/common/ws_client.rs +++ b/tests/common/ws_client.rs @@ -8,6 +8,7 @@ use tokio_tungstenite::{ }; /// Create a WebSocket client for testing. +#[allow(dead_code)] pub async fn ws_client(s: &str) -> WsClient { let request = s.into_client_request().unwrap(); let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); From e3f3e14421eaa6b43b8c0ced03a0463ddcdbe5c9 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 10:10:03 -0300 Subject: [PATCH 06/10] nit: remove dead line --- tests/axum_ws.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/axum_ws.rs b/tests/axum_ws.rs index ea1a917..1a8693e 100644 --- a/tests/axum_ws.rs +++ b/tests/axum_ws.rs @@ -6,7 +6,6 @@ use common::{test_router, ws_client::ws_client}; use ajj::pubsub::AxumWsCfg; use axum::routing::any; -// const SOCKET: SocketAddr = SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3399); const SOCKET_STR: &str = "127.0.0.1:3399"; const URL: &str = "ws://127.0.0.1:3399"; From 8f4ce67d039d5a0718e3b134ad6c7662c5fe7fc1 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 29 Mar 2025 10:10:33 -0300 Subject: [PATCH 07/10] test: lower delay --- tests/axum_ws.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/axum_ws.rs b/tests/axum_ws.rs index 1a8693e..4c1339d 100644 --- a/tests/axum_ws.rs +++ b/tests/axum_ws.rs @@ -26,7 +26,7 @@ async fn test_ws() { let _server = tokio::spawn(serve()); // Give the server a moment to start - tokio::time::sleep(std::time::Duration::from_secs(1)).await; + tokio::time::sleep(std::time::Duration::from_millis(250)).await; let mut client = ws_client(URL).await; common::basic_tests(&mut client).await; From 21920d033210fd19c450e054bf8881e6f9eb6c61 Mon Sep 17 00:00:00 2001 From: James Date: Sun, 30 Mar 2025 02:56:16 -0300 Subject: [PATCH 08/10] chore: remove unused inports in examples --- src/lib.rs | 1 - src/pubsub/axum.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 226b1a4..5955ec2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,7 +94,6 @@ //! ```no_run //! # #[cfg(all(feature = "axum", feature = "pubsub"))] //! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; -//! # use std::sync::Arc; //! # { //! # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ //! // The config object contains the tokio runtime handle, and the diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs index b536556..490ff5f 100644 --- a/src/pubsub/axum.rs +++ b/src/pubsub/axum.rs @@ -132,7 +132,6 @@ impl AxumWsCfg { /// ```no_run /// # #[cfg(all(feature = "axum", feature = "pubsub"))] /// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; -/// # use std::sync::Arc; /// # { /// # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ /// // The config object contains the tokio runtime handle, and the From 87706ea0aa12752223de19bbcd07dfa9b41d8241 Mon Sep 17 00:00:00 2001 From: James Date: Sun, 30 Mar 2025 03:18:14 -0300 Subject: [PATCH 09/10] docs: expand them :) --- src/lib.rs | 9 +++-- src/pubsub/axum.rs | 83 ++++++++++++++++++++++++++++++++++++++++++-- src/pubsub/shared.rs | 5 +-- 3 files changed, 90 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5955ec2..cb67dc9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,7 +89,9 @@ //! Routers can also be served over axum websockets. When both `axum` and //! `pubsub` features are enabled, the `pubsub` module provides //! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This -//! handler will serve the router over websockets at a specific route. +//! handler will serve the router over websockets at a specific route. The +//! router is a property of the `AxumWsCfg` object, and is passed to the +//! handler via axum's `State` extractor. //! //! ```no_run //! # #[cfg(all(feature = "axum", feature = "pubsub"))] @@ -110,7 +112,10 @@ //! implementations of the `Connect` trait for [`std::net::SocketAddr`] to //! create simple WS servers, and //! [`interprocess::local_socket::ListenerOptions`] to create simple IPC -//! servers. +//! servers. We generally recommend using `axum` for WebSocket connections, as +//! it provides a more complete and robust implementation, however, users +//! needing additional control, or wanting to avoid the `axum` dependency +//! can use the `pubsub` module directly. //! //! ```no_run //! # #[cfg(feature = "pubsub")] diff --git a/src/pubsub/axum.rs b/src/pubsub/axum.rs index 490ff5f..338b099 100644 --- a/src/pubsub/axum.rs +++ b/src/pubsub/axum.rs @@ -59,10 +59,34 @@ impl Listener for AxumListener { /// /// The main points of configuration are: /// - The runtime [`Handle`] on which to execute tasks, which can be set with -/// [`Self::with_handle`]. +/// [`Self::with_handle`]. This defaults to the current thread's runtime +/// handle. /// - The notification buffer size per client, which can be set with /// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] /// module documentation for more details. +/// +/// This struct is used as the [`State`] for the [`ajj_websocket`] handler, and +/// should be created from a fully-configured [`Router<()>`]. +/// +/// # Note +/// +/// If [`AxumWsCfg`] is NOT used within a `tokio` runtime, +/// [`AxumWsCfg::with_handle`] MUST be called to set the runtime handle before +/// any requests are routed. Attempting to execute a task without an active +/// runtime will result in a panic. +/// +/// # Example +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router, handle: tokio::runtime::Handle) { +/// let cfg = AxumWsCfg::from(router) +/// .with_handle(handle) +/// .with_notification_buffer_per_client(10); +/// # }} +/// ``` #[derive(Clone)] pub struct AxumWsCfg { inner: Arc, @@ -113,7 +137,8 @@ impl AxumWsCfg { } } - /// Set the notification buffer size per client. + /// Set the notification buffer size per client. See the [`crate::pubsub`] + /// module documentation for more details. pub fn with_notification_buffer_per_client( self, notification_buffer_per_client: usize, @@ -127,7 +152,15 @@ impl AxumWsCfg { } } -/// Axum handler for WebSocket connections. Used to serve +/// Axum handler for WebSocket connections. +/// +/// Used to serve [`crate::Router`]s over WebSocket connections via [`axum`]'s +/// built-in WebSocket support. This handler is used in conjunction with +/// [`AxumWsCfg`], which is passed as the [`State`] to the handler. +/// +/// # Examples +/// +/// Basic usage: /// /// ```no_run /// # #[cfg(all(feature = "axum", feature = "pubsub"))] @@ -143,6 +176,50 @@ impl AxumWsCfg { /// .with_state(cfg) /// # }} /// ``` +/// +/// The [`Router`] is a property of the [`AxumWsCfg`]. This means it is not +/// paramterized until the [`axum::Router::with_state`] method is called. This +/// has two significant consequences: +/// 1. You can easily register the same [`Router`] with multiple handlers. +/// 2. In order to register a second [`Router`] you need a second [`AxumWsCfg`]. +/// +/// Registering the same [`Router`] with multiple handlers: +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// +/// axum +/// .route("/ws", axum::routing::any(ajj_websocket)) +/// .route("/super-secret-ws", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// # }} +/// ``` +/// +/// Registering a second [`Router`] at a different path: +/// +/// ```no_run +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; +/// # { +/// # async fn _main(router: Router<()>, other_router: Router<()>, axum: axum::Router) -> axum::Router<()>{ +/// // The config object contains the tokio runtime handle, and the +/// // notification buffer size. +/// let cfg = AxumWsCfg::new(router); +/// let other_cfg = AxumWsCfg::new(other_router); +/// +/// axum +/// .route("/really-cool-ws-1", axum::routing::any(ajj_websocket)) +/// .with_state(cfg) +/// .route("/even-cooler-ws-2", axum::routing::any(ajj_websocket)) +/// .with_state(other_cfg) +/// # }} +/// ``` pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State) -> Response { ws.on_upgrade(move |ws| { let (sink, stream) = ws.split(); diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 8e8e285..fe9bde6 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -73,7 +73,7 @@ impl ConnectionManager { /// Create a new [`ConnectionManager`] with the given [`crate::Router`]. pub(crate) fn new(router: crate::Router<()>) -> Self { Self { - root_tasks: Handle::current().into(), + root_tasks: Default::default(), next_id: AtomicU64::new(0).into(), router, notification_buffer_per_task: DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT, @@ -86,7 +86,8 @@ impl ConnectionManager { self } - /// Set the handle, overriding the root tasks. + /// Set the handle, overriding the root tasks. This should generally not be + /// used after tasks have been spawned. pub(crate) fn with_handle(mut self, handle: Handle) -> Self { self.root_tasks = handle.into(); self From 0f01a22b0a8ae69a35708daaee2a6195584befde Mon Sep 17 00:00:00 2001 From: James Date: Sun, 30 Mar 2025 03:21:52 -0300 Subject: [PATCH 10/10] fix: with_handle --- src/pubsub/shared.rs | 5 +++-- src/tasks.rs | 12 ++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index fe9bde6..74ed4b3 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -80,7 +80,8 @@ impl ConnectionManager { } } - /// Set the root task set. + /// Set the root task set. This should generally not be used after tasks + /// have been spawned. pub(crate) fn with_root_tasks(mut self, root_tasks: TaskSet) -> Self { self.root_tasks = root_tasks; self @@ -89,7 +90,7 @@ impl ConnectionManager { /// Set the handle, overriding the root tasks. This should generally not be /// used after tasks have been spawned. pub(crate) fn with_handle(mut self, handle: Handle) -> Self { - self.root_tasks = handle.into(); + self.root_tasks = self.root_tasks.with_handle(handle); self } diff --git a/src/tasks.rs b/src/tasks.rs index f445d34..1bb9cef 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -20,14 +20,14 @@ pub(crate) struct TaskSet { impl From for TaskSet { fn from(handle: Handle) -> Self { - Self::with_handle(handle) + Self::from_handle(handle) } } #[allow(dead_code)] // used in pubsub and axum features impl TaskSet { /// Create a new [`TaskSet`] with a handle. - pub(crate) fn with_handle(handle: Handle) -> Self { + pub(crate) fn from_handle(handle: Handle) -> Self { Self { tasks: TaskTracker::new(), token: CancellationToken::new(), @@ -35,6 +35,14 @@ impl TaskSet { } } + /// Change the handle for the task set. This is used to spawn tasks on a + /// specific runtime.This should generally not be called after tasks have + /// been spawned. + pub(crate) fn with_handle(mut self, handle: Handle) -> Self { + self.handle = Some(handle); + self + } + /// Get a handle to the runtime that the task set is running on. /// /// ## Panics