diff --git a/lib/body.ml b/lib/body.ml index dbca1c30..cdbc513f 100644 --- a/lib/body.ml +++ b/lib/body.ml @@ -37,13 +37,12 @@ type _ t = ; mutable write_final_if_chunked : bool ; mutable on_eof : unit -> unit ; mutable on_read : Bigstringaf.t -> off:int -> len:int -> unit - ; mutable when_ready_to_write : unit -> unit + ; mutable when_ready_to_write : Optional_thunk.t ; buffered_bytes : int ref } let default_on_eof = Sys.opaque_identity (fun () -> ()) let default_on_read = Sys.opaque_identity (fun _ ~off:_ ~len:_ -> ()) -let default_ready_to_write = Sys.opaque_identity (fun () -> ()) let of_faraday faraday = { faraday @@ -51,7 +50,7 @@ let of_faraday faraday = ; write_final_if_chunked = true ; on_eof = default_on_eof ; on_read = default_on_read - ; when_ready_to_write = default_ready_to_write + ; when_ready_to_write = Optional_thunk.none ; buffered_bytes = ref 0 } @@ -79,8 +78,8 @@ let schedule_bigstring t ?off ?len (b:Bigstringaf.t) = let ready_to_write t = let callback = t.when_ready_to_write in - t.when_ready_to_write <- default_ready_to_write; - callback () + t.when_ready_to_write <- Optional_thunk.none; + Optional_thunk.unchecked_value callback () let flush t kontinue = Faraday.flush t.faraday kontinue; @@ -145,11 +144,11 @@ let close_reader t = ;; let when_ready_to_write t callback = - if not (t.when_ready_to_write == default_ready_to_write) - then failwith "Body.when_ready_to_write: only one callback can be registered at a time" - else if is_closed t + if is_closed t then callback () - else t.when_ready_to_write <- callback + else if Optional_thunk.is_some t.when_ready_to_write + then failwith "Body.when_ready_to_write: only one callback can be registered at a time" + else t.when_ready_to_write <- Optional_thunk.some callback let transfer_to_writer_with_encoding t ~encoding writer = let faraday = t.faraday in diff --git a/lib/optional_thunk.ml b/lib/optional_thunk.ml new file mode 100644 index 00000000..5a8cf961 --- /dev/null +++ b/lib/optional_thunk.ml @@ -0,0 +1,11 @@ +type t = unit -> unit + +let none = Sys.opaque_identity (fun () -> ()) +let some f = + if f == none + then failwith "Optional_thunk: this function is not representable as a some value"; + f + +let is_none t = t == none +let is_some t = not (is_none t) +let unchecked_value t = t diff --git a/lib/optional_thunk.mli b/lib/optional_thunk.mli new file mode 100644 index 00000000..dd36f50d --- /dev/null +++ b/lib/optional_thunk.mli @@ -0,0 +1,9 @@ +type t + +val none : t +val some : (unit -> unit) -> t + +val is_none : t -> bool +val is_some : t -> bool + +val unchecked_value : t -> unit -> unit diff --git a/lib/reqd.ml b/lib/reqd.ml index f598f5b5..87dbbea0 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -35,7 +35,7 @@ type error = [ `Bad_request | `Bad_gateway | `Internal_server_error | `Exn of exn ] type response_state = - | Waiting of (unit -> unit) ref + | Waiting of Optional_thunk.t ref | Complete of Response.t | Streaming of Response.t * [`write] Body.t @@ -78,8 +78,6 @@ type t = ; mutable error_code : [`Ok | error ] } -let default_waiting = Sys.opaque_identity (fun () -> ()) - let create error_handler request request_body writer response_body_buffer = { request ; request_body @@ -87,14 +85,14 @@ let create error_handler request request_body writer response_body_buffer = ; response_body_buffer ; error_handler ; persistent = Request.persistent_connection request - ; response_state = Waiting (ref default_waiting) + ; response_state = Waiting (ref Optional_thunk.none) ; error_code = `Ok } let done_waiting when_done_waiting = let f = !when_done_waiting in - when_done_waiting := default_waiting; - f () + when_done_waiting := Optional_thunk.none; + Optional_thunk.unchecked_value f () let request { request; _ } = request let request_body { request_body; _ } = request_body @@ -213,9 +211,9 @@ let error_code t = let on_more_output_available t f = match t.response_state with | Waiting when_done_waiting -> - if not (!when_done_waiting == default_waiting) then - failwith "httpaf.Reqd.on_more_output_available: only one callback can be registered at a time"; - when_done_waiting := f + if Optional_thunk.is_some !when_done_waiting + then failwith "httpaf.Reqd.on_more_output_available: only one callback can be registered at a time"; + when_done_waiting := Optional_thunk.some f | Streaming(_, response_body) -> Body.when_ready_to_write response_body f | Complete _ -> diff --git a/lib/server_connection.ml b/lib/server_connection.ml index fbaaa8b7..d57b5c44 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -64,12 +64,10 @@ type t = ; request_queue : Reqd.t Queue.t (* invariant: If [request_queue] is not empty, then the head of the queue has already had [request_handler] called on it. *) - ; mutable wakeup_writer : (unit -> unit) - ; mutable wakeup_reader : (unit -> unit) + ; mutable wakeup_writer : Optional_thunk.t + ; mutable wakeup_reader : Optional_thunk.t } -let default_wakeup = Sys.opaque_identity (fun () -> ()) - let is_closed t = Reader.is_closed t.reader && Writer.is_closed t.writer @@ -85,35 +83,37 @@ let current_reqd_exn t = let yield_reader t k = if is_closed t then failwith "on_wakeup_reader on closed conn" - else if not (t.wakeup_reader == default_wakeup); + else if Optional_thunk.is_some t.wakeup_reader then failwith "yield_reader: only one callback can be registered at a time" - else t.wakeup_reader <- k + else t.wakeup_reader <- Optional_thunk.some k ;; let wakeup_reader t = let f = t.wakeup_reader in - t.wakeup_reader <- default_wakeup; - f () + t.wakeup_reader <- Optional_thunk.none; + Optional_thunk.unchecked_value f () ;; let on_wakeup_writer t k = if is_closed t then failwith "on_wakeup_writer on closed conn" - else if not (t.wakeup_writer == default_wakeup) + else if Optional_thunk.is_some t.wakeup_writer then failwith "yield_writer: only one callback can be registered at a time" - else t.wakeup_writer <- k + else t.wakeup_writer <- Optional_thunk.some k ;; let wakeup_writer t = let f = t.wakeup_writer in - t.wakeup_writer <- default_wakeup; - f () + t.wakeup_writer <- Optional_thunk.none; + Optional_thunk.unchecked_value f () ;; let transfer_writer_callback t reqd = - let f = t.wakeup_writer in - t.wakeup_writer <- default_wakeup; - Reqd.on_more_output_available reqd f + if Optional_thunk.is_some t.wakeup_writer + then ( + let f = t.wakeup_writer in + t.wakeup_writer <- Optional_thunk.none; + Reqd.on_more_output_available reqd (Optional_thunk.unchecked_value f)) ;; let default_error_handler ?request:_ error handle = @@ -149,8 +149,8 @@ let create ?(config=Config.default) ?(error_handler=default_error_handler) reque ; request_handler = request_handler ; error_handler = error_handler ; request_queue - ; wakeup_writer = default_wakeup - ; wakeup_reader = default_wakeup + ; wakeup_writer = Optional_thunk.none + ; wakeup_reader = Optional_thunk.none } let shutdown_reader t =