diff --git a/src/clob/ws/client.rs b/src/clob/ws/client.rs
index d18d45c8..e038dba8 100644
--- a/src/clob/ws/client.rs
+++ b/src/clob/ws/client.rs
@@ -1,4 +1,5 @@
 use std::sync::Arc;
+use std::time::Duration;
 
 use async_stream::try_stream;
 use dashmap::mapref::one::{Ref, RefMut};
@@ -20,6 +21,7 @@ use crate::types::{Address, B256, Decimal, U256};
 use crate::ws::ConnectionManager;
 use crate::ws::config::Config;
 use crate::ws::connection::ConnectionState;
+use crate::ws::task::AbortOnDrop;
 
 /// WebSocket client for real-time market data and user updates.
 ///
@@ -388,6 +390,19 @@ impl Client {
             .sum()
     }
 
+    /// Gracefully close all open WebSocket channels.
+    ///
+    /// Sends a normal WebSocket Close frame for each open channel, waits up to
+    /// `timeout` for the peer Close frame, then drops the channel resources.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first close error observed after attempting to close every
+    /// open channel.
+    pub async fn close(&self, timeout: Duration) -> Result<()> {
+        self.inner.close(timeout).await
+    }
+
     /// Unsubscribe from orderbook updates for specific assets.
     ///
    /// This decrements the reference count for each asset. The server unsubscribe
@@ -407,6 +422,22 @@ impl Client {
         self.unsubscribe_orderbook(asset_ids)
     }
 
+    /// Unsubscribe from last trade price updates for specific assets.
+    ///
+    /// This decrements the reference count for each asset. The server unsubscribe
+    /// is only sent when no other subscriptions are using those assets.
+    pub fn unsubscribe_last_trade_price(&self, asset_ids: &[U256]) -> Result<()> {
+        self.unsubscribe_orderbook(asset_ids)
+    }
+
+    /// Unsubscribe from best bid/ask updates for specific assets.
+    ///
+    /// This decrements the reference count for each asset. The server unsubscribe
+    /// is only sent when no other subscriptions are using those assets.
+    pub fn unsubscribe_best_bid_ask(&self, asset_ids: &[U256]) -> Result<()> {
+        self.unsubscribe_orderbook(asset_ids)
+    }
+
     /// Unsubscribe from tick size change updates for specific assets.
     ///
     /// This decrements the reference count for each asset. The server unsubscribe
@@ -591,6 +622,34 @@ impl ClientInner {
         self.channels.get(&channel_type)
     }
 
+    async fn close(&self, timeout: Duration) -> Result<()> {
+        let channels: Vec<_> = self
+            .channels
+            .iter()
+            .map(|entry| (*entry.key(), entry.value().connection.clone()))
+            .collect();
+
+        let mut first_error = None;
+        for (channel_type, connection) in channels {
+            if let Err(error) = connection.close(timeout).await
+                && first_error.is_none()
+            {
+                first_error = Some(error);
+            }
+
+            self.channels.remove(&channel_type);
+        }
+
+        if let Some(error) = first_error {
+            Err(error)
+        } else {
+            Ok(())
+        }
+    }
+
     /// Helper to unsubscribe and remove connection if there are no more subscriptions on this channel
     fn unsubscribe_and_cleanup<F>(&self, channel_type: ChannelType, unsubscribe_fn: F) -> Result<()>
     where
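The expected call pattern for the new shutdown path, as a minimal sketch (not part of the patch): the endpoint URL is a placeholder, and `Client`, `Config`, and `U256` are the same types the tests below use.

```rust
use std::time::Duration;

// Sketch only: graceful shutdown via the new Client::close.
// The endpoint string is a placeholder, not a real service.
async fn shutdown_example() -> Result<()> {
    let client = Client::new("wss://example.invalid/clob", Config::default())?;
    let stream = client.subscribe_orderbook(vec![U256::from(1_u64)])?;
    // ... consume `stream`, then release the subscription ...
    drop(stream);
    // Sends a Close frame on every open channel, waits up to 5s for each
    // peer's Close reply, then drops the channel resources either way.
    client.close(Duration::from_secs(5)).await
}
```

Note that `close` is best effort: a peer that never answers the handshake only costs the timeout, after which the socket is dropped anyway.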
@@ -622,6 +681,15 @@ impl ClientInner {
 struct ChannelResources {
     connection: ConnectionManager>,
     subscriptions: Arc<SubscriptionManager>,
+    /// Owns the reconnection task spawned by
+    /// [`SubscriptionManager::start_reconnection_handler`]. The wrapper
+    /// aborts the task on drop, which releases the strong
+    /// `Arc` clone held by the task's future and
+    /// breaks the reference cycle that would otherwise leak the whole
+    /// channel (task, WebSocket, subscription manager) for the lifetime
+    /// of the process — see issue #325 and [`AbortOnDrop`].
+    #[expect(dead_code, reason = "Field held only for its Drop side effect")]
+    reconnect_handle: AbortOnDrop,
 }
 
 impl ChannelResources {
@@ -630,11 +698,12 @@
         let connection = ConnectionManager::new(endpoint, config, Arc::clone(&interest))?;
         let subscriptions = Arc::new(SubscriptionManager::new(connection.clone(), interest));
 
-        subscriptions.start_reconnection_handler();
+        let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());
 
         Ok(Self {
             connection,
             subscriptions,
+            reconnect_handle,
         })
     }
 
@@ -664,3 +733,154 @@ fn channel_endpoint(base: &str, channel: ChannelType) -> String {
     };
 
     format!("{trimmed}/ws/{segment}")
 }
+
+#[cfg(test)]
+mod close_tests {
+    use std::sync::{Arc, Weak};
+    use std::time::{Duration, Instant};
+
+    use futures::StreamExt as _;
+    use tokio::net::TcpListener;
+    use tokio_tungstenite::accept_async;
+    use tokio_tungstenite::tungstenite::Message;
+
+    use super::{ChannelType, Client, ConnectionState, SubscriptionManager};
+    use crate::types::U256;
+    use crate::ws::config::Config;
+
+    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";
+
+    async fn wait_for_market_connection(client: &Client) {
+        let start = Instant::now();
+
+        loop {
+            if matches!(
+                client.connection_state(ChannelType::Market),
+                ConnectionState::Connected { .. }
+            ) {
+                return;
+            }
+
+            assert!(
+                start.elapsed() < Duration::from_secs(2),
+                "market WebSocket did not connect"
+            );
+            tokio::time::sleep(Duration::from_millis(10)).await;
+        }
+    }
+
+    async fn wait_for_drop(weak: &Weak<SubscriptionManager>) {
+        let start = Instant::now();
+
+        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
+            tokio::task::yield_now().await;
+            tokio::time::sleep(Duration::from_millis(10)).await;
+        }
+    }
+
+    #[tokio::test]
+    async fn client_close_timeout_drops_half_closed_socket() {
+        let listener = TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test WebSocket listener");
+        let endpoint = format!("ws://{}", listener.local_addr().expect("local addr"));
+
+        let server = tokio::spawn(async move {
+            let (socket, _) = listener.accept().await.expect("accept client socket");
+            let mut websocket = accept_async(socket).await.expect("accept WebSocket");
+
+            let subscription = tokio::time::timeout(Duration::from_secs(2), websocket.next())
+                .await
+                .expect("subscription frame before timeout")
+                .expect("subscription frame")
+                .expect("valid subscription frame");
+            assert!(
+                subscription.is_text(),
+                "expected subscription text frame, got {subscription:?}"
+            );
+
+            // Do not read or echo the client's Close frame until after the
+            // client's close timeout has elapsed. This exercises the
+            // half-closed path where the peer does not complete the close
+            // handshake promptly.
+            tokio::time::sleep(Duration::from_millis(200)).await;
+
+            let close = tokio::time::timeout(Duration::from_secs(2), async {
+                loop {
+                    let message = websocket
+                        .next()
+                        .await
+                        .expect("close frame")
+                        .expect("valid close frame");
+
+                    if matches!(message, Message::Close(_)) {
+                        return message;
+                    }
+                }
+            })
+            .await
+            .expect("close frame before timeout");
+
+            assert!(matches!(close, Message::Close(_)));
+
+            matches!(
+                tokio::time::timeout(Duration::from_secs(2), websocket.next()).await,
+                Ok(None | Some(Err(_)))
+            )
+        });
+
+        let client = Client::new(&endpoint, Config::default()).expect("Client::new");
+        let _stream = client
+            .subscribe_orderbook(vec![U256::from(1_u64)])
+            .expect("subscribe_orderbook");
+        wait_for_market_connection(&client).await;
+
+        client
+            .close(Duration::from_millis(50))
+            .await
+            .expect("client close");
+
+        assert_eq!(
+            client.subscription_count(),
+            0,
+            "close should remove all channel resources"
+        );
+        assert!(
+            server.await.expect("server task"),
+            "server socket stayed open"
+        );
+    }
+
+    #[tokio::test]
+    async fn arc_strong_count_returns_to_zero_after_subscription_manager_drop() {
+        let client = Client::new(UNROUTABLE_ENDPOINT, Config::default()).expect("Client::new");
+        let stream = client
+            .subscribe_orderbook(vec![U256::from(1_u64)])
+            .expect("subscribe_orderbook");
+
+        let subscriptions = {
+            let resources = client
+                .inner
+                .channel(ChannelType::Market)
+                .expect("market channel resources");
+            Arc::clone(&resources.subscriptions)
+        };
+
+        assert!(
+            Arc::strong_count(&subscriptions) >= 2,
+            "reconnection task should have cloned the SubscriptionManager"
+        );
+
+        let weak = Arc::downgrade(&subscriptions);
+        drop(stream);
+        drop(client);
+        drop(subscriptions);
+        wait_for_drop(&weak).await;
+
+        assert!(
+            weak.upgrade().is_none(),
+            "SubscriptionManager leaked after client drop: strong_count={}",
+            weak.strong_count()
+        );
+    }
+}
diff --git a/src/clob/ws/subscription.rs b/src/clob/ws/subscription.rs
index d01aa2da..29c480dd 100644
--- a/src/clob/ws/subscription.rs
+++ b/src/clob/ws/subscription.rs
@@ -105,7 +105,17 @@ impl SubscriptionManager {
     }
 
     /// Start the reconnection handler that re-subscribes on connection recovery.
-    pub fn start_reconnection_handler(self: &Arc<Self>) {
+    ///
+    /// Returns the [`tokio::task::JoinHandle`] for the spawned handler so the
+    /// caller can abort it when the owning client is dropped. The handler
+    /// holds a strong `Arc` clone and also owns a clone of the
+    /// underlying [`ConnectionManager`]; without external cancellation, the
+    /// `watch::Sender` it waits on can never close, so the task (and every
+    /// `Arc` it transitively keeps alive) leaks for the lifetime of the
+    /// process. Callers MUST retain the returned handle and `abort()` it in
+    /// their `Drop` impl to break this cycle — see
+    /// [`crate::clob::ws::client`] for the canonical pattern.
+    pub fn start_reconnection_handler(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
         let this = Arc::clone(self);
 
         tokio::spawn(async move {
@@ -140,7 +150,7 @@ impl SubscriptionManager {
                     }
                 }
             }
-        });
+        })
     }
 
     /// Re-send subscription requests for all tracked assets and markets.
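The retain-and-abort pattern this doc comment mandates, distilled into a sketch (not code from this PR; `OwnerInner` is an illustrative stand-in for types like `ChannelResources` in client.rs):

```rust
// Illustrative stand-in: hold the JoinHandle in an AbortOnDrop right next
// to the Arc that the spawned task keeps alive.
struct OwnerInner {
    subscriptions: Arc<SubscriptionManager>,
    // Dropped together with the owner; its Drop aborts the spawned task,
    // releasing the Arc clone captured by the task's future.
    _reconnect_handle: AbortOnDrop,
}

impl OwnerInner {
    fn new(subscriptions: Arc<SubscriptionManager>) -> Self {
        let _reconnect_handle =
            AbortOnDrop::new(subscriptions.start_reconnection_handler());
        Self { subscriptions, _reconnect_handle }
    }
}
```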
@@ -563,3 +573,66 @@ impl SubscriptionManager {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod reconnect_handler_tests {
+    //! Regression tests for issue #325: the reconnection handler task used
+    //! to hold a strong `Arc` that was only released
+    //! when the connection's watch `Sender` closed — but the task itself
+    //! kept that `Sender` alive (via a cloned `ConnectionManager` inside
+    //! the manager), creating a refcount cycle that leaked the entire
+    //! channel whenever a `WsClient` was dropped.
+
+    use std::sync::{Arc, Weak};
+    use std::time::Duration;
+
+    use super::{InterestTracker, SubscriptionManager};
+    use crate::ws::ConnectionManager;
+    use crate::ws::config::Config;
+    use crate::ws::task::AbortOnDrop;
+
+    /// Endpoint that resolves immediately and refuses connections, so the
+    /// underlying connection task never blocks on DNS or a slow TCP handshake.
+    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";
+
+    #[tokio::test]
+    async fn aborting_reconnect_handle_releases_subscription_manager() {
+        let interest = Arc::new(InterestTracker::new());
+        let connection = ConnectionManager::new(
+            UNROUTABLE_ENDPOINT.to_owned(),
+            Config::default(),
+            Arc::clone(&interest),
+        )
+        .expect("ConnectionManager::new");
+
+        let subscriptions = Arc::new(SubscriptionManager::new(connection, interest));
+        let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());
+
+        // The spawned task holds an extra strong clone; with the owner clone
+        // we should observe at least 2 strong refs.
+        assert!(
+            Arc::strong_count(&subscriptions) >= 2,
+            "reconnection task should have cloned an Arc"
+        );
+
+        let weak: Weak<SubscriptionManager> = Arc::downgrade(&subscriptions);
+        drop(subscriptions);
+        drop(reconnect_handle);
+
+        // Abort is observed on the next scheduler tick. Poll briefly rather
+        // than sleeping a hardcoded amount — on a loaded CI runner the task
+        // may need a few yields before its future is actually dropped.
+        let start = std::time::Instant::now();
+        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
+            tokio::task::yield_now().await;
+            tokio::time::sleep(Duration::from_millis(10)).await;
+        }
+
+        assert!(
+            weak.upgrade().is_none(),
+            "SubscriptionManager leaked after reconnect handle aborted: \
+             strong_count={} (issue #325 regression)",
+            weak.strong_count(),
+        );
+    }
+}
diff --git a/src/rtds/client.rs b/src/rtds/client.rs
index 25ba3833..4192738e 100644
--- a/src/rtds/client.rs
+++ b/src/rtds/client.rs
@@ -14,6 +14,7 @@ use crate::types::Address;
 use crate::ws::ConnectionManager;
 use crate::ws::config::Config;
 use crate::ws::connection::ConnectionState;
+use crate::ws::task::AbortOnDrop;
 
 /// RTDS (Real-Time Data Socket) client for streaming Polymarket data.
 ///
@@ -65,6 +66,14 @@ struct ClientInner {
     connection: ConnectionManager,
     /// Subscription manager for handling subscriptions
     subscriptions: Arc<SubscriptionManager>,
+    /// Owns the reconnection task spawned by
+    /// [`SubscriptionManager::start_reconnection_handler`]. The wrapper
+    /// aborts the task on drop, which releases the strong
+    /// `Arc` clone held by the task's future and
+    /// breaks the reference cycle that would otherwise leak the whole
+    /// client (task, WebSocket, subscription manager) for the lifetime
+    /// of the process — see issue #325 and [`AbortOnDrop`].
+    reconnect_handle: AbortOnDrop,
 }
 
 impl Client {
@@ -73,8 +82,10 @@ impl Client {
         let connection = ConnectionManager::new(endpoint.to_owned(), config.clone(), SimpleParser)?;
         let subscriptions = Arc::new(SubscriptionManager::new(connection.clone()));
 
-        // Start reconnection handler to re-subscribe on connection recovery
-        subscriptions.start_reconnection_handler();
+        // Start reconnection handler to re-subscribe on connection recovery.
+        // The handle is retained in an `AbortOnDrop` so the task is
+        // cancelled when the client is dropped — see the field docs.
+        let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());
 
         Ok(Self {
             inner: Arc::new(ClientInner {
@@ -83,6 +94,7 @@ impl Client {
                 endpoint: endpoint.to_owned(),
                 connection,
                 subscriptions,
+                reconnect_handle,
             }),
         })
     }
@@ -110,6 +122,7 @@ impl Client {
                 endpoint: inner.endpoint,
                 connection: inner.connection,
                 subscriptions: inner.subscriptions,
+                reconnect_handle: inner.reconnect_handle,
             }),
         })
     }
@@ -325,7 +338,148 @@ impl Client> {
                 endpoint: inner.endpoint,
                 connection: inner.connection,
                 subscriptions: inner.subscriptions,
+                reconnect_handle: inner.reconnect_handle,
             }),
         })
     }
 }
+
+// Test-only accessor: expose the inner Arc so the
+// teardown tests can take a `Weak` without widening the public API. The
+// field is module-private, so the accessor lives in the same file.
+#[cfg(test)]
+impl Client {
+    fn inner_subscriptions_for_test(&self) -> Arc<SubscriptionManager> {
+        Arc::clone(&self.inner.subscriptions)
+    }
+}
+
+#[cfg(test)]
+mod teardown_tests {
+    //! RTDS client teardown regression tests for issue #325. These cover
+    //! the `reconnect_handle` plumbing in `Client::new`,
+    //! `Client::authenticate`, and `Client::deauthenticate` — each must
+    //! forward the `AbortOnDrop` into the new `ClientInner` so the
+    //! spawned task is still tied to the live client, and each must let
+    //! the wrapper run its `Drop` (aborting the task) when the last
+    //! client clone goes away.
+
+    use std::sync::Weak;
+    use std::time::Duration;
+
+    use super::{Client, SubscriptionManager};
+    use crate::auth::{Credentials, Uuid};
+    use crate::types::Address;
+    use crate::ws::config::Config;
+
+    /// Resolves immediately and refuses TCP connections.
+    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";
+
+    /// Dummy credentials for `authenticate` / `deauthenticate` round-trips.
+    /// Only the struct shape matters — the test never hits the network.
+    fn dummy_credentials() -> Credentials {
+        Credentials::new(
+            Uuid::nil(),
+            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=".to_owned(),
+            "passphrase".to_owned(),
+        )
+    }
+
+    /// Dummy EOA used for the authenticated client state.
+    fn dummy_address() -> Address {
+        "0x0000000000000000000000000000000000000001"
+            .parse()
+            .expect("valid zero-ish address")
+    }
+
+    /// Poll the weak reference until the strong count drops to zero, with
+    /// a generous timeout so a busy CI runner doesn't flake.
+    async fn wait_for_drop(weak: &Weak<SubscriptionManager>) {
+        let start = std::time::Instant::now();
+        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
+            tokio::task::yield_now().await;
+            tokio::time::sleep(Duration::from_millis(10)).await;
+        }
+    }
+
+    #[tokio::test]
+    async fn unauthenticated_client_drop_releases_subscription_manager() {
+        let client = Client::new(UNROUTABLE_ENDPOINT, Config::default())
+            .expect("Client::new should not fail for a well-formed endpoint string");
+
+        let weak = std::sync::Arc::downgrade(&client.inner_subscriptions_for_test());
+
+        drop(client);
+        wait_for_drop(&weak).await;
+
+        assert!(
+            weak.upgrade().is_none(),
+            "Unauthenticated RTDS Client leaked SubscriptionManager on drop: \
+             strong_count={} (issue #325 regression)",
+            weak.strong_count(),
+        );
+    }
+
+    #[tokio::test]
+    async fn authenticate_then_drop_releases_subscription_manager() {
+        let client = Client::new(UNROUTABLE_ENDPOINT, Config::default()).expect("Client::new");
+
+        let weak = std::sync::Arc::downgrade(&client.inner_subscriptions_for_test());
+
+        let authenticated = client
+            .authenticate(dummy_address(), dummy_credentials())
+            .expect("authenticate should succeed when no extra clones exist");
+
+        // `authenticate` moved the reconnect handle + subscription manager
+        // into a new `ClientInner`, so the weak ref should still upgrade.
+        assert!(
+            weak.upgrade().is_some(),
+            "authenticate prematurely dropped the SubscriptionManager"
+        );
+
+        drop(authenticated);
+        wait_for_drop(&weak).await;
+
+        assert!(
+            weak.upgrade().is_none(),
+            "Authenticated RTDS Client leaked SubscriptionManager on drop: \
+             strong_count={} (issue #325 regression)",
+            weak.strong_count(),
+        );
+    }
+
+    #[tokio::test]
+    async fn deauthenticate_preserves_reconnect_handle_then_drop_cleans_up() {
+        let client = Client::new(UNROUTABLE_ENDPOINT, Config::default()).expect("Client::new");
+
+        let weak = std::sync::Arc::downgrade(&client.inner_subscriptions_for_test());
+
+        let authenticated = client
+            .authenticate(dummy_address(), dummy_credentials())
+            .expect("authenticate");
+
+        // Round-trip through deauthenticate; the handle must be forwarded
+        // into the new `ClientInner` so the task stays alive.
+        let unauth = authenticated
+            .deauthenticate()
+            .expect("deauthenticate should succeed when no extra clones exist");
+
+        // After the round-trip the manager is still reachable — nothing has
+        // dropped yet.
+        assert!(
+            weak.upgrade().is_some(),
+            "Round-tripping through authenticate/deauthenticate prematurely \
+             dropped the SubscriptionManager"
+        );
+
+        drop(unauth);
+        wait_for_drop(&weak).await;
+
+        assert!(
+            weak.upgrade().is_none(),
+            "RTDS Client leaked SubscriptionManager after deauthenticate+drop: \
+             strong_count={} (issue #325 regression)",
+            weak.strong_count(),
+        );
+    }
+}
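All three teardown tests above share one Weak-probe idiom; distilled into a standalone helper it looks like this (a sketch using only `std` and `tokio`, not part of the patch):

```rust
use std::sync::Weak;
use std::time::{Duration, Instant};

// Probe: after dropping every known strong reference, poll until the
// allocation is actually freed. Polling (rather than one fixed sleep)
// absorbs scheduler jitter on loaded CI runners.
async fn assert_released<T>(weak: &Weak<T>) {
    let start = Instant::now();
    while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
        tokio::task::yield_now().await;
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert!(
        weak.upgrade().is_none(),
        "value leaked: strong_count={}",
        weak.strong_count()
    );
}
```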
diff --git a/src/rtds/subscription.rs b/src/rtds/subscription.rs
index 233d4b40..484f5c5d 100644
--- a/src/rtds/subscription.rs
+++ b/src/rtds/subscription.rs
@@ -83,7 +83,15 @@ impl SubscriptionManager {
     }
 
     /// Start the reconnection handler that re-subscribes on connection recovery.
-    pub fn start_reconnection_handler(self: &Arc<Self>) {
+    ///
+    /// Returns the [`tokio::task::JoinHandle`] for the spawned handler so the
+    /// caller can abort it when the owning client is dropped. The handler
+    /// holds a strong `Arc` clone and awaits on a `watch::Sender` it
+    /// transitively keeps alive, so without external cancellation the task
+    /// (and the whole `SubscriptionManager` graph) leaks — same class of
+    /// reference cycle as `clob::ws::subscription`. Callers MUST retain the
+    /// returned handle and `abort()` it on drop.
+    pub fn start_reconnection_handler(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
         let this = Arc::clone(self);
 
         tokio::spawn(async move {
@@ -118,7 +126,7 @@ impl SubscriptionManager {
                     }
                 }
             }
-        });
+        })
     }
 
     /// Re-send subscription requests for all tracked topics.
@@ -325,3 +333,62 @@ impl SubscriptionManager {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod reconnect_handler_tests {
+    //! RTDS-side regression tests for issue #325. Mirrors the
+    //! `clob::ws::subscription` test so the same reference-cycle
+    //! invariant is enforced on both code paths.
+
+    use std::sync::{Arc, Weak};
+    use std::time::Duration;
+
+    use super::{SimpleParser, SubscriptionManager};
+    use crate::ws::ConnectionManager;
+    use crate::ws::config::Config;
+    use crate::ws::task::AbortOnDrop;
+
+    /// Resolves immediately and refuses TCP connections, so the underlying
+    /// connection task does not block on DNS or a slow handshake.
+    const UNROUTABLE_ENDPOINT: &str = "ws://127.0.0.1:1";
+
+    #[tokio::test]
+    async fn aborting_reconnect_handle_releases_rtds_subscription_manager() {
+        let connection = ConnectionManager::new(
+            UNROUTABLE_ENDPOINT.to_owned(),
+            Config::default(),
+            SimpleParser,
+        )
+        .expect("ConnectionManager::new");
+
+        let subscriptions = Arc::new(SubscriptionManager::new(connection));
+        let reconnect_handle = AbortOnDrop::new(subscriptions.start_reconnection_handler());
+
+        // The spawned reconnect task clones the `Arc`, so with the owner
+        // clone we must observe at least 2 strong refs before drop.
+        assert!(
+            Arc::strong_count(&subscriptions) >= 2,
+            "reconnection task should have cloned an Arc"
+        );
+
+        let weak: Weak<SubscriptionManager> = Arc::downgrade(&subscriptions);
+        drop(subscriptions);
+        drop(reconnect_handle);
+
+        // Poll briefly: the task future is only dropped once the runtime
+        // processes the abort, which may take a handful of scheduler ticks
+        // on a loaded runner.
+        let start = std::time::Instant::now();
+        while weak.strong_count() != 0 && start.elapsed() < Duration::from_secs(2) {
+            tokio::task::yield_now().await;
+            tokio::time::sleep(Duration::from_millis(10)).await;
+        }
+
+        assert!(
+            weak.upgrade().is_none(),
+            "RTDS SubscriptionManager leaked after reconnect handle aborted: \
+             strong_count={} (issue #325 regression)",
+            weak.strong_count(),
+        );
+    }
+}
diff --git a/src/ws/connection.rs b/src/ws/connection.rs
index 56b8c693..c9cdda99 100644
--- a/src/ws/connection.rs
+++ b/src/ws/connection.rs
@@ -5,16 +5,19 @@
 use std::fmt::Debug;
 use std::marker::PhantomData;
-use std::time::Instant;
+use std::time::{Duration, Instant};
 
 use backoff::backoff::Backoff as _;
-use futures::{SinkExt as _, StreamExt as _};
+use futures::{Sink, SinkExt as _, Stream, StreamExt as _};
 use serde::Serialize;
 use serde::de::DeserializeOwned;
 use tokio::net::TcpStream;
-use tokio::sync::{broadcast, mpsc, watch};
+use tokio::sync::{broadcast, mpsc, oneshot, watch};
 use tokio::time::{interval, sleep, timeout};
-use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
+use tokio_tungstenite::{
+    MaybeTlsStream, WebSocketStream, connect_async,
+    tungstenite::{Error as TungsteniteError, Message},
+};
 
 use super::config::Config;
 use super::error::WsError;
@@ -26,6 +29,19 @@ use crate::{Result, error::Error};
 
 type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
 
+enum OutgoingMessage {
+    Text(String),
+    Close {
+        timeout: Duration,
+        complete: oneshot::Sender<Result<()>>,
+    },
+}
+
+enum ConnectionExit {
+    Reconnect,
+    ClosedByClient,
+}
+
 /// Broadcast channel capacity for incoming messages.
 const BROADCAST_CAPACITY: usize = 1024;
@@ -97,7 +113,7 @@ where
     /// Watch channel receiver for state changes (for use in checking the current state)
     state_rx: watch::Receiver<ConnectionState>,
     /// Sender channel for outgoing messages
-    sender_tx: mpsc::UnboundedSender<String>,
+    sender_tx: mpsc::UnboundedSender<OutgoingMessage>,
     /// Broadcast sender for incoming messages
     broadcast_tx: broadcast::Sender,
     /// Phantom data for unused type parameters
@@ -150,7 +166,7 @@ where
     async fn connection_loop(
         endpoint: String,
         config: Config,
-        mut sender_rx: mpsc::UnboundedReceiver<String>,
+        mut sender_rx: mpsc::UnboundedReceiver<OutgoingMessage>,
         broadcast_tx: broadcast::Sender,
         parser: P,
         state_tx: watch::Sender<ConnectionState>,
@@ -181,7 +197,7 @@ where
                     });
 
                     // Handle connection
-                    if let Err(e) = Self::handle_connection(
+                    match Self::handle_connection(
                         ws_stream,
                         &mut sender_rx,
                         &broadcast_tx,
@@ -191,10 +207,17 @@ where
                    )
                    .await
                    {
-                        #[cfg(feature = "tracing")]
-                        tracing::error!("Error handling connection: {e:?}");
-                        #[cfg(not(feature = "tracing"))]
-                        let _: &_ = &e;
+                        Ok(ConnectionExit::Reconnect) => {}
+                        Ok(ConnectionExit::ClosedByClient) => {
+                            _ = state_tx.send(ConnectionState::Disconnected);
+                            break;
+                        }
+                        Err(e) => {
+                            #[cfg(feature = "tracing")]
+                            tracing::error!("Error handling connection: {e:?}");
+                            #[cfg(not(feature = "tracing"))]
+                            let _: &_ = &e;
+                        }
                     }
                 }
                 Err(e) => {
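The new `OutgoingMessage::Close` variant is an instance of the command-with-completion pattern: each request carries its own `oneshot::Sender`, so the caller can await the I/O task's result over the existing mpsc pipe. A minimal, crate-independent sketch (all names illustrative):

```rust
use tokio::sync::{mpsc, oneshot};

// Illustrative command enum: the close request carries its reply channel.
enum Command {
    Close { complete: oneshot::Sender<Result<(), String>> },
}

async fn demo() {
    let (tx, mut rx) = mpsc::unbounded_channel::<Command>();

    // Stand-in for the connection task's select! loop.
    tokio::spawn(async move {
        while let Some(Command::Close { complete }) = rx.recv().await {
            // ... perform the close handshake here ...
            let _ = complete.send(Ok(()));
            break; // exit instead of reconnecting
        }
    });

    let (complete_tx, complete_rx) = oneshot::channel();
    let _ = tx.send(Command::Close { complete: complete_tx });
    let _ = complete_rx.await; // resolves once the task acknowledges
}
```

This shape is also what lets `connection_loop` tell `ClosedByClient` apart from `Reconnect`: an explicit command, not a dropped channel, signals intent.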
@@ -227,12 +250,12 @@ where
     /// Handle an active WebSocket connection.
     async fn handle_connection(
         ws_stream: WsStream,
-        sender_rx: &mut mpsc::UnboundedReceiver<String>,
+        sender_rx: &mut mpsc::UnboundedReceiver<OutgoingMessage>,
         broadcast_tx: &broadcast::Sender,
         state_rx: watch::Receiver<ConnectionState>,
         config: Config,
         parser: &P,
-    ) -> Result<()> {
+    ) -> Result<ConnectionExit> {
         let (mut write, mut read) = ws_stream.split();
 
         // Channel to notify heartbeat loop when PONG is received
@@ -292,10 +315,20 @@ where
                     }
                 }
 
-                // Handle outgoing messages from subscriptions
-                Some(text) = sender_rx.recv() => {
-                    if write.send(Message::Text(text.into())).await.is_err() {
-                        break;
+                // Handle outgoing messages from subscriptions and close requests.
+                Some(message) = sender_rx.recv() => {
+                    match message {
+                        OutgoingMessage::Text(text) => {
+                            if write.send(Message::Text(text.into())).await.is_err() {
+                                break;
+                            }
+                        }
+                        OutgoingMessage::Close { timeout, complete } => {
+                            let close_result = Self::close_stream(&mut write, &mut read, timeout).await;
+                            heartbeat_handle.abort();
+                            _ = complete.send(close_result);
+                            return Ok(ConnectionExit::ClosedByClient);
+                        }
                     }
                 }
 
@@ -316,7 +349,32 @@ where
         // Cleanup
         heartbeat_handle.abort();
 
-        Ok(())
+        Ok(ConnectionExit::Reconnect)
+    }
+
+    async fn close_stream<W, R>(write: &mut W, read: &mut R, close_timeout: Duration) -> Result<()>
+    where
+        W: Sink<Message, Error = TungsteniteError> + Unpin,
+        R: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
+    {
+        write.send(Message::Close(None)).await?;
+
+        let wait_for_peer_close = async {
+            while let Some(message) = read.next().await {
+                match message {
+                    Ok(Message::Close(_)) => return Ok(()),
+                    Ok(_) => {}
+                    Err(e) => return Err(Error::from(e)),
+                }
+            }
+
+            Ok(())
+        };
+
+        match timeout(close_timeout, wait_for_peer_close).await {
+            Ok(result) => result,
+            Err(_) => Ok(()),
+        }
     }
 
     /// Heartbeat loop that sends PING messages and monitors PONG responses.
@@ -382,7 +440,7 @@ where
     pub fn send(&self, request: &R) -> Result<()> {
         let json = serde_json::to_string(request)?;
         self.sender_tx
-            .send(json)
+            .send(OutgoingMessage::Text(json))
             .map_err(|_e| WsError::ConnectionClosed)?;
         Ok(())
     }
@@ -395,11 +453,27 @@ where
     ) -> Result<()> {
         let json = request.as_authenticated(credentials)?;
         self.sender_tx
-            .send(json)
+            .send(OutgoingMessage::Text(json))
             .map_err(|_e| WsError::ConnectionClosed)?;
         Ok(())
     }
 
+    pub(crate) async fn close(&self, close_timeout: Duration) -> Result<()> {
+        let (complete_tx, complete_rx) = oneshot::channel();
+
+        self.sender_tx
+            .send(OutgoingMessage::Close {
+                timeout: close_timeout,
+                complete: complete_tx,
+            })
+            .map_err(|_e| WsError::ConnectionClosed)?;
+
+        match timeout(close_timeout, complete_rx).await {
+            Ok(Ok(result)) => result,
+            Ok(Err(_)) | Err(_) => Ok(()),
+        }
+    }
+
     /// Get the current connection state.
     #[must_use]
     pub fn state(&self) -> ConnectionState {
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index 5ab261cd..0e089a1d 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -22,6 +22,7 @@
 pub mod config;
 pub mod connection;
 pub mod error;
+pub(crate) mod task;
 pub mod traits;
 
 pub use connection::ConnectionManager;
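For readers unfamiliar with the leak class that `task.rs` documents below: a detached task owns its future's locals, so an `Arc` clone captured by a never-ending task stays pinned until the task is aborted. A minimal repro sketch (std + tokio only, independent of this crate):

```rust
use std::sync::Arc;

// Minimal repro of the cycle class: the spawned future owns an Arc clone,
// and a future that never resolves keeps that clone alive indefinitely.
async fn leak_demo() {
    let shared = Arc::new(String::from("pinned by the task"));
    let weak = Arc::downgrade(&shared);

    let handle = tokio::spawn({
        let clone = Arc::clone(&shared);
        async move {
            let _keep = clone; // strong ref lives inside the task's future
            std::future::pending::<()>().await; // never completes on its own
        }
    });

    drop(shared);
    assert!(weak.upgrade().is_some()); // still alive: the task pins it

    handle.abort(); // dropping the future is what finally releases `_keep`
}
```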
diff --git a/src/ws/task.rs b/src/ws/task.rs
new file mode 100644
index 00000000..89dfa515
--- /dev/null
+++ b/src/ws/task.rs
@@ -0,0 +1,84 @@
+//! Helpers for owning spawned [`tokio`] tasks.
+//!
+//! The reconnection handlers in [`crate::clob::ws::subscription`] and
+//! [`crate::rtds::subscription`] spawn detached tasks that each hold a
+//! strong `Arc` clone. Those tasks also own a clone
+//! of the underlying [`crate::ws::ConnectionManager`], whose `watch::Sender`
+//! is what the task awaits on for state changes. Because the `Sender` only
+//! closes when every clone of it drops, and a strong `Arc` clone inside the
+//! spawned task prevents the owning `SubscriptionManager` (and therefore
+//! that `ConnectionManager` clone) from dropping, the task can never exit
+//! on its own — a reference cycle that leaks the entire channel for the
+//! lifetime of the process (issue #325).
+//!
+//! [`AbortOnDrop`] is a thin wrapper around [`tokio::task::JoinHandle`]
+//! that calls `abort()` on drop. Clients store the wrapped handle next to
+//! the `Arc`; when the client (and therefore the
+//! wrapper) drops, the handler task is aborted, its stack locals — which
+//! include the strong `Arc` clone — are released, and the whole graph can
+//! drop normally.
+
+use tokio::task::JoinHandle;
+
+/// Owns a [`JoinHandle`] and calls [`JoinHandle::abort`] on drop.
+///
+/// See the module-level documentation for the cycle this breaks.
+pub(crate) struct AbortOnDrop(JoinHandle<()>);
+
+impl AbortOnDrop {
+    /// Wrap a spawned task handle so it is aborted when this value drops.
+    #[must_use]
+    pub(crate) fn new(handle: JoinHandle<()>) -> Self {
+        Self(handle)
+    }
+}
+
+impl Drop for AbortOnDrop {
+    fn drop(&mut self) {
+        // `JoinHandle::abort` is a no-op if the task has already completed,
+        // so this is always safe to call.
+        self.0.abort();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicBool, Ordering};
+    use std::time::Duration;
+
+    use super::AbortOnDrop;
+
+    #[tokio::test]
+    async fn abort_on_drop_cancels_pending_task() {
+        let finished = Arc::new(AtomicBool::new(false));
+        let finished_task = Arc::clone(&finished);
+
+        let handle = tokio::spawn(async move {
+            // Park forever unless aborted.
+            std::future::pending::<()>().await;
+            finished_task.store(true, Ordering::SeqCst);
+        });
+
+        let wrapper = AbortOnDrop::new(handle);
+        drop(wrapper);
+
+        // Give the runtime a moment to process the abort.
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        assert!(
+            !finished.load(Ordering::SeqCst),
+            "task body should never have run to completion after abort"
+        );
+    }
+
+    #[tokio::test]
+    async fn abort_on_drop_is_noop_for_finished_task() {
+        // Spawn a task that completes immediately, wait for it, then drop
+        // the wrapper. `JoinHandle::abort` on a finished task is documented
+        // as a no-op, so this must not panic or error.
+        let wrapper = AbortOnDrop::new(tokio::spawn(async {}));
+        tokio::time::sleep(Duration::from_millis(10)).await;
+        drop(wrapper);
+    }
+}