Skip to content

Add ability to handle events that expect an ack #463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ test-fast: keys

run-test-servers:
cd ci && docker build -t test_suite:latest . && cd ..
docker run -d -p 4200:4200 -p 4201:4201 -p 4202:4202 -p 4203:4203 -p 4204:4204 -p 4205:4205 -p 4206:4206 --name socketio_test test_suite:latest
docker run --rm -d -p 4200:4200 -p 4201:4201 -p 4202:4202 -p 4203:4203 -p 4204:4204 -p 4205:4205 -p 4206:4206 --name socketio_test test_suite:latest

test-all: keys run-test-servers
@cargo test --verbose --all-features
-cargo test --verbose --all-features
docker stop socketio_test

clippy:
Expand Down
80 changes: 80 additions & 0 deletions socketio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,51 @@ impl ClientBuilder {
self
}

/// Registers a new callback for a certain [`crate::event::Event`] that expects the client to
/// ack. The event could either be one of the common events like `message`, `error`, `open`,
/// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
/// use futures_util::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let socket = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on_with_ack("test", |payload: Payload, client: Client, ack: i32| {
/// async move {
/// match payload {
/// Payload::Text(values) => println!("Received: {:#?}", values),
/// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data),
/// // This is deprecated, use Payload::Text instead
/// Payload::String(str) => println!("Received: {}", str),
/// }
/// client.ack(ack, "received").await;
/// }
/// .boxed()
/// })
/// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
/// .connect()
/// .await;
/// }
///
#[cfg(feature = "async-callbacks")]
pub fn on_with_ack<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
where
F: for<'a> std::ops::FnMut(Payload, Client, i32) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on.insert(
event.into(),
Callback::<DynAsyncCallback>::new_with_ack(callback),
);
self
}

/// Registers a callback for reconnect events. The event handler must return
/// a [ReconnectSettings] struct with the settings that should be updated.
///
Expand Down Expand Up @@ -263,6 +308,41 @@ impl ClientBuilder {
self
}

/// Registers a Callback for all [`crate::event::Event::Custom`] and
/// [`crate::event::Event::Message`] that expect the client to ack.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::ClientBuilder, Payload};
/// use futures_util::future::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let client = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on_any_with_ack(|event, payload, client, ack| {
/// async move {
/// if let Payload::String(str) = payload {
/// println!("{}: {}", String::from(event), str);
/// }
/// client.ack(ack, "received").await;
/// }.boxed()
/// })
/// .connect()
/// .await;
/// }
/// ```
pub fn on_any_with_ack<F>(mut self, callback: F) -> Self
where
F: for<'a> FnMut(Event, Payload, Client, i32) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on_any = Some(Callback::<DynAsyncAnyCallback>::new_with_ack(callback));
self
}

/// Uses a preconfigured TLS connector for secure communication. This configures
/// both the `polling` as well as the `websocket` transport type.
/// # Example
Expand Down
69 changes: 55 additions & 14 deletions socketio/src/asynchronous/client/callback.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use futures_util::future::BoxFuture;
use futures_util::{future::BoxFuture, FutureExt};
use std::{
fmt::Debug,
future::Future,
ops::{Deref, DerefMut},
};

Expand All @@ -9,11 +10,18 @@ use crate::{Event, Payload};
use super::client::{Client, ReconnectSettings};

/// Internal type, provides a way to store futures and return them in a boxed manner.
pub(crate) type DynAsyncCallback =
Box<dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>;
pub(crate) type DynAsyncCallback = Box<
dyn for<'a> FnMut(Payload, Client, Option<i32>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncAnyCallback = Box<
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
dyn for<'a> FnMut(Event, Payload, Client, Option<i32>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncReconnectSettingsCallback =
Expand All @@ -30,8 +38,10 @@ impl<T> Debug for Callback<T> {
}

impl Deref for Callback<DynAsyncCallback> {
type Target =
dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Payload, Client, Option<i32>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -45,19 +55,34 @@ impl DerefMut for Callback<DynAsyncCallback> {
}

impl Callback<DynAsyncCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new_with_ack<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Payload, Client, i32) -> BoxFuture<'static, ()> + 'static + Sync + Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |p, c, a| match a {
Some(a) => callback(p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |p, c, _a| callback(p, c).boxed()),
}
}
}

impl Deref for Callback<DynAsyncAnyCallback> {
type Target =
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Event, Payload, Client, Option<i32>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -71,12 +96,28 @@ impl DerefMut for Callback<DynAsyncAnyCallback> {
}

impl Callback<DynAsyncAnyCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new_with_ack<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Event, Payload, Client, i32) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |e, p, c, a| match a {
Some(a) => callback(e, p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Event, Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |e, p, c, _a| callback(e, p, c).boxed()),
}
}
}
Expand Down
53 changes: 39 additions & 14 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl Client {
// We don't need to do that in the other cases, since proper server close
// and manual client close are handled explicitly.
if let Some(err) = client_clone
.callback(&Event::Close, CloseReason::TransportClose.as_str())
.callback(&Event::Close, CloseReason::TransportClose.as_str(), None)
.await
.err()
{
Expand Down Expand Up @@ -410,19 +410,33 @@ impl Client {
self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(&self, event: &Event, payload: P) -> Result<()> {
pub async fn ack<D>(&self, ack_id: i32, data: D) -> Result<()>
where
D: Into<Payload>,
{
let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id);

self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(
&self,
event: &Event,
payload: P,
ack_id: Option<i32>,
) -> Result<()> {
let mut builder = self.builder.write().await;
let payload = payload.into();

if let Some(callback) = builder.on.get_mut(event) {
callback(payload.clone(), self.clone()).await;
callback(payload.clone(), self.clone(), ack_id).await;
}

// Call on_any for all common and custom events.
match event {
Event::Message | Event::Custom(_) => {
if let Some(callback) = builder.on_any.as_mut() {
callback(event.clone(), payload, self.clone()).await;
callback(event.clone(), payload, self.clone(), ack_id).await;
}
}
_ => (),
Expand All @@ -445,6 +459,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::from(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand All @@ -453,6 +468,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::Binary(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand Down Expand Up @@ -480,8 +496,12 @@ impl Client {

if let Some(attachments) = &packet.attachments {
if let Some(binary_payload) = attachments.get(0) {
self.callback(&event, Payload::Binary(binary_payload.to_owned()))
.await?;
self.callback(
&event,
Payload::Binary(binary_payload.to_owned()),
packet.id,
)
.await?;
}
}
Ok(())
Expand Down Expand Up @@ -514,7 +534,7 @@ impl Client {
};

// call the correct callback
self.callback(&event, payloads.to_vec()).await?;
self.callback(&event, payloads.to_vec(), packet.id).await?;
}

Ok(())
Expand All @@ -529,23 +549,27 @@ impl Client {
match packet.packet_type {
PacketId::Ack | PacketId::BinaryAck => {
if let Err(err) = self.handle_ack(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
return Err(err);
}
}
PacketId::BinaryEvent => {
if let Err(err) = self.handle_binary_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
PacketId::Connect => {
*(self.disconnect_reason.write().await) = DisconnectReason::default();
self.callback(&Event::Connect, "").await?;
self.callback(&Event::Connect, "", None).await?;
}
PacketId::Disconnect => {
*(self.disconnect_reason.write().await) = DisconnectReason::Server;
self.callback(&Event::Close, CloseReason::IOServerDisconnect.as_str())
.await?;
self.callback(
&Event::Close,
CloseReason::IOServerDisconnect.as_str(),
None,
)
.await?;
}
PacketId::ConnectError => {
self.callback(
Expand All @@ -555,12 +579,13 @@ impl Client {
.data
.as_ref()
.unwrap_or(&String::from("\"No error message provided\"")),
None,
)
.await?;
}
PacketId::Event => {
if let Err(err) = self.handle_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
}
Expand All @@ -582,7 +607,7 @@ impl Client {
None => None,
Some(Err(err)) => {
// call the error callback
match self.callback(&Event::Error, err.to_string()).await {
match self.callback(&Event::Error, err.to_string(), None).await {
Err(callback_err) => Some((Err(callback_err), socket)),
Ok(_) => Some((Err(err), socket)),
}
Expand Down
Loading