Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ mod respond;
mod response_template;
mod verification;

pub type ErrorResponse = Box<dyn std::error::Error + Send + Sync + 'static>;

pub use mock::{Match, Mock, MockBuilder, Times};
pub use mock_server::{MockGuard, MockServer, MockServerBuilder};
pub use request::Request;
Expand Down
34 changes: 28 additions & 6 deletions src/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::respond::Respond;
use crate::{MockGuard, MockServer, Request, ResponseTemplate};
use crate::respond::{Respond, RespondErr};
use crate::{ErrorResponse, MockGuard, MockServer, Request, ResponseTemplate};
use std::fmt::{Debug, Formatter};
use std::ops::{
Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
Expand Down Expand Up @@ -256,7 +256,7 @@ impl Debug for Matcher {
#[must_use = "`Mock`s have to be mounted or registered with a `MockServer` to become effective"]
pub struct Mock {
pub(crate) matchers: Vec<Matcher>,
pub(crate) response: Box<dyn Respond>,
pub(crate) response: Result<Box<dyn Respond>, Box<dyn RespondErr>>,
/// Maximum number of times (inclusive) we should return a response from this Mock on
/// matching requests.
/// If `None`, there is no cap and we will respond to all incoming matching requests.
Expand Down Expand Up @@ -629,8 +629,14 @@ impl Mock {

/// Given a [`Request`] build an instance a [`ResponseTemplate`] using
/// the responder associated with the `Mock`.
pub(crate) fn response_template(&self, request: &Request) -> ResponseTemplate {
self.response.respond(request)
pub(crate) fn response_template(
&self,
request: &Request,
) -> Result<ResponseTemplate, ErrorResponse> {
match &self.response {
Ok(responder) => Ok(responder.respond(request)),
Err(responder_err) => Err(responder_err.respond_err(request)),
}
}
}

Expand All @@ -656,7 +662,23 @@ impl MockBuilder {
pub fn respond_with<R: Respond + 'static>(self, responder: R) -> Mock {
Mock {
matchers: self.matchers,
response: Box::new(responder),
response: Ok(Box::new(responder)),
max_n_matches: None,
priority: 5,
name: None,
expectation_range: Times(TimesEnum::Unbounded(RangeFull)),
}
}

/// Instead of response with an HTTP reply, return a Rust error.
///
/// This can simulate lower level errors, e.g., a [`ConnectionReset`] IO Error.
///
/// [`ConnectionReset`]: std::io::ErrorKind::ConnectionReset
pub fn respond_with_err<R: RespondErr + 'static>(self, responder_err: R) -> Mock {
Mock {
matchers: self.matchers,
response: Err(Box::new(responder_err)),
max_n_matches: None,
priority: 5,
name: None,
Expand Down
4 changes: 2 additions & 2 deletions src/mock_server/bare_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::mock_server::hyper::run_server;
use crate::mock_set::MockId;
use crate::mock_set::MountedMockSet;
use crate::request::BodyPrintLimit;
use crate::{mock::Mock, verification::VerificationOutcome, Request};
use crate::{mock::Mock, verification::VerificationOutcome, ErrorResponse, Request};
use http_body_util::Full;
use hyper::body::Bytes;
use std::fmt::{Debug, Write};
Expand Down Expand Up @@ -39,7 +39,7 @@ impl MockServerState {
pub(super) async fn handle_request(
&mut self,
request: Request,
) -> (hyper::Response<Full<Bytes>>, Option<tokio::time::Sleep>) {
) -> Result<(hyper::Response<Full<Bytes>>, Option<tokio::time::Sleep>), ErrorResponse> {
// If request recording is enabled, record the incoming request
// by adding it to the `received_requests` stack
if let Some(received_requests) = &mut self.received_requests {
Expand Down
16 changes: 14 additions & 2 deletions src/mock_server/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::RwLock;

/// Work around a lifetime error where, for some reason,
/// `Box<dyn std::error::Error + Send + Sync + 'static>` can't be converted to a
/// `Box<dyn std::error::Error + Send + Sync>`
struct ErrorLifetimeCast(Box<dyn std::error::Error + Send + Sync + 'static>);

impl From<ErrorLifetimeCast> for Box<dyn std::error::Error + Send + Sync> {
fn from(value: ErrorLifetimeCast) -> Self {
value.0
}
}

/// The actual HTTP server responding to incoming requests according to the specified mocks.
pub(super) async fn run_server(
listener: std::net::TcpListener,
Expand All @@ -24,7 +35,8 @@ pub(super) async fn run_server(
.write()
.await
.handle_request(wiremock_request)
.await;
.await
.map_err(ErrorLifetimeCast)?;

// We do not wait for the delay within the handler otherwise we would be
// holding on to the write-side of the `RwLock` on `mock_set`.
Expand All @@ -38,7 +50,7 @@ pub(super) async fn run_server(
delay.await;
}

Ok::<_, &'static str>(response)
Ok::<_, ErrorLifetimeCast>(response)
}
};

Expand Down
28 changes: 16 additions & 12 deletions src/mock_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use crate::request::BodyPrintLimit;
use crate::{
mounted_mock::MountedMock,
verification::{VerificationOutcome, VerificationReport},
ErrorResponse,
};
use crate::{Mock, Request, ResponseTemplate};
use crate::{Mock, Request};
use http_body_util::Full;
use hyper::body::Bytes;
use log::debug;
Expand Down Expand Up @@ -56,9 +57,9 @@ impl MountedMockSet {
pub(crate) async fn handle_request(
&mut self,
request: Request,
) -> (hyper::Response<Full<Bytes>>, Option<Sleep>) {
) -> Result<(hyper::Response<Full<Bytes>>, Option<Sleep>), ErrorResponse> {
debug!("Handling request.");
let mut response_template: Option<ResponseTemplate> = None;
let mut response_template: Option<_> = None;
self.mocks.sort_by_key(|(m, _)| m.specification.priority);
for (mock, mock_state) in &mut self.mocks {
if *mock_state == MountedMockState::OutOfScope {
Expand All @@ -70,19 +71,22 @@ impl MountedMockSet {
}
}
if let Some(response_template) = response_template {
let delay = response_template.delay().map(sleep);
(response_template.generate_response(), delay)
match response_template {
Ok(response_template) => {
let delay = response_template.delay().map(sleep);
Ok((response_template.generate_response(), delay))
}
Err(err) => Err(err),
}
} else {
let mut msg = "Got unexpected request:\n".to_string();
_ = request.print_with_limit(&mut msg, self.body_print_limit);
debug!("{}", msg);
(
hyper::Response::builder()
.status(hyper::StatusCode::NOT_FOUND)
.body(Full::default())
.unwrap(),
None,
)
let not_found_response = hyper::Response::builder()
.status(hyper::StatusCode::NOT_FOUND)
.body(Full::default())
.unwrap();
Ok((not_found_response, None))
}
}

Expand Down
9 changes: 7 additions & 2 deletions src/mounted_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::sync::{atomic::AtomicBool, Arc};

use tokio::sync::Notify;

use crate::{verification::VerificationReport, Match, Mock, Request, ResponseTemplate};
use crate::{
verification::VerificationReport, ErrorResponse, Match, Mock, Request, ResponseTemplate,
};

/// Given the behaviour specification as a [`Mock`], keep track of runtime information
/// concerning this mock - e.g. how many times it matched on a incoming request.
Expand Down Expand Up @@ -80,7 +82,10 @@ impl MountedMock {
}
}

pub(crate) fn response_template(&self, request: &Request) -> ResponseTemplate {
pub(crate) fn response_template(
&self,
request: &Request,
) -> Result<ResponseTemplate, ErrorResponse> {
self.specification.response_template(request)
}

Expand Down
17 changes: 16 additions & 1 deletion src/respond.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Request, ResponseTemplate};
use crate::{ErrorResponse, Request, ResponseTemplate};

/// Anything that implements `Respond` can be used to reply to an incoming request when a
/// [`Mock`] is activated.
Expand Down Expand Up @@ -147,3 +147,18 @@ where
(self)(request)
}
}

/// Like [`Respond`], but it only allows returning an error through a function.
pub trait RespondErr: Send + Sync {
fn respond_err(&self, request: &Request) -> ErrorResponse;
}

impl<F, Err> RespondErr for F
where
F: Send + Sync + Fn(&Request) -> Err,
Err: std::error::Error + Send + Sync + 'static,
{
fn respond_err(&self, request: &Request) -> ErrorResponse {
Box::new((self)(request))
}
}
62 changes: 61 additions & 1 deletion tests/mocks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use futures::FutureExt;
use serde::Serialize;
use serde_json::json;
use std::fmt::{Display, Formatter};
use std::io::ErrorKind;
use std::iter;
use std::net::TcpStream;
use std::time::Duration;
use surf::StatusCode;
use wiremock::matchers::{body_json, body_partial_json, method, path, PathExactMatcher};
use wiremock::{Mock, MockServer, ResponseTemplate};
use wiremock::{Mock, MockServer, Request, ResponseTemplate};

#[async_std::test]
async fn new_starts_the_server() {
Expand Down Expand Up @@ -365,3 +368,60 @@ async fn debug_prints_mock_server_variants() {
format!("{:?}", bare_mock_server)
);
}

#[tokio::test]
async fn io_err() {
// Act
let mock_server = MockServer::start().await;
let mock = Mock::given(method("GET")).respond_with_err(|_: &Request| {
std::io::Error::new(ErrorKind::ConnectionReset, "connection reset")
});
mock_server.register(mock).await;

// Assert
let err = reqwest::get(&mock_server.uri()).await.unwrap_err();
// We're skipping the original error since it can be either `error sending request` or
// `error sending request for url (http://127.0.0.1:<port>/)`
let actual_err: Vec<String> =
iter::successors(std::error::Error::source(&err), |err| err.source())
.map(|err| err.to_string())
.collect();

let expected_err = vec![
"client error (SendRequest)".to_string(),
"connection closed before message completed".to_string(),
];
assert_eq!(actual_err, expected_err);
}

#[tokio::test]
async fn custom_err() {
// Act
#[derive(Debug)]
struct CustomErr;
impl Display for CustomErr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str("custom error")
}
}
impl std::error::Error for CustomErr {}

let mock_server = MockServer::start().await;
let mock = Mock::given(method("GET")).respond_with_err(|_: &Request| CustomErr);
mock_server.register(mock).await;

// Assert
let err = reqwest::get(&mock_server.uri()).await.unwrap_err();
// We're skipping the original error since it can be either `error sending request` or
// `error sending request for url (http://127.0.0.1:<port>/)`
let actual_err: Vec<String> =
iter::successors(std::error::Error::source(&err), |err| err.source())
.map(|err| err.to_string())
.collect();

let expected_err = vec![
"client error (SendRequest)".to_string(),
"connection closed before message completed".to_string(),
];
assert_eq!(actual_err, expected_err);
}
Loading