Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ datafusion-udf-wasm-arrow2bytes = { path = "arrow2bytes", version = "0.1.0" }
datafusion-udf-wasm-bundle = { path = "guests/bundle", version = "0.1.0" }
datafusion-udf-wasm-guest = { path = "guests/rust", version = "0.1.0" }
datafusion-udf-wasm-python = { path = "guests/python", version = "0.1.0" }
http = { version = "1.3.1", default-features = false }
hyper = { version = "1.7", default-features = false }
tokio = { version = "1.48.0", default-features = false }
pyo3 = { version = "0.27.1", default-features = false, features = ["macros"] }
tar = { version = "0.4.44", default-features = false }
Expand Down
2 changes: 2 additions & 0 deletions host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ arrow.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-udf-wasm-arrow2bytes.workspace = true
http.workspace = true
hyper.workspace = true
tar.workspace = true
tempfile.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "sync"] }
Expand Down
266 changes: 266 additions & 0 deletions host/src/http.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
//! Interfaces for HTTP interactions of the guest.

use std::{borrow::Cow, collections::HashSet, fmt};

use http::Method;
use wasmtime_wasi_http::body::HyperOutgoingBody;

/// Validates if an outgoing HTTP interaction is allowed.
///
/// You can implement your own business logic here or use one of the pre-built implementations in [this module](self).
pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {
/// Validate an outgoing request.
///
/// `request` is the request to inspect. `use_tls` tells the validator whether the request is a TLS request,
/// which matters e.g. when a default port has to be derived for a request that carries no explicit port.
///
/// Return [`Ok`] if the request should be allowed, return [`Err`] otherwise.
fn validate(
&self,
request: &hyper::Request<HyperOutgoingBody>,
use_tls: bool,
) -> Result<(), Rejected>;
}

/// Reject ALL requests.
///
/// This validator carries no configuration; its [`HttpRequestValidator::validate`] implementation
/// returns [`Rejected`] for every request.
#[derive(Debug, Clone, Copy, Default)]
pub struct RejectAllHttpRequests;

impl HttpRequestValidator for RejectAllHttpRequests {
/// Unconditionally returns `Err(Rejected)`.
///
/// Neither the request nor the TLS flag is inspected, hence the underscore-prefixed parameters.
fn validate(
&self,
_request: &hyper::Request<HyperOutgoingBody>,
_use_tls: bool,
) -> Result<(), Rejected> {
Err(Rejected)
}
}

/// A request matcher.
///
/// Identifies requests by the combination of method, host and port. Equality and hashing are derived,
/// so two matchers are equal only if all three fields are equal.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Matcher {
/// Method.
pub method: Method,

/// Host.
///
/// Requests without a host will be rejected.
///
/// NOTE(review): the derived equality compares this field as a plain string; confirm whether callers
/// are expected to normalize the host (e.g. letter case) before registering a matcher.
pub host: Cow<'static, str>,

/// Port.
///
/// For requests without an explicit port, this defaults to `80` for non-TLS requests and to `443` for TLS requests.
pub port: u16,
}

/// Allow-list requests.
///
/// Starts out empty (see the derived [`Default`]), in which case no matcher can match.
#[derive(Debug, Clone, Default)]
pub struct AllowCertainHttpRequests {
/// Set of all matchers.
///
/// If ANY of them matches, the request will be allowed.
matchers: HashSet<Matcher>,
}

impl AllowCertainHttpRequests {
/// Create new, empty request matcher.
pub fn new() -> Self {
Self::default()
}

/// Allow given request.
pub fn allow(&mut self, matcher: Matcher) {
self.matchers.insert(matcher);
}
}

impl HttpRequestValidator for AllowCertainHttpRequests {
fn validate(
&self,
request: &hyper::Request<HyperOutgoingBody>,
use_tls: bool,
) -> Result<(), Rejected> {
let matcher = Matcher {
method: request.method().clone(),
host: request.uri().host().ok_or(Rejected)?.to_owned().into(),
port: request
.uri()
.port_u16()
.unwrap_or(if use_tls { 443 } else { 80 }),
};

if self.matchers.contains(&matcher) {
Ok(())
} else {
Err(Rejected)
}
}
}

/// Error signalling that an HTTP request was rejected.
///
/// The type carries no further detail; its [`fmt::Display`] output is the fixed text `rejected`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Rejected;

impl fmt::Display for Rejected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written verbatim: width/fill flags of the formatter are not applied.
        write!(f, "rejected")
    }
}

impl std::error::Error for Rejected {}

#[cfg(test)]
mod test {
    use super::*;

    /// The deny-all policy refuses even a default request.
    #[test]
    fn reject_all() {
        let request: hyper::Request<HyperOutgoingBody> = hyper::Request::builder()
            .body(Default::default())
            .unwrap();

        let verdict = RejectAllHttpRequests.validate(&request, false);
        assert_eq!(verdict, Err(Rejected));
    }

    /// The allow-list policy matches on method, host and (possibly defaulted) port.
    #[test]
    fn allow_certain() {
        // Shorthands that keep the case table below compact.
        const ALLOW: Result<(), Rejected> = Ok(());
        const DENY: Result<(), Rejected> = Err(Rejected);

        fn rule(method: Method, host: &'static str, port: u16) -> Matcher {
            Matcher {
                method,
                host: host.into(),
                port,
            }
        }

        // Expected verdicts for one allow-list configuration.
        struct Case {
            matchers: Vec<Matcher>,
            result_no_port_no_tls: Result<(), Rejected>,
            result_no_port_with_tls: Result<(), Rejected>,
            result_with_port_no_tls: Result<(), Rejected>,
            result_with_port_with_tls: Result<(), Rejected>,
        }

        // Request whose port has to be derived from the TLS flag.
        let request_no_port: hyper::Request<HyperOutgoingBody> = hyper::Request::builder()
            .method(Method::GET)
            .uri("http://foo.bar")
            .body(Default::default())
            .unwrap();

        // Request that states its port explicitly.
        let request_with_port: hyper::Request<HyperOutgoingBody> = hyper::Request::builder()
            .method(Method::GET)
            .uri("http://my.universe:1337")
            .body(Default::default())
            .unwrap();

        let cases = [
            // Empty allow-list: nothing passes.
            Case {
                matchers: vec![],
                result_no_port_no_tls: DENY,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: DENY,
                result_with_port_with_tls: DENY,
            },
            // Port 80 only matches the port-less request without TLS.
            Case {
                matchers: vec![rule(Method::GET, "foo.bar", 80)],
                result_no_port_no_tls: ALLOW,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: DENY,
                result_with_port_with_tls: DENY,
            },
            // Port 443 only matches the port-less request with TLS.
            Case {
                matchers: vec![rule(Method::GET, "foo.bar", 443)],
                result_no_port_no_tls: DENY,
                result_no_port_with_tls: ALLOW,
                result_with_port_no_tls: DENY,
                result_with_port_with_tls: DENY,
            },
            // Wrong method.
            Case {
                matchers: vec![rule(Method::POST, "foo.bar", 80)],
                result_no_port_no_tls: DENY,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: DENY,
                result_with_port_with_tls: DENY,
            },
            // Right host, wrong port for the explicit-port request.
            Case {
                matchers: vec![rule(Method::GET, "my.universe", 80)],
                result_no_port_no_tls: DENY,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: DENY,
                result_with_port_with_tls: DENY,
            },
            // An explicit port wins over the TLS-derived default.
            Case {
                matchers: vec![rule(Method::GET, "my.universe", 1337)],
                result_no_port_no_tls: DENY,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: ALLOW,
                result_with_port_with_tls: ALLOW,
            },
            // Several matchers: any single hit is sufficient.
            Case {
                matchers: vec![
                    rule(Method::GET, "foo.bar", 80),
                    rule(Method::POST, "foo.bar", 80),
                    rule(Method::GET, "my.universe", 1337),
                ],
                result_no_port_no_tls: ALLOW,
                result_no_port_with_tls: DENY,
                result_with_port_no_tls: ALLOW,
                result_with_port_with_tls: ALLOW,
            },
        ];

        for (idx, case) in cases.into_iter().enumerate() {
            println!("case: {}", idx + 1);

            let mut policy = AllowCertainHttpRequests::new();
            for matcher in case.matchers {
                policy.allow(matcher);
            }

            let checks = [
                (&request_no_port, false, case.result_no_port_no_tls),
                (&request_no_port, true, case.result_no_port_with_tls),
                (&request_with_port, false, case.result_with_port_no_tls),
                (&request_with_port, true, case.result_with_port_with_tls),
            ];

            for (request, use_tls, expected) in checks {
                assert_eq!(policy.validate(request, use_tls), expected);
            }
        }
    }
}
Loading