diff --git a/examples/lwt/dune b/examples/lwt/dune index 1494aee..90d1b71 100644 --- a/examples/lwt/dune +++ b/examples/lwt/dune @@ -1,7 +1,7 @@ (executables (libraries h1 h1-lwt-unix h1_examples base stdio lwt lwt.unix) (optional true) - (names lwt_get lwt_post lwt_echo_post lwt_echo_upgrade)) + (names lwt_get lwt_post lwt_echo_post lwt_echo_upgrade lwt_chunked)) (alias (name runtest) diff --git a/examples/lwt/lwt_chunked.ml b/examples/lwt/lwt_chunked.ml new file mode 100644 index 0000000..81c6d44 --- /dev/null +++ b/examples/lwt/lwt_chunked.ml @@ -0,0 +1,40 @@ +open Base +open Lwt.Infix +module Arg = Stdlib.Arg + +open H1_lwt_unix + +let request_handler (_ : Unix.sockaddr) reqd = + let body = H1.Reqd.respond_with_streaming reqd (H1.Response.create ~headers:(H1.Headers.of_list ["connection", "close"]) `OK) in + let rec respond_loop i = + H1.Body.Writer.write_string body (Printf.sprintf "Chunk %i\n" i); + H1.Body.Writer.flush_with_reason body (function + | `Closed -> Stdio.print_endline "closed" + | `Written -> Stdio.print_endline "written"; Lwt.bind (Lwt_unix.sleep 5.) (fun () -> respond_loop (i+1)) |> ignore + ); + Lwt.return_unit + in ignore (respond_loop 0) + +let error_handler (_ : Unix.sockaddr) = H1_examples.Server.error_handler + +let main port = + let listen_address = Unix.(ADDR_INET (inet_addr_loopback, port)) in + Lwt.async (fun () -> + Lwt_io.establish_server_with_client_socket + listen_address + (Server.create_connection_handler ~upgrade_handler:None ~request_handler ~error_handler) + >|= fun _server -> + Stdio.printf "Listening on port %i.\n" port); + let forever, _ = Lwt.wait () in + Lwt_main.run forever +;; + +let () = + Stdlib.Sys.set_signal Stdlib.Sys.sigpipe Stdlib.Sys.Signal_ignore; + let port = ref 8080 in + Arg.parse + ["-p", Arg.Set_int port, " Listening port number (8080 by default)"] + ignore + "Echoes POST requests. Runs forever."; + main !port +;; diff --git a/lib/body.ml b/lib/body.ml index d5f0424..63fe27f 100644 --- a/lib/body.ml +++ b/lib/body.ml @@ -102,18 +102,20 @@ module Reader = struct end module Writer = struct + module Writer = Serialize.Writer + type encoding = | Identity | Chunked of { mutable written_final_chunk : bool } type t = - { faraday : Faraday.t - ; encoding : encoding - ; when_ready_to_write : unit -> unit - ; buffered_bytes : int ref + { faraday : Faraday.t + ; writer : Writer.t + ; encoding : encoding + ; buffered_bytes : int ref } - let of_faraday faraday ~encoding ~when_ready_to_write = + let of_faraday faraday writer ~encoding = let encoding = match encoding with | `Fixed _ | `Close_delimited -> Identity @@ -121,34 +123,57 @@ module Writer = struct in { faraday ; encoding - ; when_ready_to_write + ; writer ; buffered_bytes = ref 0 } - let create buffer ~encoding ~when_ready_to_write = - of_faraday (Faraday.of_bigstring buffer) ~encoding ~when_ready_to_write + let create buffer writer ~encoding = + of_faraday (Faraday.of_bigstring buffer) writer ~encoding let write_char t c = - Faraday.write_char t.faraday c + if not (Faraday.is_closed t.faraday) then + Faraday.write_char t.faraday c let write_string t ?off ?len s = - Faraday.write_string ?off ?len t.faraday s + if not (Faraday.is_closed t.faraday) then + Faraday.write_string ?off ?len t.faraday s let write_bigstring t ?off ?len b = - Faraday.write_bigstring ?off ?len t.faraday b + if not (Faraday.is_closed t.faraday) then + Faraday.write_bigstring ?off ?len t.faraday b let schedule_bigstring t ?off ?len (b:Bigstringaf.t) = - Faraday.schedule_bigstring ?off ?len t.faraday b + if not (Faraday.is_closed t.faraday) then + Faraday.schedule_bigstring ?off ?len t.faraday b - let ready_to_write t = t.when_ready_to_write () + let ready_to_write t = Writer.wakeup t.writer let flush t kontinue = Faraday.flush t.faraday kontinue; ready_to_write t + let flush_with_reason t kontinue = + if Writer.is_closed t.writer then + kontinue `Closed + else begin + Faraday.flush_with_reason t.faraday (fun reason -> + let result = + match reason with + | Nothing_pending | Shift -> `Written + | Drain -> `Closed + in + kontinue result); + ready_to_write t + end + let is_closed t = Faraday.is_closed t.faraday + let close_and_drain t = + Faraday.close t.faraday; + (* Resolve all pending flushes *) + ignore (Faraday.drain t.faraday : int) + let close t = Faraday.close t.faraday; ready_to_write t; @@ -166,33 +191,39 @@ module Writer = struct in faraday_has_output || additional_encoding_output - let transfer_to_writer t writer = + let transfer_to_writer t = let faraday = t.faraday in - begin match Faraday.operation faraday with - | `Yield -> () - | `Close -> - (match t.encoding with - | Identity -> () - | Chunked ({ written_final_chunk } as chunked) -> - if not written_final_chunk then begin - chunked.written_final_chunk <- true; - Serialize.Writer.schedule_chunk writer []; - end); - Serialize.Writer.unyield writer; - | `Writev iovecs -> - let buffered = t.buffered_bytes in - begin match IOVec.shiftv iovecs !buffered with - | [] -> () - | iovecs -> - let lengthv = IOVec.lengthv iovecs in - buffered := !buffered + lengthv; - begin match t.encoding with - | Identity -> Serialize.Writer.schedule_fixed writer iovecs - | Chunked _ -> Serialize.Writer.schedule_chunk writer iovecs - end; - Serialize.Writer.flush writer (fun () -> - Faraday.shift faraday lengthv; - buffered := !buffered - lengthv) - end + if Writer.is_closed t.writer then + close_and_drain t + else begin + match Faraday.operation faraday with + | `Yield -> () + | `Close -> + (match t.encoding with + | Identity -> () + | Chunked ({ written_final_chunk } as chunked) -> + if not written_final_chunk then begin + chunked.written_final_chunk <- true; + Serialize.Writer.schedule_chunk t.writer []; + end); + Serialize.Writer.unyield t.writer; + | `Writev iovecs -> + let buffered = t.buffered_bytes in + begin match IOVec.shiftv iovecs !buffered with + | [] -> () + | iovecs -> + let lengthv = IOVec.lengthv iovecs in + buffered := !buffered + lengthv; + begin match t.encoding with + | Identity -> Serialize.Writer.schedule_fixed t.writer iovecs + | Chunked _ -> Serialize.Writer.schedule_chunk t.writer iovecs + end; + Serialize.Writer.flush t.writer (fun result -> + match result with + | `Closed -> close_and_drain t + | `Written -> + Faraday.shift faraday lengthv; + buffered := !buffered - lengthv) + end end end diff --git a/lib/client_connection.ml b/lib/client_connection.ml index 75c3635..9a1d65b 100644 --- a/lib/client_connection.ml +++ b/lib/client_connection.ml @@ -71,8 +71,8 @@ module Oneshot = struct | `Error `Bad_request -> failwith "H1.Client_connection.request: invalid body length" in - Body.Writer.create (Bigstringaf.create config.request_body_buffer_size) - ~encoding ~when_ready_to_write:(fun () -> Writer.wakeup writer) + Body.Writer.create (Bigstringaf.create config.request_body_buffer_size) writer + ~encoding in let t = { request @@ -89,7 +89,7 @@ module Oneshot = struct let flush_request_body t = if Body.Writer.has_pending_output t.request_body - then Body.Writer.transfer_to_writer t.request_body t.writer + then Body.Writer.transfer_to_writer t.request_body ;; let set_error_and_handle_without_shutdown t error = diff --git a/lib/h1.mli b/lib/h1.mli index 8aa2fec..97052d7 100644 --- a/lib/h1.mli +++ b/lib/h1.mli @@ -494,23 +494,29 @@ module Body : sig modified until a subsequent call to {!flush} has successfully completed. *) - val flush : t -> (unit -> unit) -> unit - (** [flush t f] makes all bytes in [t] available for writing to the awaiting - output channel. Once those bytes have reached that output channel, [f] - will be called. + val flush_with_reason : t -> ([ `Written | `Closed ] -> unit) -> unit + (** [flush_with_reason t f] makes all bytes in [t] available for writing to the awaiting output + channel. Once those bytes have reached that output channel, [f `Written] will be + called. If instead, the output channel is closed before all of those bytes are + successfully written, [f `Closed] will be called. The type of the output channel is runtime-dependent, as are guarantees about whether those packets have been queued for delivery or have actually been received by the intended recipient. *) + val flush: t -> (unit -> unit) -> unit + (** [flush t f] is identical to [flush_with_reason t], except ignoring the result of the flush. + In most situations, you should use flush_with_reason and properly handle a closed output channel. *) + val close : t -> unit (** [close t] closes [t], causing subsequent write calls to raise. If [t] is writable, this will cause any pending output to become available to the output channel. *) val is_closed : t -> bool - (** [is_closed t] is [true] if {!close} has been called on [t] and [false] - otherwise. A closed [t] may still have pending output. *) + (** [is_closed t] is [true] if {!close} has been called on [t], or if the attached + output channel is closed (e.g. because [report_write_result `Closed] has been + called). A closed [t] may still have pending output. *) end end diff --git a/lib/reqd.ml b/lib/reqd.ml index 1de19b2..544bb2b 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -175,8 +175,7 @@ let unsafe_respond_with_streaming ~flush_headers_immediately t response = "H1.Reqd.respond_with_streaming: invalid response body length" in let response_body = - Body.Writer.create t.response_body_buffer ~encoding - ~when_ready_to_write:(fun () -> Writer.wakeup t.writer) + Body.Writer.create t.response_body_buffer t.writer ~encoding in Writer.write_response t.writer response; if t.persistent then @@ -288,6 +287,5 @@ let flush_request_body t = let flush_response_body t = match t.response_state with - | Streaming (_, response_body) -> - Body.Writer.transfer_to_writer response_body t.writer + | Streaming (_, response_body) -> Body.Writer.transfer_to_writer response_body | _ -> () diff --git a/lib/serialize.ml b/lib/serialize.ml index 76039b9..4f3c231 100644 --- a/lib/serialize.ml +++ b/lib/serialize.ml @@ -89,18 +89,18 @@ let schedule_bigstring_chunk t chunk = module Writer = struct type t = { buffer : Bigstringaf.t - (* The buffer that the encoder uses for buffered writes. Managed by the - * control module for the encoder. *) + (* The buffer that the encoder uses for buffered writes. Managed by the + * control module for the encoder. *) ; encoder : Faraday.t - (* The encoder that handles encoding for writes. Uses the [buffer] - * referenced above internally. *) + (* The encoder that handles encoding for writes. Uses the [buffer] + * referenced above internally. *) ; mutable drained_bytes : int - (* The number of bytes that were not written due to the output stream - * being closed before all buffered output could be written. Useful for - * detecting error cases. *) + (* The number of bytes that were not written due to the output stream + * being closed before all buffered output could be written. Useful for + * detecting error cases. *) ; mutable wakeup : Optional_thunk.t - (* The callback from the runtime to be invoked when output is ready to be - * flushed. *) + (* The callback from the runtime to be invoked when output is ready to be + * flushed. *) } let create ?(buffer_size=0x800) () = @@ -158,13 +158,19 @@ module Writer = struct ;; let flush t f = - flush t.encoder f + flush_with_reason t.encoder (fun reason -> + let result = + match reason with + | Nothing_pending | Shift -> `Written + | Drain -> `Closed + in + f result) let unyield t = (* This would be better implemented by a function that just takes the encoder out of a yielded state if it's in that state. Requires a change to the faraday library. *) - flush t (fun () -> ()) + flush t (fun _result -> ()) let yield t = Faraday.yield t.encoder diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 71f7b58..ca2de2e 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -177,8 +177,9 @@ let set_error_and_handle ?request t error = "H1.Server_connection.error_handler: invalid response body \ length" in - Body.Writer.of_faraday (Writer.faraday writer) ~encoding - ~when_ready_to_write:(fun () -> Writer.wakeup writer))) + Body.Writer.of_faraday (Writer.faraday writer) writer ~encoding + ) + ) let report_exn t exn = set_error_and_handle t (`Exn exn) diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index bd41ce6..5de4a0e 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -292,12 +292,6 @@ let connection_is_shutdown t = writer_closed t; ;; -let raises_writer_closed f = - (* This is raised when you write to a closed [Faraday.t] *) - Alcotest.check_raises "raises because writer is closed" - (Failure "cannot write to closed writer") f -;; - let request_handler_with_body body reqd = Body.Reader.close (Reqd.request_body reqd); Reqd.respond_with_string reqd (Response.create `OK) body @@ -312,7 +306,7 @@ let echo_handler response reqd = let response_body = Reqd.respond_with_streaming reqd response in let rec on_read buffer ~off ~len = Body.Writer.write_string response_body (Bigstringaf.substring ~off ~len buffer); - Body.Writer.flush response_body (fun () -> + Body.Writer.flush response_body (fun _ -> Body.Reader.schedule_read request_body ~on_eof ~on_read) and on_eof () = print_endline "echo handler eof"; @@ -332,7 +326,9 @@ let streaming_handler ?(flush=false) response writes reqd = | w :: ws -> Body.Writer.write_string body w; writes := ws; - Body.Writer.flush body write + Body.Writer.flush_with_reason body (function + | `Closed -> () + | `Written -> write ()) in write (); ;; @@ -772,9 +768,11 @@ let test_chunked_encoding () = let response = Response.create `OK ~headers:Headers.encoding_chunked in let resp_body = Reqd.respond_with_streaming reqd response in Body.Writer.write_string resp_body "First chunk"; - Body.Writer.flush resp_body (fun () -> - Body.Writer.write_string resp_body "Second chunk"; - Body.Writer.close resp_body); + Body.Writer.flush_with_reason resp_body (function + | `Closed -> assert false + | `Written -> + Body.Writer.write_string resp_body "Second chunk"; + Body.Writer.close resp_body); in let t = create ~error_handler request_handler in writer_yielded t; @@ -801,9 +799,11 @@ let test_chunked_encoding_for_error () = `Bad_request error; let body = start_response Headers.encoding_chunked in Body.Writer.write_string body "Bad"; - Body.Writer.flush body (fun () -> - Body.Writer.write_string body " request"; - Body.Writer.close body); + Body.Writer.flush_with_reason body (function + | `Closed -> assert false + | `Written -> + Body.Writer.write_string body " request"; + Body.Writer.close body); in let t = create ~error_handler (fun _ -> assert false) in let c = feed_string t " X\r\n\r\n" in @@ -838,6 +838,53 @@ let test_blocked_write_on_chunked_encoding () = write_string t ~msg:"second write" second_write ;; +let test_body_writing_when_socket_closes () = + let response = Response.create `OK ~headers:Headers.encoding_chunked in + let body_ref = ref None in + let request_handler reqd = + let body = Reqd.respond_with_streaming reqd response in + body_ref := Some body + in + let t = create request_handler in + writer_yielded t; + read_request t (Request.create `GET "/"); + + let flush_result_testable = + Alcotest.of_pp + (Fmt.using (function `Closed -> "Closed" | `Written -> "Written") Fmt.string) + in + + let body = Option.get !body_ref in + let check_flush ~expect service_writer = + let flush_result = ref None in + Body.Writer.flush_with_reason body (fun r -> flush_result := Some r); + service_writer (); + Alcotest.(check' (option flush_result_testable)) + ~msg:"flush_result is as expected" + ~expected:(Some expect) + ~actual:!flush_result; + in + + Body.Writer.write_string body "First chunk"; + check_flush (fun () -> + write_response t + ~msg:"First chunk written" + ~body:"b\r\nFirst chunk\r\n" + response) + ~expect:`Written; + + Body.Writer.write_string body "Second chunk"; + check_flush (fun () -> write_eof t) ~expect:`Closed; + + (* Writing after the writer is closed does not raise, but flushes get immediately + resolved with `Closed. *) + Body.Writer.write_string body "Chunk after closed"; + check_flush (fun () -> ()) ~expect:`Closed; + + Body.Writer.close body; + check_flush (fun () -> ()) ~expect:`Closed; +;; + let test_unexpected_eof () = let t = create default_request_handler in read_request t (Request.create `GET "/"); @@ -1087,41 +1134,6 @@ let test_shutdown_in_request_handler () = writer_closed t ;; -let test_shutdown_during_asynchronous_request () = - let request = Request.create `GET "/" in - let response = Response.create `OK in - let continue = ref (fun () -> ()) in - let t = create (fun reqd -> - continue := (fun () -> - Reqd.respond_with_string reqd response "")) - in - read_request t request; - shutdown t; - raises_writer_closed !continue; - reader_closed t; - writer_closed t -;; - -let test_flush_response_before_shutdown () = - let request = Request.create `GET "/" ~headers:(Headers.encoding_fixed 0) in - let response = Response.create `OK ~headers:Headers.encoding_chunked in - let continue = ref (fun () -> ()) in - let request_handler reqd = - let body = Reqd.respond_with_streaming ~flush_headers_immediately:true reqd response in - continue := (fun () -> - Body.Writer.write_string body "hello world"; - Body.Writer.close body); - in - let t = create request_handler in - read_request t request; - write_response t response; - !continue (); - shutdown t; - raises_writer_closed (fun () -> - write_string t "b\r\nhello world\r\n"; - connection_is_shutdown t); -;; - let test_schedule_read_with_data_available () = let response = Response.create `OK in let body = ref None in @@ -1267,6 +1279,7 @@ let tests = ; "chunked encoding", `Quick, test_chunked_encoding ; "chunked encoding for error", `Quick, test_chunked_encoding_for_error ; "blocked write on chunked encoding", `Quick, test_blocked_write_on_chunked_encoding + ; "body writing when socket closes", `Quick, test_body_writing_when_socket_closes ; "writer unexpected eof", `Quick, test_unexpected_eof ; "input shrunk", `Quick, test_input_shrunk ; "failed request parse", `Quick, test_failed_request_parse @@ -1279,8 +1292,6 @@ let tests = ; "parse failure at eof", `Quick, test_parse_failure_at_eof ; "response finished before body read", `Quick, test_response_finished_before_body_read ; "shutdown in request handler", `Quick, test_shutdown_in_request_handler - ; "shutdown during asynchronous request", `Quick, test_shutdown_during_asynchronous_request - ; "flush response before shutdown", `Quick, test_flush_response_before_shutdown ; "schedule read with data available", `Quick, test_schedule_read_with_data_available ; "test upgrades", `Quick, test_upgrade ; "test upgrade where server does not upgrade", `Quick, test_upgrade_where_server_does_not_upgrade