diff --git a/lib/reqd.ml b/lib/reqd.ml index b8d915b3..40c989aa 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -44,6 +44,7 @@ end module Input_state = struct type t = + | Waiting | Ready | Complete | Upgraded @@ -189,11 +190,14 @@ let respond_with_streaming ?(flush_headers_immediately=false) t response = let respond_with_upgrade ?reason t headers = match t.response_state with | Waiting -> - let response = Response.create ?reason ~headers `Switching_protocols in - t.response_state <- Upgrade response; - Body.close_reader t.request_body; - Writer.write_response t.writer response; - Writer.wakeup t.writer; + if not (Request.is_upgrade t.request) then + failwith "httpaf.Reqd.respond_with_upgrade: request was not an upgrade request" + else ( + let response = Response.create ?reason ~headers `Switching_protocols in + t.response_state <- Upgrade response; + Body.close_reader t.request_body; + Writer.write_response t.writer response; + Writer.wakeup t.writer); | Streaming _ -> failwith "httpaf.Reqd.respond_with_upgrade: response already started" | Upgrade _ @@ -248,12 +252,23 @@ let persistent_connection t = t.persistent let input_state t : Input_state.t = - match t.response_state with - | Upgrade _ -> Upgraded - | Waiting | Fixed _ | Streaming _ -> + let upgrade_status = + match Request.is_upgrade t.request with + | false -> `Not_upgrading + | true -> + match t.response_state with + | Upgrade _ -> `Finished_upgrading + | Fixed _ | Streaming _ -> `Upgrade_declined + | Waiting -> `Upgrade_in_progress + in + match upgrade_status with + | `Finished_upgrading -> Upgraded + | `Not_upgrading | `Upgrade_declined -> if Body.is_closed t.request_body then Complete else Ready + | `Upgrade_in_progress -> + Waiting ;; let output_state t : Output_state.t = diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 64f4ea80..53b308ea 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -197,6 +197,7 @@ let rec _next_read_operation t = ) else ( let reqd = current_reqd_exn t in match Reqd.input_state reqd with + | Waiting -> `Yield | Ready -> Reader.next t.reader | Complete -> _final_read_operation_for t reqd | Upgraded -> `Upgrade @@ -293,6 +294,7 @@ and _final_write_operation_for t reqd ~upgrade = Writer.next t.writer; ) else ( match Reqd.input_state reqd with + | Waiting -> `Yield | Ready -> Writer.next t.writer; | Upgraded -> `Upgrade | Complete -> diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index 33064c26..cbe969f4 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -967,16 +967,24 @@ let test_upgrade () = let test_upgrade_where_server_does_not_upgrade () = let upgrade_headers =["Connection", "upgrade" ; "Upgrade", "foo"] in - let response = Response.create `Bad_request ~headers:(Headers.of_list upgrade_headers) in - let request_handler reqd = - Reqd.respond_with_string reqd response "" - in + let reqd = ref None in + let request_handler reqd' = reqd := Some reqd' in let t = create request_handler in read_request t (Request.create `GET "/" ~headers:(Headers.of_list (("Content-Length", "0") :: upgrade_headers))); - Alcotest.check read_operation "Reader is `Close" `Close (current_read_operation t); + (* At this point, we don't know if the response handler will call respond_with_upgrade + or not. So we pause the reader until that is determined. *) + Alcotest.check read_operation "Reader is `Yield during upgrade negotiation" + `Yield (current_read_operation t); + + (* Now pretend the user doesn't want to do the upgrade and make sure we close the + connection *) + let reqd = Option.get !reqd in + let response = Response.create `Bad_request ~headers:(Headers.of_list upgrade_headers) in + Reqd.respond_with_string reqd response ""; write_response t response; + Alcotest.check read_operation "Reader is `Close" `Close (current_read_operation t); Alcotest.check write_operation "Writer is `Close" (`Close 0) (current_write_operation t); ;;