Skip to content

Commit

Permalink
add custom equivalence functions
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Jan 5, 2025
1 parent 0081e5b commit 897d31e
Showing 1 changed file with 145 additions and 20 deletions.
165 changes: 145 additions & 20 deletions h3i/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::cmp;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;

use multimap::MultiMap;
use quiche;
Expand Down Expand Up @@ -433,33 +434,64 @@ impl Serialize for SerializableQFrame<'_> {
}
}

type CustomEquivalenceHandler =
Box<dyn for<'f> Fn(&'f H3iFrame) -> bool + Send + Sync + 'static>;

#[derive(Serialize, Clone)]
enum Comparator {
Frame(H3iFrame),
#[serde(skip)]
/// Specifies how to compare an incoming [`H3iFrame`] with this [`ExpectedFrame`]. Typically,
/// the validation attempts to fuzzy-match HEADERS frames, but there are times where other
/// behavior is desired (for example, checking deserialized JSON payloads). See
/// [`ExpectedFrame::is_equivalent`] for more logic on how frames are expected.
Fn(Arc<CustomEquivalenceHandler>),
}

/// A combination of stream ID and [`H3iFrame`] which is used to instruct h3i to
/// watch for specific frames.
#[derive(Debug, PartialEq, Eq, Serialize, Clone)]
#[derive(Serialize, Clone)]
pub struct ExpectedFrame {
stream_id: u64,
frame: H3iFrame,
comparator: Comparator,
}

impl ExpectedFrame {
pub fn new(stream_id: u64, frame: impl Into<H3iFrame>) -> Self {
Self {
stream_id,
frame: frame.into(),
comparator: Comparator::Frame(frame.into()),
}
}

pub fn new_with_comparator<F>(stream_id: u64, comparator_fn: F) -> Self
where
F: Fn(&H3iFrame) -> bool + Send + Sync + 'static,
{
Self {
stream_id,
comparator: Comparator::Fn(Arc::new(Box::new(comparator_fn))),
}
}

pub(crate) fn stream_id(&self) -> u64 {
self.stream_id
}

/// Check if this [`ExpectedFrame`] is equivalent to another [`H3iFrame`].
/// For QuicheH3/ResetStream variants, equivalence is the same as
/// equality. For Headers variants, this [`ExpectedFrame`] is equivalent
/// to another if the other frame contains all [`Header`]s in _this_
/// frame.
/// Check if this [`ExpectedFrame`] is equivalent to an incoming [`H3iFrame`]. If `self.comparator`
/// is defined, equivalence is determined by the return value of the user-supplied closure.
///
/// If `self.comparator` is not defined, for QuicheH3/ResetStream variants, equivalence is the
/// same as equality. For Headers variants, this [`ExpectedFrame`] is equivalent to another if
/// the other frame contains all [`Header`]s in _this_ frame. This allows users to fuzzy-match
/// on header values.
pub(crate) fn is_equivalent(&self, other: &H3iFrame) -> bool {
match &self.frame {
let frame = match &self.comparator {
Comparator::Fn(compare) => return compare(other),
Comparator::Frame(frame) => frame,
};

match frame {
H3iFrame::Headers(me) => {
let H3iFrame::Headers(other) = other else {
return false;
Expand All @@ -485,25 +517,118 @@ impl ExpectedFrame {
}
}

impl Debug for ExpectedFrame {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let repr = match &self.comparator {
Comparator::Frame(frame) => format!("{frame:?}"),
Comparator::Fn(_) => "closure".to_string(),
};

write!(
f,
"{}",
format!(
"ExpectedFrame {{ stream_id: {}, comparator: {repr} }}",
self.stream_id
)
)
}
}

impl PartialEq for ExpectedFrame {
fn eq(&self, other: &Self) -> bool {
match (&self.comparator, &other.comparator) {
(Comparator::Frame(this_frame), Comparator::Frame(other_frame)) => {
self.stream_id == other.stream_id && this_frame == other_frame
},
_ => false,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use quiche::h3::frame::Frame;

#[test]
fn test_equivalence() {
let this = ExpectedFrame::new(0, vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
]);
let other = ExpectedFrame::new(0, vec![
fn test_header_equivalence() {
let mut this = ExpectedFrame::new(
0,
vec![Header::new(b"hello", b"world"), Header::new(b"go", b"jets")],
);
let mut other: H3iFrame = vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
Header::new(b"go", b"devils"),
]);
]
.into();
assert!(this.is_equivalent(&other));

this = ExpectedFrame::new(0, other);
other =
vec![Header::new(b"hello", b"world"), Header::new(b"go", b"jets")]
.into();
// `other` does not contain the `go: devils` header, so it's not equivalent to `this.
assert!(!this.is_equivalent(&other));
}

assert!(this.is_equivalent(&other.frame));
// `this` does not contain the `go: devils` header, so `other` is not
// equivalent to `this`.
assert!(!other.is_equivalent(&this.frame));
#[test]
fn test_rst_stream_equivalence() {
let mut rs = ResetStream {
stream_id: 0,
error_code: 57,
};

let this = ExpectedFrame::new(0, H3iFrame::ResetStream(rs.clone()));
let incoming = H3iFrame::ResetStream(rs.clone());
assert!(this.is_equivalent(&incoming));

rs.stream_id = 57;
let incoming = H3iFrame::ResetStream(rs);
assert!(!this.is_equivalent(&incoming));
}

#[test]
fn test_frame_equivalence() {
let mut d = Frame::Data {
payload: b"57".to_vec(),
};

let this = ExpectedFrame::new(0, H3iFrame::QuicheH3(d.clone()));
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(this.is_equivalent(&incoming));

d = Frame::Data {
payload: b"go jets".to_vec(),
};
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(!this.is_equivalent(&incoming));
}

#[test]
fn test_comparator() {
let this = ExpectedFrame::new_with_comparator(0, |frame| {
if let H3iFrame::Headers(..) = frame {
frame
.to_enriched_headers()
.unwrap()
.header_map()
.get(&b"cookie".to_vec())
.is_some_and(|v| {
std::str::from_utf8(v)
.map(|s| s.to_lowercase())
.unwrap()
.contains("cookie")
})
} else {
false
}
});

let incoming: H3iFrame =
vec![Header::new(b"cookie", b"SomeRandomCookie1234")].into();

assert!(this.is_equivalent(&incoming));
}
}

0 comments on commit 897d31e

Please sign in to comment.