diff --git a/lib/message.ml b/lib/message.ml index acf8b98e..c24ddbe3 100644 --- a/lib/message.ml +++ b/lib/message.ml @@ -41,7 +41,7 @@ let persistent_connection ?(proxy=false) version headers = (* XXX(seliopou): use proxy argument in the case of HTTP/1.0 as per https://tools.ietf.org/html/rfc7230#section-6.3 *) match Headers.get headers "connection" with - | Some ("close" | "upgrade") -> false + | Some "close" -> false | Some "keep-alive" -> Version.(compare version v1_0) >= 0 | _ -> Version.(compare version v1_1) >= 0 diff --git a/lib/parse.ml b/lib/parse.ml index 925e5f02..664e3c3b 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -212,22 +212,13 @@ module Reader = struct | `Invalid_response_body_length of Response.t | `Parse of string list * string ] - type parser_result = - | Can_continue - | Stop - type 'error parse_state = - | Done of parser_result + | Done | Fail of 'error - | Partial of - (Bigstringaf.t - -> off:int - -> len:int - -> AU.more - -> (parser_result, 'error) result AU.state) + | Partial of (Bigstringaf.t -> off:int -> len:int -> AU.more -> (unit, 'error) result AU.state) type 'error t = - { parser : (parser_result, 'error) result Angstrom.t + { parser : (unit, 'error) result Angstrom.t ; mutable parse_state : 'error parse_state (* The state of the parse for the current request *) ; mutable closed : bool @@ -240,25 +231,24 @@ module Reader = struct let create parser = { parser - ; parse_state = Done Can_continue + ; parse_state = Done ; closed = false } + let ok = return (Ok ()) + let request handler = let parser = request <* commit >>= fun request -> match Request.body_length request with - | `Error `Bad_request -> - return (Error (`Bad_request request)) + | `Error `Bad_request -> return (Error (`Bad_request request)) | `Fixed 0L -> handler request Body.empty; - (* If the client has requested an upgrade, then any bytes after the headers are - likely not HTTP, so we should be careful not to try to parse them. *) - return (Ok (if Request.is_upgrade request then Stop else Can_continue)) + ok | `Fixed _ | `Chunked as encoding -> let request_body = Body.create_reader Bigstringaf.empty in handler request request_body; - body ~encoding request_body *> return (Ok Can_continue) + body ~encoding request_body *> ok in create parser @@ -271,13 +261,13 @@ module Reader = struct | `Error `Internal_server_error -> return (Error (`Invalid_response_body_length response)) | `Fixed 0L -> handler response Body.empty; - return (Ok Can_continue) + ok | `Fixed _ | `Chunked | `Close_delimited as encoding -> (* We do not trust the length provided in the [`Fixed] case, as the client could DOS easily. *) let response_body = Body.create_reader Bigstringaf.empty in handler response response_body; - body ~encoding response_body *> return (Ok Can_continue) + body ~encoding response_body *> ok in create parser ;; @@ -288,8 +278,8 @@ module Reader = struct let transition t state = match state with - | AU.Done(consumed, Ok result) -> - t.parse_state <- Done result; + | AU.Done(consumed, Ok ()) -> + t.parse_state <- Done; consumed | AU.Done(consumed, Error error) -> t.parse_state <- Fail error; @@ -313,8 +303,8 @@ module Reader = struct let rec read_with_more t bs ~off ~len more = let consumed = match t.parse_state with - | Fail _ | Done Stop -> 0 - | Done Can_continue -> + | Fail _ -> 0 + | Done -> start t (AU.parse t.parser); read_with_more t bs ~off ~len more; | Partial continue -> @@ -334,10 +324,11 @@ module Reader = struct let next t = if t.closed then `Close - else + else ( match t.parse_state with | Fail err -> `Error err - | Done Stop -> `Close - | Done Can_continue | Partial _ -> `Read + | Done -> `Read + | Partial _ -> `Read + ) ;; end diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index cbe969f4..4cd6c085 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -967,8 +967,8 @@ let test_upgrade () = let test_upgrade_where_server_does_not_upgrade () = let upgrade_headers =["Connection", "upgrade" ; "Upgrade", "foo"] in - let reqd = ref None in - let request_handler reqd' = reqd := Some reqd' in + let reqd_ref = ref None in + let request_handler reqd = reqd_ref := Some reqd in let t = create request_handler in read_request t (Request.create `GET "/" @@ -980,12 +980,17 @@ let test_upgrade_where_server_does_not_upgrade () = (* Now pretend the user doesn't want to do the upgrade and make sure we close the connection *) - let reqd = Option.get !reqd in - let response = Response.create `Bad_request ~headers:(Headers.of_list upgrade_headers) in + let reqd = Option.get !reqd_ref in + let response = Response.create `Bad_request ~headers:(Headers.encoding_fixed 0) in + Reqd.respond_with_string reqd response ""; + write_response t response; + + (* The connection is left healthy and can be used for more requests *) + read_request t (Request.create `GET "/" ~headers:(Headers.encoding_fixed 0)); + let reqd = Option.get !reqd_ref in + let response = Response.create `OK ~headers:(Headers.encoding_fixed 0) in Reqd.respond_with_string reqd response ""; write_response t response; - Alcotest.check read_operation "Reader is `Close" `Close (current_read_operation t); - Alcotest.check write_operation "Writer is `Close" (`Close 0) (current_write_operation t); ;; let tests =