Skip to content

feat: use axum WS #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tracing = "0.1.41"

# axum
axum = { version = "0.8.1", optional = true }
mime = { version = "0.3.17", optional = true}
mime = { version = "0.3.17", optional = true }

# pubsub
tokio-stream = { version = "0.1.17", optional = true }
Expand All @@ -41,11 +41,12 @@ futures-util = { version = "0.3.31", optional = true }
[dev-dependencies]
tempfile = "3.15.0"
tracing-subscriber = "0.3.19"
axum = { version = "*", features = ["macros"] }

[features]
default = ["axum", "ws", "ipc"]
axum = ["dep:axum", "dep:mime"]
pubsub = ["dep:tokio-stream"]
pubsub = ["dep:tokio-stream", "axum?/ws"]
ipc = ["pubsub", "dep:interprocess"]
ws = ["pubsub", "dep:tokio-tungstenite", "dep:futures-util"]

Expand Down
30 changes: 26 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,32 @@
//! # }}
//! ```
//!
//! For WS and IPC connections, the `pubsub` module provides implementations of
//! the `Connect` trait for [`std::net::SocketAddr`] to create simple WS
//! servers, and [`interprocess::local_socket::ListenerOptions`] to create
//! simple IPC servers.
//! Routers can also be served over axum websockets. When both `axum` and
//! `pubsub` features are enabled, the `pubsub` module provides
//! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This
//! handler will serve the router over websockets at a specific route.
//!
//! ```no_run
//! # #[cfg(all(feature = "axum", feature = "pubsub"))]
//! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
//! # use std::sync::Arc;
//! # {
//! # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
//! // The config object contains the tokio runtime handle, and the
//! // notification buffer size.
//! let cfg = AxumWsCfg::new(router);
//!
//! axum
//! .route("/ws", axum::routing::any(ajj_websocket))
//! .with_state(cfg)
//! # }}
//! ```
//!
//! For IPC and non-axum WebSocket connections, the `pubsub` module provides
//! implementations of the `Connect` trait for [`std::net::SocketAddr`] to
//! create simple WS servers, and
//! [`interprocess::local_socket::ListenerOptions`] to create simple IPC
//! servers.
//!
//! ```no_run
//! # #[cfg(feature = "pubsub")]
Expand Down
225 changes: 225 additions & 0 deletions src/pubsub/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
//! WebSocket connection manager for [`axum`]
//!
//! How this works:
//! `axum` does not provide a connection pattern that allows us to iplement
//! [`Listener`] or [`Connect`] directly. Instead, it uses a
//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means
//! that we cannot use the [`Listener`] trait directly. Instead, we make a
//! [`AxumWsCfg`] that will be the [`State`] for our handler.
//!
//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this
//! case.
//!
//! [`Connect`]: crate::pubsub::Connect

use crate::{
pubsub::{shared::ConnectionManager, Listener},
Router,
};
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
response::Response,
};
use bytes::Bytes;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, Stream, StreamExt,
};
use serde_json::value::RawValue;
use std::{
convert::Infallible,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::runtime::Handle;
use tracing::debug;

pub(crate) type SendHalf = SplitSink<WebSocket, Message>;
pub(crate) type RecvHalf = SplitStream<WebSocket>;

struct AxumListener;

impl Listener for AxumListener {
type RespSink = SendHalf;

type ReqStream = WsJsonStream;

type Error = Infallible;

async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> {
unreachable!()
}
}

/// Configuration details for WebSocket connections using [`axum::extract::ws`].
///
/// The main points of configuration are:
/// - The runtime [`Handle`] on which to execute tasks, which can be set with
/// [`Self::with_handle`].
/// - The notification buffer size per client, which can be set with
/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`]
/// module documentation for more details.
#[derive(Clone)]
pub struct AxumWsCfg {
inner: Arc<ConnectionManager>,
}

impl core::fmt::Debug for AxumWsCfg {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("AxumWsCfg")
.field(
"notification_buffer_per_client",
&self.inner.notification_buffer_per_task,
)
.field("next_id", &self.inner.next_id)
.finish()
}
}

impl From<Router<()>> for AxumWsCfg {
fn from(router: Router<()>) -> Self {
Self::new(router)
}
}

impl AxumWsCfg {
/// Create a new [`AxumWsCfg`] with the given [`Router`].
pub fn new(router: Router<()>) -> Self {
Self {
inner: ConnectionManager::new(router).into(),
}
}

fn into_inner(self) -> ConnectionManager {
match Arc::try_unwrap(self.inner) {
Ok(inner) => inner,
Err(arc) => ConnectionManager {
root_tasks: arc.root_tasks.clone(),
next_id: arc.next_id.clone(),
router: arc.router.clone(),
notification_buffer_per_task: arc.notification_buffer_per_task,
},
}
}

/// Set the handle on which to execute tasks.
pub fn with_handle(self, handle: Handle) -> Self {
Self {
inner: self.into_inner().with_handle(handle).into(),
}
}

/// Set the notification buffer size per client.
pub fn with_notification_buffer_per_client(
self,
notification_buffer_per_client: usize,
) -> Self {
Self {
inner: self
.into_inner()
.with_notification_buffer_per_client(notification_buffer_per_client)
.into(),
}
}
}

/// Axum handler for WebSocket connections. Used to serve
///
/// ```no_run
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
/// # use std::sync::Arc;
/// # {
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
/// // The config object contains the tokio runtime handle, and the
/// // notification buffer size.
/// let cfg = AxumWsCfg::new(router);
///
/// axum
/// .route("/ws", axum::routing::any(ajj_websocket))
/// .with_state(cfg)
/// # }}
/// ```
pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State<AxumWsCfg>) -> Response {
ws.on_upgrade(move |ws| {
let (sink, stream) = ws.split();

state
.inner
.handle_new_connection::<AxumListener>(stream.into(), sink);

async {}
})
}

/// Simple stream adapter for extracting text from a [`WebSocket`].
#[derive(Debug)]
struct WsJsonStream {
inner: RecvHalf,
complete: bool,
}

impl From<RecvHalf> for WsJsonStream {
fn from(inner: RecvHalf) -> Self {
Self {
inner,
complete: false,
}
}
}

impl WsJsonStream {
/// Handle an incoming [`Message`]
fn handle(&self, message: Message) -> Result<Option<Bytes>, &'static str> {
match message {
Message::Text(text) => Ok(Some(text.into())),
Message::Close(Some(frame)) => {
let s = "Received close frame with data";
let reason = format!("{} ({})", frame.reason, frame.code);
debug!(%reason, "{}", &s);
Err(s)
}
Message::Close(None) => {
let s = "WS client has gone away";
debug!("{}", &s);
Err(s)
}
_ => Ok(None),
}
}
}

impl Stream for WsJsonStream {
type Item = Bytes;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if self.complete {
return Poll::Ready(None);
}

let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else {
self.complete = true;
return Poll::Ready(None);
};

match self.handle(msg) {
Ok(Some(item)) => return Poll::Ready(Some(item)),
Ok(None) => continue,
Err(_) => self.complete = true,
}
}
}
}

impl crate::pubsub::JsonSink for SendHalf {
type Error = axum::Error;

async fn send_json(&mut self, json: Box<RawValue>) -> Result<(), Self::Error> {
self.send(Message::text(json.get())).await
}
}
5 changes: 5 additions & 0 deletions src/pubsub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,8 @@ pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out};

#[cfg(feature = "ws")]
mod ws;

#[cfg(feature = "axum")]
mod axum;
#[cfg(feature = "axum")]
pub use axum::{ajj_websocket, AxumWsCfg};
52 changes: 40 additions & 12 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::{
};
use core::fmt;
use serde_json::value::RawValue;
use tokio::{pin, select, sync::mpsc, task::JoinHandle};
use std::sync::{atomic::AtomicU64, Arc};
use tokio::{pin, runtime::Handle, select, sync::mpsc, task::JoinHandle};
use tokio_stream::StreamExt;
use tokio_util::sync::WaitForCancellationFutureOwned;
use tracing::{debug, debug_span, error, trace, Instrument};
Expand All @@ -32,10 +33,7 @@ where
/// This future is a simple loop that accepts new connections, and uses
/// the [`ConnectionManager`] to handle them.
pub(crate) async fn task_future(self) {
let ListenerTask {
listener,
mut manager,
} = self;
let ListenerTask { listener, manager } = self;

loop {
let (resp_sink, req_stream) = match listener.accept().await {
Expand Down Expand Up @@ -64,19 +62,49 @@ where
pub(crate) struct ConnectionManager {
pub(crate) root_tasks: TaskSet,

pub(crate) next_id: ConnectionId,
pub(crate) next_id: Arc<AtomicU64>,

pub(crate) router: crate::Router<()>,

pub(crate) notification_buffer_per_task: usize,
}

impl ConnectionManager {
/// Create a new [`ConnectionManager`] with the given [`crate::Router`].
pub(crate) fn new(router: crate::Router<()>) -> Self {
Self {
root_tasks: Handle::current().into(),
next_id: AtomicU64::new(0).into(),
router,
notification_buffer_per_task: DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT,
}
}

/// Set the root task set.
pub(crate) fn with_root_tasks(mut self, root_tasks: TaskSet) -> Self {
self.root_tasks = root_tasks;
self
}

/// Set the handle, overriding the root tasks.
pub(crate) fn with_handle(mut self, handle: Handle) -> Self {
self.root_tasks = handle.into();
self
}

/// Set the notification buffer size per task.
pub(crate) const fn with_notification_buffer_per_client(
mut self,
notification_buffer_per_client: usize,
) -> Self {
self.notification_buffer_per_task = notification_buffer_per_client;
self
}

/// Increment the connection ID counter and return an unused ID.
fn next_id(&mut self) -> ConnectionId {
let id = self.next_id;
self.next_id += 1;
id
fn next_id(&self) -> ConnectionId {
self.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}

/// Get a clone of the router.
Expand Down Expand Up @@ -114,7 +142,7 @@ impl ConnectionManager {
}

/// Spawn a new [`RouteTask`] and [`WriteTask`] for a connection.
fn spawn_tasks<T: Listener>(&mut self, requests: In<T>, connection: Out<T>) {
fn spawn_tasks<T: Listener>(&self, requests: In<T>, connection: Out<T>) {
let conn_id = self.next_id();
let (rt, wt) = self.make_tasks::<T>(conn_id, requests, connection);
rt.spawn();
Expand All @@ -123,7 +151,7 @@ impl ConnectionManager {

/// Handle a new connection, enrolling it in the write task, and spawning
/// its route task.
fn handle_new_connection<T: Listener>(&mut self, requests: In<T>, connection: Out<T>) {
pub(crate) fn handle_new_connection<T: Listener>(&self, requests: In<T>, connection: Out<T>) {
self.spawn_tasks::<T>(requests, connection);
}
}
Expand Down
Loading