Skip to content

Commit

Permalink
add missing async/ack ack method
Browse files Browse the repository at this point in the history
  • Loading branch information
dontcryme committed Jan 29, 2025
1 parent a4e5287 commit d10db36
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 2 deletions.
46 changes: 46 additions & 0 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,52 @@ impl Client {
.await
}

/// When receive server's emitwithack callback event, invoke socket.ack(..) function can react to server with ack signal
/// use futures_util::FutureExt;
///
/// # Example
/// ```
/// use futures_util::FutureExt;
/// use rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload};
/// use serde_json::json;
/// use std::time::Duration;
/// use std::thread;
/// use bytes::Bytes;
///
/// #[tokio::main]
/// async fn main() {
///
/// let callback = |payload: Payload, socket: Client| {
/// async move {
/// let byte_test = vec![0x01, 0x02];
/// let _ = socket.ack(byte_test).await;
/// }.boxed()
/// };
///
/// // get a socket that is connected to the admin namespace
/// let socket = ClientBuilder::new("http://localhost:4200")
/// .namespace("/")
/// .on("foo", callback)
/// .on("error", |err, _| {
/// async move { eprintln!("Error: {:#?}", err) }.boxed()
/// })
/// .connect()
/// .await
/// .expect("Connection failed");
///
///
/// thread::sleep(Duration::from_millis(30000));
/// socket.disconnect().await.expect("Disconnect failed");
/// }
/// ```
#[inline]
pub async fn ack<D>(&self, data: D) -> Result<()>
where
D: Into<Payload>,
{
self.socket.read().await.ack(&self.nsp, data.into()).await
}

/// Disconnects this client from the server by sending a `socket.io` closing
/// packet.
/// # Example
Expand Down
23 changes: 21 additions & 2 deletions socketio/src/asynchronous/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{
fmt::Debug,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
atomic::{AtomicBool, AtomicI32, Ordering},
Arc,
},
};
Expand All @@ -24,16 +24,20 @@ pub(crate) struct Socket {
engine_client: Arc<EngineClient>,
connected: Arc<AtomicBool>,
generator: StreamGenerator<Packet>,
ack_id: Arc<AtomicI32>,
}

impl Socket {
/// Creates an instance of `Socket`.
pub(super) fn new(engine_client: EngineClient) -> Result<Self> {
let connected = Arc::new(AtomicBool::default());
let ack_id = Arc::new(AtomicI32::new(-1));

Ok(Socket {
engine_client: Arc::new(engine_client.clone()),
connected: connected.clone(),
generator: StreamGenerator::new(Self::stream(engine_client, connected)),
ack_id: ack_id.clone(),
generator: StreamGenerator::new(Self::stream(engine_client, connected, ack_id)),
})
}

Expand All @@ -58,6 +62,9 @@ impl Socket {
if self.connected.load(Ordering::Acquire) {
self.connected.store(false, Ordering::Release);
}
if self.ack_id.load(Ordering::Acquire) != -1 {
self.ack_id.store(-1, Ordering::Release);
}
Ok(())
}

Expand Down Expand Up @@ -89,9 +96,17 @@ impl Socket {
self.send(socket_packet).await
}

/// Emits to connected other side with given data
pub async fn ack(&self, nsp: &str, data: Payload) -> Result<()> {
let socket_packet =
Packet::ack_from_payload(data, nsp, Some(self.ack_id.load(Ordering::Acquire)))?;
self.send(socket_packet).await
}

fn stream(
client: EngineClient,
is_connected: Arc<AtomicBool>,
ack_id: Arc<AtomicI32>,
) -> Pin<Box<impl Stream<Item = Result<Packet>> + Send>> {
Box::pin(try_stream! {
for await received_data in client.clone() {
Expand All @@ -101,6 +116,10 @@ impl Socket {
|| packet.packet_id == EnginePacketId::MessageBinary
{
let packet = Self::handle_engineio_packet(packet, client.clone()).await?;

if ack_id.load(Ordering::Acquire) != packet.id.unwrap_or(-1) {
ack_id.store(packet.id.unwrap_or(-1), Ordering::Release);
}
Self::handle_socketio_packet(&packet, is_connected.clone());

yield packet;
Expand Down
36 changes: 36 additions & 0 deletions socketio/src/client/raw_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,42 @@ impl RawClient {
Ok(())
}

/// Example code for handling ACK response when calling emitWithAck on the server
///
/// # Example
/// ```
/// use rust_socketio::{ClientBuilder, Payload, RawClient};
/// use std::time::Duration;
/// use std::thread::sleep;
///
///
/// let ack_callback = |message: Payload, socket: RawClient| {
/// match message {
/// Payload::Text(values) => println!("{:#?}", values),
/// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes),
/// // This is deprecated, use Payload::Text instead
/// Payload::String(str) => println!("{}", str),
/// }
/// socket.ack("foo").unwrap();
/// };
///
/// let mut socket = ClientBuilder::new("http://localhost:4200/")
/// .on("foo", ack_callback)
/// .connect()
/// .expect("connection failed");
///
///
///
/// sleep(Duration::from_secs(2));
/// ```
#[inline]
pub fn ack<D>(&self, data: D) -> Result<()>
where
D: Into<Payload>,
{
self.socket.ack(&self.nsp, data.into())
}

/// Sends a message to the server using the underlying `engine.io` protocol.
/// This message takes an event, which could either be one of the common
/// events like "message" or "error" or a custom event like "foo". But be
Expand Down
103 changes: 103 additions & 0 deletions socketio/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,53 @@ impl Packet {
}
}
}

#[inline]
pub(crate) fn ack_from_payload<'a>(
payload: Payload,
nsp: &'a str,
ack_id: Option<i32>,
) -> Result<Packet> {
match payload {
Payload::Binary(bin_data) => Ok(Packet::new(
PacketId::BinaryAck,
nsp.to_owned(),
None,
ack_id,
1,
Some(vec![bin_data]),
)),
#[allow(deprecated)]
Payload::String(str_data) => {
let payload = if serde_json::from_str::<IgnoredAny>(&str_data).is_ok() {
format!("[{str_data}]")
} else {
format!("[\"{str_data}\"]")
};

Ok(Packet::new(
PacketId::Ack,
nsp.to_owned(),
Some(payload),
ack_id,
0,
None,
))
}
Payload::Text(data) => {
let payload = serde_json::Value::Array(data).to_string();

Ok(Packet::new(
PacketId::Ack,
nsp.to_owned(),
Some(payload),
ack_id,
0,
None,
))
}
}
}
}

impl Default for Packet {
Expand Down Expand Up @@ -671,4 +718,60 @@ mod test {
}
)
}

#[test]
fn ack_from_payload_binary() {
let payload = Payload::Binary(Bytes::from_static(&[0, 4, 9]));
let result = Packet::ack_from_payload(payload.clone(), "namespace", None).unwrap();
assert_eq!(
result,
Packet {
packet_type: PacketId::BinaryAck,
nsp: "namespace".to_owned(),
data: None,
id: None,
attachment_count: 1,
attachments: Some(vec![Bytes::from_static(&[0, 4, 9])]),
}
)
}

#[test]
#[allow(deprecated)]
fn ack_from_payload_string() {
let payload = Payload::String("test".to_owned());
let result =
Packet::ack_from_payload(payload.clone(), "other_namespace", Some(10)).unwrap();
assert_eq!(
result,
Packet {
packet_type: PacketId::Ack,
nsp: "other_namespace".to_owned(),
data: Some("[\"test\"]".to_owned()),
id: Some(10),
attachment_count: 0,
attachments: None,
}
)
}

#[test]
fn ack_from_payload_json() {
let payload = Payload::Text(vec![
serde_json::json!("String test"),
serde_json::json!({"type":"object"}),
]);
let result = Packet::ack_from_payload(payload.clone(), "/", Some(10)).unwrap();
assert_eq!(
result,
Packet {
packet_type: PacketId::Ack,
nsp: "/".to_owned(),
data: Some("[\"String test\",{\"type\":\"object\"}]".to_owned()),
id: Some(10),
attachment_count: 0,
attachments: None,
}
)
}
}
19 changes: 19 additions & 0 deletions socketio/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::packet::{Packet, PacketId};
use bytes::Bytes;
use rust_engineio::{Client as EngineClient, Packet as EnginePacket, PacketId as EnginePacketId};
use std::convert::TryFrom;
use std::sync::atomic::AtomicI32;
use std::sync::{atomic::AtomicBool, Arc};
use std::{fmt::Debug, sync::atomic::Ordering};

Expand All @@ -14,15 +15,19 @@ pub(crate) struct Socket {
//TODO: 0.4.0 refactor this
engine_client: Arc<EngineClient>,
connected: Arc<AtomicBool>,
ack_id: Arc<AtomicI32>,
}

impl Socket {
/// Creates an instance of `Socket`.
pub(super) fn new(engine_client: EngineClient) -> Result<Self> {
let ack_id = Arc::new(AtomicI32::new(-1));

Ok(Socket {
engine_client: Arc::new(engine_client),
connected: Arc::new(AtomicBool::default()),
ack_id: ack_id.clone(),
})
}

Expand All @@ -47,6 +52,9 @@ impl Socket {
if self.connected.load(Ordering::Acquire) {
self.connected.store(false, Ordering::Release);
}
if self.ack_id.load(Ordering::Acquire) != -1 {
self.ack_id.store(-1, Ordering::Release);
}
Ok(())
}

Expand Down Expand Up @@ -78,6 +86,13 @@ impl Socket {
self.send(socket_packet)
}

/// Emits to connected other side with given data
pub fn ack(&self, nsp: &str, data: Payload) -> Result<()> {
let socket_packet =
Packet::ack_from_payload(data, nsp, Some(self.ack_id.load(Ordering::Acquire)))?;
self.send(socket_packet)
}

pub(crate) fn poll(&self) -> Result<Option<Packet>> {
loop {
match self.engine_client.poll() {
Expand All @@ -86,6 +101,10 @@ impl Socket {
|| packet.packet_id == EnginePacketId::MessageBinary
{
let packet = self.handle_engineio_packet(packet)?;
if self.ack_id.load(Ordering::Acquire) != packet.id.unwrap_or(-1) {
self.ack_id
.store(packet.id.unwrap_or(-1), Ordering::Release);
}
self.handle_socketio_packet(&packet);
return Ok(Some(packet));
} else {
Expand Down

0 comments on commit d10db36

Please sign in to comment.