|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use crate::client::{ |
| 19 | + HttpError, HttpErrorKind, HttpRequest, HttpResponse, HttpResponseBody, HttpService, |
| 20 | +}; |
| 21 | +use async_trait::async_trait; |
| 22 | +use bytes::Bytes; |
| 23 | +use http::Response; |
| 24 | +use http_body_util::BodyExt; |
| 25 | +use hyper::body::{Body, Frame}; |
| 26 | +use std::pin::Pin; |
| 27 | +use std::task::{Context, Poll}; |
| 28 | +use thiserror::Error; |
| 29 | +use tokio::runtime::Handle; |
| 30 | +use tokio::task::JoinHandle; |
| 31 | + |
| 32 | +/// Spawn error |
| 33 | +#[derive(Debug, Error)] |
| 34 | +#[error("SpawnError")] |
| 35 | +struct SpawnError {} |
| 36 | + |
| 37 | +impl From<SpawnError> for HttpError { |
| 38 | + fn from(value: SpawnError) -> Self { |
| 39 | + Self::new(HttpErrorKind::Interrupted, value) |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +/// Wraps a provided [`HttpService`] and runs it on a separate tokio runtime |
| 44 | +#[derive(Debug)] |
| 45 | +pub struct SpawnService<T: HttpService + Clone> { |
| 46 | + inner: T, |
| 47 | + runtime: Handle, |
| 48 | +} |
| 49 | + |
| 50 | +impl<T: HttpService + Clone> SpawnService<T> { |
| 51 | + /// Creates a new [`SpawnService`] from the provided |
| 52 | + pub fn new(inner: T, runtime: Handle) -> Self { |
| 53 | + Self { inner, runtime } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +#[async_trait] |
| 58 | +impl<T: HttpService + Clone> HttpService for SpawnService<T> { |
| 59 | + async fn call(&self, req: HttpRequest) -> Result<HttpResponse, HttpError> { |
| 60 | + let inner = self.inner.clone(); |
| 61 | + let (send, recv) = tokio::sync::oneshot::channel(); |
| 62 | + |
| 63 | + // We use an unbounded channel to prevent backpressure across the runtime boundary |
| 64 | + // which could in turn starve the underlying IO operations |
| 65 | + let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); |
| 66 | + |
| 67 | + let handle = SpawnHandle(self.runtime.spawn(async move { |
| 68 | + let r = match HttpService::call(&inner, req).await { |
| 69 | + Ok(resp) => resp, |
| 70 | + Err(e) => { |
| 71 | + let _ = send.send(Err(e)); |
| 72 | + return; |
| 73 | + } |
| 74 | + }; |
| 75 | + |
| 76 | + let (parts, mut body) = r.into_parts(); |
| 77 | + if send.send(Ok(parts)).is_err() { |
| 78 | + return; |
| 79 | + } |
| 80 | + |
| 81 | + while let Some(x) = body.frame().await { |
| 82 | + sender.send(x).unwrap(); |
| 83 | + } |
| 84 | + })); |
| 85 | + |
| 86 | + let parts = recv.await.map_err(|_| SpawnError {})??; |
| 87 | + |
| 88 | + Ok(Response::from_parts( |
| 89 | + parts, |
| 90 | + HttpResponseBody::new(SpawnBody { |
| 91 | + stream: receiver, |
| 92 | + _worker: handle, |
| 93 | + }), |
| 94 | + )) |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +/// A wrapper around a [`JoinHandle`] that aborts on drop |
| 99 | +struct SpawnHandle(JoinHandle<()>); |
| 100 | +impl Drop for SpawnHandle { |
| 101 | + fn drop(&mut self) { |
| 102 | + self.0.abort(); |
| 103 | + } |
| 104 | +} |
| 105 | + |
| 106 | +type StreamItem = Result<Frame<Bytes>, HttpError>; |
| 107 | + |
| 108 | +struct SpawnBody { |
| 109 | + stream: tokio::sync::mpsc::UnboundedReceiver<StreamItem>, |
| 110 | + _worker: SpawnHandle, |
| 111 | +} |
| 112 | + |
| 113 | +impl Body for SpawnBody { |
| 114 | + type Data = Bytes; |
| 115 | + type Error = HttpError; |
| 116 | + |
| 117 | + fn poll_frame(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> { |
| 118 | + self.stream.poll_recv(cx) |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +#[cfg(test)] |
| 123 | +mod tests { |
| 124 | + use super::*; |
| 125 | + use crate::client::mock_server::MockServer; |
| 126 | + use crate::client::retry::RetryExt; |
| 127 | + use crate::client::HttpClient; |
| 128 | + use crate::RetryConfig; |
| 129 | + |
| 130 | + async fn test_client(client: HttpClient) { |
| 131 | + let (send, recv) = tokio::sync::oneshot::channel(); |
| 132 | + |
| 133 | + let mock = MockServer::new().await; |
| 134 | + mock.push(Response::new("BANANAS".to_string())); |
| 135 | + |
| 136 | + let url = mock.url().to_string(); |
| 137 | + let thread = std::thread::spawn(|| { |
| 138 | + futures::executor::block_on(async move { |
| 139 | + let retry = RetryConfig::default(); |
| 140 | + let ret = client.get(url).send_retry(&retry).await.unwrap(); |
| 141 | + let payload = ret.into_body().bytes().await.unwrap(); |
| 142 | + assert_eq!(payload.as_ref(), b"BANANAS"); |
| 143 | + let _ = send.send(()); |
| 144 | + }) |
| 145 | + }); |
| 146 | + recv.await.unwrap(); |
| 147 | + thread.join().unwrap(); |
| 148 | + } |
| 149 | + |
| 150 | + #[tokio::test] |
| 151 | + async fn test_spawn() { |
| 152 | + let client = HttpClient::new(SpawnService::new(reqwest::Client::new(), Handle::current())); |
| 153 | + test_client(client).await; |
| 154 | + } |
| 155 | + |
| 156 | + #[tokio::test] |
| 157 | + #[should_panic] |
| 158 | + async fn test_no_spawn() { |
| 159 | + let client = HttpClient::new(reqwest::Client::new()); |
| 160 | + test_client(client).await; |
| 161 | + } |
| 162 | +} |
0 commit comments