Skip to content

Commit e0cc5a5

Browse files
authored
Fix task leak (#292)
* Close all opened tokio task on close * Implement tests. Fix close and drop on clone
1 parent ccabc93 commit e0cc5a5

File tree

4 files changed

+154
-107
lines changed

4 files changed

+154
-107
lines changed

src/producer.rs

Lines changed: 80 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1+
use futures::executor::block_on;
12
use std::future::Future;
23
use std::time::Duration;
34
use std::{
45
marker::PhantomData,
56
sync::{
6-
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
7+
atomic::{AtomicBool, AtomicU64, Ordering},
78
Arc,
89
},
910
};
1011

1112
use dashmap::DashMap;
1213
use futures::{future::BoxFuture, FutureExt};
14+
use tokio::sync::mpsc;
1315
use tokio::sync::mpsc::channel;
14-
use tokio::sync::{mpsc, Mutex};
1516
use tokio::time::sleep;
16-
use tracing::{debug, error, trace};
17+
use tracing::{error, info, trace};
1718

1819
use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind};
1920

@@ -60,18 +61,49 @@ impl ConfirmationStatus {
6061
}
6162

6263
pub struct ProducerInternal {
63-
client: Client,
64+
client: Arc<Client>,
6465
stream: String,
6566
producer_id: u8,
66-
batch_size: usize,
6767
publish_sequence: Arc<AtomicU64>,
6868
waiting_confirmations: WaiterMap,
6969
closed: Arc<AtomicBool>,
70-
accumulator: MessageAccumulator,
71-
publish_version: u16,
70+
sender: mpsc::Sender<ClientMessage>,
7271
filter_value_extractor: Option<FilterValueExtractor>,
7372
}
7473

74+
impl Drop for ProducerInternal {
75+
fn drop(&mut self) {
76+
block_on(async {
77+
if let Err(e) = self.close().await {
78+
error!(error = ?e, "Error closing producer");
79+
}
80+
});
81+
}
82+
}
83+
84+
impl ProducerInternal {
85+
pub async fn close(&self) -> Result<(), ProducerCloseError> {
86+
match self
87+
.closed
88+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
89+
{
90+
Ok(false) => {
91+
let response = self.client.delete_publisher(self.producer_id).await?;
92+
if response.is_ok() {
93+
self.client.close().await?;
94+
Ok(())
95+
} else {
96+
Err(ProducerCloseError::Close {
97+
status: response.code().clone(),
98+
stream: self.stream.clone(),
99+
})
100+
}
101+
}
102+
_ => Ok(()), // Already closed
103+
}
104+
}
105+
}
106+
75107
/// API for publising messages to RabbitMQ stream
76108
#[derive(Clone)]
77109
pub struct Producer<T>(Arc<ProducerInternal>, PhantomData<T>);
@@ -139,22 +171,29 @@ impl<T> ProducerBuilder<T> {
139171
};
140172

141173
if response.is_ok() {
174+
let (sender, receiver) = mpsc::channel(self.batch_size);
175+
176+
let client = Arc::new(client);
142177
let producer = ProducerInternal {
143178
producer_id,
144-
batch_size: self.batch_size,
145179
stream: stream.to_string(),
146180
client,
147181
publish_sequence,
148182
waiting_confirmations,
149-
publish_version,
150183
closed: Arc::new(AtomicBool::new(false)),
151-
accumulator: MessageAccumulator::new(self.batch_size),
184+
sender,
152185
filter_value_extractor: self.filter_value_extractor,
153186
};
154187

155188
let internal_producer = Arc::new(producer);
156-
let producer = Producer(internal_producer.clone(), PhantomData);
157-
schedule_batch_send(internal_producer);
189+
schedule_batch_send(
190+
self.batch_size,
191+
receiver,
192+
internal_producer.client.clone(),
193+
producer_id,
194+
publish_version,
195+
);
196+
let producer = Producer(internal_producer, PhantomData);
158197

159198
Ok(producer)
160199
} else {
@@ -205,78 +244,33 @@ impl<T> ProducerBuilder<T> {
205244
}
206245
}
207246

208-
pub struct MessageAccumulator {
209-
sender: mpsc::Sender<ClientMessage>,
210-
receiver: Mutex<mpsc::Receiver<ClientMessage>>,
211-
message_count: AtomicUsize,
212-
}
213-
214-
impl MessageAccumulator {
215-
pub fn new(batch_size: usize) -> Self {
216-
let (sender, receiver) = mpsc::channel(batch_size);
217-
Self {
218-
sender,
219-
receiver: Mutex::new(receiver),
220-
message_count: AtomicUsize::new(0),
221-
}
222-
}
223-
224-
pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult<()> {
225-
match self.sender.send(message).await {
226-
Ok(_) => {
227-
self.message_count.fetch_add(1, Ordering::Relaxed);
228-
Ok(())
229-
}
230-
Err(e) => Err(ClientError::GenericError(Box::new(e))),
231-
}
232-
}
233-
234-
pub async fn get(&self, buffer: &mut Vec<ClientMessage>, batch_size: usize) -> (bool, usize) {
235-
let mut receiver = self.receiver.lock().await;
236-
237-
let count = receiver.recv_many(buffer, batch_size).await;
238-
self.message_count.fetch_sub(count, Ordering::Relaxed);
239-
240-
// `recv_many` returns 0 only if the channel is closed
241-
// Read https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.recv_many
242-
(count == 0, count)
243-
}
244-
}
245-
246-
fn schedule_batch_send(producer: Arc<ProducerInternal>) {
247+
fn schedule_batch_send(
248+
batch_size: usize,
249+
mut receiver: mpsc::Receiver<ClientMessage>,
250+
client: Arc<Client>,
251+
producer_id: u8,
252+
publish_version: u16,
253+
) {
247254
tokio::task::spawn(async move {
248-
let mut buffer = Vec::with_capacity(producer.batch_size);
255+
let mut buffer = Vec::with_capacity(batch_size);
249256
loop {
250-
let (is_closed, count) = producer
251-
.accumulator
252-
.get(&mut buffer, producer.batch_size)
253-
.await;
257+
let count = receiver.recv_many(&mut buffer, batch_size).await;
254258

255-
if is_closed {
256-
error!("Channel is closed and this is bad");
259+
if count == 0 || buffer.is_empty() {
260+
// Channel is closed, exit the loop
257261
break;
258262
}
259263

260-
if count > 0 {
261-
debug!("Sending batch of {} messages", count);
262-
let messages: Vec<_> = buffer.drain(..count).collect();
263-
match producer
264-
.client
265-
.publish(producer.producer_id, messages, producer.publish_version)
266-
.await
267-
{
268-
Ok(_) => {}
269-
Err(e) => {
270-
error!("Error publishing batch {:?}", e);
271-
272-
// Stop loop if producer is closed
273-
if producer.closed.load(Ordering::Relaxed) {
274-
break;
275-
}
276-
}
277-
};
278-
}
264+
let messages: Vec<_> = buffer.drain(..count).collect();
265+
match client.publish(producer_id, messages, publish_version).await {
266+
Ok(_) => {}
267+
Err(e) => {
268+
error!("Error publishing batch {:?}", e);
269+
}
270+
};
279271
}
272+
273+
info!("Batch send task finished");
280274
});
281275
}
282276

@@ -455,10 +449,13 @@ impl<T> Producer<T> {
455449
.waiting_confirmations
456450
.insert(publishing_id, ProducerMessageWaiter::Once(waiter));
457451

458-
self.0.accumulator.add(msg).await?;
452+
if let Err(e) = self.0.sender.send(msg).await {
453+
return Err(ClientError::GenericError(Box::new(e)))?;
454+
}
459455

460456
Ok(())
461457
}
458+
462459
async fn internal_batch_send<Fut>(
463460
&self,
464461
messages: Vec<Message>,
@@ -488,7 +485,9 @@ impl<T> Producer<T> {
488485
}
489486

490487
// Queue the message for sending
491-
self.0.accumulator.add(client_message).await?;
488+
if let Err(e) = self.0.sender.send(client_message).await {
489+
return Err(ClientError::GenericError(Box::new(e)))?;
490+
}
492491
self.0
493492
.waiting_confirmations
494493
.insert(publishing_id, ProducerMessageWaiter::Shared(waiter.clone()));
@@ -500,27 +499,9 @@ impl<T> Producer<T> {
500499
pub fn is_closed(&self) -> bool {
501500
self.0.closed.load(Ordering::Relaxed)
502501
}
503-
// TODO handle producer state after close
502+
504503
pub async fn close(self) -> Result<(), ProducerCloseError> {
505-
match self
506-
.0
507-
.closed
508-
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
509-
{
510-
Ok(false) => {
511-
let response = self.0.client.delete_publisher(self.0.producer_id).await?;
512-
if response.is_ok() {
513-
self.0.client.close().await?;
514-
Ok(())
515-
} else {
516-
Err(ProducerCloseError::Close {
517-
status: response.code().clone(),
518-
stream: self.0.stream.clone(),
519-
})
520-
}
521-
}
522-
_ => Err(ProducerCloseError::AlreadyClosed),
523-
}
504+
self.0.close().await
524505
}
525506
}
526507

tests/client_test.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,13 @@ async fn client_test_route_test() {
462462
async fn client_close() {
463463
let test = TestClient::create().await;
464464

465+
let output = test
466+
.client
467+
.metadata(vec![test.stream.clone()])
468+
.await
469+
.unwrap();
470+
assert_ne!(output.len(), 0);
471+
465472
test.client
466473
.close()
467474
.await

tests/consumer_test.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,6 @@ async fn consumer_close_test() {
221221
consumer.handle().close().await,
222222
Err(ConsumerCloseError::AlreadyClosed),
223223
));
224-
producer.clone().close().await.unwrap();
225-
226-
assert!(matches!(
227-
producer.close().await,
228-
Err(ProducerCloseError::AlreadyClosed),
229-
));
230224
}
231225

232226
#[tokio::test(flavor = "multi_thread")]

tests/producer_test.rs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use std::{collections::HashSet, sync::Arc};
1+
use std::{collections::HashSet, sync::Arc, time::Duration};
22

33
use chrono::Utc;
44
use fake::{Fake, Faker};
55
use futures::{lock::Mutex, StreamExt};
6-
use tokio::sync::mpsc::channel;
6+
use tokio::{sync::mpsc::channel, task::yield_now, time::sleep};
77

88
use rabbitmq_stream_client::{
99
error::ClientError,
@@ -19,6 +19,7 @@ use common::*;
1919
use rabbitmq_stream_client::types::{
2020
HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy,
2121
};
22+
use tracing::span;
2223

2324
use std::sync::atomic::{AtomicU32, Ordering};
2425
use tokio::sync::Notify;
@@ -719,3 +720,67 @@ async fn producer_drop_connection() {
719720
rabbitmq_stream_client::error::ProducerCloseError::Client(ClientError::ConnectionClosed)
720721
));
721722
}
723+
724+
#[tokio::test(flavor = "multi_thread")]
725+
async fn producer_close() {
726+
let env = TestEnvironment::create().await;
727+
728+
let producer = env.env.producer().build(&env.stream).await.unwrap();
729+
let producer2 = producer.clone();
730+
731+
let metrics = tokio::runtime::Handle::current().metrics();
732+
assert_eq!(metrics.num_alive_tasks(), 3);
733+
734+
producer.close().await.unwrap();
735+
736+
let status = producer2
737+
.send_with_confirm(Message::builder().body(b"message".to_vec()).build())
738+
.await;
739+
let err = status.unwrap_err();
740+
assert!(matches!(
741+
err,
742+
rabbitmq_stream_client::error::ProducerPublishError::Closed
743+
));
744+
drop(producer2);
745+
746+
// Ensure that the producer is closed and no tasks are alive
747+
sleep(Duration::from_millis(500)).await;
748+
749+
let metrics = tokio::runtime::Handle::current().metrics();
750+
assert_eq!(metrics.num_alive_tasks(), 0);
751+
}
752+
753+
#[tokio::test(flavor = "multi_thread")]
754+
async fn producer_drop() {
755+
let env = TestEnvironment::create().await;
756+
757+
let producer = env.env.producer().build(&env.stream).await.unwrap();
758+
let producer2 = producer.clone();
759+
760+
let metrics = tokio::runtime::Handle::current().metrics();
761+
assert_eq!(metrics.num_alive_tasks(), 3);
762+
763+
// This should not close everything: another producer is still alive
764+
drop(producer);
765+
766+
// Ensure that if something should drop some tasks, it has time to do so
767+
sleep(Duration::from_millis(500)).await;
768+
769+
producer2
770+
.send_with_confirm(Message::builder().body(b"message".to_vec()).build())
771+
.await
772+
.unwrap();
773+
774+
let metrics = tokio::runtime::Handle::current().metrics();
775+
assert_eq!(metrics.num_alive_tasks(), 3);
776+
777+
drop(producer2);
778+
779+
// If we drop the last reference to internal producer,
780+
// all tasks should be closed
781+
// Keep time for tasks to close
782+
sleep(Duration::from_millis(500)).await;
783+
784+
let metrics = tokio::runtime::Handle::current().metrics();
785+
assert_eq!(metrics.num_alive_tasks(), 0);
786+
}

0 commit comments

Comments
 (0)