Skip to content

feat: use axum WS #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tracing = "0.1.41"

# axum
axum = { version = "0.8.1", optional = true }
mime = { version = "0.3.17", optional = true}
mime = { version = "0.3.17", optional = true }

# pubsub
tokio-stream = { version = "0.1.17", optional = true }
Expand All @@ -41,11 +41,12 @@ futures-util = { version = "0.3.31", optional = true }
[dev-dependencies]
tempfile = "3.15.0"
tracing-subscriber = "0.3.19"
axum = { version = "*", features = ["macros"] }

[features]
default = ["axum", "ws", "ipc"]
axum = ["dep:axum", "dep:mime"]
pubsub = ["dep:tokio-stream"]
pubsub = ["dep:tokio-stream", "axum?/ws"]
ipc = ["pubsub", "dep:interprocess"]
ws = ["pubsub", "dep:tokio-tungstenite", "dep:futures-util"]

Expand Down
34 changes: 30 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,36 @@
//! # }}
//! ```
//!
//! For WS and IPC connections, the `pubsub` module provides implementations of
//! the `Connect` trait for [`std::net::SocketAddr`] to create simple WS
//! servers, and [`interprocess::local_socket::ListenerOptions`] to create
//! simple IPC servers.
//! Routers can also be served over axum websockets. When both `axum` and
//! `pubsub` features are enabled, the `pubsub` module provides
//! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This
//! handler will serve the router over websockets at a specific route. The
//! router is a property of the `AxumWsCfg` object, and is passed to the
//! handler via axum's `State` extractor.
//!
//! ```no_run
//! # #[cfg(all(feature = "axum", feature = "pubsub"))]
//! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
//! # {
//! # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
//! // The config object contains the tokio runtime handle, and the
//! // notification buffer size.
//! let cfg = AxumWsCfg::new(router);
//!
//! axum
//! .route("/ws", axum::routing::any(ajj_websocket))
//! .with_state(cfg)
//! # }}
//! ```
//!
//! For IPC and non-axum WebSocket connections, the `pubsub` module provides
//! implementations of the `Connect` trait for [`std::net::SocketAddr`] to
//! create simple WS servers, and
//! [`interprocess::local_socket::ListenerOptions`] to create simple IPC
//! servers. We generally recommend using `axum` for WebSocket connections, as
//! it provides a more complete and robust implementation, however, users
//! needing additional control, or wanting to avoid the `axum` dependency
//! can use the `pubsub` module directly.
//!
//! ```no_run
//! # #[cfg(feature = "pubsub")]
Expand Down
301 changes: 301 additions & 0 deletions src/pubsub/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
//! WebSocket connection manager for [`axum`]
//!
//! How this works:
//! `axum` does not provide a connection pattern that allows us to iplement
//! [`Listener`] or [`Connect`] directly. Instead, it uses a
//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means
//! that we cannot use the [`Listener`] trait directly. Instead, we make a
//! [`AxumWsCfg`] that will be the [`State`] for our handler.
//!
//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this
//! case.
//!
//! [`Connect`]: crate::pubsub::Connect

use crate::{
pubsub::{shared::ConnectionManager, Listener},
Router,
};
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
response::Response,
};
use bytes::Bytes;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, Stream, StreamExt,
};
use serde_json::value::RawValue;
use std::{
convert::Infallible,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::runtime::Handle;
use tracing::debug;

pub(crate) type SendHalf = SplitSink<WebSocket, Message>;
pub(crate) type RecvHalf = SplitStream<WebSocket>;

struct AxumListener;

impl Listener for AxumListener {
type RespSink = SendHalf;

type ReqStream = WsJsonStream;

type Error = Infallible;

async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> {
unreachable!()
}
}

/// Configuration details for WebSocket connections using [`axum::extract::ws`].
///
/// The main points of configuration are:
/// - The runtime [`Handle`] on which to execute tasks, which can be set with
/// [`Self::with_handle`]. This defaults to the current thread's runtime
/// handle.
/// - The notification buffer size per client, which can be set with
/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`]
/// module documentation for more details.
///
/// This struct is used as the [`State`] for the [`ajj_websocket`] handler, and
/// should be created from a fully-configured [`Router<()>`].
///
/// # Note
///
/// If [`AxumWsCfg`] is NOT used within a `tokio` runtime,
/// [`AxumWsCfg::with_handle`] MUST be called to set the runtime handle before
/// any requests are routed. Attempting to execute a task without an active
/// runtime will result in a panic.
///
/// # Example
///
/// ```no_run
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
/// # {
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>, handle: tokio::runtime::Handle) {
/// let cfg = AxumWsCfg::from(router)
/// .with_handle(handle)
/// .with_notification_buffer_per_client(10);
/// # }}
/// ```
#[derive(Clone)]
pub struct AxumWsCfg {
inner: Arc<ConnectionManager>,
}

impl core::fmt::Debug for AxumWsCfg {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("AxumWsCfg")
.field(
"notification_buffer_per_client",
&self.inner.notification_buffer_per_task,
)
.field("next_id", &self.inner.next_id)
.finish()
}
}

impl From<Router<()>> for AxumWsCfg {
fn from(router: Router<()>) -> Self {
Self::new(router)
}
}

impl AxumWsCfg {
/// Create a new [`AxumWsCfg`] with the given [`Router`].
pub fn new(router: Router<()>) -> Self {
Self {
inner: ConnectionManager::new(router).into(),
}
}

fn into_inner(self) -> ConnectionManager {
match Arc::try_unwrap(self.inner) {
Ok(inner) => inner,
Err(arc) => ConnectionManager {
root_tasks: arc.root_tasks.clone(),
next_id: arc.next_id.clone(),
router: arc.router.clone(),
notification_buffer_per_task: arc.notification_buffer_per_task,
},
}
}

/// Set the handle on which to execute tasks.
pub fn with_handle(self, handle: Handle) -> Self {
Self {
inner: self.into_inner().with_handle(handle).into(),
}
}

/// Set the notification buffer size per client. See the [`crate::pubsub`]
/// module documentation for more details.
pub fn with_notification_buffer_per_client(
self,
notification_buffer_per_client: usize,
) -> Self {
Self {
inner: self
.into_inner()
.with_notification_buffer_per_client(notification_buffer_per_client)
.into(),
}
}
}

/// Axum handler for WebSocket connections.
///
/// Used to serve [`crate::Router`]s over WebSocket connections via [`axum`]'s
/// built-in WebSocket support. This handler is used in conjunction with
/// [`AxumWsCfg`], which is passed as the [`State`] to the handler.
///
/// # Examples
///
/// Basic usage:
///
/// ```no_run
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
/// # {
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
/// // The config object contains the tokio runtime handle, and the
/// // notification buffer size.
/// let cfg = AxumWsCfg::new(router);
///
/// axum
/// .route("/ws", axum::routing::any(ajj_websocket))
/// .with_state(cfg)
/// # }}
/// ```
///
/// The [`Router`] is a property of the [`AxumWsCfg`]. This means it is not
/// paramterized until the [`axum::Router::with_state`] method is called. This
/// has two significant consequences:
/// 1. You can easily register the same [`Router`] with multiple handlers.
/// 2. In order to register a second [`Router`] you need a second [`AxumWsCfg`].
///
/// Registering the same [`Router`] with multiple handlers:
///
/// ```no_run
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
/// # {
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
/// // The config object contains the tokio runtime handle, and the
/// // notification buffer size.
/// let cfg = AxumWsCfg::new(router);
///
/// axum
/// .route("/ws", axum::routing::any(ajj_websocket))
/// .route("/super-secret-ws", axum::routing::any(ajj_websocket))
/// .with_state(cfg)
/// # }}
/// ```
///
/// Registering a second [`Router`] at a different path:
///
/// ```no_run
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
/// # {
/// # async fn _main(router: Router<()>, other_router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
/// // The config object contains the tokio runtime handle, and the
/// // notification buffer size.
/// let cfg = AxumWsCfg::new(router);
/// let other_cfg = AxumWsCfg::new(other_router);
///
/// axum
/// .route("/really-cool-ws-1", axum::routing::any(ajj_websocket))
/// .with_state(cfg)
/// .route("/even-cooler-ws-2", axum::routing::any(ajj_websocket))
/// .with_state(other_cfg)
/// # }}
/// ```
pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State<AxumWsCfg>) -> Response {
ws.on_upgrade(move |ws| {
let (sink, stream) = ws.split();

state
.inner
.handle_new_connection::<AxumListener>(stream.into(), sink);

async {}
})
}

/// Simple stream adapter for extracting text from a [`WebSocket`].
#[derive(Debug)]
struct WsJsonStream {
inner: RecvHalf,
complete: bool,
}

impl From<RecvHalf> for WsJsonStream {
fn from(inner: RecvHalf) -> Self {
Self {
inner,
complete: false,
}
}
}

impl WsJsonStream {
/// Handle an incoming [`Message`]
fn handle(&self, message: Message) -> Result<Option<Bytes>, &'static str> {
match message {
Message::Text(text) => Ok(Some(text.into())),
Message::Close(Some(frame)) => {
let s = "Received close frame with data";
let reason = format!("{} ({})", frame.reason, frame.code);
debug!(%reason, "{}", &s);
Err(s)
}
Message::Close(None) => {
let s = "WS client has gone away";
debug!("{}", &s);
Err(s)
}
_ => Ok(None),
}
}
}

impl Stream for WsJsonStream {
type Item = Bytes;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if self.complete {
return Poll::Ready(None);
}

let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else {
self.complete = true;
return Poll::Ready(None);
};

match self.handle(msg) {
Ok(Some(item)) => return Poll::Ready(Some(item)),
Ok(None) => continue,
Err(_) => self.complete = true,
}
}
}
}

impl crate::pubsub::JsonSink for SendHalf {
type Error = axum::Error;

async fn send_json(&mut self, json: Box<RawValue>) -> Result<(), Self::Error> {
self.send(Message::text(json.get())).await
}
}
5 changes: 5 additions & 0 deletions src/pubsub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,8 @@ pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out};

#[cfg(feature = "ws")]
mod ws;

#[cfg(feature = "axum")]
mod axum;
#[cfg(feature = "axum")]
pub use axum::{ajj_websocket, AxumWsCfg};
Loading