Skip to content

Commit 7d7eeb0

Browse files
committed
refactor(ffi): Add HeaderCaseMap preserving http1 header casing
1 parent bed9b0d commit 7d7eeb0

File tree

12 files changed

+322
-36
lines changed

12 files changed

+322
-36
lines changed

capi/examples/client.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ int main(int argc, char *argv[]) {
228228
}
229229

230230
hyper_headers *req_headers = hyper_request_headers(req);
231-
hyper_headers_set(req_headers, STR_ARG("host"), STR_ARG(host));
231+
hyper_headers_set(req_headers, STR_ARG("Host"), STR_ARG(host));
232232

233233
// Send it!
234234
hyper_task *send = hyper_clientconn_send(client, req);

src/ffi/body.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub(crate) struct UserBody {
2424
type hyper_body_foreach_callback = extern "C" fn(*mut c_void, *const hyper_buf) -> c_int;
2525

2626
type hyper_body_data_callback =
27-
extern "C" fn(*mut c_void, *mut hyper_context, *mut *mut hyper_buf) -> c_int;
27+
extern "C" fn(*mut c_void, *mut hyper_context<'_>, *mut *mut hyper_buf) -> c_int;
2828

2929
ffi_fn! {
3030
/// Create a new "empty" body.

src/ffi/client.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,15 @@ ffi_fn! {
6767
return std::ptr::null_mut();
6868
}
6969

70-
let req = unsafe { Box::from_raw(req) };
70+
let mut req = unsafe { Box::from_raw(req) };
71+
72+
// Update request with original-case map of headers
73+
req.finalize_request();
74+
7175
let fut = unsafe { &mut *conn }.tx.send_request(req.0);
7276

7377
let fut = async move {
74-
fut.await.map(hyper_response)
78+
fut.await.map(hyper_response::wrap)
7579
};
7680

7781
Box::into_raw(Task::boxed(fut))

src/ffi/error.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl hyper_error {
3333
ErrorKind::IncompleteMessage => hyper_code::HYPERE_UNEXPECTED_EOF,
3434
ErrorKind::User(User::AbortedByCallback) => hyper_code::HYPERE_ABORTED_BY_CALLBACK,
3535
// TODO: add more variants
36-
_ => hyper_code::HYPERE_ERROR
36+
_ => hyper_code::HYPERE_ERROR,
3737
}
3838
}
3939

src/ffi/http_types.rs

+156-23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use bytes::Bytes;
12
use libc::{c_int, size_t};
23
use std::ffi::c_void;
34

@@ -8,13 +9,21 @@ use super::HYPER_ITER_CONTINUE;
89
use crate::header::{HeaderName, HeaderValue};
910
use crate::{Body, HeaderMap, Method, Request, Response, Uri};
1011

11-
// ===== impl Request =====
12-
1312
pub struct hyper_request(pub(super) Request<Body>);
1413

1514
pub struct hyper_response(pub(super) Response<Body>);
1615

17-
pub struct hyper_headers(pub(super) HeaderMap);
16+
#[derive(Default)]
17+
pub struct hyper_headers {
18+
pub(super) headers: HeaderMap,
19+
orig_casing: HeaderCaseMap,
20+
}
21+
22+
// Will probably be moved to `hyper::ext::http1`
23+
#[derive(Debug, Default)]
24+
pub(crate) struct HeaderCaseMap(HeaderMap<Bytes>);
25+
26+
// ===== impl hyper_request =====
1827

1928
ffi_fn! {
2029
/// Construct a new HTTP request.
@@ -96,7 +105,7 @@ ffi_fn! {
96105
/// This is not an owned reference, so it should not be accessed after the
97106
/// `hyper_request` has been consumed.
98107
fn hyper_request_headers(req: *mut hyper_request) -> *mut hyper_headers {
99-
hyper_headers::wrap(unsafe { &mut *req }.0.headers_mut())
108+
hyper_headers::get_or_default(unsafe { &mut *req }.0.extensions_mut())
100109
}
101110
}
102111

@@ -114,7 +123,16 @@ ffi_fn! {
114123
}
115124
}
116125

117-
// ===== impl Response =====
126+
impl hyper_request {
127+
pub(super) fn finalize_request(&mut self) {
128+
if let Some(headers) = self.0.extensions_mut().remove::<hyper_headers>() {
129+
*self.0.headers_mut() = headers.headers;
130+
self.0.extensions_mut().insert(headers.orig_casing);
131+
}
132+
}
133+
}
134+
135+
// ===== impl hyper_response =====
118136

119137
ffi_fn! {
120138
/// Free an HTTP response after using it.
@@ -159,7 +177,7 @@ ffi_fn! {
159177
/// This is not an owned reference, so it should not be accessed after the
160178
/// `hyper_response` has been freed.
161179
fn hyper_response_headers(resp: *mut hyper_response) -> *mut hyper_headers {
162-
hyper_headers::wrap(unsafe { &mut *resp }.0.headers_mut())
180+
hyper_headers::get_or_default(unsafe { &mut *resp }.0.extensions_mut())
163181
}
164182
}
165183

@@ -173,6 +191,22 @@ ffi_fn! {
173191
}
174192
}
175193

194+
impl hyper_response {
195+
pub(super) fn wrap(mut resp: Response<Body>) -> hyper_response {
196+
let headers = std::mem::take(resp.headers_mut());
197+
let orig_casing = resp
198+
.extensions_mut()
199+
.remove::<HeaderCaseMap>()
200+
.unwrap_or_default();
201+
resp.extensions_mut().insert(hyper_headers {
202+
headers,
203+
orig_casing,
204+
});
205+
206+
hyper_response(resp)
207+
}
208+
}
209+
176210
unsafe impl AsTaskType for hyper_response {
177211
fn as_task_type(&self) -> hyper_task_return_type {
178212
hyper_task_return_type::HYPER_TASK_RESPONSE
@@ -185,9 +219,15 @@ type hyper_headers_foreach_callback =
185219
extern "C" fn(*mut c_void, *const u8, size_t, *const u8, size_t) -> c_int;
186220

187221
impl hyper_headers {
188-
pub(crate) fn wrap(cx: &mut HeaderMap) -> &mut hyper_headers {
189-
// A struct with only one field has the same layout as that field.
190-
unsafe { std::mem::transmute::<&mut HeaderMap, &mut hyper_headers>(cx) }
222+
pub(super) fn get_or_default(ext: &mut http::Extensions) -> &mut hyper_headers {
223+
if let None = ext.get_mut::<hyper_headers>() {
224+
ext.insert(hyper_headers {
225+
headers: Default::default(),
226+
orig_casing: Default::default(),
227+
});
228+
}
229+
230+
ext.get_mut::<hyper_headers>().unwrap()
191231
}
192232
}
193233

@@ -199,14 +239,31 @@ ffi_fn! {
199239
/// The callback should return `HYPER_ITER_CONTINUE` to keep iterating, or
200240
/// `HYPER_ITER_BREAK` to stop.
201241
fn hyper_headers_foreach(headers: *const hyper_headers, func: hyper_headers_foreach_callback, userdata: *mut c_void) {
202-
for (name, value) in unsafe { &*headers }.0.iter() {
203-
let name_ptr = name.as_str().as_bytes().as_ptr();
204-
let name_len = name.as_str().as_bytes().len();
205-
let val_ptr = value.as_bytes().as_ptr();
206-
let val_len = value.as_bytes().len();
207-
208-
if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) {
209-
break;
242+
let headers = unsafe { &*headers };
243+
// For each header name/value pair, there may be a value in the casemap
244+
// that corresponds to the HeaderValue. So, we iterator all the keys,
245+
// and for each one, try to pair the originally cased name with the value.
246+
//
247+
// TODO: consider adding http::HeaderMap::entries() iterator
248+
for name in headers.headers.keys() {
249+
let mut names = headers.orig_casing.get_all(name).iter();
250+
251+
for value in headers.headers.get_all(name) {
252+
let (name_ptr, name_len) = if let Some(orig_name) = names.next() {
253+
(orig_name.as_ptr(), orig_name.len())
254+
} else {
255+
(
256+
name.as_str().as_bytes().as_ptr(),
257+
name.as_str().as_bytes().len(),
258+
)
259+
};
260+
261+
let val_ptr = value.as_bytes().as_ptr();
262+
let val_len = value.as_bytes().len();
263+
264+
if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) {
265+
return;
266+
}
210267
}
211268
}
212269
}
@@ -219,8 +276,9 @@ ffi_fn! {
219276
fn hyper_headers_set(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code {
220277
let headers = unsafe { &mut *headers };
221278
match unsafe { raw_name_value(name, name_len, value, value_len) } {
222-
Ok((name, value)) => {
223-
headers.0.insert(name, value);
279+
Ok((name, value, orig_name)) => {
280+
headers.headers.insert(&name, value);
281+
headers.orig_casing.insert(name, orig_name);
224282
hyper_code::HYPERE_OK
225283
}
226284
Err(code) => code,
@@ -237,8 +295,9 @@ ffi_fn! {
237295
let headers = unsafe { &mut *headers };
238296

239297
match unsafe { raw_name_value(name, name_len, value, value_len) } {
240-
Ok((name, value)) => {
241-
headers.0.append(name, value);
298+
Ok((name, value, orig_name)) => {
299+
headers.headers.append(&name, value);
300+
headers.orig_casing.append(name, orig_name);
242301
hyper_code::HYPERE_OK
243302
}
244303
Err(code) => code,
@@ -251,8 +310,9 @@ unsafe fn raw_name_value(
251310
name_len: size_t,
252311
value: *const u8,
253312
value_len: size_t,
254-
) -> Result<(HeaderName, HeaderValue), hyper_code> {
313+
) -> Result<(HeaderName, HeaderValue, Bytes), hyper_code> {
255314
let name = std::slice::from_raw_parts(name, name_len);
315+
let orig_name = Bytes::copy_from_slice(name);
256316
let name = match HeaderName::from_bytes(name) {
257317
Ok(name) => name,
258318
Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG),
@@ -263,5 +323,78 @@ unsafe fn raw_name_value(
263323
Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG),
264324
};
265325

266-
Ok((name, value))
326+
Ok((name, value, orig_name))
327+
}
328+
329+
// ===== impl HeaderCaseMap =====
330+
331+
impl HeaderCaseMap {
332+
pub(crate) fn get_all(&self, name: &HeaderName) -> http::header::GetAll<'_, Bytes> {
333+
self.0.get_all(name)
334+
}
335+
336+
pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) {
337+
self.0.insert(name, orig);
338+
}
339+
340+
pub(crate) fn append<N>(&mut self, name: N, orig: Bytes)
341+
where
342+
N: http::header::IntoHeaderName,
343+
{
344+
self.0.append(name, orig);
345+
}
346+
}
347+
348+
#[cfg(test)]
349+
mod tests {
350+
use super::*;
351+
352+
#[test]
353+
fn test_headers_foreach_cases_preserved() {
354+
let mut headers = hyper_headers::default();
355+
356+
let name1 = b"Set-CookiE";
357+
let value1 = b"a=b";
358+
hyper_headers_add(
359+
&mut headers,
360+
name1.as_ptr(),
361+
name1.len(),
362+
value1.as_ptr(),
363+
value1.len(),
364+
);
365+
366+
let name2 = b"SET-COOKIE";
367+
let value2 = b"c=d";
368+
hyper_headers_add(
369+
&mut headers,
370+
name2.as_ptr(),
371+
name2.len(),
372+
value2.as_ptr(),
373+
value2.len(),
374+
);
375+
376+
let mut vec = Vec::<u8>::new();
377+
hyper_headers_foreach(&headers, concat, &mut vec as *mut _ as *mut c_void);
378+
379+
assert_eq!(vec, b"Set-CookiE: a=b\r\nSET-COOKIE: c=d\r\n");
380+
381+
extern "C" fn concat(
382+
vec: *mut c_void,
383+
name: *const u8,
384+
name_len: usize,
385+
value: *const u8,
386+
value_len: usize,
387+
) -> c_int {
388+
unsafe {
389+
let vec = &mut *(vec as *mut Vec<u8>);
390+
let name = std::slice::from_raw_parts(name, name_len);
391+
let value = std::slice::from_raw_parts(value, value_len);
392+
vec.extend(name);
393+
vec.extend(b": ");
394+
vec.extend(value);
395+
vec.extend(b"\r\n");
396+
}
397+
HYPER_ITER_CONTINUE
398+
}
399+
}
267400
}

src/ffi/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ mod io;
2828
mod task;
2929

3030
pub(crate) use self::body::UserBody;
31+
pub(crate) use self::http_types::HeaderCaseMap;
3132

3233
pub const HYPER_ITER_CONTINUE: libc::c_int = 0;
3334
#[allow(unused)]

src/proto/h1/conn.rs

+16
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ where
4444
error: None,
4545
keep_alive: KA::Busy,
4646
method: None,
47+
#[cfg(feature = "ffi")]
48+
preserve_header_case: false,
4749
title_case_headers: false,
4850
notify_read: false,
4951
reading: Reading::Init,
@@ -142,6 +144,8 @@ where
142144
ParseContext {
143145
cached_headers: &mut self.state.cached_headers,
144146
req_method: &mut self.state.method,
147+
#[cfg(feature = "ffi")]
148+
preserve_header_case: self.state.preserve_header_case,
145149
}
146150
)) {
147151
Ok(msg) => msg,
@@ -474,6 +478,16 @@ where
474478

475479
self.enforce_version(&mut head);
476480

481+
// Maybe check if we should preserve header casing on received
482+
// message headers...
483+
#[cfg(feature = "ffi")]
484+
{
485+
if T::is_client() && !self.state.preserve_header_case {
486+
self.state.preserve_header_case =
487+
head.extensions.get::<crate::ffi::HeaderCaseMap>().is_some();
488+
}
489+
}
490+
477491
let buf = self.io.headers_buf();
478492
match super::role::encode_headers::<T>(
479493
Encode {
@@ -736,6 +750,8 @@ struct State {
736750
/// This is used to know things such as if the message can include
737751
/// a body or not.
738752
method: Option<Method>,
753+
#[cfg(feature = "ffi")]
754+
preserve_header_case: bool,
739755
title_case_headers: bool,
740756
/// Set to true when the Dispatcher should poll read operations
741757
/// again. See the `maybe_notify` method for more.

src/proto/h1/dispatch.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ cfg_server! {
492492
version: parts.version,
493493
subject: parts.status,
494494
headers: parts.headers,
495-
extensions: http::Extensions::default(),
495+
extensions: parts.extensions,
496496
};
497497
Poll::Ready(Some(Ok((head, body))))
498498
} else {
@@ -576,7 +576,7 @@ cfg_client! {
576576
version: parts.version,
577577
subject: crate::proto::RequestLine(parts.method, parts.uri),
578578
headers: parts.headers,
579-
extensions: http::Extensions::default(),
579+
extensions: parts.extensions,
580580
};
581581
*this.callback = Some(cb);
582582
Poll::Ready(Some(Ok((head, body))))

src/proto/h1/io.rs

+4
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ where
159159
ParseContext {
160160
cached_headers: parse_ctx.cached_headers,
161161
req_method: parse_ctx.req_method,
162+
#[cfg(feature = "ffi")]
163+
preserve_header_case: parse_ctx.preserve_header_case,
162164
},
163165
)? {
164166
Some(msg) => {
@@ -636,6 +638,8 @@ mod tests {
636638
let parse_ctx = ParseContext {
637639
cached_headers: &mut None,
638640
req_method: &mut None,
641+
#[cfg(feature = "ffi")]
642+
preserve_header_case: false,
639643
};
640644
assert!(buffered
641645
.parse::<ClientTransaction>(cx, parse_ctx)

0 commit comments

Comments
 (0)