Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 221 additions & 1 deletion src/clob/ws/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::time::Duration;

use async_stream::try_stream;
use dashmap::mapref::one::{Ref, RefMut};
Expand All @@ -20,6 +21,7 @@ use crate::types::{Address, B256, Decimal, U256};
use crate::ws::ConnectionManager;
use crate::ws::config::Config;
use crate::ws::connection::ConnectionState;
use crate::ws::task::AbortOnDrop;

/// WebSocket client for real-time market data and user updates.
///
Expand Down Expand Up @@ -388,6 +390,19 @@ impl<S: State> Client<S> {
.sum()
}

/// Gracefully close all open WebSocket channels.
///
/// Sends a normal WebSocket Close frame for each open channel, waits up to
/// `timeout` for the peer Close frame, then drops the channel resources.
///
/// # Errors
///
/// Returns the first close error observed after attempting to close every
/// open channel.
pub async fn close(&self, timeout: Duration) -> Result<()> {
self.inner.close(timeout).await
}

/// Unsubscribe from orderbook updates for specific assets.
///
/// This decrements the reference count for each asset. The server unsubscribe
Expand All @@ -407,6 +422,22 @@ impl<S: State> Client<S> {
self.unsubscribe_orderbook(asset_ids)
}

    /// Unsubscribe from last trade price updates for specific assets.
    ///
    /// This decrements the reference count for each asset. The server unsubscribe
    /// is only sent when no other subscriptions are using those assets.
    ///
    /// NOTE(review): delegates to [`Self::unsubscribe_orderbook`] — presumably
    /// every market-channel subscription kind shares one per-asset reference
    /// count, making the unsubscribe entry points interchangeable; confirm
    /// against the `InterestTracker` accounting before relying on this.
    pub fn unsubscribe_last_trade_price(&self, asset_ids: &[U256]) -> Result<()> {
        self.unsubscribe_orderbook(asset_ids)
    }

    /// Unsubscribe from best bid/ask updates for specific assets.
    ///
    /// This decrements the reference count for each asset. The server unsubscribe
    /// is only sent when no other subscriptions are using those assets.
    ///
    /// NOTE(review): delegates to [`Self::unsubscribe_orderbook`] — presumably
    /// every market-channel subscription kind shares one per-asset reference
    /// count, making the unsubscribe entry points interchangeable; confirm
    /// against the `InterestTracker` accounting before relying on this.
    pub fn unsubscribe_best_bid_ask(&self, asset_ids: &[U256]) -> Result<()> {
        self.unsubscribe_orderbook(asset_ids)
    }

/// Unsubscribe from tick size change updates for specific assets.
///
/// This decrements the reference count for each asset. The server unsubscribe
Expand Down Expand Up @@ -591,6 +622,34 @@ impl<S: State> ClientInner<S> {
self.channels.get(&channel_type)
}

async fn close(&self, timeout: Duration) -> Result<()> {
let channels: Vec<(
ChannelType,
ConnectionManager<WsMessage, Arc<InterestTracker>>,
)> = self
.channels
.iter()
.map(|entry| (*entry.key(), entry.value().connection.clone()))
.collect();
Comment on lines +626 to +633
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Prevent channels from being created during close

ClientInner::close takes a one-time snapshot of self.channels and only closes/removes those entries. Because subscription APIs can concurrently call get_or_create_channel, a new channel inserted after this snapshot is never visited, so close() may return success while leaving a live WebSocket channel running. This appears in multi-task usage where one task shuts down while another subscribes.

Useful? React with 👍 / 👎.


let mut first_error = None;
for (channel_type, connection) in channels {
if let Err(error) = connection.close(timeout).await
&& first_error.is_none()
{
first_error = Some(error);
}

self.channels.remove(&channel_type);
}

if let Some(error) = first_error {
Err(error)
} else {
Ok(())
}
}

/// Helper to unsubscribe and remove connection if there are no more subscriptions on this channel
fn unsubscribe_and_cleanup<F>(&self, channel_type: ChannelType, unsubscribe_fn: F) -> Result<()>
where
Expand Down Expand Up @@ -622,6 +681,15 @@ impl<S: State> ClientInner<S> {
/// Everything a single live WebSocket channel owns: its connection handle,
/// its subscription bookkeeping, and the guard for its reconnection task.
struct ChannelResources {
    /// Connection handle for this channel; `ClientInner::close` clones it
    /// when snapshotting the channel map so the close handshake can run
    /// without holding a map reference across an `.await`.
    connection: ConnectionManager<WsMessage, Arc<InterestTracker>>,
    /// Subscription bookkeeping for this channel; its reconnection handler
    /// re-sends subscription requests when the connection recovers.
    subscriptions: Arc<SubscriptionManager>,
    /// Owns the reconnection task spawned by
    /// [`SubscriptionManager::start_reconnection_handler`]. The wrapper
    /// aborts the task on drop, which releases the strong
    /// `Arc<SubscriptionManager>` clone held by the task's future and
    /// breaks the reference cycle that would otherwise leak the whole
    /// channel (task, WebSocket, subscription manager) for the lifetime
    /// of the process — see issue #325 and [`AbortOnDrop`].
    #[expect(dead_code, reason = "Field held only for its Drop side effect")]
    reconnect_handle: AbortOnDrop,
}

impl ChannelResources {
Expand All @@ -630,11 +698,12 @@ impl ChannelResources {
let connection = ConnectionManager::new(endpoint, config, Arc::clone(&interest))?;
let subscriptions = Arc::new(SubscriptionManager::new(connection.clone(), interest));

subscriptions.start_reconnection_handler();
let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());

Ok(Self {
connection,
subscriptions,
reconnect_handle,
})
}

Expand Down Expand Up @@ -664,3 +733,154 @@ fn channel_endpoint(base: &str, channel: ChannelType) -> String {
};
format!("{trimmed}/ws/{segment}")
}

#[cfg(test)]
mod close_tests {
    //! Integration-style tests for `Client::close` and channel teardown.
    //! A real loopback WebSocket server exercises the close handshake's
    //! timeout path, and weak-reference probing detects leaked
    //! `SubscriptionManager`s.

    use std::sync::{Arc, Weak};
    use std::time::{Duration, Instant};

    use futures::StreamExt as _;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    use super::{ChannelType, Client, ConnectionState, SubscriptionManager};
    use crate::types::U256;
    use crate::ws::config::Config;

    // Connection-refused endpoint: resolves instantly, never connects, so
    // tests using it don't block on DNS or a TCP handshake.
    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";

    /// Spin until the market channel reports `Connected`, failing the test
    /// if it does not connect within 2 seconds.
    async fn wait_for_market_connection(client: &Client) {
        let start = Instant::now();

        loop {
            if matches!(
                client.connection_state(ChannelType::Market),
                ConnectionState::Connected { .. }
            ) {
                return;
            }

            assert!(
                start.elapsed() < Duration::from_secs(2),
                "market WebSocket did not connect"
            );
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Poll until every strong `Arc<SubscriptionManager>` is gone or 2 s
    /// elapse. Deliberately does not assert on timeout — the caller checks
    /// `weak.upgrade()` afterwards so the failure message can include the
    /// final strong count.
    async fn wait_for_drop(weak: &Weak<SubscriptionManager>) {
        let start = Instant::now();

        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn client_close_timeout_drops_half_closed_socket() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test WebSocket listener");
        let endpoint = format!("ws://{}", listener.local_addr().expect("local addr"));

        // Server task returns `true` once the client side is fully torn
        // down (its stream ends or errors after the close exchange).
        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.expect("accept client socket");
            let mut websocket = accept_async(socket).await.expect("accept WebSocket");

            let subscription = tokio::time::timeout(Duration::from_secs(2), websocket.next())
                .await
                .expect("subscription frame before timeout")
                .expect("subscription frame")
                .expect("valid subscription frame");
            assert!(
                subscription.is_text(),
                "expected subscription text frame, got {subscription:?}"
            );

            // Do not read or echo the client's Close frame until after the
            // client's close timeout has elapsed. This exercises the
            // half-closed path where the peer does not complete the close
            // handshake promptly.
            tokio::time::sleep(Duration::from_millis(200)).await;

            let close = tokio::time::timeout(Duration::from_secs(2), async {
                loop {
                    let message = websocket
                        .next()
                        .await
                        .expect("close frame")
                        .expect("valid close frame");

                    if matches!(message, Message::Close(_)) {
                        return message;
                    }
                }
            })
            .await
            .expect("close frame before timeout");

            assert!(matches!(close, Message::Close(_)));

            matches!(
                tokio::time::timeout(Duration::from_secs(2), websocket.next()).await,
                Ok(None | Some(Err(_)))
            )
        });

        let client = Client::new(&endpoint, Config::default()).expect("Client::new");
        let _stream = client
            .subscribe_orderbook(vec![U256::from(1_u64)])
            .expect("subscribe_orderbook");
        wait_for_market_connection(&client).await;

        // 50 ms close timeout expires well before the server (which sleeps
        // 200 ms above) echoes the Close frame — forcing the timeout path.
        client
            .close(Duration::from_millis(50))
            .await
            .expect("client close");

        assert_eq!(
            client.subscription_count(),
            0,
            "close should remove all channel resources"
        );
        assert!(
            server.await.expect("server task"),
            "server socket stayed open"
        );
    }

    #[tokio::test]
    async fn arc_strong_count_returns_to_zero_after_subscription_manager_drop() {
        let client = Client::new(UNROUTABLE_ENDPOINT, Config::default()).expect("Client::new");
        let stream = client
            .subscribe_orderbook(vec![U256::from(1_u64)])
            .expect("subscribe_orderbook");

        // Grab our own strong clone so we can watch the count after the
        // client (and its channel resources) are dropped.
        let subscriptions = {
            let resources = client
                .inner
                .channel(ChannelType::Market)
                .expect("market channel resources");
            Arc::clone(&resources.subscriptions)
        };

        assert!(
            Arc::strong_count(&subscriptions) >= 2,
            "reconnection task should have cloned the SubscriptionManager"
        );

        let weak = Arc::downgrade(&subscriptions);
        drop(stream);
        drop(client);
        drop(subscriptions);
        wait_for_drop(&weak).await;

        assert!(
            weak.upgrade().is_none(),
            "SubscriptionManager leaked after client drop: strong_count={}",
            weak.strong_count()
        );
    }
}
77 changes: 75 additions & 2 deletions src/clob/ws/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,17 @@ impl SubscriptionManager {
}

/// Start the reconnection handler that re-subscribes on connection recovery.
pub fn start_reconnection_handler(self: &Arc<Self>) {
///
/// Returns the [`tokio::task::JoinHandle`] for the spawned handler so the
/// caller can abort it when the owning client is dropped. The handler
/// holds a strong `Arc<Self>` clone and also owns a clone of the
/// underlying [`ConnectionManager`]; without external cancellation, the
/// `watch::Sender` it waits on can never close, so the task (and every
/// `Arc` it transitively keeps alive) leaks for the lifetime of the
/// process. Callers MUST retain the returned handle and `abort()` it in
/// their `Drop` impl to break this cycle — see
/// [`crate::clob::ws::client`] for the canonical pattern.
pub fn start_reconnection_handler(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
let this = Arc::clone(self);

tokio::spawn(async move {
Expand Down Expand Up @@ -140,7 +150,7 @@ impl SubscriptionManager {
}
}
}
});
})
}

/// Re-send subscription requests for all tracked assets and markets.
Expand Down Expand Up @@ -563,3 +573,66 @@ impl SubscriptionManager {
Ok(())
}
}

#[cfg(test)]
mod reconnect_handler_tests {
    //! Regression tests for issue #325: the reconnection handler task used
    //! to hold a strong `Arc<SubscriptionManager>` that was only released
    //! when the connection's watch `Sender` closed — but the task itself
    //! kept that `Sender` alive (via a cloned `ConnectionManager` inside
    //! the manager), creating a refcount cycle that leaked the entire
    //! channel whenever a `WsClient` was dropped.

    use std::sync::{Arc, Weak};
    use std::time::Duration;

    use super::{InterestTracker, SubscriptionManager};
    use crate::ws::ConnectionManager;
    use crate::ws::config::Config;
    use crate::ws::task::AbortOnDrop;

    /// Endpoint that resolves immediately and refuses connections, so the
    /// underlying connection task never blocks on DNS or a slow TCP handshake.
    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";

    /// Dropping the `AbortOnDrop` guard must abort the reconnection task,
    /// releasing its `Arc<SubscriptionManager>` clone and letting the
    /// manager's strong count reach zero.
    #[tokio::test]
    async fn aborting_reconnect_handle_releases_subscription_manager() {
        let interest = Arc::new(InterestTracker::new());
        let connection = ConnectionManager::new(
            UNROUTABLE_ENDPOINT.to_owned(),
            Config::default(),
            Arc::clone(&interest),
        )
        .expect("ConnectionManager::new");

        let subscriptions = Arc::new(SubscriptionManager::new(connection, interest));
        let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());

        // The spawned task holds an extra strong clone; with the owner clone
        // we should observe at least 2 strong refs.
        assert!(
            Arc::strong_count(&subscriptions) >= 2,
            "reconnection task should have cloned an Arc<SubscriptionManager>"
        );

        let weak: Weak<SubscriptionManager> = Arc::downgrade(&subscriptions);
        drop(subscriptions);
        // This drop aborts the spawned task; the task's future — and the
        // Arc clone it owns — is released when the runtime next polls it.
        drop(reconnect_handle);

        // Abort is observed on the next scheduler tick. Poll briefly rather
        // than sleeping a hardcoded amount — on a loaded CI runner the task
        // may need a few yields before its future is actually dropped.
        let start = std::time::Instant::now();
        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert!(
            weak.upgrade().is_none(),
            "SubscriptionManager leaked after reconnect handle aborted: \
            strong_count={} (issue #325 regression)",
            weak.strong_count(),
        );
    }
}
Loading
Loading