diff --git a/Cargo.lock b/Cargo.lock
index c88e74d..74aec94 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -90,12 +90,54 @@ dependencies = [
  "miniz_oxide",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cd0210d8c325c245ff06fd95a3b13689a1a276ac8cfa8e8720cb840bfb84b9e"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-channel"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7fc8cd39e3dbf865f7340dce6a2d401d24fd37c6fe6c4f0ee0de8bfca2252d27"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+]
+
 [[package]]
 name = "futures-core"
 version = "0.3.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "629316e42fe7c2a0b9a65b47d159ceaa5453ab14e8f0a3c5eedbb8cd55b4a445"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b808bf53348a36cab739d7e04755909b9fcaaa69b7d7e588b37b6ec62704c97"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e481354db6b5c353246ccf6a728b0c5511d752c08da7260546fc0933869daa11"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.18"
@@ -107,6 +149,12 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "futures-sink"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "996c6442437b62d21a32cd9906f9c41e7dc1e19a9579843fad948696769305af"
+
 [[package]]
 name = "futures-task"
 version = "0.3.18"
@@ -119,9 +167,13 @@ version = "0.3.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "41d22213122356472061ac0f1ab2cee28d2bac8491410fd68c2af53d1cedb83e"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
+ "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -342,6 +394,7 @@ version = "0.1.0"
 dependencies = [
  "env_logger",
  "flate2",
+ "futures",
  "futures-util",
  "log",
  "thiserror",
diff --git a/Cargo.toml b/Cargo.toml
index 8e60b5d..d81e7dd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,18 +12,21 @@ repository = "https://github.com/ZeroTwo-Bot/zlib-stream-rs"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-log = "0.4.14"
-env_logger = "0.9.0"
-thiserror = "1.0.28"
+log = "^0.4"
+env_logger = "^0.9"
+thiserror = "^1.0"
 flate2 = { version = "1.0", features = ["cloudflare_zlib"], default-features = false }
-futures-util = { version = "0.3.17", optional = true }
-tokio = { version = "^1", features = ["rt-multi-thread"], optional = true }
+futures-util = { version = "^0.3", optional = true }
+futures = { version = "^0.3", optional = true }
+tokio = { version = "^1", features = ["rt", "rt-multi-thread"], optional = true }
 
 [dev-dependencies]
-tokio = { version = "1.14.0", features = ["rt", "macros"] }
+tokio = { version = "1.14.0", features = ["rt", "rt-multi-thread", "macros"] }
 
 [features]
 default = ["stream"]
-tokio-runtime = ["stream", "tokio"]
 stream = ["futures-util"]
+thread = ["futures"]
+chunk = ["stream"]
+tokio-runtime = ["tokio", "thread"]
diff --git a/src/chunk.rs b/src/chunk.rs
new file mode 100644
index 0000000..e894f24
--- /dev/null
+++ b/src/chunk.rs
@@ -0,0 +1,57 @@
+use futures_util::{Stream, StreamExt};
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pub struct ChunkedByteStream<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> {
+    stream: T,
+    max_chunk_size: usize,
+    pending_frame: Option<Vec<u8>>,
+}
+
+impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> ChunkedByteStream<V, T> {
+    /// Creates a new ChunkedByteStream object with a defined chunk size which splits incoming
+    /// vec chunks into chunks of the defined `max_chunk_size`.
+    /// `max_chunk_size` must not be lower than 64 to avoid issues with the `ZlibStream` implementation.
+    pub fn new(stream: T, max_chunk_size: usize) -> Self {
+        Self {
+            stream,
+            max_chunk_size,
+            pending_frame: None,
+        }
+    }
+}
+
+impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> Stream for ChunkedByteStream<V, T> {
+    type Item = Vec<u8>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if let Some(frame) = &self.pending_frame {
+            let frame = frame.clone();
+            let send_frame = if frame.len() > self.max_chunk_size {
+                self.pending_frame = Some(frame[self.max_chunk_size..].to_owned());
+                frame[..self.max_chunk_size].to_owned()
+            } else {
+                self.pending_frame = None;
+                frame
+            };
+            return Poll::Ready(Some(send_frame));
+        }
+        match Pin::new(&mut self.stream.next()).poll(cx) {
+            Poll::Ready(data) => {
+                if let Some(data) = data {
+                    let vec = data.as_ref().to_vec();
+                    if vec.len() > self.max_chunk_size {
+                        self.pending_frame = Some(vec[self.max_chunk_size..].to_owned());
+                        Poll::Ready(Some(vec[..self.max_chunk_size].to_owned()))
+                    } else {
+                        Poll::Ready(Some(vec))
+                    }
+                } else {
+                    Poll::Ready(None)
+                }
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 11f273c..dd1b013 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,9 @@
+#[cfg(feature = "chunk")]
+pub mod chunk;
 #[cfg(feature = "stream")]
 pub mod stream;
+#[cfg(feature = "thread")]
+pub mod thread;
 #[cfg(test)]
 mod test;
 
diff --git a/src/stream.rs b/src/stream.rs
index 1848506..73fb246 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -1,3 +1,5 @@
+#[cfg(feature = "thread")]
+use crate::thread::ZlibStreamDecompressorThread;
 use crate::{ZlibDecompressionError, ZlibStreamDecompressor};
 use flate2::DecompressError;
 use futures_util::{Stream, StreamExt};
@@ -6,26 +8,42 @@ use std::pin::Pin;
 use std::task::{Context, Poll};
 
 pub struct ZlibStream<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> {
+    #[cfg(not(feature = "thread"))]
     decompressor: ZlibStreamDecompressor,
+    #[cfg(feature = "thread")]
+    decompressor: ZlibStreamDecompressorThread,
     stream: T,
+    #[cfg(feature = "thread")]
+    thread_poll:
+        Option<futures::channel::oneshot::Receiver<Result<Vec<u8>, ZlibDecompressionError>>>,
 }
 
 impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> ZlibStream<V, T> {
     /// Creates a new ZlibStream object with the default decompressor and the underlying
     /// stream as data source
     pub fn new(stream: T) -> Self {
+        #[cfg(not(feature = "thread"))]
+        let decompressor = Default::default();
+        #[cfg(feature = "thread")]
+        let decompressor = ZlibStreamDecompressorThread::spawn(Default::default());
         Self {
-            decompressor: Default::default(),
+            decompressor,
             stream,
+            #[cfg(feature = "thread")]
+            thread_poll: None,
         }
     }
 
     /// Creates a new ZlibStream object with the specified decompressor and the underlying
     /// stream as data source
     pub fn new_with_decompressor(decompressor: ZlibStreamDecompressor, stream: T) -> Self {
+        #[cfg(feature = "thread")]
+        let decompressor = ZlibStreamDecompressorThread::spawn(decompressor);
         Self {
             decompressor,
             stream,
+            #[cfg(feature = "thread")]
+            thread_poll: None,
         }
     }
 }
@@ -34,15 +52,29 @@ impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> Stream for ZlibStream<V, T> {
     type Item = Result<Vec<u8>, DecompressError>;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        #[cfg(feature = "thread")]
+        if let Some(poll) = &mut self.thread_poll {
+            if let Some(poll) = poll_decompress_channel(poll, cx) {
+                match poll {
+                    Poll::Ready(_) => {
+                        self.thread_poll = None;
+                    }
+                    Poll::Pending => {
+                        cx.waker().wake_by_ref();
+                    }
+                }
+                return poll;
+            }
+        }
         match Pin::new(&mut self.stream.next()).poll(cx) {
             Poll::Ready(vec) => {
                 if let Some(vec) = vec {
-                    #[cfg(feature = "tokio-runtime")]
-                    let result = tokio::task::block_in_place(|| self.decompressor.decompress(vec));
-
-                    #[cfg(not(feature = "tokio-runtime"))]
+                    #[cfg(not(feature = "thread"))]
                     let result = self.decompressor.decompress(vec);
 
+                    #[cfg(feature = "thread")]
+                    let mut result = self.decompressor.decompress(vec.as_ref().to_vec());
+                    #[cfg(not(feature = "thread"))]
                     match result {
                         Ok(data) => Poll::Ready(Some(Ok(data))),
                         Err(ZlibDecompressionError::NeedMoreData) => {
@@ -53,6 +85,20 @@ impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> Stream for ZlibStream<V, T> {
                             Poll::Ready(Some(Err(err)))
                         }
                     }
+
+                    #[cfg(feature = "thread")]
+                    {
+                        if let Some(poll) = poll_decompress_channel(&mut result, cx) {
+                            if let Poll::Pending = poll {
+                                self.thread_poll = Some(result);
+                                cx.waker().wake_by_ref();
+                            }
+                            poll
+                        } else {
+                            cx.waker().wake_by_ref();
+                            Poll::Pending
+                        }
+                    }
                 } else {
                     Poll::Ready(None)
                 }
@@ -61,3 +107,25 @@ impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> Stream for ZlibStream<V, T> {
         }
     }
 }
+
+#[cfg(feature = "thread")]
+fn poll_decompress_channel(
+    channel: &mut futures::channel::oneshot::Receiver<Result<Vec<u8>, ZlibDecompressionError>>,
+    cx: &mut Context<'_>,
+) -> Option<Poll<Option<Result<Vec<u8>, DecompressError>>>> {
+    match Pin::new(channel).poll(cx) {
+        Poll::Ready(outer) => match outer {
+            Ok(result) => match result {
+                Ok(data) => Some(Poll::Ready(Some(Ok(data)))),
+                Err(ZlibDecompressionError::DecompressError(err)) => {
+                    Some(Poll::Ready(Some(Err(err))))
+                }
+                _ => None,
+            },
+            Err(_cancelled) => return Some(Poll::Ready(None)),
+        },
+        Poll::Pending => {
+            return Some(Poll::Pending);
+        }
+    }
+}
diff --git a/src/test.rs b/src/test.rs
index 1090069..6f3cb62 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -1,7 +1,9 @@
+#[cfg(feature = "chunk")]
+use crate::chunk::ChunkedByteStream;
+#[cfg(feature = "stream")]
 use crate::stream::ZlibStream;
 use crate::{ZlibDecompressionError, ZlibStreamDecompressor};
-use futures_util::{Stream, StreamExt};
-use std::pin::Pin;
+use futures_util::StreamExt;
 
 fn payload() -> Vec<u8> {
     vec![
@@ -72,7 +74,7 @@ async fn test_stream() {
     let stream = futures_util::stream::iter(stream);
     let mut stream = ZlibStream::new(stream);
 
-    let result = futures_util::future::poll_fn(move |cx| Pin::new(&mut stream).poll_next(cx)).await;
+    let result = stream.next().await;
     assert_eq!(
         inflated(),
         String::from_utf8(
@@ -104,3 +106,42 @@
         .unwrap()
     )
 }
+
+#[cfg(feature = "chunk")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_chunk_stream() {
+    let chunk_size = 8usize;
+    let data = vec![payload()];
+
+    let stream = futures_util::stream::iter(data);
+    let mut stream = ChunkedByteStream::new(stream, chunk_size);
+    let mut concat = vec![];
+    while let Some(data) = stream.next().await {
+        concat.extend_from_slice(data.as_slice());
+        assert!(data.len() <= chunk_size, "Data size exceeded threshold!")
+    }
+
+    assert_eq!(concat, payload(), "Payloads aren't equal")
+}
+
+#[cfg(feature = "chunk")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_chunk_stream_zlib() {
+    let chunk_size = 55usize;
+    let data = vec![payload()];
+
+    let stream = futures_util::stream::iter(data);
+    let stream = ChunkedByteStream::new(stream, chunk_size);
+    let mut stream = ZlibStream::new(stream);
+
+    let result = stream.next().await;
+    assert_eq!(
+        inflated(),
+        String::from_utf8(
+            result
+                .expect("Poll returned end of stream")
+                .expect("Decompression failed")
+        )
+        .unwrap()
+    )
+}
diff --git a/src/thread.rs b/src/thread.rs
new file mode 100644
index 0000000..798512a
--- /dev/null
+++ b/src/thread.rs
@@ -0,0 +1,65 @@
+use crate::{ZlibDecompressionError, ZlibStreamDecompressor};
+pub use futures::channel::oneshot::Canceled;
+use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
+
+#[derive(Clone)]
+pub struct ZlibStreamDecompressorThread {
+    sender: SyncSender<ThreadMessage>,
+}
+
+enum ThreadMessage {
+    Finish,
+    Decompress(
+        Vec<u8>,
+        futures::channel::oneshot::Sender<Result<Vec<u8>, ZlibDecompressionError>>,
+    ),
+}
+
+impl ZlibStreamDecompressorThread {
+    pub fn spawn(decompressor: ZlibStreamDecompressor) -> ZlibStreamDecompressorThread {
+        let (tx, rx) = sync_channel(128);
+        #[cfg(not(feature = "tokio-runtime"))]
+        std::thread::spawn(move || ZlibStreamDecompressorThread::work(decompressor, rx));
+        #[cfg(feature = "tokio-runtime")]
+        tokio::task::spawn_blocking(move || ZlibStreamDecompressorThread::work(decompressor, rx));
+        ZlibStreamDecompressorThread { sender: tx }
+    }
+
+    fn work(mut decompressor: ZlibStreamDecompressor, rx: Receiver<ThreadMessage>) {
+        loop {
+            match rx.recv() {
+                Ok(message) => {
+                    match message {
+                        ThreadMessage::Finish => break,
+                        ThreadMessage::Decompress(data, channel) => {
+                            let result = decompressor.decompress(data);
+                            channel.send(result).ok(); // this is fine
+                        }
+                    }
+                }
+                Err(err) => {
+                    log::error!("Decompressor Thread errored on channel recv: {}", err)
+                }
+            }
+        }
+    }
+
+    pub fn abort(&self) {
+        self.sender.send(ThreadMessage::Finish).ok();
+    }
+
+    pub fn decompress(
+        &self,
+        data: Vec<u8>,
+    ) -> futures::channel::oneshot::Receiver<Result<Vec<u8>, ZlibDecompressionError>> {
+        let (tx, rx) = futures::channel::oneshot::channel();
+        self.sender.send(ThreadMessage::Decompress(data, tx)).ok();
+        rx
+    }
+}
+
+impl Drop for ZlibStreamDecompressorThread {
+    fn drop(&mut self) {
+        self.abort();
+    }
+}
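
Usage sketch (not part of the diff above): with the new `chunk` feature from the Cargo.toml changes enabled, the added `ChunkedByteStream` can sit between a raw byte-frame stream and `ZlibStream`, mirroring the added `test_chunk_stream_zlib` test. The crate path `zlib_stream`, the function name `inflate_first_payload`, and the `frames` parameter are illustrative assumptions for this sketch, not names taken from the diff.

// Illustrative only: the crate path `zlib_stream` and the names below are assumed.
use futures_util::StreamExt;
use zlib_stream::chunk::ChunkedByteStream;
use zlib_stream::stream::ZlibStream;

/// Re-slices already-received compressed frames into chunks of at most 64 bytes
/// and feeds them to the zlib stream, returning the first decompressed payload.
async fn inflate_first_payload(frames: Vec<Vec<u8>>) -> Option<Vec<u8>> {
    // `ChunkedByteStream` splits each incoming frame; 64 is the documented minimum chunk size.
    let source = futures_util::stream::iter(frames);
    let chunked = ChunkedByteStream::new(source, 64);

    // `ZlibStream` keeps pulling chunks until the decompressor yields a complete payload.
    let mut inflater = ZlibStream::new(chunked);
    match inflater.next().await {
        Some(Ok(payload)) => Some(payload), // one fully decompressed message
        _ => None,                          // decompression error or end of stream
    }
}

Note on the runtime change in this diff: with `tokio-runtime` (which now implies `thread`), decompression no longer runs via `tokio::task::block_in_place` inside `poll_next`; it is handed to a `ZlibStreamDecompressorThread` spawned with `tokio::task::spawn_blocking`, and the result comes back through a `futures` oneshot channel polled by `ZlibStream`.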