1
+ use bytes:: Bytes ;
1
2
use libc:: { c_int, size_t} ;
2
3
use std:: ffi:: c_void;
3
4
@@ -8,13 +9,21 @@ use super::HYPER_ITER_CONTINUE;
8
9
use crate :: header:: { HeaderName , HeaderValue } ;
9
10
use crate :: { Body , HeaderMap , Method , Request , Response , Uri } ;
10
11
11
- // ===== impl Request =====
12
-
13
12
pub struct hyper_request ( pub ( super ) Request < Body > ) ;
14
13
15
14
pub struct hyper_response ( pub ( super ) Response < Body > ) ;
16
15
17
- pub struct hyper_headers ( pub ( super ) HeaderMap ) ;
16
+ #[ derive( Default ) ]
17
+ pub struct hyper_headers {
18
+ pub ( super ) headers : HeaderMap ,
19
+ orig_casing : HeaderCaseMap ,
20
+ }
21
+
22
+ // Will probably be moved to `hyper::ext::http1`
23
+ #[ derive( Debug , Default ) ]
24
+ pub ( crate ) struct HeaderCaseMap ( HeaderMap < Bytes > ) ;
25
+
26
+ // ===== impl hyper_request =====
18
27
19
28
ffi_fn ! {
20
29
/// Construct a new HTTP request.
@@ -96,7 +105,7 @@ ffi_fn! {
96
105
/// This is not an owned reference, so it should not be accessed after the
97
106
/// `hyper_request` has been consumed.
98
107
fn hyper_request_headers( req: * mut hyper_request) -> * mut hyper_headers {
99
- hyper_headers:: wrap ( unsafe { & mut * req } . 0 . headers_mut ( ) )
108
+ hyper_headers:: get_or_default ( unsafe { & mut * req } . 0 . extensions_mut ( ) )
100
109
}
101
110
}
102
111
@@ -114,7 +123,16 @@ ffi_fn! {
114
123
}
115
124
}
116
125
117
- // ===== impl Response =====
126
+ impl hyper_request {
127
+ pub ( super ) fn finalize_request ( & mut self ) {
128
+ if let Some ( headers) = self . 0 . extensions_mut ( ) . remove :: < hyper_headers > ( ) {
129
+ * self . 0 . headers_mut ( ) = headers. headers ;
130
+ self . 0 . extensions_mut ( ) . insert ( headers. orig_casing ) ;
131
+ }
132
+ }
133
+ }
134
+
135
+ // ===== impl hyper_response =====
118
136
119
137
ffi_fn ! {
120
138
/// Free an HTTP response after using it.
@@ -159,7 +177,7 @@ ffi_fn! {
159
177
/// This is not an owned reference, so it should not be accessed after the
160
178
/// `hyper_response` has been freed.
161
179
fn hyper_response_headers( resp: * mut hyper_response) -> * mut hyper_headers {
162
- hyper_headers:: wrap ( unsafe { & mut * resp } . 0 . headers_mut ( ) )
180
+ hyper_headers:: get_or_default ( unsafe { & mut * resp } . 0 . extensions_mut ( ) )
163
181
}
164
182
}
165
183
@@ -173,6 +191,22 @@ ffi_fn! {
173
191
}
174
192
}
175
193
194
+ impl hyper_response {
195
+ pub ( super ) fn wrap ( mut resp : Response < Body > ) -> hyper_response {
196
+ let headers = std:: mem:: take ( resp. headers_mut ( ) ) ;
197
+ let orig_casing = resp
198
+ . extensions_mut ( )
199
+ . remove :: < HeaderCaseMap > ( )
200
+ . unwrap_or_default ( ) ;
201
+ resp. extensions_mut ( ) . insert ( hyper_headers {
202
+ headers,
203
+ orig_casing,
204
+ } ) ;
205
+
206
+ hyper_response ( resp)
207
+ }
208
+ }
209
+
176
210
unsafe impl AsTaskType for hyper_response {
177
211
fn as_task_type ( & self ) -> hyper_task_return_type {
178
212
hyper_task_return_type:: HYPER_TASK_RESPONSE
@@ -185,9 +219,15 @@ type hyper_headers_foreach_callback =
185
219
extern "C" fn ( * mut c_void , * const u8 , size_t , * const u8 , size_t ) -> c_int ;
186
220
187
221
impl hyper_headers {
188
- pub ( crate ) fn wrap ( cx : & mut HeaderMap ) -> & mut hyper_headers {
189
- // A struct with only one field has the same layout as that field.
190
- unsafe { std:: mem:: transmute :: < & mut HeaderMap , & mut hyper_headers > ( cx) }
222
+ pub ( super ) fn get_or_default ( ext : & mut http:: Extensions ) -> & mut hyper_headers {
223
+ if let None = ext. get_mut :: < hyper_headers > ( ) {
224
+ ext. insert ( hyper_headers {
225
+ headers : Default :: default ( ) ,
226
+ orig_casing : Default :: default ( ) ,
227
+ } ) ;
228
+ }
229
+
230
+ ext. get_mut :: < hyper_headers > ( ) . unwrap ( )
191
231
}
192
232
}
193
233
@@ -199,14 +239,31 @@ ffi_fn! {
199
239
/// The callback should return `HYPER_ITER_CONTINUE` to keep iterating, or
200
240
/// `HYPER_ITER_BREAK` to stop.
201
241
fn hyper_headers_foreach( headers: * const hyper_headers, func: hyper_headers_foreach_callback, userdata: * mut c_void) {
202
- for ( name, value) in unsafe { & * headers } . 0 . iter( ) {
203
- let name_ptr = name. as_str( ) . as_bytes( ) . as_ptr( ) ;
204
- let name_len = name. as_str( ) . as_bytes( ) . len( ) ;
205
- let val_ptr = value. as_bytes( ) . as_ptr( ) ;
206
- let val_len = value. as_bytes( ) . len( ) ;
207
-
208
- if HYPER_ITER_CONTINUE != func( userdata, name_ptr, name_len, val_ptr, val_len) {
209
- break ;
242
+ let headers = unsafe { & * headers } ;
243
+ // For each header name/value pair, there may be a value in the casemap
244
+ // that corresponds to the HeaderValue. So, we iterator all the keys,
245
+ // and for each one, try to pair the originally cased name with the value.
246
+ //
247
+ // TODO: consider adding http::HeaderMap::entries() iterator
248
+ for name in headers. headers. keys( ) {
249
+ let mut names = headers. orig_casing. get_all( name) . iter( ) ;
250
+
251
+ for value in headers. headers. get_all( name) {
252
+ let ( name_ptr, name_len) = if let Some ( orig_name) = names. next( ) {
253
+ ( orig_name. as_ptr( ) , orig_name. len( ) )
254
+ } else {
255
+ (
256
+ name. as_str( ) . as_bytes( ) . as_ptr( ) ,
257
+ name. as_str( ) . as_bytes( ) . len( ) ,
258
+ )
259
+ } ;
260
+
261
+ let val_ptr = value. as_bytes( ) . as_ptr( ) ;
262
+ let val_len = value. as_bytes( ) . len( ) ;
263
+
264
+ if HYPER_ITER_CONTINUE != func( userdata, name_ptr, name_len, val_ptr, val_len) {
265
+ return ;
266
+ }
210
267
}
211
268
}
212
269
}
@@ -219,8 +276,9 @@ ffi_fn! {
219
276
fn hyper_headers_set( headers: * mut hyper_headers, name: * const u8 , name_len: size_t, value: * const u8 , value_len: size_t) -> hyper_code {
220
277
let headers = unsafe { & mut * headers } ;
221
278
match unsafe { raw_name_value( name, name_len, value, value_len) } {
222
- Ok ( ( name, value) ) => {
223
- headers. 0 . insert( name, value) ;
279
+ Ok ( ( name, value, orig_name) ) => {
280
+ headers. headers. insert( & name, value) ;
281
+ headers. orig_casing. insert( name, orig_name) ;
224
282
hyper_code:: HYPERE_OK
225
283
}
226
284
Err ( code) => code,
@@ -237,8 +295,9 @@ ffi_fn! {
237
295
let headers = unsafe { & mut * headers } ;
238
296
239
297
match unsafe { raw_name_value( name, name_len, value, value_len) } {
240
- Ok ( ( name, value) ) => {
241
- headers. 0 . append( name, value) ;
298
+ Ok ( ( name, value, orig_name) ) => {
299
+ headers. headers. append( & name, value) ;
300
+ headers. orig_casing. append( name, orig_name) ;
242
301
hyper_code:: HYPERE_OK
243
302
}
244
303
Err ( code) => code,
@@ -251,8 +310,9 @@ unsafe fn raw_name_value(
251
310
name_len : size_t ,
252
311
value : * const u8 ,
253
312
value_len : size_t ,
254
- ) -> Result < ( HeaderName , HeaderValue ) , hyper_code > {
313
+ ) -> Result < ( HeaderName , HeaderValue , Bytes ) , hyper_code > {
255
314
let name = std:: slice:: from_raw_parts ( name, name_len) ;
315
+ let orig_name = Bytes :: copy_from_slice ( name) ;
256
316
let name = match HeaderName :: from_bytes ( name) {
257
317
Ok ( name) => name,
258
318
Err ( _) => return Err ( hyper_code:: HYPERE_INVALID_ARG ) ,
@@ -263,5 +323,78 @@ unsafe fn raw_name_value(
263
323
Err ( _) => return Err ( hyper_code:: HYPERE_INVALID_ARG ) ,
264
324
} ;
265
325
266
- Ok ( ( name, value) )
326
+ Ok ( ( name, value, orig_name) )
327
+ }
328
+
329
+ // ===== impl HeaderCaseMap =====
330
+
331
+ impl HeaderCaseMap {
332
+ pub ( crate ) fn get_all ( & self , name : & HeaderName ) -> http:: header:: GetAll < ' _ , Bytes > {
333
+ self . 0 . get_all ( name)
334
+ }
335
+
336
+ pub ( crate ) fn insert ( & mut self , name : HeaderName , orig : Bytes ) {
337
+ self . 0 . insert ( name, orig) ;
338
+ }
339
+
340
+ pub ( crate ) fn append < N > ( & mut self , name : N , orig : Bytes )
341
+ where
342
+ N : http:: header:: IntoHeaderName ,
343
+ {
344
+ self . 0 . append ( name, orig) ;
345
+ }
346
+ }
347
+
348
+ #[ cfg( test) ]
349
+ mod tests {
350
+ use super :: * ;
351
+
352
+ #[ test]
353
+ fn test_headers_foreach_cases_preserved ( ) {
354
+ let mut headers = hyper_headers:: default ( ) ;
355
+
356
+ let name1 = b"Set-CookiE" ;
357
+ let value1 = b"a=b" ;
358
+ hyper_headers_add (
359
+ & mut headers,
360
+ name1. as_ptr ( ) ,
361
+ name1. len ( ) ,
362
+ value1. as_ptr ( ) ,
363
+ value1. len ( ) ,
364
+ ) ;
365
+
366
+ let name2 = b"SET-COOKIE" ;
367
+ let value2 = b"c=d" ;
368
+ hyper_headers_add (
369
+ & mut headers,
370
+ name2. as_ptr ( ) ,
371
+ name2. len ( ) ,
372
+ value2. as_ptr ( ) ,
373
+ value2. len ( ) ,
374
+ ) ;
375
+
376
+ let mut vec = Vec :: < u8 > :: new ( ) ;
377
+ hyper_headers_foreach ( & headers, concat, & mut vec as * mut _ as * mut c_void ) ;
378
+
379
+ assert_eq ! ( vec, b"Set-CookiE: a=b\r \n SET-COOKIE: c=d\r \n " ) ;
380
+
381
+ extern "C" fn concat (
382
+ vec : * mut c_void ,
383
+ name : * const u8 ,
384
+ name_len : usize ,
385
+ value : * const u8 ,
386
+ value_len : usize ,
387
+ ) -> c_int {
388
+ unsafe {
389
+ let vec = & mut * ( vec as * mut Vec < u8 > ) ;
390
+ let name = std:: slice:: from_raw_parts ( name, name_len) ;
391
+ let value = std:: slice:: from_raw_parts ( value, value_len) ;
392
+ vec. extend ( name) ;
393
+ vec. extend ( b": " ) ;
394
+ vec. extend ( value) ;
395
+ vec. extend ( b"\r \n " ) ;
396
+ }
397
+ HYPER_ITER_CONTINUE
398
+ }
399
+ }
267
400
}
0 commit comments