diff --git a/futures-util/src/future/flatten_stream.rs b/futures-util/src/future/flatten_stream.rs index fe3da461bf..d42133ef5a 100644 --- a/futures-util/src/future/flatten_stream.rs +++ b/futures-util/src/future/flatten_stream.rs @@ -3,14 +3,17 @@ use core::pin::Pin; use futures_core::future::Future; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; +use pin_utils::unsafe_pinned; /// Stream for the [`flatten_stream`](super::FutureExt::flatten_stream) method. #[must_use = "streams do nothing unless polled"] pub struct FlattenStream { - state: State + state: State, } impl FlattenStream { + unsafe_pinned!(state: State); + pub(super) fn new(future: Fut) -> FlattenStream { FlattenStream { state: State::Future(future) @@ -30,11 +33,25 @@ impl fmt::Debug for FlattenStream } #[derive(Debug)] -enum State { +enum State { // future is not yet called or called and not ready Future(Fut), // future resolved to Stream - Stream(Fut::Output), + Stream(St), +} + +impl State { + fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> State, Pin<&'a mut St>> { + // safety: data is never moved via the resulting &mut reference + match unsafe { Pin::get_unchecked_mut(self) } { + // safety: the future we're re-pinning here will never be moved; + // it will just be polled, then dropped in place + State::Future(f) => State::Future(unsafe { Pin::new_unchecked(f) }), + // safety: the stream we're repinning here will never be moved; + // it will just be polled, then dropped in place + State::Stream(s) => State::Stream(unsafe { Pin::new_unchecked(s) }), + } + } } impl FusedStream for FlattenStream @@ -57,35 +74,15 @@ impl Stream for FlattenStream fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - // safety: data is never moved via the resulting &mut reference - let stream = match &mut unsafe { Pin::get_unchecked_mut(self.as_mut()) }.state { + match self.as_mut().state().get_pin_mut() { State::Future(f) => { - // safety: the future we're re-pinning here will never be moved; - // it will just be polled, then dropped in place - match unsafe { Pin::new_unchecked(f) }.poll(cx) { - Poll::Pending => { - // State is not changed, early return. - return Poll::Pending - }, - Poll::Ready(stream) => { - // Future resolved to stream. - // We do not return, but poll that - // stream in the next loop iteration. - stream - } - } - } - State::Stream(s) => { - // safety: the stream we're repinning here will never be moved; - // it will just be polled, then dropped in place - return unsafe { Pin::new_unchecked(s) }.poll_next(cx); + let stream = ready!(f.poll(cx)); + // Future resolved to stream. + // We do not return, but poll that + // stream in the next loop iteration. + self.as_mut().state().set(State::Stream(stream)); } - }; - - unsafe { - // safety: we use the &mut only for an assignment, which causes - // only an in-place drop - Pin::get_unchecked_mut(self.as_mut()).state = State::Stream(stream); + State::Stream(s) => return s.poll_next(cx), } } } diff --git a/futures-util/src/try_future/mod.rs b/futures-util/src/try_future/mod.rs index 37c8ab5a5e..31f0566f6a 100644 --- a/futures-util/src/try_future/mod.rs +++ b/futures-util/src/try_future/mod.rs @@ -5,6 +5,7 @@ use core::pin::Pin; use futures_core::future::TryFuture; +use futures_core::stream::TryStream; use futures_core::task::{Context, Poll}; use futures_sink::Sink; @@ -51,6 +52,9 @@ pub use self::map_ok::MapOk; mod or_else; pub use self::or_else::OrElse; +mod try_flatten_stream; +pub use self::try_flatten_stream::TryFlattenStream; + mod unwrap_or_else; pub use self::unwrap_or_else::UnwrapOrElse; @@ -318,6 +322,39 @@ pub trait TryFutureExt: TryFuture { OrElse::new(self, f) } + /// Flatten the execution of this future when the successful result of this + /// future is a stream. + /// + /// This can be useful when stream initialization is deferred, and it is + /// convenient to work with that stream as if stream was available at the + /// call site. + /// + /// Note that this function consumes this future and returns a wrapped + /// version of it. + /// + /// # Examples + /// + /// ``` + /// #![feature(async_await)] + /// # futures::executor::block_on(async { + /// use futures::future::{self, TryFutureExt}; + /// use futures::stream::{self, TryStreamExt}; + /// + /// let stream_items = vec![17, 18, 19].into_iter().map(Ok); + /// let future_of_a_stream = future::ok::<_, ()>(stream::iter(stream_items)); + /// + /// let stream = future_of_a_stream.try_flatten_stream(); + /// let list = stream.try_collect::>().await; + /// assert_eq!(list, Ok(vec![17, 18, 19])); + /// # }); + /// ``` + fn try_flatten_stream(self) -> TryFlattenStream + where Self::Ok: TryStream, + Self: Sized + { + TryFlattenStream::new(self) + } + /// Unwraps this future's ouput, producing a future with this future's /// [`Ok`](TryFuture::Ok) type as its /// [`Output`](std::future::Future::Output) type. diff --git a/futures-util/src/try_future/try_flatten_stream.rs b/futures-util/src/try_future/try_flatten_stream.rs new file mode 100644 index 0000000000..942110bc09 --- /dev/null +++ b/futures-util/src/try_future/try_flatten_stream.rs @@ -0,0 +1,113 @@ +use core::fmt; +use core::pin::Pin; +use futures_core::future::TryFuture; +use futures_core::stream::{FusedStream, Stream, TryStream}; +use futures_core::task::{Context, Poll}; +use pin_utils::unsafe_pinned; + +/// Stream for the [`try_flatten_stream`](super::TryFutureExt::try_flatten_stream) method. +#[must_use = "streams do nothing unless polled"] +pub struct TryFlattenStream +where + Fut: TryFuture, +{ + state: State, +} + +impl TryFlattenStream +where + Fut: TryFuture, + Fut::Ok: TryStream, +{ + unsafe_pinned!(state: State); + + pub(super) fn new(future: Fut) -> Self { + Self { + state: State::Future(future) + } + } +} + +impl fmt::Debug for TryFlattenStream +where + Fut: TryFuture + fmt::Debug, + Fut::Ok: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("TryFlattenStream") + .field("state", &self.state) + .finish() + } +} + +#[derive(Debug)] +enum State { + // future is not yet called or called and not ready + Future(Fut), + // future resolved to Stream + Stream(St), + // future resolved to error + Done, +} + +impl State { + fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> State, Pin<&'a mut St>> { + // safety: data is never moved via the resulting &mut reference + match unsafe { Pin::get_unchecked_mut(self) } { + // safety: the future we're re-pinning here will never be moved; + // it will just be polled, then dropped in place + State::Future(f) => State::Future(unsafe { Pin::new_unchecked(f) }), + // safety: the stream we're repinning here will never be moved; + // it will just be polled, then dropped in place + State::Stream(s) => State::Stream(unsafe { Pin::new_unchecked(s) }), + State::Done => State::Done, + } + } +} + +impl FusedStream for TryFlattenStream +where + Fut: TryFuture, + Fut::Ok: TryStream + FusedStream, +{ + fn is_terminated(&self) -> bool { + match &self.state { + State::Future(_) => false, + State::Stream(stream) => stream.is_terminated(), + State::Done => true, + } + } +} + +impl Stream for TryFlattenStream +where + Fut: TryFuture, + Fut::Ok: TryStream, +{ + type Item = Result<::Ok, Fut::Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.as_mut().state().get_pin_mut() { + State::Future(f) => { + match ready!(f.try_poll(cx)) { + Ok(stream) => { + // Future resolved to stream. + // We do not return, but poll that + // stream in the next loop iteration. + self.as_mut().state().set(State::Stream(stream)); + } + Err(e) => { + // Future resolved to error. + // We have neither a pollable stream nor a future. + self.as_mut().state().set(State::Done); + return Poll::Ready(Some(Err(e))); + } + } + } + State::Stream(s) => return s.try_poll_next(cx), + State::Done => return Poll::Ready(None), + } + } + } +} diff --git a/futures/src/lib.rs b/futures/src/lib.rs index d290e93e4a..fc5f0d106a 100644 --- a/futures/src/lib.rs +++ b/futures/src/lib.rs @@ -251,7 +251,7 @@ pub mod future { TryFutureExt, AndThen, ErrInto, FlattenSink, IntoFuture, MapErr, MapOk, OrElse, - UnwrapOrElse, + TryFlattenStream, UnwrapOrElse, }; #[cfg(feature = "never-type")]