Skip to content

Commit b4fd93f

Browse files
committed
feat: WASM streaming body
1 parent 4cb2866 commit b4fd93f

File tree

4 files changed

+156
-12
lines changed

4 files changed

+156
-12
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ tower-service = "0.3"
109109
futures-core = { version = "0.3.28", default-features = false }
110110
futures-util = { version = "0.3.28", default-features = false, optional = true }
111111
sync_wrapper = { version = "1.0", features = ["futures"] }
112+
pin-project-lite = "0.2.11"
112113

113114
# Optional deps...
114115

@@ -129,7 +130,6 @@ percent-encoding = "2.3"
129130
tokio = { version = "1.0", default-features = false, features = ["net", "time"] }
130131
tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] }
131132
tower-http = { version = "0.6.5", default-features = false, features = ["follow-redirect"] }
132-
pin-project-lite = "0.2.11"
133133

134134
# Optional deps...
135135
rustls-pki-types = { version = "1.9.0", features = ["std"], optional = true }

src/lib.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,11 @@ fn _assert_impls() {
333333
assert_sync::<Client>();
334334
assert_clone::<Client>();
335335

336-
assert_send::<Request>();
337-
assert_send::<RequestBuilder>();
336+
#[cfg(not(target_arch = "wasm32"))]
337+
{
338+
assert_send::<Request>();
339+
assert_send::<RequestBuilder>();
340+
}
338341

339342
#[cfg(not(target_arch = "wasm32"))]
340343
{
@@ -344,8 +347,11 @@ fn _assert_impls() {
344347
assert_send::<Error>();
345348
assert_sync::<Error>();
346349

347-
assert_send::<Body>();
348-
assert_sync::<Body>();
350+
#[cfg(not(target_arch = "wasm32"))]
351+
{
352+
assert_send::<Body>();
353+
assert_sync::<Body>();
354+
}
349355
}
350356

351357
if_hyper! {

src/wasm/body.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ enum Inner {
2222
/// MultipartForm holds a multipart/form-data body.
2323
#[cfg(feature = "multipart")]
2424
MultipartForm(Form),
25+
#[cfg(feature = "stream")]
26+
Streaming(Streaming),
2527
}
2628

2729
#[derive(Clone)]
@@ -58,6 +60,15 @@ impl Single {
5860
}
5961
}
6062

63+
pub(crate) type BodyFuture =
64+
std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), JsValue>> + 'static>>;
65+
66+
#[cfg(feature = "stream")]
67+
pub(crate) struct Streaming {
68+
write_fut: BodyFuture,
69+
readable: web_sys::ReadableStream,
70+
}
71+
6172
impl Body {
6273
/// Returns a reference to the internal data of the `Body`.
6374
///
@@ -68,6 +79,56 @@ impl Body {
6879
Inner::Single(single) => Some(single.as_bytes()),
6980
#[cfg(feature = "multipart")]
7081
Inner::MultipartForm(_) => None,
82+
#[cfg(feature = "stream")]
83+
Inner::Streaming(_) => None,
84+
}
85+
}
86+
87+
/// Turn a futures `Stream` into a JS `ReadableStream`.
88+
///
89+
/// # Example
90+
///
91+
/// ```
92+
/// # use reqwest::Body;
93+
/// # use futures_util;
94+
/// # fn main() {
95+
/// let chunks: Vec<Result<_, ::std::io::Error>> = vec![
96+
/// Ok("hello"),
97+
/// Ok(" "),
98+
/// Ok("world"),
99+
/// ];
100+
///
101+
/// let stream = futures_util::stream::iter(chunks);
102+
///
103+
/// let body = Body::wrap_stream(stream);
104+
/// # }
105+
/// ```
106+
///
107+
/// # Optional
108+
///
109+
/// This requires the `stream` feature to be enabled.
110+
#[cfg(feature = "stream")]
111+
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
112+
pub fn wrap_stream<S>(stream: S) -> Body
113+
where
114+
S: futures_core::stream::TryStream + 'static,
115+
S::Error: Into<Box<dyn std::error::Error>>,
116+
Bytes: From<S::Ok>,
117+
{
118+
use futures_util::{FutureExt, StreamExt, TryStreamExt};
119+
use wasm_bindgen::{JsError, UnwrapThrowExt};
120+
121+
let transform_stream =
122+
wasm_streams::TransformStream::from_raw(web_sys::TransformStream::new().unwrap_throw());
123+
Body {
124+
inner: Inner::Streaming(Streaming {
125+
write_fut: stream
126+
.map_ok(|b| Single::Bytes(b.into()).to_js_value())
127+
.map_err(|err| JsValue::from(JsError::new(&err.into().to_string())))
128+
.forward(transform_stream.writable().into_sink())
129+
.boxed_local(),
130+
readable: transform_stream.readable().into_raw(),
131+
}),
71132
}
72133
}
73134

@@ -80,6 +141,18 @@ impl Body {
80141
let js_value: &JsValue = form_data.as_ref();
81142
Ok(js_value.to_owned())
82143
}
144+
#[cfg(feature = "stream")]
145+
Inner::Streaming(streaming) => Ok(streaming.readable.clone().into()),
146+
}
147+
}
148+
149+
pub(crate) fn into_future(self) -> Option<BodyFuture> {
150+
match self.inner {
151+
Inner::Single(_) => None,
152+
#[cfg(feature = "multipart")]
153+
Inner::MultipartForm(_) => None,
154+
#[cfg(feature = "stream")]
155+
Inner::Streaming(streaming) => Some(streaming.write_fut),
83156
}
84157
}
85158

@@ -88,6 +161,8 @@ impl Body {
88161
match &self.inner {
89162
Inner::Single(single) => Some(single),
90163
Inner::MultipartForm(_) => None,
164+
#[cfg(feature = "stream")]
165+
Inner::Streaming(_) => None,
91166
}
92167
}
93168

@@ -109,6 +184,10 @@ impl Body {
109184
Inner::MultipartForm(form) => Self {
110185
inner: Inner::MultipartForm(form),
111186
},
187+
#[cfg(feature = "stream")]
188+
Inner::Streaming(streaming) => Self {
189+
inner: Inner::Streaming(streaming),
190+
},
112191
}
113192
}
114193

@@ -117,6 +196,8 @@ impl Body {
117196
Inner::Single(single) => single.is_empty(),
118197
#[cfg(feature = "multipart")]
119198
Inner::MultipartForm(form) => form.is_empty(),
199+
#[cfg(feature = "stream")]
200+
Inner::Streaming(_) => false,
120201
}
121202
}
122203

@@ -127,6 +208,8 @@ impl Body {
127208
}),
128209
#[cfg(feature = "multipart")]
129210
Inner::MultipartForm(_) => None,
211+
#[cfg(feature = "stream")]
212+
Inner::Streaming(_) => None,
130213
}
131214
}
132215
}

src/wasm/client.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
use std::convert::TryInto;
2+
use std::fmt;
3+
use std::future::Future;
4+
use std::pin::Pin;
5+
use std::sync::Arc;
6+
use std::task::{ready, Context, Poll};
7+
18
use http::header::USER_AGENT;
29
use http::{HeaderMap, HeaderValue, Method};
310
use js_sys::{Promise, JSON};
4-
use std::convert::TryInto;
5-
use std::{fmt, future::Future, sync::Arc};
11+
use pin_project_lite::pin_project;
612
use url::Url;
713
use wasm_bindgen::prelude::{wasm_bindgen, UnwrapThrowExt as _};
814

@@ -182,11 +188,46 @@ impl fmt::Debug for ClientBuilder {
182188
}
183189
}
184190

191+
pin_project! {
192+
struct Pending {
193+
#[pin]
194+
body_fut: Option<super::body::BodyFuture>,
195+
#[pin]
196+
fetch: wasm_bindgen_futures::JsFuture,
197+
}
198+
}
199+
200+
impl Future for Pending {
201+
type Output = Result<web_sys::Response, crate::error::BoxError>;
202+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203+
use wasm_bindgen::JsCast;
204+
205+
let mut this = self.project();
206+
if let Some(body_fut) = this.body_fut.as_mut().as_pin_mut() {
207+
if let Poll::Ready(res) = body_fut.poll(cx) {
208+
this.body_fut.set(None);
209+
if let Err(err) = res {
210+
return Poll::Ready(Err(crate::error::wasm(err)));
211+
}
212+
}
213+
}
214+
Poll::Ready(
215+
ready!(this.fetch.poll(cx))
216+
.map_err(crate::error::wasm)
217+
.and_then(|js_resp| {
218+
js_resp
219+
.dyn_into::<web_sys::Response>()
220+
.map_err(|_js_val| "promise resolved to unexpected type".into())
221+
}),
222+
)
223+
}
224+
}
225+
185226
// Can use new methods in web-sys when requiring v0.2.93.
186227
// > `init.method(m)` to `init.set_method(m)`
187228
// For now, ignore their deprecation.
188229
#[allow(deprecated)]
189-
async fn fetch(req: Request) -> crate::Result<Response> {
230+
async fn fetch(mut req: Request) -> crate::Result<Response> {
190231
// Build the js Request
191232
let mut init = web_sys::RequestInit::new();
192233
init.method(req.method().as_str());
@@ -216,11 +257,22 @@ async fn fetch(req: Request) -> crate::Result<Response> {
216257
init.credentials(creds);
217258
}
218259

219-
if let Some(body) = req.body() {
260+
let body_fut = if let Some(body) = req.body_mut().take() {
220261
if !body.is_empty() {
221262
init.body(Some(body.to_js_value()?.as_ref()));
263+
let fut = body.into_future();
264+
if fut.is_some() {
265+
js_sys::Reflect::set(&init, &"duplex".into(), &"half".into())
266+
.map_err(crate::error::wasm)
267+
.map_err(crate::error::builder)?;
268+
}
269+
fut
270+
} else {
271+
None
222272
}
223-
}
273+
} else {
274+
None
275+
};
224276

225277
let mut abort = AbortGuard::new()?;
226278
if let Some(timeout) = req.timeout() {
@@ -233,8 +285,11 @@ async fn fetch(req: Request) -> crate::Result<Response> {
233285
.map_err(crate::error::builder)?;
234286

235287
// Await the fetch() promise
236-
let p = js_fetch(&js_req);
237-
let js_resp = super::promise::<web_sys::Response>(p)
288+
let pending = Pending {
289+
body_fut,
290+
fetch: js_fetch(&js_req).into(),
291+
};
292+
let js_resp = pending
238293
.await
239294
.map_err(|error| {
240295
if error.to_string() == "JsValue(\"reqwest::errors::TimedOut\")" {

0 commit comments

Comments
 (0)