diff --git a/examples/eio/dune b/examples/eio/dune index 1b01b8b..cf34605 100644 --- a/examples/eio/dune +++ b/examples/eio/dune @@ -1,6 +1,11 @@ (executables (libraries httpun httpun-eio httpun_examples base stdio eio_main eio-ssl) - (names eio_echo_post eio_get eio_ssl_get) + (names + eio_echo_post + eio_get + eio_ssl_get + eio_connect_server + eio_connect_client) (flags :standard -warn-error -A)) (alias diff --git a/examples/eio/eio_connect_client.ml b/examples/eio/eio_connect_client.ml new file mode 100644 index 0000000..d0e2f21 --- /dev/null +++ b/examples/eio/eio_connect_client.ml @@ -0,0 +1,102 @@ +module Arg = Stdlib.Arg +open Httpun + +let http_handler ~on_eof response response_body = + match response with + | { Response.status = `OK; _ } as response -> + Format.eprintf "response: %a@." Response.pp_hum response; + let rec on_read bs ~off ~len = + Bigstringaf.substring ~off ~len bs |> print_string; + flush stdout; + Body.Reader.schedule_read response_body ~on_read ~on_eof + in + Body.Reader.schedule_read response_body ~on_read ~on_eof + | response -> + Format.fprintf Format.err_formatter "%a\n%!" Response.pp_hum response; + Stdlib.exit 124 + +let proxy_handler _env ~sw ~headers flow ~on_eof response _response_body = + Format.eprintf "CONNECT response: %a@." Response.pp_hum response; + match response with + | { Response.status = `OK; _ } as response -> + (* This means we can now communicate via any protocol on the socket since + the server approved the tunnel. + + We'll be boring and use HTTP/1.1 again. *) + let connection = Httpun_eio.Client.create_connection ~sw flow in + let exit_cond = Eio.Condition.create () in + Eio.Fiber.fork ~sw (fun () -> + let response_handler = + http_handler ~on_eof:(fun () -> + Stdlib.Format.eprintf "http eof@."; + Eio.Condition.broadcast exit_cond; + on_eof ()) + in + let request_body = + Httpun_eio.Client.request + ~flush_headers_immediately:true + ~error_handler:Httpun_examples.Client.error_handler + ~response_handler + connection + (Request.create ~headers `GET "/") + in + Body.Writer.close request_body); + Eio.Condition.await_no_mutex exit_cond; + Httpun_eio.Client.shutdown connection |> Eio.Promise.await + | _response -> Stdlib.exit 124 + +let main port proxy_host = + let real_host = "example.com:80" in + Eio_main.run (fun _env -> + Eio.Switch.run (fun sw -> + let fd = Unix.socket ~cloexec:true Unix.PF_INET Unix.SOCK_STREAM 0 in + let addrs = + Eio_unix.run_in_systhread (fun () -> + Unix.getaddrinfo + proxy_host + (Int.to_string port) + [ Unix.(AI_FAMILY PF_INET) ]) + in + Eio_unix.run_in_systhread (fun () -> + Unix.connect fd (List.hd addrs).ai_addr); + let socket = Eio_unix.Net.import_socket_stream ~sw ~close_unix:true fd in + let headers = Headers.of_list [ "host", real_host ] in + let connection = Httpun_eio.Client.create_connection ~sw socket in + + let exit_cond = Eio.Condition.create () in + Eio.Fiber.fork ~sw (fun ()-> + let response_handler = + fun response response_body -> + Eio.Fiber.fork ~sw @@ fun () -> + proxy_handler _env ~sw socket ~headers ~on_eof:(fun () -> + Stdlib.Format.eprintf "(connect) eof@."; + Eio.Condition.broadcast exit_cond) + response + response_body + in + let request_body = + Httpun_eio.Client.request + ~flush_headers_immediately:true + ~error_handler:Httpun_examples.Client.error_handler + ~response_handler + connection + (Request.create ~headers `CONNECT real_host) + in + Body.Writer.close request_body; + Eio.Condition.await_no_mutex exit_cond; + + Httpun_eio.Client.shutdown connection |> Eio.Promise.await))) + +let () = + let host = ref None in + let port = ref 80 in + Arg.parse + [ "-p", Set_int port, " Port number (80 by default)" ] + (fun host_argument -> host := Some host_argument) + "lwt_get.exe [-p N] HOST"; + let host = + match !host with + | None -> failwith "No hostname provided" + | Some host -> host + in + main !port host diff --git a/examples/eio/eio_connect_server.ml b/examples/eio/eio_connect_server.ml new file mode 100644 index 0000000..eb8b141 --- /dev/null +++ b/examples/eio/eio_connect_server.ml @@ -0,0 +1,120 @@ +(* curl -v -p -x http://localhost:8080 http://example.com *) +open Base +module Arg = Stdlib.Arg +open Httpun_eio +open Httpun + +let error_handler (_ : Eio.Net.Sockaddr.stream) = + Httpun_examples.Server.error_handler + +let request_handler + env + ~sw + ~u + flow + (_ : Eio.Net.Sockaddr.stream) + { Gluten.reqd; _ } + = + match Reqd.request reqd with + | { Request.meth = `CONNECT; headers; _ } -> + Stdlib.Format.eprintf "x: %a@." Request.pp_hum (Reqd.request reqd); + let host, port = + let host_and_port = Headers.get_exn headers "host" in + let[@ocaml.warning "-8"] [ host; port ] = + String.split_on_chars ~on:[ ':' ] host_and_port + in + host, port + in + let () = + (* todo: try/with *) + let p, u' = Eio.Promise.create () in + Eio.Fiber.fork ~sw (fun () -> + Eio.Net.with_tcp_connect + (Eio.Stdenv.net env) + ~host + ~service:port + (fun upstream -> + Eio.Promise.resolve u' (); + Stdlib.Format.eprintf + "connected to upstream %s (port %s)@." + host + port; + Eio.Fiber.both + (fun () -> Eio.Flow.copy flow upstream) + (fun () -> Eio.Flow.copy upstream flow); + Eio.Promise.resolve_ok u ())); + Eio.Promise.await p + in + Reqd.respond_with_string reqd (Response.create `OK) "" + | _ -> + let headers = Headers.of_list [ "connection", "close" ] in + Reqd.respond_with_string + reqd + (Response.create ~headers `Method_not_allowed) + "" + +let log_connection_error ex = + Eio.traceln "Uncaught exception handling client: %a" Fmt.exn ex + +let main port = + Eio_main.run (fun env -> + let listen_address = `Tcp (Eio.Net.Ipaddr.V4.loopback, port) in + let network = Eio.Stdenv.net env in + let handler ~u = + fun ~sw client_addr socket -> + let request_handler = request_handler env ~sw ~u socket in + Server.create_connection_handler + ~request_handler + ~error_handler + ~sw + client_addr + socket + in + Eio.Switch.run (fun sw -> + let socket = + Eio.Net.listen + ~reuse_addr:true + ~reuse_port:false + ~backlog:5 + ~sw + network + listen_address + in + Stdio.printf "Listening on port %i and echoing POST requests.\n" port; + Stdio.printf "To send a POST request, try one of the following\n\n"; + Stdio.printf + " echo \"Testing echo POST\" | dune exec examples/async/async_post.exe\n"; + Stdio.printf + " echo \"Testing echo POST\" | dune exec examples/lwt/lwt_post.exe\n"; + Stdio.printf + " echo \"Testing echo POST\" | curl -XPOST --data @- \ + http://localhost:%d\n\n\ + %!" + port; + let domain_mgr = Eio.Stdenv.domain_mgr env in + let p, _ = Eio.Promise.create () in + for _i = 1 to Stdlib.Domain.recommended_domain_count () do + Eio.Fiber.fork_daemon ~sw (fun () -> + Eio.Domain_manager.run domain_mgr (fun () -> + Eio.Switch.run (fun sw -> + while true do + Eio.Net.accept_fork + socket + ~sw + ~on_error:log_connection_error + (fun client_sock client_addr -> + let p, u = Eio.Promise.create () in + handler ~sw ~u client_addr client_sock; + Eio.Promise.await_exn p) + done; + `Stop_daemon))) + done; + Eio.Promise.await p)) + +let () = + let port = ref 8080 in + Arg.parse + [ "-p", Arg.Set_int port, " Listening port number (8080 by default)" ] + ignore + "Echoes POST requests. Runs forever."; + main !port diff --git a/lib/reqd.ml b/lib/reqd.ml index ef3c54e..b1929b5 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -257,8 +257,8 @@ let persistent_connection t = t.persistent let input_state t : Io_state.t = - match t.response_state with - | Upgrade _ -> Ready + match t.response_state, t.request.meth with + | Upgrade _,_ | _, `CONNECT -> Wait | _ -> if Body.Reader.is_closed t.request_body then Complete @@ -266,8 +266,11 @@ let input_state t : Io_state.t = then Ready else Wait -let output_state { response_state; writer; _ } = - Response_state.output_state response_state ~writer +let output_state { request; response_state; writer; _ } = + Response_state.output_state + response_state + ~request_method:request.meth + ~writer let flush_request_body t = if Body.Reader.has_pending_output t.request_body diff --git a/lib/respd.ml b/lib/respd.ml index 7c0499e..ebc39ee 100644 --- a/lib/respd.ml +++ b/lib/respd.ml @@ -25,13 +25,18 @@ type t = ; mutable persistent : bool } -let create error_handler request request_body writer response_handler = +let create error_handler (request: Request.t) request_body writer response_handler = let rec handler response body = let t = Lazy.force t in if t.persistent then t.persistent <- Response.persistent_connection response; - let next_state : Request_state.t = match response.status with - | `Switching_protocols -> + let next_state : Request_state.t = match request.meth, response.status with + (* From RFC9110ยง6.4.1: + * 2xx (Successful) responses to a CONNECT request method (Section + * 9.3.6) switch the connection to tunnel mode instead of having + * content. *) + | `CONNECT, #Status.successful + | _, `Switching_protocols -> Upgraded response | _ -> Received_response (response, body) @@ -103,9 +108,7 @@ let input_state t : Io_state.t = else if Body.Reader.is_read_scheduled response_body then Ready else Wait - (* Upgraded is "Complete" because the descriptor doesn't wish to receive - * any more input. *) - | Upgraded _ + | Upgraded _ -> Wait | Closed -> Complete let output_state { request_body; state; writer; _ } : Io_state.t = diff --git a/lib/response_state.ml b/lib/response_state.ml index a88f305..80c3ba7 100644 --- a/lib/response_state.ml +++ b/lib/response_state.ml @@ -4,18 +4,23 @@ type t = | Streaming of Response.t * Body.Writer.t | Upgrade of Response.t * (unit -> unit) -let output_state t ~writer : Io_state.t = - match t with - | Fixed _ -> Complete - | Waiting -> - if Serialize.Writer.is_closed writer then Complete - else Wait - | Streaming(_, response_body) -> - if Serialize.Writer.is_closed writer then Complete - else if Body.Writer.requires_output response_body - then Ready - else Complete - | Upgrade _ -> Ready +let output_state = + let response_sent_state = function + | `CONNECT -> Io_state.Wait + | _ -> Complete + in + fun t ~request_method ~writer : Io_state.t -> + match t with + | Upgrade _ -> Wait + | Waiting -> + if Serialize.Writer.is_closed writer then Complete + else Wait + | Fixed _ -> response_sent_state request_method + | Streaming(_, response_body) -> + if Serialize.Writer.is_closed writer then response_sent_state request_method + else if Body.Writer.requires_output response_body + then Ready + else response_sent_state request_method let flush_response_body t = match t with diff --git a/lib/server_connection.ml b/lib/server_connection.ml index e737627..3398517 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -338,8 +338,18 @@ let rec _next_write_operation t = if Reader.is_closed t.reader then shutdown t; Writer.next t.writer - | Error { response_state; _ } -> - match Response_state.output_state response_state ~writer:t.writer with + | Error { request; response_state } -> + match + let request_method = + Option.value + ~default:`GET + (Option.map (fun (t: Request.t) -> t.meth) request) + in + Response_state.output_state + response_state + ~request_method + ~writer:t.writer + with | Wait -> `Yield | Ready -> flush_response_error_body response_state; diff --git a/lib_test/test_client_connection.ml b/lib_test/test_client_connection.ml index 17e278c..f3e0a18 100644 --- a/lib_test/test_client_connection.ml +++ b/lib_test/test_client_connection.ml @@ -1866,6 +1866,41 @@ let test_read_response_before_shutdown () = connection_is_shutdown t; ;; +let test_client_connect () = + let writer_woken_up = ref false in + let reader_woken_up = ref false in + let request' = Request.create + ~headers:(Headers.of_list ["host", "example.com:80"]) + `CONNECT "/" + in + let t = create () in + let response = Response.create `OK in + let body = + request + t + request' + ~flush_headers_immediately:true + ~response_handler:(default_response_handler response) + ~error_handler:no_error_handler + in + write_request t request'; + writer_yielded t; + Body.Writer.close body; + reader_ready t; + read_response t response; + reader_yielded t; + yield_reader t (fun () -> reader_woken_up := true); + writer_yielded t; + yield_writer t (fun () -> writer_woken_up := true); + Alcotest.(check bool) "Reader hasn't woken up yet" false !reader_woken_up; + Alcotest.(check bool) "Writer hasn't woken up yet" false !writer_woken_up; + shutdown t; + Alcotest.(check bool) "Reader woken up" true !reader_woken_up; + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + connection_is_shutdown t; +;; + + let tests = [ "commit parse after every header line", `Quick, test_commit_parse_after_every_header ; "GET" , `Quick, test_get @@ -1914,4 +1949,5 @@ let tests = ; "shut down closes request body ", `Quick, test_read_response_before_shutdown ; "report exn during body read", `Quick, test_report_exn_during_body_read ; "read response after write eof", `Quick, test_can_read_response_after_write_eof + ; "Client support for CONNECT", `Quick, test_client_connect ] diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index c5a2aac..22a0863 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -957,7 +957,7 @@ let test_respond_with_upgrade () = read_request t (Request.create `GET "/"); write_response ~msg:"Upgrade response written" t (Response.create `Switching_protocols); Alcotest.(check bool) "Callback was called" true !upgraded; - reader_ready t; + reader_yielded t; ;; let test_unexpected_eof () = @@ -2438,6 +2438,24 @@ let test_write_response_after_read_eof () = connection_is_shutdown t; ;; +let test_connect_method () = + let upgraded = ref false in + let upgrade_handler reqd = + Reqd.respond_with_upgrade reqd Headers.empty (fun () -> + upgraded := true) + in + let t = create ~error_handler upgrade_handler in + read_request + t + (Request.create + ~headers:(Headers.of_list [ "host", "example.com:80" ]) + `CONNECT + "/"); + write_response ~msg:"Upgrade response written" t (Response.create `Switching_protocols); + Alcotest.(check bool) "Callback was called" true !upgraded; + reader_yielded t; +;; + let tests = [ "initial reader state" , `Quick, test_initial_reader_state ; "shutdown reader closed", `Quick, test_reader_is_closed_after_eof @@ -2520,4 +2538,5 @@ let tests = ; "can read more requests after write eof", `Quick, test_can_read_more_requests_after_write_eof ; "can read more requests after write eof (before response sent)", `Quick, test_can_read_more_requests_after_write_eof_before_send_response ; "write response after reader EOF", `Quick,test_write_response_after_read_eof + ; "CONNECT method", `Quick, test_connect_method ]