diff --git a/.ocamlformat b/.ocamlformat index 3e21906..d782af1 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,2 +1,2 @@ -version=0.26.1 +# version=0.26.1 ocaml-version=4.08 diff --git a/arpaca.opam b/arpaca.opam new file mode 100644 index 0000000..de4e38d --- /dev/null +++ b/arpaca.opam @@ -0,0 +1,40 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "An Eio implementation of gRPC client" +description: "Functionality for building gRPC services and rpcs with `eio`." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "grpc" {= version} + "grpc-client-eio" {= version} + "grpc-client" {= version} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/dune b/dune new file mode 100644 index 0000000..494744c --- /dev/null +++ b/dune @@ -0,0 +1,3 @@ +(dirs bench examples lib) + +(vendored_dirs ocaml-h2 gluten) diff --git a/dune-project b/dune-project index 32c4d3a..ecc6a07 100644 --- a/dune-project +++ b/dune-project @@ -19,6 +19,8 @@ (maintainers "Daniel Quernheim ") +(cram enable) + (source (github dialohq/ocaml-grpc)) @@ -28,7 +30,7 @@ (name grpc) (synopsis "A modular gRPC library") (description - "This library builds some of the signatures and implementations of gRPC functionality. This is used in the more specialised package `grpc-lwt` which has more machinery, however this library can also be used to do some bits yourself.") + "This library contains the implementation of (de)serialization of gRPC messages and statuses.") (tags (network rpc serialisation)) (depends @@ -36,12 +38,34 @@ (>= 4.08)) (bigstringaf (>= 0.9.1)) - (h2 - (>= 0.9.0)) ppx_deriving (uri (>= 4.0.0)))) +(package + (name grpc-server) + (synopsis "Reusable logic for server side gRPC") + (description + "All modules are networking-layer and concurrency-layer agnostic.") + (tags + (network rpc serialisation)) + (depends + (ocaml + (>= 4.08)) + (grpc (= :version)))) + +(package + (name grpc-client) + (synopsis "Reusable logic for client side gRPC") + (description + "All modules are networking-layer and concurrency-layer agnostic.") + (tags + (network rpc serialisation)) + (depends + (ocaml + (>= 4.08)) + (grpc (= :version)))) + (package (name grpc-lwt) (synopsis "An Lwt implementation of gRPC") @@ -50,7 +74,7 @@ (tags (network rpc serialisation)) (depends - (grpc + (grpc-server (= :version)) (h2 (>= 0.9.0)) @@ -70,7 +94,7 @@ (>= 4.11)) (async (>= v0.16)) - (grpc + (grpc-server (= :version)) (h2 (>= 0.9.0)) @@ -79,18 +103,71 @@ stringext)) (package - (name grpc-eio) - (synopsis "An Eio implementation of gRPC") + (name grpc-server-eio) + (deprecated_package_names grpc-eio) + (synopsis "An Eio implementation of gRPC server") (description "Functionality for building gRPC services and rpcs with `eio`.") (depends (eio (>= 0.12)) - (grpc + (grpc-server + (= :version)) + stringext)) + +(package + (name grpc-client-eio) + (synopsis "An Eio implementation of gRPC client") + (description + "Functionality for building gRPC services and rpcs with `eio`.") + (depends + (eio + (>= 0.12)) + (grpc-client + (= :version)))) + +(package + (name grpc-eio-io-client-h2-ocaml-protoc) + (synopsis "An h2 implementation of gRPC networking layer for eio based clients.") + (depends + (grpc-client-eio (= :version)) (h2 (>= 0.9.0)) - stringext)) + pbrt + pbrt_services + eio + h2-eio + grpc-eio-core)) + +(package + (name grpc-eio-io-server-h2-ocaml-protoc) + (synopsis "An h2 implementation of gRPC networking layer for eio based servers.") + (depends + (grpc-server-eio + (= :version)) + (h2 + (>= 0.9.0)) + stringext + pbrt + pbrt_services + eio + h2-eio + grpc-eio-core)) + + +(package + (name arpaca) + (synopsis "An Eio implementation of gRPC client") + (description + "Functionality for building gRPC services and rpcs with `eio`.") + (depends + (grpc + (= :version)) + (grpc-client-eio + (= :version)) + (grpc-client + (= :version)))) (package (name grpc-examples) @@ -154,3 +231,12 @@ grpc (notty (>= 0.2.3)))) + +(package + (name grpc-eio-core) + (synopsis "Benchmarking package for gRPC") + (description "Benchmarking package for gRPC.") + (tags + (network rpc serialisation benchmark)) + (depends + eio)) diff --git a/examples/greeter-client-eio/dune b/examples/greeter-client-eio/dune index 37f97bc..654ad7a 100644 --- a/examples/greeter-client-eio/dune +++ b/examples/greeter-client-eio/dune @@ -1,3 +1,8 @@ (executable (name greeter_client_eio) - (libraries grpc grpc-eio ocaml-protoc-plugin eio_main greeter h2 h2-eio)) + (libraries + grpc-client-eio + ocaml-protoc-plugin + eio_main + greeter + grpc-eio-net-client-h2)) diff --git a/examples/greeter-client-eio/greeter_client_eio.ml b/examples/greeter-client-eio/greeter_client_eio.ml index c8b0530..1ed0bda 100644 --- a/examples/greeter-client-eio/greeter_client_eio.ml +++ b/examples/greeter-client-eio/greeter_client_eio.ml @@ -1,56 +1,33 @@ let main env = let name = if Array.length Sys.argv > 1 then Sys.argv.(1) else "anonymous" in - let host = "localhost" in - let port = "8080" in let network = Eio.Stdenv.net env in let run sw = - let inet, port = - Eio_unix.run_in_systhread (fun () -> - Unix.getaddrinfo host port [ Unix.(AI_FAMILY PF_INET) ]) - |> List.filter_map (fun (addr : Unix.addr_info) -> - match addr.ai_addr with - | Unix.ADDR_UNIX _ -> None - | ADDR_INET (addr, port) -> Some (addr, port)) - |> List.hd - in - let addr = `Tcp (Eio_unix.Net.Ipaddr.of_unix inet, port) in - let socket = Eio.Net.connect ~sw network addr in - let connection = - H2_eio.Client.create_connection ~sw ~error_handler:ignore socket - in - let open Ocaml_protoc_plugin in let open Greeter.Mypackage in let encode, decode = Service.make_client_functions Greeter.sayHello in - let encoded_request = - HelloRequest.make ~name () |> encode |> Writer.contents - in - let f decoder = - match decoder with - | Some decoder -> ( - Reader.create decoder |> decode |> function - | Ok v -> v - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e))) - | None -> Greeter.SayHello.Response.make () + let net = + Grpc_eio_net_client_h2.create_client ~sw ~net:network + "http://localhost:8080" in let result = - Grpc_eio.Client.call ~service:"mypackage.Greeter" ~rpc:"SayHello" - ~do_request:(H2_eio.Client.request connection ~error_handler:ignore) - ~handler:(Grpc_eio.Client.Rpc.unary encoded_request ~f) - () + Grpc_client_eio.Client.unary ~sw ~net ~service:"mypackage.Greeter" + ~method_name:"SayHello" + ~encode:(fun x -> x |> encode |> Writer.contents) + ~decode:(fun x -> Reader.create x |> decode) + ~headers:(Grpc_client.make_request_headers `Proto) + (HelloRequest.make ~name ()) in - Eio.Promise.await (H2_eio.Client.shutdown connection); - result + match result with + | Ok message -> Eio.traceln "%s" message + | Error (`Rpc (response, status)) -> + Eio.traceln "Error: %a, %a" H2.Status.pp_hum response.status + Grpc.Status.pp status + | Error (`Connection _err) -> Eio.traceln "Connection error" + | Error (`Decoding err) -> + Eio.traceln "Decoding error: %a" Ocaml_protoc_plugin.Result.pp_error err in Eio.Switch.run run -let () = - match Eio_main.run main with - | Ok (message, status) -> - Eio.traceln "%s: %s" (Grpc.Status.show status) message - | Error err -> Eio.traceln "Error: %a" H2.Status.pp_hum err +let () = Eio_main.run main diff --git a/examples/greeter-server-eio/dune b/examples/greeter-server-eio/dune index 8108aa6..9d68bae 100644 --- a/examples/greeter-server-eio/dune +++ b/examples/greeter-server-eio/dune @@ -1,3 +1,9 @@ (executable (name greeter_server_eio) - (libraries grpc grpc-eio ocaml-protoc-plugin eio_main greeter h2 h2-eio)) + (libraries + eio + grpc-server-eio + ocaml-protoc-plugin + eio_main + greeter + grpc-eio-net-server-h2)) diff --git a/examples/greeter-server-eio/greeter_server_eio.ml b/examples/greeter-server-eio/greeter_server_eio.ml index 16aaba0..0d41123 100644 --- a/examples/greeter-server-eio/greeter_server_eio.ml +++ b/examples/greeter-server-eio/greeter_server_eio.ml @@ -1,6 +1,7 @@ -open Grpc_eio +module Server = Grpc_server_eio +module Net = Grpc_eio_net_server_h2 -let say_hello buffer = +let say_hello env buffer = let open Ocaml_protoc_plugin in let open Greeter.Mypackage in let decode, encode = Service.make_service_functions Greeter.sayHello in @@ -16,55 +17,35 @@ let say_hello buffer = else Format.sprintf "Hello, %s!" request in let reply = Greeter.SayHello.Response.make ~message () in - (Grpc.Status.(v OK), Some (encode reply |> Writer.contents)) - -let connection_handler server sw = - let error_handler client_address ?request:_ _error start_response = - Eio.traceln "Error in request from:%a" Eio.Net.Sockaddr.pp client_address; - let response_body = start_response H2.Headers.empty in - H2.Body.Writer.write_string response_body - "There was an error handling your request.\n"; - H2.Body.Writer.close response_body - in - let request_handler client_address request_descriptor = - Eio.traceln "Handling a request from:%a" Eio.Net.Sockaddr.pp client_address; - Eio.Fiber.fork ~sw (fun () -> - Grpc_eio.Server.handle_request server request_descriptor) - in - fun socket addr -> - H2_eio.Server.create_connection_handler ?config:None ~request_handler - ~error_handler addr ~sw socket + Eio.Time.sleep env#clock 10.0; + (Grpc_server.trailers_with_code OK, Some (encode reply |> Writer.contents)) let serve server env = let port = 8080 in let net = Eio.Stdenv.net env in let addr = `Tcp (Eio.Net.Ipaddr.V4.loopback, port) in Eio.Switch.run @@ fun sw -> - let handler = connection_handler server sw in let server_socket = Eio.Net.listen net ~sw ~reuse_addr:true ~backlog:10 addr in - let rec listen () = - Eio.Net.accept_fork ~sw server_socket - ~on_error:(fun exn -> Eio.traceln "%s" (Printexc.to_string exn)) - handler; - listen () + let connection_handler client_addr socket = + Eio.Switch.run (fun sw -> + Net.connection_handler ~sw server client_addr socket) in - Printf.printf "Listening on port %i for grpc requests\n" port; - print_endline ""; - print_endline "Try running:"; - print_endline ""; - print_endline - {| dune exec -- examples/greeter-client-eio/greeter_client_eio.exe |}; - listen () + Eio.Net.run_server + ~on_error:(fun exn -> Eio.traceln "%s" (Printexc.to_string exn)) + server_socket connection_handler -let () = - let greeter_service = - Server.Service.( - v () |> add_rpc ~name:"SayHello" ~rpc:(Unary say_hello) |> handle_request) - in - let server = - Server.( - v () |> add_service ~name:"mypackage.Greeter" ~service:greeter_service) - in - Eio_main.run (serve server) +let mk_handler f = + { Grpc_server_eio.Rpc.headers = (fun _ -> Grpc_server.headers `Proto); f } + +let server env = + let add_rpc = Server.Service.add_rpc in + let open Server.Rpc in + let service = + Server.Service.v () + |> add_rpc ~name:"SayHello" ~rpc:(mk_handler (unary (say_hello env))) + in + Server.(make () |> add_service ~name:"mypackage.Greeter" ~service) + +let () = Eio_main.run (fun env -> serve (server env) env) diff --git a/examples/routeguide/proto/dune b/examples/routeguide/proto/dune index 50e1fe2..b7ca9e8 100644 --- a/examples/routeguide/proto/dune +++ b/examples/routeguide/proto/dune @@ -3,7 +3,7 @@ (package grpc-examples) (preprocess (pps ppx_deriving.show ppx_deriving.eq)) - (libraries ocaml-protoc-plugin)) + (libraries pbrt)) (rule (targets route_guide.ml) @@ -11,8 +11,4 @@ (:proto route_guide.proto)) (action (run - protoc - -I - . - "--ocaml_out=annot=[@@deriving show { with_path = false }, eq]:." - %{proto}))) + ocaml-protoc --ocaml_all_types_ppx "deriving show" --int32_type int_t --int64_type int_t --binary --ml_out ./ %{proto}))) diff --git a/examples/routeguide/proto/route_guide.proto b/examples/routeguide/proto/route_guide.proto index 789b4cb..0263ea9 100644 --- a/examples/routeguide/proto/route_guide.proto +++ b/examples/routeguide/proto/route_guide.proto @@ -102,4 +102,4 @@ message RouteSummary { // The duration of the traversal in seconds. int32 elapsed_time = 4; -} \ No newline at end of file +} diff --git a/examples/routeguide/src/client.ml b/examples/routeguide/src/client.ml index 47d8dba..d039077 100644 --- a/examples/routeguide/src/client.ml +++ b/examples/routeguide/src/client.ml @@ -1,139 +1,109 @@ -open Grpc_eio -open Routeguide.Route_guide.Routeguide -open Ocaml_protoc_plugin - -(* $MDX part-begin=client-h2 *) -let client ~sw host port network = - let inet, port = - Eio_unix.run_in_systhread (fun () -> - Unix.getaddrinfo host port [ Unix.(AI_FAMILY PF_INET) ]) - |> List.filter_map (fun (addr : Unix.addr_info) -> - match addr.ai_addr with - | Unix.ADDR_UNIX _ -> None - | ADDR_INET (addr, port) -> Some (addr, port)) - |> List.hd +open Routeguide +module Client = Grpc_client_eio.Client + +let get_feature sw io request = + let response = + Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"GetFeature" + ~headers:(Grpc_client.make_request_headers `Proto) (fun encoder -> + Route_guide.encode_pb_point request encoder) in - let addr = `Tcp (Eio_unix.Net.Ipaddr.of_unix inet, port) in - let socket = Eio.Net.connect ~sw network addr in - H2_eio.Client.create_connection ~sw ~error_handler:ignore socket + match response with + | `Success ({ response = res; _ } as result) -> + `Success + { + result with + response = + res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature; + } + | ( `Premature_close _ | `Write_error _ | `Connection_error _ + | `Response_not_ok _ ) as rest -> + rest (* $MDX part-end *) (* $MDX part-begin=client-get-feature *) -let call_get_feature connection point = - let encode, decode = Service.make_client_functions RouteGuide.getFeature in +let call_get_feature sw io point = let response = - Client.call ~service:"routeguide.RouteGuide" ~rpc:"GetFeature" - ~do_request:(H2_eio.Client.request connection ~error_handler:ignore) - ~handler: - (Client.Rpc.unary - (encode point |> Writer.contents) - ~f:(fun response -> - match response with - | Some response -> ( - Reader.create response |> decode |> function - | Ok feature -> feature - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e))) - | None -> Feature.make ())) - () + Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"GetFeature" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun encoder -> Route_guide.encode_pb_point point encoder) in match response with - | Ok (res, _ok) -> Printf.printf "RESPONSE = {%s}" (Feature.show res) - | Error _ -> Printf.printf "an error occurred" + | `Success { response = res; _ } -> + Printf.printf "RESPONSE = {%s}%!" + (Route_guide.show_feature + (res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature)) + | _ -> Printf.printf "an error occurred" (* $MDX part-end *) + (* $MDX part-begin=client-list-features *) -let print_features connection = +let print_features sw io = let rectangle = - Rectangle.make - ~lo:(Point.make ~latitude:400000000 ~longitude:(-750000000) ()) - ~hi:(Point.make ~latitude:420000000 ~longitude:(-730000000) ()) + Route_guide.default_rectangle + ~lo: + (Some + (Route_guide.default_point ~latitude:400000000 + ~longitude:(-750000000) ())) + ~hi: + (Some + (Route_guide.default_point ~latitude:420000000 + ~longitude:(-730000000) ())) () in - let encode, decode = Service.make_client_functions RouteGuide.listFeatures in let stream = - Client.call ~service:"routeguide.RouteGuide" ~rpc:"ListFeatures" - ~do_request:(H2_eio.Client.request connection ~error_handler:ignore) - ~handler: - (Client.Rpc.server_streaming - (encode rectangle |> Writer.contents) - ~f:(fun responses -> - let stream = - Seq.map - (fun str -> - Reader.create str |> decode |> function - | Ok feature -> feature - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e))) - responses - in - stream)) - () + Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"ListFeatures" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_rectangle rectangle) (fun _ ~read -> + Seq.iter + (fun f -> + Printf.printf "RESPONSE = {%s}%!" + (Route_guide.show_feature + (f.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_feature))) + read) in match stream with - | Ok (results, _ok) -> - Seq.iter - (fun f -> Printf.printf "RESPONSE = {%s}" (Feature.show f)) - results - | Error e -> - failwith (Printf.sprintf "HTTP2 error: %s" (H2.Status.to_string e)) + | `Stream_result { err = None; _ } -> () + | _ -> failwith "an erra" (* $MDX part-end *) -(* $MDX part-begin=client-random-point *) -let random_point () : Point.t = +(* $MDX part-begin=client-record-route *) +let random_point () = let latitude = (Random.int 180 - 90) * 10000000 in let longitude = (Random.int 360 - 180) * 10000000 in - Point.make ~latitude ~longitude () + Route_guide.default_point ~latitude ~longitude () -(* $MDX part-end *) -(* $MDX part-begin=client-record-route *) -let run_record_route connection = +let run_record_route sw io = let points = Random.int 100 |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) in - let encode, decode = Service.make_client_functions RouteGuide.recordRoute in let response = - Client.call ~service:"routeguide.RouteGuide" ~rpc:"RecordRoute" - ~do_request:(H2_eio.Client.request connection ~error_handler:ignore) - ~handler: - (Client.Rpc.client_streaming ~f:(fun f response -> - (* Stream points to server. *) - Seq.iter - (fun point -> - encode point |> Writer.contents |> fun x -> Seq.write f x) - points; - - (* Signal we have finished sending points. *) - Seq.close_writer f; - - (* Decode RouteSummary responses. *) - Eio.Promise.await response |> function - | Some str -> ( - Reader.create str |> decode |> function - | Ok feature -> feature - | Error err -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error err))) - | None -> failwith (Printf.sprintf "No RouteSummary received."))) - () + Client.Client_streaming.call ~io ~sw ~service:"routeguide.RouteGuide" + ~headers:(Grpc_client.make_request_headers `Proto) + ~method_name:"RecordRoute" (fun _ ~writer -> + Seq.iter + (fun point -> + writer.write (Route_guide.encode_pb_point point) |> ignore; + Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) + points) in match response with - | Ok (result, _ok) -> - Printf.printf "SUMMARY = {%s}" (RouteSummary.show result) - | Error e -> - failwith (Printf.sprintf "HTTP2 error: %s" (H2.Status.to_string e)) + | `Success { response; _ } -> + Printf.printf "SUMMARY = {%s}\n%!" + (Route_guide.show_route_summary + (response.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_summary)) + | _ -> failwith "Error occured" (* $MDX part-end *) (* $MDX part-begin=client-route-chat-1 *) -let run_route_chat clock connection = +let run_route_chat clock io sw = (* Generate locations. *) let location_count = 5 in Printf.printf "Generating %i locations\n" location_count; @@ -143,83 +113,73 @@ let run_route_chat clock connection = | 0 -> None | x -> Some - ( RouteNote.make ~location:(random_point ()) + ( Route_guide.default_route_note + ~location:(Some (random_point ())) ~message:(Printf.sprintf "Random Message %i" x) (), x - 1 )) in (* $MDX part-end *) (* $MDX part-begin=client-route-chat-2 *) - let encode, decode = Service.make_client_functions RouteGuide.routeChat in - let rec go writer reader notes = + let rec go ~send ~close reader notes = match Seq.uncons notes with - | None -> - Seq.close_writer writer (* Signal no more notes from the client. *) + | None -> () (* Signal no more notes from the server. *) | Some (route_note, xs) -> ( - encode route_note |> Writer.contents |> fun x -> - Seq.write writer x; + send (Route_guide.encode_pb_route_note route_note) |> ignore; - (* Yield and sleep, waiting for server reply. *) Eio.Time.sleep clock 1.0; - Eio.Fiber.yield (); - - match Seq.uncons reader with - | None -> failwith "Expecting response" - | Some (response, reader') -> - let route_note = - Reader.create response |> decode |> function - | Ok route_note -> route_note - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e)) - in - Printf.printf "NOTE = {%s}\n" (RouteNote.show route_note); - go writer reader' xs) + + match reader () with + | Seq.Nil -> failwith "Expecting response" + | Seq.Cons (route_note, reader') -> + Printf.printf "NOTE = {%s}\n%!" + (Route_guide.show_route_note + (route_note.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_note)); + go ~send ~close reader' xs) in let result = - Client.call ~service:"routeguide.RouteGuide" ~rpc:"RouteChat" - ~do_request:(H2_eio.Client.request connection ~error_handler:ignore) - ~handler: - (Client.Rpc.bidirectional_streaming ~f:(fun writer reader -> - go writer reader route_notes)) - () + Client.Bidirectional_streaming.call ~service:"routeguide.RouteGuide" + ~method_name:"RouteChat" ~io ~sw + ~headers:(Grpc_client.make_request_headers `Proto) (fun _ ~writer ~read -> + go ~send:writer.write ~close:writer.close read route_notes; + []) in match result with - | Ok ((), _ok) -> () - | Error e -> - failwith (Printf.sprintf "HTTP2 error: %s" (H2.Status.to_string e)) + | `Stream_result { err = None; _ } -> () + | _e -> failwith "Error" (* $MDX part-end *) +(* $MDX part-end *) + (* $MDX part-begin=client-main *) let main env = - let port = "8080" in - let host = "localhost" in let clock = Eio.Stdenv.clock env in let network = Eio.Stdenv.net env in let () = Random.self_init () in let run sw = - let connection = client ~sw host port network in + let io = + Io_client_h2_ocaml_protoc.create_client ~net:network ~sw + "http://localhost:8080" + in + + Printf.printf "*** SIMPLE RPC ***\n%!"; - Printf.printf "*** SIMPLE RPC ***\n"; let request = - RouteGuide.GetFeature.Request.make ~latitude:409146138 - ~longitude:(-746188906) () + Route_guide.default_point ~latitude:409146138 ~longitude:(-746188906) () in - let result = call_get_feature connection request in + let result = call_get_feature sw io request in - Printf.printf "\n*** SERVER STREAMING ***\n"; - print_features connection; + Printf.printf "\n*** SERVER STREAMING ***\n%!"; + print_features sw io; - Printf.printf "\n*** CLIENT STREAMING ***\n"; - run_record_route connection; + Printf.printf "\n*** CLIENT STREAMING ***\n%!"; + run_record_route sw io; - Printf.printf "\n*** BIDIRECTIONAL STREAMING ***\n"; - run_route_chat clock connection; - - Eio.Promise.await (H2_eio.Client.shutdown connection); + Printf.printf "\n*** BIDIRECTIONAL STREAMING ***\n%!"; + run_route_chat clock io sw; result in @@ -228,3 +188,17 @@ let main env = let () = Eio_main.run main (* $MDX part-end *) + +let list_features ~sw ~io request handler = + Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"ListFeatures" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_rectangle request) (fun net_response ~read -> + let responses = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_feature) + read + in + handler net_response responses) diff --git a/examples/routeguide/src/dune b/examples/routeguide/src/dune index 9c5afaf..1cc839f 100644 --- a/examples/routeguide/src/dune +++ b/examples/routeguide/src/dune @@ -3,12 +3,14 @@ (package grpc-examples) (public_names routeguide-server routeguide-client) (libraries - grpc-eio + grpc-server-eio + grpc-client-eio eio_main - h2-eio ocaml-protoc-plugin routeguide yojson - ppx_deriving_yojson.runtime) + ppx_deriving_yojson.runtime + io_server_h2_ocaml_protoc + io_client_h2_ocaml_protoc) (preprocess (pps ppx_deriving_yojson ppx_deriving.eq))) diff --git a/examples/routeguide/src/server.ml b/examples/routeguide/src/server.ml index bfa30d9..debc320 100644 --- a/examples/routeguide/src/server.ml +++ b/examples/routeguide/src/server.ml @@ -1,37 +1,29 @@ -open Grpc_eio -open Routeguide.Route_guide.Routeguide -open Ocaml_protoc_plugin +open Routeguide +module Server = Grpc_server_eio +module R = Route_guide (* Derived data types to make reading JSON data easier. *) -type location = { latitude : int; longitude : int } [@@deriving yojson] -type feature = { location : location; name : string } [@@deriving yojson] +type location = R.point = { latitude : int; longitude : int } +[@@deriving yojson] + +type feature = { name : string; location : location } [@@deriving yojson] type feature_list = feature list [@@deriving yojson] -let features : Feature.t list ref = ref [] +let features : feature list ref = ref [] module RouteNotesMap = Hashtbl.Make (struct - type t = Point.t + type t = Route_guide.point - let equal = Point.equal + let equal = ( = ) let hash s = Hashtbl.hash s end) (** Load route_guide data from a JSON file. *) -let load path : Feature.t list = +let load path : feature list = let json = Yojson.Safe.from_file path in - match feature_list_of_yojson json with - | Ok v -> - List.map - (fun feature -> - Feature.make ~name:feature.name - ~location: - (Point.make ~longitude:feature.location.longitude - ~latitude:feature.location.latitude ()) - ()) - v - | Error err -> failwith err - -let in_range (point : Point.t) (rect : Rectangle.t) : bool = + match feature_list_of_yojson json with Ok v -> v | Error err -> failwith err + +let in_range (point : R.point) (rect : R.rectangle) : bool = let lo = Option.get rect.lo in let hi = Option.get rect.hi in @@ -48,7 +40,7 @@ let radians_of_degrees = ( *. ) (pi /. 180.) (* Calculates the distance between two points using the "haversine" formula. *) (* This code was taken from http://www.movable-type.co.uk/scripts/latlong.html. *) -let calc_distance (p1 : Point.t) (p2 : Point.t) : int = +let calc_distance (p1 : R.point) (p2 : R.point) : int = let cord_factor = 1e7 in let r = 6_371_000.0 in (* meters *) @@ -73,189 +65,172 @@ let calc_distance (p1 : Point.t) (p2 : Point.t) : int = Float.to_int (r *. c) (* $MDX part-begin=server-get-feature *) -let get_feature (buffer : string) = - let decode, encode = Service.make_service_functions RouteGuide.getFeature in - (* Decode the request. *) - let point = - Reader.create buffer |> decode |> function - | Ok v -> v - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" (Result.show_error e)) - in - Eio.traceln "GetFeature = {:%s}" (Point.show point); - (* Lookup the feature and if found return it. *) +let get_feature _ point = + Format.printf "%a" Route_guide.pp_point point; + Eio.traceln "GetFeature = {:%s}" (R.show_point point); let feature = - List.find_opt - (fun (f : Feature.t) -> - match (f.location, point) with - | Some p1, p2 -> Point.equal p1 p2 - | _, _ -> false) - !features + List.find_opt (fun (f : feature) -> f.location = point) !features + |> Option.map (fun { location; name } : R.feature -> + { R.name; location = Some location }) in Eio.traceln "Found feature %s" - (feature |> Option.map Feature.show |> Option.value ~default:"Missing"); + (feature |> Option.map R.show_feature |> Option.value ~default:"Missing"); match feature with - | Some feature -> - (Grpc.Status.(v OK), Some (feature |> encode |> Writer.contents)) + | Some feature -> (feature, []) | None -> (* No feature was found, return an unnamed feature. *) - ( Grpc.Status.(v OK), - Some (Feature.make ~location:point () |> encode |> Writer.contents) ) + (R.default_feature ~location:(Some point) (), []) + +let get_feature = + Grpc_server_eio.Rpc.unary (fun req -> + let feature, headers = + get_feature () + (req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_point) + in + ((fun encoder -> R.encode_pb_feature feature encoder), headers)) (* $MDX part-end *) -(* $MDX part-begin=server-list-features *) -let list_features (buffer : string) (f : string -> unit) = - (* Decode request. *) - let decode, encode = Service.make_service_functions RouteGuide.listFeatures in - let rectangle = - Reader.create buffer |> decode |> function - | Ok v -> v - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" (Result.show_error e)) - in +(* $MDX part-begin=server-grpc *) - (* Lookup and reply with features found. *) - let () = - List.iter - (fun (feature : Feature.t) -> - if in_range (Option.get feature.location) rectangle then - encode feature |> Writer.contents |> f - else ()) - !features - in - Grpc.Status.(v OK) +let mk_handler f _req rpc = + rpc.Grpc_server_eio.Rpc.accept (Grpc_server.headers `Proto) f + +(* +let route_guide_service clock = + let add_rpc = Server.Service.add_rpc in + let open Server.Rpc in + Server.Service.v () + |> add_rpc ~name:"GetFeature" ~rpc:(mk_handler (unary get_feature)) + |> add_rpc ~name:"ListFeatures" + ~rpc:(mk_handler (server_streaming list_features)) + |> add_rpc ~name:"RecordRoute" + ~rpc:(mk_handler (client_streaming (record_route clock))) + |> add_rpc ~name:"RouteChat" ~rpc:(mk_handler route_chat) + +let server clock = + Server.( + make () + |> add_service ~name:"routeguide.RouteGuide" + ~service:(route_guide_service clock)) +*) (* $MDX part-end *) -(* $MDX part-begin=server-record-route *) -let record_route (clock : _ Eio.Time.clock) (stream : string Seq.t) = - Eio.traceln "RecordRoute"; - - let last_point = ref None in - let start = Eio.Time.now clock in - let decode, encode = Service.make_service_functions RouteGuide.recordRoute in - - let point_count, feature_count, distance = - Seq.fold_left - (fun (point_count, feature_count, distance) i -> - let point = - Reader.create i |> decode |> function - | Ok v -> v - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e)) - in - Eio.traceln " ==> Point = {%s}" (Point.show point); - - (* Increment the point count *) - let point_count = point_count + 1 in - - (* Find features *) - let feature_count = - List.find_all - (fun (feature : Feature.t) -> - Point.equal (Option.get feature.location) point) - !features - |> fun x -> List.length x + feature_count - in - - (* Calculate the distance *) - let distance = - match !last_point with - | Some last_point -> calc_distance last_point point - | None -> distance - in - last_point := Some point; - (point_count, feature_count, distance)) - (0, 0, 0) stream - in - let stop = Eio.Time.now clock in - let elapsed_time = int_of_float (stop -. start) in - let summary = - RouteSummary.make ~point_count ~feature_count ~distance ~elapsed_time () - in - Eio.traceln "RecordRoute exit\n"; - (Grpc.Status.(v OK), Some (encode summary |> Writer.contents)) + +let list_features = + Grpc_server_eio.Rpc.server_streaming (fun req write -> + let rectangle = + req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_rectangle + in + + let () = + List.iter + (fun feature -> + if in_range feature.location rectangle then + write (fun encoder -> + R.encode_pb_feature + { R.location = Some feature.location; name = feature.name } + encoder) + else ()) + !features + in + []) (* $MDX part-end *) (* $MDX part-begin=server-route-chat *) -let route_chat (stream : string Seq.t) (f : string -> unit) = - Printf.printf "RouteChat\n"; +let route_chat = + fun read write -> + Printf.printf "RouteChat\n%!"; - let decode, encode = Service.make_service_functions RouteGuide.routeChat in Seq.iter (fun i -> let note = - Reader.create i |> decode |> function - | Ok v -> v - | Error e -> - failwith - (Printf.sprintf "Could not decode request: %s" - (Result.show_error e)) + i.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_route_note in - Printf.printf " ==> Note = {%s}\n" (RouteNote.show note); - encode note |> Writer.contents |> f) - stream; - - Printf.printf "RouteChat exit\n"; - Grpc.Status.(v OK) + Printf.printf " ==> Note = {%s}\n%!" (Route_guide.show_route_note note); + write (Route_guide.encode_pb_route_note note)) + read; + Printf.printf "RouteChat exit\n%!"; + [] (* $MDX part-end *) -(* $MDX part-begin=server-grpc *) -let route_guide_service clock = - Server.Service.( - v () - |> add_rpc ~name:"GetFeature" ~rpc:(Unary get_feature) - |> add_rpc ~name:"ListFeatures" ~rpc:(Server_streaming list_features) - |> add_rpc ~name:"RecordRoute" ~rpc:(Client_streaming (record_route clock)) - |> add_rpc ~name:"RouteChat" ~rpc:(Bidirectional_streaming route_chat) - |> handle_request) +(* $MDX part-begin=server-record-route *) -let server clock = - Server.( - v () - |> add_service ~name:"routeguide.RouteGuide" - ~service:(route_guide_service clock)) +(* $MDX part-end *) +(* $MDX part-begin=server-record-route *) +let record_route clock = + Grpc_server_eio.Rpc.client_streaming (fun stream -> + Eio.traceln "RecordRoute"; + let last_point = ref None in + let start = Eio.Time.now clock in + + let point_count, feature_count, distance = + Seq.fold_left + (fun (point_count, feature_count, distance) point -> + let point = + point.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_point + in + Eio.traceln " ==> Point = {%s}" (Route_guide.show_point point); + + (* Increment the point count *) + let point_count = point_count + 1 in + + (* Find features *) + let feature_count = + List.find_all + (fun (feature : feature) -> feature.location = point) + !features + |> fun x -> List.length x + feature_count + in + + (* Calculate the distance *) + let distance = + match !last_point with + | Some last_point -> calc_distance last_point point + | None -> distance + in + last_point := Some point; + (point_count, feature_count, distance)) + (0, 0, 0) stream + in + ( Route_guide.encode_pb_route_summary + { + point_count; + feature_count; + distance; + elapsed_time = Eio.Time.now clock -. start |> Float.to_int; + }, + [] )) (* $MDX part-end *) -let connection_handler server ~sw = - let error_handler client_address ?request:_ _error start_response = - Eio.traceln "Error in request from:%a" Eio.Net.Sockaddr.pp client_address; - let response_body = start_response H2.Headers.empty in - H2.Body.Writer.write_string response_body - "There was an error handling your request.\n"; - H2.Body.Writer.close response_body - in - let request_handler _client_address request_descriptor = - Eio.Fiber.fork ~sw (fun () -> - Grpc_eio.Server.handle_request server request_descriptor) - in - fun socket addr -> - H2_eio.Server.create_connection_handler ?config:None ~request_handler - ~error_handler addr socket ~sw + +let server clock ~service ~meth = + match (service, meth) with + | "routeguide.RouteGuide", "GetFeature" -> mk_handler get_feature + | "routeguide.RouteGuide", "ListFeatures" -> mk_handler list_features + | "routeguide.RouteGuide", "RecordRoute" -> mk_handler (record_route clock) + | "routeguide.RouteGuide", "RouteChat" -> mk_handler route_chat + | _ -> + raise (Grpc_server_eio.Server_error (Grpc.Status.make Unimplemented, [])) (* $MDX part-begin=server-main *) -let serve server env = +let serve server env : unit = let port = 8080 in let net = Eio.Stdenv.net env in - let clock = Eio.Stdenv.clock env in let addr = `Tcp (Eio.Net.Ipaddr.V4.loopback, port) in Eio.Switch.run @@ fun sw -> - let handler = connection_handler ~sw (server clock) in let server_socket = Eio.Net.listen net ~sw ~reuse_addr:true ~backlog:10 addr in - let rec listen () = - Eio.Net.accept_fork ~sw server_socket - ~on_error:(fun exn -> Eio.traceln "%s" (Printexc.to_string exn)) - handler; - listen () + let connection_handler client_addr socket = + Eio.Switch.run (fun sw -> + Io_server_h2_ocaml_protoc.connection_handler ~sw server client_addr + socket) in - Eio.traceln "Listening on port %i for grpc requests\n" port; - listen () + Eio.Net.run_server + ~on_error:(fun exn -> Eio.traceln "%s" (Printexc.to_string exn)) + server_socket connection_handler let () = let path = @@ -266,5 +241,5 @@ let () = (* Load features. *) features := load path; - Eio_main.run (serve server) + Eio_main.run (fun env -> serve (server (Eio.Stdenv.clock env)) env) (* $MDX part-end *) diff --git a/flake.lock b/flake.lock index 6d4b0fc..4d1f71d 100644 --- a/flake.lock +++ b/flake.lock @@ -1,24 +1,186 @@ { "nodes": { + "flake-parts": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib" + }, + "locked": { + "lastModified": 1712014858, + "narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-filter": { + "locked": { + "lastModified": 1710156097, + "narHash": "sha256-1Wvk8UP7PXdf8bCCaEoMnOT1qe5/Duqgj+rL8sRQsSM=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "3342559a24e85fc164b295c3444e8a139924675b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, "nixpkgs": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs_2" + }, "locked": { - "lastModified": 1609870929, - "narHash": "sha256-aVGF0O3T+Xg4avzyCWhkZG6DvqItK6u/1Y4yY7jnj80=", - "owner": "sternenseemann", + "lastModified": 1712561094, + "narHash": "sha256-cRvbal29hZjqtu9/hpQo4fGCH2YGKn+Kqo3apDOf5bo=", + "owner": "nix-ocaml", + "repo": "nix-overlays", + "rev": "bf4dbbb8793e72575f07489e317cc6309bca7f17", + "type": "github" + }, + "original": { + "owner": "nix-ocaml", + "repo": "nix-overlays", + "rev": "bf4dbbb8793e72575f07489e317cc6309bca7f17", + "type": "github" + } + }, + "nixpkgs-lib": { + "locked": { + "dir": "lib", + "lastModified": 1711703276, + "narHash": "sha256-iMUFArF0WCatKK6RzfUJknjem0H9m4KgorO/p3Dopkk=", + "owner": "NixOS", "repo": "nixpkgs", - "rev": "2de4f7dab09871fd05856ffde8f8e3bd40635579", + "rev": "d8fe5e6c92d0d190646fb9f1056741a229980089", "type": "github" }, "original": { - "owner": "sternenseemann", - "ref": "ppx_deriving-5.1", + "dir": "lib", + "owner": "NixOS", + "ref": "nixos-unstable", "repo": "nixpkgs", "type": "github" } }, + "nixpkgs_2": { + "locked": { + "lastModified": 1712514290, + "narHash": "sha256-Uvy+mgMdqRhuazAXwMQHVELi+yPGNj6+VTppWTurxRE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "274e6aa01f2c2266e1cd8debdb82863cd83e2ff7", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "274e6aa01f2c2266e1cd8debdb82863cd83e2ff7", + "type": "github" + } + }, + "ocaml-overlay": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1702307644, + "narHash": "sha256-uehhPApUVm+7jZ/MuHeZvJlWV8BB4ckkGb4iLZ5F0fU=", + "owner": "nix-ocaml", + "repo": "nix-overlays", + "rev": "a6364bea92bb35b01a3a70eed9a5cdb1063e128e", + "type": "github" + }, + "original": { + "owner": "nix-ocaml", + "repo": "nix-overlays", + "rev": "a6364bea92bb35b01a3a70eed9a5cdb1063e128e", + "type": "github" + } + }, "root": { "inputs": { - "nixpkgs": "nixpkgs" + "flake-parts": "flake-parts", + "nix-filter": "nix-filter", + "nixpkgs": "nixpkgs", + "ocaml-overlay": "ocaml-overlay" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" } } }, diff --git a/flake.nix b/flake.nix index 5b5a73c..fa2cd90 100644 --- a/flake.nix +++ b/flake.nix @@ -1,78 +1,178 @@ { - description = "A modular gRPC library"; + description = "Description for the project"; inputs = { nixpkgs = { - url = "github:sternenseemann/nixpkgs/ppx_deriving-5.1"; + url = + "github:nix-ocaml/nix-overlays/bf4dbbb8793e72575f07489e317cc6309bca7f17"; }; + flake-parts.url = "github:hercules-ci/flake-parts"; + nix-filter.url = "github:numtide/nix-filter"; + ocaml-overlay.url = + "github:nix-ocaml/nix-overlays/a6364bea92bb35b01a3a70eed9a5cdb1063e128e"; + ocaml-overlay.inputs.nixpkgs.follows = "nixpkgs"; }; - outputs = { self, nixpkgs }: - with import nixpkgs { system = "x86_64-linux"; }; - let - h2-src = fetchFromGitHub { - owner = "jeffa5"; - repo = "ocaml-h2"; - rev = "36bd7bfa46fb0eb2bce184413f663a46a5e0dd3b"; - sha256 = "sha256-8vsRpx0JVN6KHOVfKit6LhlQqGTO1ofRhfyDgJ7dGz0="; - }; + outputs = inputs@{ flake-parts, nix-filter, ocaml-overlay, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { + systems = + [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; - hpack = ocamlPackages.buildDunePackage { - pname = "hpack"; - version = "0.2.0"; - src = h2-src; - useDune2 = true; - buildInputs = (with ocamlPackages; [ angstrom faraday ]); - }; + imports = [ inputs.flake-parts.flakeModules.easyOverlay ]; - h2 = ocamlPackages.buildDunePackage { - pname = "h2"; - version = "0.7.0"; - src = h2-src; - useDune2 = true; - buildInputs = (with ocamlPackages; [ hpack result httpaf psq base64 ]); - }; - in - { - packages.x86_64-linux = rec { - grpc = - ocamlPackages.buildDunePackage { - pname = "grpc"; - version = "0.1.0"; - src = self; - useDune2 = true; - doCheck = true; - buildInputs = (with ocamlPackages; [ uri h2 ppx_deriving ]); + perSystem = { config, self', inputs', system, ... }: + let + pkgs = (((import inputs.nixpkgs { + inherit system; + config.allowUnfree = true; + overlays = [ ocaml-overlay.outputs.overlays ]; + })).extend (import ./overlay.nix)).extend (self: super: { + ocamlPackages = super.ocaml-ng.ocamlPackages_5_1; + }); + camlPkgs = pkgs.ocaml-ng.ocamlPackages_5_1; + bechamel-notty = camlPkgs.buildDunePackage { + pname = "bechamel-notty"; + version = "0.5.0"; + duneVersion = "3"; + propagatedBuildInputs = + [ camlPkgs.notty camlPkgs.fmt camlPkgs.bechamel ]; + src = pkgs.fetchFromGitHub { + owner = "mirage"; + repo = "bechamel"; + rev = "v0.5.0"; + sha256 = "sha256-aTz80gjVi+ITqi8TXH1NjWPECuTcLFvTEDC7BoRo+6M="; + fetchSubmodules = true; + }; }; - - grpc-lwt = - ocamlPackages.buildDunePackage { - pname = "grpc-lwt"; + dialo-ocaml-protoc-plugin = camlPkgs.buildDunePackage { + pname = "ocaml-protoc-plugin"; version = "0.1.0"; - src = self; - useDune2 = true; - doCheck = true; - buildInputs = (with ocamlPackages; [ ocaml-protoc lwt stringext h2 grpc ]); - }; - }; - - defaultPackage.x86_64-linux = self.packages.x86_64-linux.grpc; + duneVersion = "3"; - devShell.x86_64-linux = mkShell { - buildInputs = [ - ocaml - opam + INCLUDE_GOOGLE_PROTOBUF = "${pkgs.protobuf}/include"; - m4 - pkgconfig + nativeBuildInputs = [ pkgs.protobuf ]; + propagatedBuildInputs = [ pkgs.protobuf pkgs.pkg-config ]; + buildInputs = with camlPkgs; [ lwt stringext ]; + src = pkgs.fetchFromGitHub { + owner = "dialohq"; + repo = "ocaml-protoc-plugin"; + rev = "b814b305520563fff58388682cb360660cc29c47"; + sha256 = "sha256-NgFvc+HTJXc17GwyfA0VqlWXx9R35FJ6CSEQrQ52Jds="; + fetchSubmodules = true; + }; + }; - nixpkgs-fmt - rnix-lsp - ]; + in { + devShells.default = pkgs.mkShell { + inputsFrom = [ + self'.packages.grpc + self'.packages.grpc-lwt + self'.packages.grpc-async + self'.packages.grpc-eio + self'.packages.grpc-examples + self'.packages.grpc-bench + ]; + nativeBuildInputs = with pkgs; [ + nil + nixfmt + camlPkgs.ocaml-lsp + camlPkgs.ocamlformat + camlPkgs.ocaml-protoc + ]; + }; - shellHook = '' - eval $(opam env) - ''; - }; + packages = { + grpc-bench = camlPkgs.buildDunePackage { + pname = "grpc-bench"; + version = "0.1.0"; + duneVersion = "3"; + buildInputs = with camlPkgs; [ + self'.packages.grpc + self'.packages.grpc-lwt + self'.packages.grpc-async + self'.packages.grpc-eio + bechamel-notty + bigstringaf + ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "examples" ]; + }; + }; + grpc-examples = camlPkgs.buildDunePackage { + pname = "grpc-examples"; + version = "0.1.0"; + duneVersion = "3"; + nativeBuildInputs = with camlPkgs; [ + dialo-ocaml-protoc-plugin + ppx_jane + ppx_deriving + ppx_deriving_yojson + ]; + buildInputs = with camlPkgs; [ + h2-lwt-unix + conduit-lwt-unix + core_unix + ppx_deriving_yojson + cohttp-lwt-unix + camlPkgs.h2-eio + camlPkgs.h2-async + tls-async + self'.packages.grpc + self'.packages.grpc-lwt + self'.packages.grpc-async + self'.packages.grpc-eio + ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "examples" ]; + }; + }; + grpc = camlPkgs.buildDunePackage { + pname = "grpc"; + version = "0.1.0"; + duneVersion = "3"; + nativeBuildInputs = with camlPkgs; [ mdx ]; + propagatedBuildInputs = with camlPkgs; [ ppxlib ]; + buildInputs = with camlPkgs; [ uri h2 ppx_deriving ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "lib/grpc" ]; + }; + }; + grpc-lwt = camlPkgs.buildDunePackage { + pname = "grpc-lwt"; + version = "0.1.0"; + duneVersion = "3"; + buildInputs = with camlPkgs; [ self'.packages.grpc lwt ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "lib/grpc-lwt" ]; + }; + }; + grpc-async = camlPkgs.buildDunePackage { + pname = "grpc-async"; + version = "0.1.0"; + duneVersion = "3"; + buildInputs = with camlPkgs; [ self'.packages.grpc async ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "lib/grpc-async" ]; + }; + }; + grpc-eio = camlPkgs.buildDunePackage { + pname = "grpc-eio"; + version = "0.1.0"; + duneVersion = "3"; + buildInputs = with camlPkgs; [ self'.packages.grpc eio ]; + src = nix-filter.lib.filter { + root = ./.; + include = [ "dune-project" "lib/grpc-eio" ]; + }; + }; + }; + packages.default = self'.packages.grpc; + }; }; } diff --git a/grpc-async.opam b/grpc-async.opam index fd4a3b4..702cac6 100644 --- a/grpc-async.opam +++ b/grpc-async.opam @@ -22,7 +22,7 @@ depends: [ "dune" {>= "3.7"} "ocaml" {>= "4.11"} "async" {>= "v0.16"} - "grpc" {= version} + "grpc-server" {= version} "h2" {>= "0.9.0"} "ppx_jane" {>= "v0.16.0"} "stringext" diff --git a/grpc-client-eio.opam b/grpc-client-eio.opam new file mode 100644 index 0000000..1ad0b77 --- /dev/null +++ b/grpc-client-eio.opam @@ -0,0 +1,39 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "An Eio implementation of gRPC client" +description: "Functionality for building gRPC services and rpcs with `eio`." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "eio" {>= "0.12"} + "grpc-client" {= version} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc-client.opam b/grpc-client.opam new file mode 100644 index 0000000..90a909b --- /dev/null +++ b/grpc-client.opam @@ -0,0 +1,41 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "Reusable logic for client side gRPC" +description: + "All modules are networking-layer and concurrency-layer agnostic." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +tags: ["network" "rpc" "serialisation"] +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "ocaml" {>= "4.08"} + "grpc" {= version} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc-eio-core.opam b/grpc-eio-core.opam new file mode 100644 index 0000000..84c8e6f --- /dev/null +++ b/grpc-eio-core.opam @@ -0,0 +1,39 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "Benchmarking package for gRPC" +description: "Benchmarking package for gRPC." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +tags: ["network" "rpc" "serialisation" "benchmark"] +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "eio" + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc-eio-io-client-h2-ocaml-protoc.opam b/grpc-eio-io-client-h2-ocaml-protoc.opam new file mode 100644 index 0000000..c2c581c --- /dev/null +++ b/grpc-eio-io-client-h2-ocaml-protoc.opam @@ -0,0 +1,44 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: + "An h2 implementation of gRPC networking layer for eio based clients." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "grpc-client-eio" {= version} + "h2" {>= "0.9.0"} + "pbrt" + "pbrt_services" + "eio" + "h2-eio" + "grpc-eio-core" + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc-eio-io-server-h2-ocaml-protoc.opam b/grpc-eio-io-server-h2-ocaml-protoc.opam new file mode 100644 index 0000000..0a1fcac --- /dev/null +++ b/grpc-eio-io-server-h2-ocaml-protoc.opam @@ -0,0 +1,45 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: + "An h2 implementation of gRPC networking layer for eio based servers." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "grpc-server-eio" {= version} + "h2" {>= "0.9.0"} + "stringext" + "pbrt" + "pbrt_services" + "eio" + "h2-eio" + "grpc-eio-core" + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc-lwt.opam b/grpc-lwt.opam index e6797ef..f87d5c5 100644 --- a/grpc-lwt.opam +++ b/grpc-lwt.opam @@ -19,7 +19,7 @@ doc: "https://dialohq.github.io/ocaml-grpc" bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" depends: [ "dune" {>= "3.7"} - "grpc" {= version} + "grpc-server" {= version} "h2" {>= "0.9.0"} "lwt" {>= "5.3.0"} "stringext" diff --git a/grpc-eio.opam b/grpc-server-eio.opam similarity index 92% rename from grpc-eio.opam rename to grpc-server-eio.opam index 7f00944..29bc4dc 100644 --- a/grpc-eio.opam +++ b/grpc-server-eio.opam @@ -1,6 +1,6 @@ # This file is generated by dune, edit dune-project instead opam-version: "2.0" -synopsis: "An Eio implementation of gRPC" +synopsis: "An Eio implementation of gRPC server" description: "Functionality for building gRPC services and rpcs with `eio`." maintainer: ["Daniel Quernheim "] authors: [ @@ -19,8 +19,7 @@ bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" depends: [ "dune" {>= "3.7"} "eio" {>= "0.12"} - "grpc" {= version} - "h2" {>= "0.9.0"} + "grpc-server" {= version} "stringext" "odoc" {with-doc} ] diff --git a/grpc-server.opam b/grpc-server.opam new file mode 100644 index 0000000..94abdb8 --- /dev/null +++ b/grpc-server.opam @@ -0,0 +1,41 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "Reusable logic for server side gRPC" +description: + "All modules are networking-layer and concurrency-layer agnostic." +maintainer: ["Daniel Quernheim "] +authors: [ + "Andrew Jeffery " + "Daniel Quernheim " + "Michael Bacarella " + "Sven Anderson " + "Tim McGilchrist " + "Wojtek Czekalski " + "dimitris.mostrous " +] +license: "BSD-3-Clause" +tags: ["network" "rpc" "serialisation"] +homepage: "https://github.com/dialohq/ocaml-grpc" +doc: "https://dialohq.github.io/ocaml-grpc" +bug-reports: "https://github.com/dialohq/ocaml-grpc/issues" +depends: [ + "dune" {>= "3.7"} + "ocaml" {>= "4.08"} + "grpc" {= version} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/dialohq/ocaml-grpc.git" diff --git a/grpc.opam b/grpc.opam index 8355be4..1002cdb 100644 --- a/grpc.opam +++ b/grpc.opam @@ -2,7 +2,7 @@ opam-version: "2.0" synopsis: "A modular gRPC library" description: - "This library builds some of the signatures and implementations of gRPC functionality. This is used in the more specialised package `grpc-lwt` which has more machinery, however this library can also be used to do some bits yourself." + "This library contains the implementation of (de)serialization of gRPC messages and statuses." maintainer: ["Daniel Quernheim "] authors: [ "Andrew Jeffery " @@ -22,7 +22,6 @@ depends: [ "dune" {>= "3.7"} "ocaml" {>= "4.08"} "bigstringaf" {>= "0.9.1"} - "h2" {>= "0.9.0"} "ppx_deriving" "uri" {>= "4.0.0"} "odoc" {with-doc} diff --git a/lib/grpc-async/client.ml b/lib/async/client.ml similarity index 100% rename from lib/grpc-async/client.ml rename to lib/async/client.ml diff --git a/lib/grpc-async/client.mli b/lib/async/client.mli similarity index 60% rename from lib/grpc-async/client.mli rename to lib/async/client.mli index 3218b80..63b01c6 100644 --- a/lib/grpc-async/client.mli +++ b/lib/async/client.mli @@ -7,38 +7,42 @@ module Rpc : sig val bidirectional_streaming : handler:(string Pipe.Writer.t -> string Pipe.Reader.t -> 'a Deferred.t) -> 'a handler - (** [bidirectional_streaming ~handler write read] sets up the sending and receiving - logic using [write] and [read], then calls [handler] with a writer pipe and - a reader pipe, for sending and receiving payloads to and from the server. + (** [bidirectional_streaming ~handler write read] sets up the sending and + receiving logic using [write] and [read], then calls [handler] with a + writer pipe and a reader pipe, for sending and receiving payloads to and + from the server. - The stream is closed when the deferred returned by the handler becomes determined. *) + The stream is closed when the deferred returned by the handler becomes + determined. *) val client_streaming : handler:(string Pipe.Writer.t -> string option Deferred.t -> 'a Deferred.t) -> 'a handler (** [client_streaming ~handler write read] sets up the sending and receiving - logic using [write] and [read], then calls [handler] with a writer pipe to send - payloads to the server. + logic using [write] and [read], then calls [handler] with a writer pipe to + send payloads to the server. - The stream is closed when the deferred returned by the handler becomes determined. *) + The stream is closed when the deferred returned by the handler becomes + determined. *) val server_streaming : handler:(string Pipe.Reader.t -> 'a Deferred.t) -> encoded_request:string -> 'a handler - (** [server_streaming ~handler encoded_request write read] sets up the sending and - receiving logic using [write] and [read], then sends [encoded_request] and calls - [handler] with a pipe of responses. + (** [server_streaming ~handler encoded_request write read] sets up the sending + and receiving logic using [write] and [read], then sends [encoded_request] + and calls [handler] with a pipe of responses. - The stream is closed when the deferred returned by the handler becomes determined. *) + The stream is closed when the deferred returned by the handler becomes + determined. *) val unary : handler:(string option -> 'a Deferred.t) -> encoded_request:string -> 'a handler - (** [unary ~handler ~encoded_request] sends the encoded request to the server . When the - response is received, the handler is called with an option response. The response is - is None if the server sent an empty response. *) + (** [unary ~handler ~encoded_request] sends the encoded request to the server + . When the response is received, the handler is called with an option + response. The response is is None if the server sent an empty response. *) end type response_handler = H2.Client_connection.response_handler @@ -60,6 +64,6 @@ val call : ?headers:H2.Headers.t -> unit -> ('a * Grpc.Status.t, H2.Status.t) Core._result Deferred.t -(** [call ~service ~rpc ~handler ~do_request ()] calls the rpc endpoint given - by [service] and [rpc] using the [do_request] function. The [handler] is - called when this request is set up to send and receive data. *) +(** [call ~service ~rpc ~handler ~do_request ()] calls the rpc endpoint given by + [service] and [rpc] using the [do_request] function. The [handler] is called + when this request is set up to send and receive data. *) diff --git a/lib/grpc-async/connection.ml b/lib/async/connection.ml similarity index 100% rename from lib/grpc-async/connection.ml rename to lib/async/connection.ml diff --git a/lib/grpc-async/dune b/lib/async/dune similarity index 100% rename from lib/grpc-async/dune rename to lib/async/dune diff --git a/lib/grpc-async/grpc_async.ml b/lib/async/grpc_async.ml similarity index 100% rename from lib/grpc-async/grpc_async.ml rename to lib/async/grpc_async.ml diff --git a/lib/grpc-async/server.ml b/lib/async/server.ml similarity index 100% rename from lib/grpc-async/server.ml rename to lib/async/server.ml diff --git a/lib/grpc-async/server.mli b/lib/async/server.mli similarity index 76% rename from lib/grpc-async/server.mli rename to lib/async/server.mli index ed131f8..8023875 100644 --- a/lib/grpc-async/server.mli +++ b/lib/async/server.mli @@ -7,15 +7,18 @@ module Rpc : sig type client_streaming = string Pipe.Reader.t -> (Grpc.Status.t * string option) Deferred.t - (** [client_streaming] is the type for an rpc where the client streams the requests and the server responds once. *) + (** [client_streaming] is the type for an rpc where the client streams the + requests and the server responds once. *) type server_streaming = string -> string Pipe.Writer.t -> Grpc.Status.t Deferred.t - (** [server_streaming] is the type for an rpc where the client sends one request and the server sends multiple responses. *) + (** [server_streaming] is the type for an rpc where the client sends one + request and the server sends multiple responses. *) type bidirectional_streaming = string Pipe.Reader.t -> string Pipe.Writer.t -> Grpc.Status.t Deferred.t - (** [bidirectional_streaming] is the type for an rpc where both the client and server can send multiple messages. *) + (** [bidirectional_streaming] is the type for an rpc where both the client and + server can send multiple messages. *) type t = | Unary of unary @@ -25,29 +28,36 @@ module Rpc : sig (** [t] represents the types of rpcs available in gRPC. *) val unary : f:unary -> H2.Reqd.t -> unit Deferred.t - (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and handles sending the response. *) + (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and + handles sending the response. *) val client_streaming : f:client_streaming -> H2.Reqd.t -> unit Deferred.t - (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from and handles sending the response. *) + (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from + and handles sending the response. *) val server_streaming : f:server_streaming -> H2.Reqd.t -> unit Deferred.t - (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] and handles sending the responses pushed out. *) + (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] + and handles sending the responses pushed out. *) val bidirectional_streaming : f:bidirectional_streaming -> H2.Reqd.t -> unit Deferred.t - (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests from and andles sending the responses pushed out. *) + (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests + from and andles sending the responses pushed out. *) end module Service : sig type t - (** [t] represents a gRPC service with potentially multiple rpcs and the information needed to route to them. *) + (** [t] represents a gRPC service with potentially multiple rpcs and the + information needed to route to them. *) val v : unit -> t (** [v ()] creates a new service *) val add_rpc : name:string -> rpc:Rpc.t -> t -> t - (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to it with [name]. *) + (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to + it with [name]. *) val handle_request : t -> H2.Reqd.t -> unit - (** [handle_request t reqd] handles routing [reqd] to the correct rpc if available in [t]. *) + (** [handle_request t reqd] handles routing [reqd] to the correct rpc if + available in [t]. *) end diff --git a/lib/eio/arpaca/bin/codegen.ml b/lib/eio/arpaca/bin/codegen.ml new file mode 100644 index 0000000..1d2a3fc --- /dev/null +++ b/lib/eio/arpaca/bin/codegen.ml @@ -0,0 +1,300 @@ +open Ocaml_protoc_compiler_lib +module Ot = Pb_codegen_ocaml_type +module F = Pb_codegen_formatting + +let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string = + match rpc with + | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty + | Rpc_stream ty -> Pb_codegen_util.string_of_field_type ty + +let rpc_kind (req : Ot.rpc_type) (res : Ot.rpc_type) = + match (req, res) with + | Rpc_scalar _, Rpc_scalar _ -> `Unary + | Rpc_scalar _, Rpc_stream _ -> `Server_streaming + | Rpc_stream _, Rpc_scalar _ -> `Client_streaming + | Rpc_stream _, Rpc_stream _ -> `Bidirectional_streaming + +let function_name_encode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string + = + let f ty = + match ty with + | Ot.Ft_unit -> "(fun () enc -> Pbrt.Encoder.empty_nested enc)" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "encode_pb" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot binary-encode request for %s in service %s\n%!" + rpc_name service_name; + exit 1 + in + match ty with Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty + +let function_name_decode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string + = + let f ty = + match ty with + | Ot.Ft_unit -> "(fun d -> Pbrt.Decoder.empty_nested d)" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "decode_pb" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot decode binary request for %s in service %s\n%!" + rpc_name service_name; + exit 1 + in + match ty with Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty + +let to_snake_case = + let regex = + Re.replace (Re.compile Re.upper) ~f:(fun g -> + if Re.Group.start g 0 > 0 then + "_" ^ String.lowercase_ascii (Re.Group.get g 0) + else Re.Group.get g 0) + in + fun str -> regex str + +let service_name_of_package path = String.concat "." path + +let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit + = + let typ_mod_name = String.capitalize_ascii proto_gen_module in + let service_name = service.service_name in + let gen_rpc sc i (rpc : Ot.rpc) = + if i > 0 then F.empty_line sc; + let rpc_name = rpc.rpc_name in + match rpc_kind rpc.rpc_req rpc.rpc_res with + | `Unary -> + F.linep sc + {|let %s ~sw ~io request = + let response = + Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s.%s" + ~method_name:%S + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) + in + match response with + | `Success ({ response = res; _ } as result) -> + `Success + { + result with + response = + res.Grpc_eio_core.Body_reader.consume %s.%s; + } + | ( `Premature_close _ | `Write_error _ | `Connection_error _ + | `Response_not_ok _ ) as rest -> + rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Server_streaming -> + F.linep sc + {|let %s ~sw ~io request handler = + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) (fun net_response ~read -> + let responses = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + handler net_response responses)|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Client_streaming -> + F.linep sc + {|let %s ~sw ~io handler = + let response = + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer -> + let writer' req = writer.write (%s.%s req) in + handler net_response ~writer:writer') + in + match response with + | `Success ({ response = res; _ } as result) -> + `Success + { + result with + response = + res.Grpc_eio_core.Body_reader.consume + %s.%s; + } + | ( `Premature_close _ | `Stream_error _ | `Connection_error _ + | `Response_not_ok _ ) as rest -> + rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Bidirectional_streaming -> + F.linep sc + {|let %s ~sw ~io handler = + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer ~read -> + let writer' req = writer.write (%s.%s req) in + let read' = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + handler net_response ~writer:writer' ~read:read')|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + in + List.iteri (gen_rpc sc) service.service_body + +let gen_service_server_struct ~proto_gen_module (service : Ot.service) top_scope + : unit = + let typ_mod_name = String.capitalize_ascii proto_gen_module in + let gen_rpc_sig sc i (rpc : Ot.rpc) = + if i > 0 then F.empty_line sc; + let name = Pb_codegen_util.function_name_of_rpc rpc in + + F.linep sc "val %s :" (to_snake_case name); + F.linep sc " Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t ->"; + let req_type = + Printf.sprintf "%s.%s" typ_mod_name (ocaml_type_of_rpc_type rpc.rpc_req) + in + let res_type = + Printf.sprintf "%s.%s" typ_mod_name (ocaml_type_of_rpc_type rpc.rpc_res) + in + match rpc_kind rpc.rpc_req rpc.rpc_res with + | `Unary -> + F.linep sc {| %s -> + %s * (string * string) list|} req_type res_type + | `Client_streaming -> + F.linep sc + {| %s Seq.t -> + %s * (string * string) list|} + req_type res_type + | `Server_streaming -> + F.linep sc + {| %s -> + (%s -> unit) -> + (string * string) list|} + req_type res_type + | `Bidirectional_streaming -> + F.linep sc + {| %s Seq.t -> + (%s -> unit) -> + (string * string) list|} + req_type res_type + in + + let gen_impl_sig sc = + List.iteri (gen_rpc_sig sc) service.service_body + (* now generate a function from the module type to a [Service_server.t] *) + in + + let gen_rpc_handler sc (rpc : Ot.rpc) = + let rpc_name = rpc.rpc_name in + let service_name = service.service_name in + + F.linep sc {|| "%s.%s", %S ->|} + (String.concat "." service.service_packages) + service.service_name rpc.rpc_name; + let impl = Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case in + + let decoder_func = + Printf.sprintf "%s.%s" typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req) + in + let encoder_func = + Printf.sprintf "%s.%s" typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res) + in + + let p = F.linep in + let sub = F.sub_scope in + + sub sc (fun sc -> + p sc {|fun req { Grpc_server_eio.Rpc.accept } ->|}; + sub sc (fun sc -> + p sc {|accept Grpc_server.headers_grpc_proto|}; + sub sc (fun sc -> + match rpc_kind rpc.rpc_req rpc.rpc_res with + | `Unary -> + p sc {|(Grpc_server_eio.Rpc.unary (fun grpc_req ->|}; + F.line sc {|let response, trailers =|}; + sub sc (fun sc -> + p sc + {|Impl.%s req (grpc_req.Grpc_eio_core.Body_reader.consume %s)|} + impl decoder_func); + F.line sc "in"; + p sc {|((%s response), trailers )))|} encoder_func + | `Client_streaming -> + p sc + {|(Grpc_server_eio.Rpc.client_streaming (fun grpc_req_seq ->|}; + p sc {|let response, trailers =|}; + sub sc (fun sc -> + p sc {|Impl.%s req|} impl; + sub sc (fun sc -> + p sc {|(Seq.map (fun grpc_req ->|}; + sub sc (fun sc -> + p sc + {|grpc_req.Grpc_eio_core.Body_reader.consume %s|} + decoder_func); + p sc {|) grpc_req_seq)|})); + p sc "in"; + p sc {|((%s response), trailers)))|} encoder_func + | `Server_streaming -> + p sc + {|(Grpc_server_eio.Rpc.server_streaming (fun grpc_req write ->|}; + p sc {|let trailers =|}; + sub sc (fun sc -> + p sc {|Impl.%s req|} impl; + sub sc (fun sc -> + p sc + {|(grpc_req.Grpc_eio_core.Body_reader.consume %s)|} + decoder_func; + p sc {|(fun resp -> write (%s resp))|} encoder_func)); + p sc "in"; + p sc {|trailers))|} + | `Bidirectional_streaming -> + p sc {|(fun grpc_req_seq write ->|}; + p sc {|let trailers =|}; + sub sc (fun sc -> + p sc {|Impl.%s req|} impl; + sub sc (fun sc -> + p sc + {|(Seq.map (fun grpc_req -> grpc_req.Grpc_eio_core.Body_reader.consume %s) grpc_req_seq)|} + decoder_func; + p sc {|(fun resp -> write (%s resp))|} encoder_func)); + p sc "in"; + p sc {|trailers)|}))) + in + + let sc = top_scope in + + F.line sc "module type Implementation = sig"; + F.sub_scope sc gen_impl_sig; + F.line sc "end"; + F.empty_line sc; + F.linep sc "let create_server (module Impl : Implementation) ~service ~meth ="; + F.sub_scope sc (fun sc -> + F.linep sc "match (service, meth) with"; + List.iter (gen_rpc_handler sc) service.service_body; + F.linep sc + {|| _ -> + raise (Grpc_server_eio.Server_error (Grpc.Status.make Unimplemented, []))|}) diff --git a/lib/eio/arpaca/bin/codegen.mli b/lib/eio/arpaca/bin/codegen.mli new file mode 100644 index 0000000..128362f --- /dev/null +++ b/lib/eio/arpaca/bin/codegen.mli @@ -0,0 +1,13 @@ +open Ocaml_protoc_compiler_lib + +val gen_service_server_struct : + proto_gen_module:string -> + Pb_codegen_ocaml_type.service -> + Pb_codegen_formatting.scope -> + unit + +val gen_service_client_struct : + proto_gen_module:string -> + Pb_codegen_ocaml_type.service -> + Pb_codegen_formatting.scope -> + unit diff --git a/lib/eio/arpaca/bin/dune b/lib/eio/arpaca/bin/dune new file mode 100644 index 0000000..618b4c0 --- /dev/null +++ b/lib/eio/arpaca/bin/dune @@ -0,0 +1,5 @@ +(executable + (package arpaca) + (name main) + (public_name arpaca-gen) + (libraries grpc-client-eio ocaml-protoc.compiler-lib re cmdliner)) diff --git a/lib/eio/arpaca/bin/main.ml b/lib/eio/arpaca/bin/main.ml new file mode 100644 index 0000000..9665fb9 --- /dev/null +++ b/lib/eio/arpaca/bin/main.ml @@ -0,0 +1,182 @@ +open Ocaml_protoc_compiler_lib +module Pt = Pb_parsing_parse_tree +module Tt = Pb_typing_type_tree + +let find_imported_file include_dirs file_name = + if Sys.file_exists file_name then file_name + else + let found_file = + List.fold_left + (fun found_file include_dir -> + let try_file_name = Filename.concat include_dir file_name in + match (found_file, Sys.file_exists try_file_name) with + | None, true -> Some try_file_name + | Some previous, true -> + Printf.eprintf + ("[Warning] Imported file %s found in 2 directories, " + ^^ "picking: %s\n") + file_name previous; + found_file + | _, false -> found_file) + None include_dirs + in + + match found_file with + | None -> Pb_exception.import_file_not_found file_name + | Some file_name -> file_name + +let compile proto_file_name include_dirs unsigned_tag = + (* parsing *) + let protos = + Pb_parsing.parse_file + (fun file_name -> + let file_name = find_imported_file include_dirs file_name in + (file_name, Pb_util.read_file file_name)) + proto_file_name + in + + (* file options can be overriden/added with command line arguments *) + let protos = + List.map + (fun proto -> + { + proto with + Pt.file_options = Pb_option.merge proto.Pt.file_options []; + }) + protos + in + + let proto_file_options = + let main_proto = List.hd protos in + main_proto.Pt.file_options + in + + (* typing *) + let typed_proto = Pb_typing.perform_typing protos in + let all_typed_protos = List.flatten typed_proto.proto_types in + + (* Only get the types which are part of the given proto file + (compilation unit) *) + let typed_proto = + { + typed_proto with + Tt.proto_types = + List.filter + (function + | { Tt.file_name; _ } :: _ when file_name = proto_file_name -> true + | _ -> false) + typed_proto.proto_types; + } + in + + (* -- OCaml Backend -- *) + let module BO = Pb_codegen_backend in + let ocaml_proto = + BO.compile ~unsigned_tag ~all_types:all_typed_protos typed_proto + in + (ocaml_proto, proto_file_options) + +open Cmdliner + +(* Validate the protobuf file *) +let validate_proto file = + if Sys.file_exists (Sys.getcwd () ^ "/" ^ file) then Ok file + else Error (`Msg (Printf.sprintf "The protobuf file %s does not exist." file)) + +(* Validate the output directory *) +let validate_output_dir dir = + let open Sys in + if file_exists dir && is_directory dir then Ok dir + else + Error + (`Msg + (Printf.sprintf + "The output directory %s does not exist or is not a directory." dir)) + +let include_path = + let doc = "Include path for protobuf file" in + Arg.(value & opt_all string [] & info [ "I" ] ~docv:"DIR" ~doc) + +let suffix = + let doc = "Include path for protobuf file" in + Arg.(value & opt string "" & info [ "s"; "suffix" ] ~docv:"SUFFIX" ~doc) + +let output_path = + let doc = "Output directory where the files will be written" in + Arg.( + required + & opt (some (conv (validate_output_dir, Format.pp_print_string))) (Some ".") + & info [ "o" ] ~docv:"DIR" ~doc) + +(* Protobuf file argument *) +let proto_file = + let doc = "Protobuf file to process" in + Arg.( + required + & pos 0 (some (conv (validate_proto, Format.pp_print_string))) None + & info [] ~docv:"PROTO" ~doc) + +let prepare proto_file_name include_dirs = + let { Pb_codegen_ocaml_type.proto_services; _ }, _ = + compile proto_file_name include_dirs false + in + let proto_gen_module = + Pb_codegen_util.caml_file_name_of_proto_file_name ~proto_file_name + in + (proto_services, proto_gen_module) + +(* Client command *) +let client_cmd = + let doc = "Generate client-side stubs." in + let info = Cmd.info "client" ~doc in + let term = + Term.( + const (fun proto includes output suffix -> + let proto_services, proto_gen_module = prepare proto includes in + List.iter + (fun svc -> + let scope = Pb_codegen_formatting.empty_scope () in + Codegen.gen_service_client_struct ~proto_gen_module svc scope; + + let out = + Out_channel.open_text + (output ^ "/" ^ svc.service_name ^ suffix ^ ".ml") + in + + Pb_codegen_all.F.output out scope) + proto_services) + $ proto_file $ include_path $ output_path $ suffix) + in + Cmd.v info term + +(* Server command *) +let server_cmd = + let doc = "Generate server side stubs" in + let info = Cmd.info "server" ~doc in + let term = + Term.( + const (fun proto includes output suffix -> + let proto_services, proto_gen_module = prepare proto includes in + List.iter + (fun svc -> + let scope = Pb_codegen_formatting.empty_scope () in + Codegen.gen_service_server_struct ~proto_gen_module svc scope; + + let out = + Out_channel.open_text + (output ^ "/" ^ svc.service_name ^ suffix ^ ".ml") + in + + Pb_codegen_all.F.output out scope) + proto_services) + $ proto_file $ include_path $ output_path $ suffix) + in + Cmd.v info term + +(* Main command *) +let cmds = [ client_cmd; server_cmd ] + +let () = + let doc = "A command-line tool with client and server modes." in + let info = Cmd.info "command" ~doc in + exit (Cmd.eval (Cmd.group info cmds)) diff --git a/lib/eio/arpaca/codegen_tests/dune b/lib/eio/arpaca/codegen_tests/dune new file mode 100644 index 0000000..719fc86 --- /dev/null +++ b/lib/eio/arpaca/codegen_tests/dune @@ -0,0 +1,2 @@ +(cram + (deps route_guide.proto %{bin:arpaca-gen})) diff --git a/lib/eio/arpaca/codegen_tests/route_guide.proto b/lib/eio/arpaca/codegen_tests/route_guide.proto new file mode 100644 index 0000000..0263ea9 --- /dev/null +++ b/lib/eio/arpaca/codegen_tests/route_guide.proto @@ -0,0 +1,105 @@ +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; +package routeguide; + +// Interface exported by the server. +service RouteGuide { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // A feature with an empty name is returned if there's no feature at the given + // position. + rpc GetFeature(Point) returns (Feature) {} + + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + rpc ListFeatures(Rectangle) returns (stream Feature) {} + + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + rpc RecordRoute(stream Point) returns (RouteSummary) {} + + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} +} + +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +message Point { + int32 latitude = 1; + int32 longitude = 2; +} + +// A latitude-longitude rectangle, represented as two diagonally opposite +// points "lo" and "hi". +message Rectangle { + // One corner of the rectangle. + Point lo = 1; + + // The other corner of the rectangle. + Point hi = 2; +} + +// A feature names something at a given point. +// +// If a feature could not be named, the name is empty. +message Feature { + // The name of the feature. + string name = 1; + + // The point where the feature is detected. + Point location = 2; +} + +// A RouteNote is a message sent while at a given point. +message RouteNote { + // The location from which the message is sent. + Point location = 1; + + // The message to be sent. + string message = 2; +} + +// A RouteSummary is received in response to a RecordRoute rpc. +// +// It contains the number of individual points received, the number of +// detected features, and the total distance covered as the cumulative sum of +// the distance between each point. +message RouteSummary { + // The number of points received. + int32 point_count = 1; + + // The number of known features passed while traversing the route. + int32 feature_count = 2; + + // The distance covered in metres. + int32 distance = 3; + + // The duration of the traversal in seconds. + int32 elapsed_time = 4; +} diff --git a/lib/eio/arpaca/codegen_tests/test_client.t b/lib/eio/arpaca/codegen_tests/test_client.t new file mode 100644 index 0000000..05e5ea6 --- /dev/null +++ b/lib/eio/arpaca/codegen_tests/test_client.t @@ -0,0 +1,73 @@ +Simplest possible cram test + $ arpaca-gen client route_guide.proto -o . + $ cat RouteGuide.ml + let get_feature ~sw ~io request = + let response = + Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"GetFeature" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_point request) + in + match response with + | `Success ({ response = res; _ } as result) -> + `Success + { + result with + response = + res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature; + } + | ( `Premature_close _ | `Write_error _ | `Connection_error _ + | `Response_not_ok _ ) as rest -> + rest + + let list_features ~sw ~io request handler = + Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"ListFeatures" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_rectangle request) (fun net_response ~read -> + let responses = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_feature) + read + in + handler net_response responses) + + let record_route ~sw ~io handler = + let response = + Client.Client_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"RecordRoute" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer -> + let writer' req = writer (Route_guide.encode_pb_point request) in + handler net_response ~writer:writer') + in + match response with + | `Success ({ response = res; _ } as result) -> + `Success + { + result with + response = + res.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_summary; + } + | ( `Premature_close _ | `Write_error _ | `Connection_error _ + | `Response_not_ok _ ) as rest -> + rest + + let route_chat ~sw ~io handler = + Client.Bidirectional_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"RouteChat" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer ~read -> + let writer' req = writer (Route_guide.encode_pb_route_note request) in + let read' = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_note) + read + in + handler net_response ~writer:writer' ~read:read') + diff --git a/lib/eio/arpaca/codegen_tests/test_server.t b/lib/eio/arpaca/codegen_tests/test_server.t new file mode 100644 index 0000000..01df8c3 --- /dev/null +++ b/lib/eio/arpaca/codegen_tests/test_server.t @@ -0,0 +1,70 @@ +Simplest possible cram test + $ arpaca-gen server -o . --suffix _server route_guide.proto + $ cat RouteGuide_server.ml + module type Implementation = sig + val get_feature : + H2.Request.t -> + Route_guide.point -> + Route_guide.feature * (string * string) list + + val list_features : + H2.Request.t -> + Route_guide.rectangle -> + (Route_guide.feature -> unit) -> + (string * string) list + + val record_route : + H2.Request.t -> + Route_guide.point Seq.t -> + Route_guide.route_summary * (string * string) list + + val route_chat : + H2.Request.t -> + Route_guide.route_note Seq.t -> + (Route_guide.route_note -> unit) -> + (string * string) list + end + + let create_server (module Impl : Implementation) ~service ~meth = + match (service, meth) with + | "routeguide.RouteGuide", "GetFeature" -> + fun req { Grpc_server_eio.Rpc.accept } -> + accept Grpc_server.headers_grpc_proto + (Grpc_server_eio.Rpc.unary (fun grpc_req -> + let response, trailers = + Impl.get_feature req (grpc_req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_point) + in + ((Route_guide.encode_pb_feature response), trailers ))) + | "routeguide.RouteGuide", "ListFeatures" -> + fun req { Grpc_server_eio.Rpc.accept } -> + accept Grpc_server.headers_grpc_proto + (Grpc_server_eio.Rpc.server_streaming (fun grpc_req write -> + let trailers = + Impl.list_features req + (grpc_req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_rectangle) + (fun resp -> write (Route_guide.encode_pb_feature resp)) + in + trailers)) + | "routeguide.RouteGuide", "RecordRoute" -> + fun req { Grpc_server_eio.Rpc.accept } -> + accept Grpc_server.headers_grpc_proto + (Grpc_server_eio.Rpc.client_streaming (fun grpc_req_seq -> + let response, trailers = + Impl.record_route req + (Seq.map (fun grpc_req -> + grpc_req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_point + ) grpc_req_seq) + in + ((Route_guide.encode_pb_route_summary response), trailers))) + | "routeguide.RouteGuide", "RouteChat" -> + fun req { Grpc_server_eio.Rpc.accept } -> + accept Grpc_server.headers_grpc_proto + (fun grpc_req_seq write -> + let trailers = + Impl.route_chat req + (Seq.map (fun grpc_req -> grpc_req.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_route_note) grpc_req_seq) + (fun resp -> write (Route_guide.encode_pb_route_note resp)) + in + trailers) + | _ -> + raise (Grpc_server_eio.Server_error (Grpc.Status.make Unimplemented, [])) diff --git a/lib/eio/arpaca/integration_tests/client.ml b/lib/eio/arpaca/integration_tests/client.ml new file mode 100644 index 0000000..75a0362 --- /dev/null +++ b/lib/eio/arpaca/integration_tests/client.ml @@ -0,0 +1,128 @@ +let print_features sw io = + let rectangle = + Route_guide.default_rectangle + ~lo: + (Some + (Route_guide.default_point ~latitude:400000000 + ~longitude:(-750000000) ())) + ~hi: + (Some + (Route_guide.default_point ~latitude:420000000 + ~longitude:(-730000000) ())) + () + in + + let stream = + RouteGuide_client.list_features ~sw ~io rectangle (fun _ read -> + Seq.iter + (fun f -> + Printf.printf "RESPONSE = {%s}%!" (Route_guide.show_feature f)) + read) + in + match stream with + | `Stream_result { err = None; _ } -> () + | _ -> failwith "an erra" + +let random_point () = + let latitude = (Random.int 180 - 90) * 10000000 in + let longitude = (Random.int 360 - 180) * 10000000 in + Route_guide.default_point ~latitude ~longitude () + +let run_record_route sw io = + let points = + Random.int 100 + |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) + in + + let response = + RouteGuide_client.record_route ~io ~sw (fun _ ~writer -> + Seq.iter + (fun point -> + writer point |> ignore; + Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) + points) + in + match response with + | `Success { response; _ } -> + Printf.printf "SUMMARY = {%s}\n%!" + (Route_guide.show_route_summary response) + | _ -> failwith "Error occured" + +let run_route_chat clock io sw = + (* Generate locations. *) + let location_count = 5 in + Printf.printf "Generating %i locations\n" location_count; + let route_notes = + location_count + |> Seq.unfold (function + | 0 -> None + | x -> + Some + ( Route_guide.default_route_note + ~location:(Some (random_point ())) + ~message:(Printf.sprintf "Random Message %i" x) + (), + x - 1 )) + in + (* $MDX part-end *) + (* $MDX part-begin=client-route-chat-2 *) + let rec go ~send reader notes = + match Seq.uncons notes with + | None -> () (* Signal no more notes from the server. *) + | Some (route_note, xs) -> ( + send route_note |> ignore; + + Eio.Time.sleep clock 1.0; + + match reader () with + | Seq.Nil -> failwith "Expecting response" + | Seq.Cons (route_note, reader') -> + Printf.printf "NOTE = {%s}\n%!" + (Route_guide.show_route_note route_note); + go ~send reader' xs) + in + let result = + RouteGuide_client.route_chat ~io ~sw (fun _ ~writer ~read -> + go ~send:writer read route_notes; + []) + in + match result with + | `Stream_result { err = None; _ } -> () + | _e -> failwith "Error" + +let main env = + let clock = Eio.Stdenv.clock env in + let network = Eio.Stdenv.net env in + let () = Random.self_init () in + + let run sw = + let io = + Io_client_h2_ocaml_protoc.create_client ~net:network ~sw + "http://localhost:8080" + in + + Printf.printf "*** SIMPLE RPC ***\n%!"; + + let result = + RouteGuide_client.get_feature ~sw ~io + (Route_guide.default_point ~latitude:409146138 ~longitude:(-746188906) + ()) + in + Printf.printf "RESPONSE = {%s}\n%!" + (match result with + | `Success { response; _ } -> Route_guide.show_feature response + | _ -> failwith "Error occured"); + + Printf.printf "\n*** SERVER STREAMING ***\n%!"; + print_features sw io; + + Printf.printf "\n*** CLIENT STREAMING ***\n%!"; + run_record_route sw io; + + Printf.printf "\n*** BIDIRECTIONAL STREAMING ***\n%!"; + run_route_chat clock io sw + in + + Eio.Switch.run run + +let () = Eio_main.run main diff --git a/lib/eio/arpaca/integration_tests/dune b/lib/eio/arpaca/integration_tests/dune new file mode 100644 index 0000000..c7067c9 --- /dev/null +++ b/lib/eio/arpaca/integration_tests/dune @@ -0,0 +1,39 @@ +(rule + (targets route_guide.ml) + (deps + (:proto route_guide.proto)) + (action + (run + ocaml-protoc --ocaml_all_types_ppx "deriving show" --int32_type int_t --int64_type int_t --binary --ml_out ./ %{proto}))) + +(rule + (targets RouteGuide_client.ml) + (deps + (:proto route_guide.proto) %{bin:arpaca-gen}) + (action + (run + arpaca-gen client -o ./ --suffix _client %{proto}))) + +(rule + (targets RouteGuide_server.ml) + (deps + (:proto route_guide.proto) %{bin:arpaca-gen}) + (action + (run + arpaca-gen server -o ./ --suffix _server %{proto}))) + +(executables + (names server client) + (libraries + grpc-server-eio + grpc-client-eio + eio_main + ocaml-protoc-plugin + routeguide + yojson + ppx_deriving_yojson.runtime + io_server_h2_ocaml_protoc + io_client_h2_ocaml_protoc) + (preprocess + (pps ppx_deriving_yojson ppx_deriving.eq ppx_deriving.show))) + diff --git a/lib/eio/arpaca/integration_tests/route_guide.proto b/lib/eio/arpaca/integration_tests/route_guide.proto new file mode 100644 index 0000000..0263ea9 --- /dev/null +++ b/lib/eio/arpaca/integration_tests/route_guide.proto @@ -0,0 +1,105 @@ +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; +package routeguide; + +// Interface exported by the server. +service RouteGuide { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // A feature with an empty name is returned if there's no feature at the given + // position. + rpc GetFeature(Point) returns (Feature) {} + + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + rpc ListFeatures(Rectangle) returns (stream Feature) {} + + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + rpc RecordRoute(stream Point) returns (RouteSummary) {} + + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} +} + +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +message Point { + int32 latitude = 1; + int32 longitude = 2; +} + +// A latitude-longitude rectangle, represented as two diagonally opposite +// points "lo" and "hi". +message Rectangle { + // One corner of the rectangle. + Point lo = 1; + + // The other corner of the rectangle. + Point hi = 2; +} + +// A feature names something at a given point. +// +// If a feature could not be named, the name is empty. +message Feature { + // The name of the feature. + string name = 1; + + // The point where the feature is detected. + Point location = 2; +} + +// A RouteNote is a message sent while at a given point. +message RouteNote { + // The location from which the message is sent. + Point location = 1; + + // The message to be sent. + string message = 2; +} + +// A RouteSummary is received in response to a RecordRoute rpc. +// +// It contains the number of individual points received, the number of +// detected features, and the total distance covered as the cumulative sum of +// the distance between each point. +message RouteSummary { + // The number of points received. + int32 point_count = 1; + + // The number of known features passed while traversing the route. + int32 feature_count = 2; + + // The distance covered in metres. + int32 distance = 3; + + // The duration of the traversal in seconds. + int32 elapsed_time = 4; +} diff --git a/lib/eio/arpaca/integration_tests/server.ml b/lib/eio/arpaca/integration_tests/server.ml new file mode 100644 index 0000000..e17f75e --- /dev/null +++ b/lib/eio/arpaca/integration_tests/server.ml @@ -0,0 +1,171 @@ +module R = Route_guide + +type location = R.point = { latitude : int; longitude : int } +[@@deriving yojson] + +type feature = { name : string; location : location } [@@deriving yojson] +type feature_list = feature list [@@deriving yojson] + +let features : feature list ref = ref [] + +module RouteNotesMap = Hashtbl.Make (struct + type t = Route_guide.point + + let equal = ( = ) + let hash s = Hashtbl.hash s +end) + +(** Load route_guide data from a JSON file. *) +let load path : feature list = + let json = Yojson.Safe.from_file path in + match feature_list_of_yojson json with Ok v -> v | Error err -> failwith err + +let in_range (point : R.point) (rect : R.rectangle) : bool = + let lo = Option.get rect.lo in + let hi = Option.get rect.hi in + + let left = Int.min lo.longitude hi.longitude in + let right = Int.max lo.longitude hi.longitude in + let top = Int.max lo.latitude hi.latitude in + let bottom = Int.min lo.latitude hi.latitude in + + point.longitude >= left && point.longitude <= right + && point.latitude >= bottom && point.latitude <= top + +let pi = 4. *. atan 1. +let radians_of_degrees = ( *. ) (pi /. 180.) + +(* Calculates the distance between two points using the "haversine" formula. *) +(* This code was taken from http://www.movable-type.co.uk/scripts/latlong.html. *) +let calc_distance (p1 : R.point) (p2 : R.point) : int = + let cord_factor = 1e7 in + let r = 6_371_000.0 in + (* meters *) + let lat1 = Float.of_int p1.latitude /. cord_factor in + let lat2 = Float.of_int p2.latitude /. cord_factor in + let lng1 = Float.of_int p1.longitude /. cord_factor in + let lng2 = Float.of_int p2.longitude /. cord_factor in + + let lat_rad1 = radians_of_degrees lat1 in + let lat_rad2 = radians_of_degrees lat2 in + + let delta_lat = radians_of_degrees (lat2 -. lat1) in + let delta_lng = radians_of_degrees (lng2 -. lng1) in + + let a = + (sin (delta_lat /. 2.0) *. sin (delta_lat /. 2.0)) + +. cos lat_rad1 *. cos lat_rad2 + *. sin (delta_lng /. 2.0) + *. sin (delta_lng /. 2.0) + in + let c = 2.0 *. atan2 (sqrt a) (sqrt (1.0 -. a)) in + Float.to_int (r *. c) + +let serve server env : unit = + let port = 8080 in + let net = Eio.Stdenv.net env in + let addr = `Tcp (Eio.Net.Ipaddr.V4.loopback, port) in + Eio.Switch.run @@ fun sw -> + let server_socket = + Eio.Net.listen net ~sw ~reuse_addr:true ~backlog:10 addr + in + let connection_handler client_addr socket = + Eio.Switch.run (fun sw -> + Io_server_h2_ocaml_protoc.connection_handler ~sw server client_addr + socket) + in + Eio.Net.run_server + ~on_error:(fun exn -> Eio.traceln "%s" (Printexc.to_string exn)) + server_socket connection_handler + +let () = + let path = + if Array.length Sys.argv > 1 then Sys.argv.(1) + else failwith "Path to datafile required." + in + + (* Load features. *) + features := load path; + Eio_main.run (fun env -> + let module RouteGuideRpc : RouteGuide_server.Implementation = struct + let get_feature _ point = + Format.printf "%a" Route_guide.pp_point point; + Eio.traceln "GetFeature = {:%s}" (R.show_point point); + let feature = + List.find_opt (fun (f : feature) -> f.location = point) !features + |> Option.map (fun { location; name } : R.feature -> + { R.name; location = Some location }) + in + Eio.traceln "Found feature %s" + (feature |> Option.map R.show_feature + |> Option.value ~default:"Missing"); + match feature with + | Some feature -> (feature, []) + | None -> + (* No feature was found, return an unnamed feature. *) + (R.default_feature ~location:(Some point) (), []) + + let list_features _ rectangle (write : R.feature -> unit) = + List.iter + (fun feature -> + if in_range feature.location rectangle then + write + { R.location = Some feature.location; name = feature.name } + else ()) + !features; + [] + + let record_route _ read = + let clock = Eio.Stdenv.clock env in + Eio.traceln "RecordRoute"; + let last_point = ref None in + let start = Eio.Time.now clock in + + let point_count, feature_count, distance = + Seq.fold_left + (fun (point_count, feature_count, distance) point -> + Eio.traceln " ==> Point = {%s}" (Route_guide.show_point point); + + (* Increment the point count *) + let point_count = point_count + 1 in + + (* Find features *) + let feature_count = + List.find_all + (fun (feature : feature) -> feature.location = point) + !features + |> fun x -> List.length x + feature_count + in + + (* Calculate the distance *) + let distance = + match !last_point with + | Some last_point -> calc_distance last_point point + | None -> distance + in + last_point := Some point; + (point_count, feature_count, distance)) + (0, 0, 0) read + in + ( ({ + R.point_count; + feature_count; + distance; + elapsed_time = Eio.Time.now clock -. start |> Float.to_int; + } + : R.route_summary), + [] ) + + let route_chat _ read write = + Printf.printf "RouteChat\n%!"; + + Seq.iter + (fun note -> + Printf.printf " ==> Note = {%s}\n%!" + (Route_guide.show_route_note note); + write note) + read; + Printf.printf "RouteChat exit\n%!"; + [] + end in + serve (RouteGuide_server.create_server (module RouteGuideRpc)) env) diff --git a/lib/eio/client/client.ml b/lib/eio/client/client.ml new file mode 100644 index 0000000..afd7ae2 --- /dev/null +++ b/lib/eio/client/client.ml @@ -0,0 +1,428 @@ +type ('net_response, 'response, 'stream_err, 'headers) recv = { + net_response : 'net_response; + recv_seq : ('response, 'stream_err) Grpc_eio_core.Recv_seq.t; + trailers : 'headers Eio.Promise.t; +} + +type 'request writer = { + write : 'request -> bool; + (* Returns true if the write was successful, false if the stream is in error state. Throws if the stream was closed. *) + close : unit -> unit; +} + +type ('net_response, 'headers) resp_not_ok = { + net_response : 'net_response; + grpc_status : Grpc.Status.t; + trailers : 'headers; +} + +type ('net_response, + 'headers, + 'request, + 'response, + 'conn_error, + 'stream_error) + connection = { + writer : 'request writer; + recv : + ( ('net_response, 'response, 'stream_error, 'headers) recv, + 'conn_error ) + result + Eio.Promise.t; + grpc_status : Grpc.Status.t Eio.Promise.t; + write_exn : exn option ref; +} + +type ('net_response, 'headers, 'conn_err) common_error = + [ `Connection_error of 'conn_err + | `Response_not_ok of ('net_response, 'headers) resp_not_ok ] + +let call (type headers net_response request response stream_error conn_error) + ~sw + ~(io : + (headers, net_response, request, response, stream_error, conn_error) Io.t) + ~service ~method_name ~(headers : Grpc_client.request_headers) () : + ( ( net_response, + headers, + request, + response, + conn_error, + stream_error ) + connection, + conn_error ) + result = + let (module Io') = io in + let path = Grpc_client.make_path ~service ~method_name in + match Io'.send_request ~headers path with + | Error conn_error -> Error conn_error + | Ok (writer', recv_net) -> + let write_exn = ref None in + let writer = + { + write = + (fun req -> + try + writer'.write req; + true + with exn -> + write_exn := Some exn; + false); + close = writer'.close; + } + in + let status, status_notify = Eio.Promise.create () in + let recv, recv_notify = Eio.Promise.create () in + let () = + Eio.Fiber.fork ~sw (fun () -> + Eio.Promise.resolve recv_notify + (match Eio.Promise.await recv_net with + | Error conn_error -> + Eio.Promise.resolve status_notify + (Grpc.Status.make ~error_message:"Connection error" + Grpc.Status.Unknown); + Error conn_error + | Ok { response; next; trailers } -> + Eio.Fiber.fork ~sw (fun () -> + Eio.Promise.resolve status_notify + (Grpc_client.status_of_trailers + ~get_header: + (Io'.Headers.get (Eio.Promise.await trailers)))); + Ok { net_response = response; recv_seq = next; trailers })) + in + Ok { writer; recv; grpc_status = status; write_exn } + +type ('stream_err, 'headers) streaming_err = { + stream_error : 'stream_err option; + write_exn : exn option; + grpc_status : Grpc.Status.t; +} + +type ('a, 'headers, 'stream_err) streaming_result = { + result : 'a; + trailers : 'headers; + err : ('stream_err, 'headers) streaming_err option; +} + +module Bidirectional_streaming = struct + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) result' = + [ `Stream_result of ('a, 'headers, 'stream_err) streaming_result + | ('net_response, 'headers, 'conn_err) common_error ] + + let call (type headers net_response request response stream_error conn_error) + ~sw + ~(io : + ( headers, + net_response, + request, + response, + stream_error, + conn_error ) + Io.t) ~service ~method_name ~headers f : + (_, headers, stream_error, conn_error, net_response) result' = + match call ~sw ~io ~service ~method_name ~headers () with + | Ok { writer; recv; grpc_status; write_exn } -> ( + match Eio.Promise.await recv with + | Ok { net_response; recv_seq; trailers } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then ( + let error = ref None in + let closed = ref false in + let writer = + { + write = writer.write; + close = + (fun () -> + writer.close (); + closed := true); + } + in + let rec read recv_seq' () = + match recv_seq' () with + | Grpc_eio_core.Recv_seq.Done -> Seq.Nil + | Err e -> + let () = error := Some e in + Seq.Nil + | Next (t, next) -> Seq.Cons (t, fun () -> read next ()) + in + + let res = f net_response ~writer ~read:(read recv_seq) in + if not !closed then writer.close (); + match !error with + | Some error -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some + { + stream_error = Some error; + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + }; + } + | None -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | Grpc.Status.OK -> ( + match !write_exn with + | None -> + `Stream_result + { + result = res; + err = None; + trailers = Eio.Promise.await trailers; + } + | Some _ -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some + { + write_exn = !write_exn; + grpc_status = Eio.Promise.await grpc_status; + stream_error = None; + }; + }) + | _ -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some + { + grpc_status = status; + stream_error = None; + write_exn = !write_exn; + }; + })) + else + `Response_not_ok + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + } + | Error e -> `Connection_error e) + | Error e -> `Connection_error e +end + +module Unary = struct + type ('net_response, 'headers, 'stream_err) premature_close = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + net_response : 'net_response; + stream_error : 'stream_err option; + } + + type ('net_response, 'response, 'headers) success = { + net_response : 'net_response; + response : 'response; + trailers : 'headers; + } + + type ('response, 'headers, 'stream_err, 'conn_err, 'net_response) result' = + [ `Success of ('net_response, 'response, 'headers) success + | `Premature_close of ('net_response, 'headers, 'stream_err) premature_close + | `Response_not_ok of ('net_response, 'headers) resp_not_ok + | `Connection_error of 'conn_err + | `Write_error of exn ] + + let call (type headers net_response request response stream_error conn_error) + ~sw + ~(io : + ( headers, + net_response, + request, + response, + stream_error, + conn_error ) + Io.t) ~service ~method_name ~headers request : + (_, headers, stream_error, conn_error, net_response) result' = + match call ~sw ~io ~service ~method_name ~headers () with + | Ok { writer; recv; grpc_status; write_exn } -> ( + try + if not (writer.write request) then + `Write_error (Option.get !write_exn) + else ( + writer.close (); + match Eio.Promise.await recv with + | Ok { net_response; recv_seq; trailers } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then + match recv_seq () with + | Grpc_eio_core.Recv_seq.Done -> + `Premature_close + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + stream_error = None; + } + | Err stream_error -> + `Premature_close + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + stream_error = Some stream_error; + } + | Next (response, _) -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | OK -> + `Success + { + net_response; + response; + trailers = Eio.Promise.await trailers; + } + | _ -> + (* Not reachable under normal circumstances + https://github.com/grpc/grpc/issues/12824 *) + `Response_not_ok + { + net_response; + grpc_status = status; + trailers = Eio.Promise.await trailers; + }) + else + `Response_not_ok + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + } + | Error e -> `Connection_error e) + with exn -> `Write_error exn) + | Error e -> `Connection_error e +end + +module Client_streaming = struct + type ('a, 'headers, 'stream_err) stream_err = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + result : 'a; + stream_error : 'stream_err; + write_exn : exn option; + } + + type ('a, 'response, 'headers) success = { + result : 'a; + response : 'response; + trailers : 'headers; + write_exn : exn option; + } + + type ('a, 'headers) premature_close = { + result : 'a; + trailers : 'headers; + grpc_status : Grpc.Status.t; + write_exn : exn option; + } + + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response, 'response) result' = + [ `Success of ('a, 'response, 'headers) success + | `Premature_close of ('a, 'headers) premature_close + | `Stream_error of ('a, 'headers, 'stream_err) stream_err + | ('net_response, 'headers, 'conn_err) common_error ] + + let call (type headers net_response request response stream_error conn_error) + ~sw + ~(io : + ( headers, + net_response, + request, + response, + stream_error, + conn_error ) + Io.t) ~service ~method_name ~headers f : + (_, headers, stream_error, conn_error, net_response, response) result' = + match call ~sw ~io ~service ~method_name ~headers () with + | Ok { writer; recv; grpc_status; write_exn } -> ( + match Eio.Promise.await recv with + | Error e -> `Connection_error e + | Ok { net_response; recv_seq; trailers } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then ( + let closed = ref false in + let writer = + { + write = writer.write; + close = + (fun () -> + writer.close (); + closed := true); + } + in + + let res = f net_response ~writer in + if not !closed then writer.close (); + + match recv_seq () with + | Grpc_eio_core.Recv_seq.Done -> + `Premature_close + { + result = res; + trailers = Eio.Promise.await trailers; + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + } + | Err e -> + `Stream_error + { + result = res; + stream_error = e; + trailers = Eio.Promise.await trailers; + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + } + | Next (t, _) -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | OK -> + `Success + { + result = res; + response = t; + trailers = Eio.Promise.await trailers; + write_exn = !write_exn; + } + | _ -> + `Response_not_ok + { + net_response; + grpc_status = status; + trailers = Eio.Promise.await trailers; + })) + else + `Response_not_ok + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + }) + | Error e -> `Connection_error e +end + +module Server_streaming = struct + let call ~sw ~io ~service ~method_name ~headers request f = + let result = + Bidirectional_streaming.call ~sw ~io ~service ~method_name ~headers + (fun net_response ~writer ~read -> + if writer.write request then ( + writer.close (); + `Stream (f net_response ~read)) + else `Write_error) + in + let module Bs = Bidirectional_streaming in + match result with + | (`Connection_error _ | `Response_not_ok _) as e -> e + | `Stream_result { result; err; trailers } -> ( + match result with + | `Write_error -> `Write_error (err, trailers) + | `Stream res -> `Stream_result { result = res; err; trailers }) +end diff --git a/lib/eio/client/client.mli b/lib/eio/client/client.mli new file mode 100644 index 0000000..be6bf2c --- /dev/null +++ b/lib/eio/client/client.mli @@ -0,0 +1,199 @@ +type ('net_response, 'response, 'stream_err, 'headers) recv = { + net_response : 'net_response; + recv_seq : ('response, 'stream_err) Grpc_eio_core.Recv_seq.t; + trailers : 'headers Eio.Promise.t; +} + +type 'request writer = { write : 'request -> bool; close : unit -> unit } + +type ('net_response, + 'headers, + 'request, + 'response, + 'conn_error, + 'stream_error) + connection = { + writer : 'request writer; + recv : + ( ('net_response, 'response, 'stream_error, 'headers) recv, + 'conn_error ) + result + Eio.Promise.t; + grpc_status : Grpc.Status.t Eio.Promise.t; + write_exn : exn option ref; +} + +type ('net_response, 'headers) resp_not_ok = { + net_response : 'net_response; + grpc_status : Grpc.Status.t; + trailers : 'headers; +} + +type ('net_response, 'headers, 'conn_err) common_error = + [ `Connection_error of 'conn_err + | `Response_not_ok of ('net_response, 'headers) resp_not_ok ] + +val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_error ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + unit -> + ( ( 'net_response, + 'headers, + 'request, + 'response, + 'conn_error, + 'stream_error ) + connection, + 'conn_error ) + result + +type ('stream_err, 'headers) streaming_err = { + stream_error : 'stream_err option; + write_exn : exn option; + grpc_status : Grpc.Status.t; +} + +type ('a, 'headers, 'stream_err) streaming_result = { + result : 'a; + trailers : 'headers; + err : ('stream_err, 'headers) streaming_err option; +} + +module Unary : sig + type ('net_response, 'headers, 'stream_err) premature_close = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + net_response : 'net_response; + stream_error : 'stream_err option; + } + + type ('net_response, 'response, 'headers) success = { + net_response : 'net_response; + response : 'response; + trailers : 'headers; + } + + type ('response, 'headers, 'stream_err, 'conn_err, 'net_response) result' = + [ `Premature_close of ('net_response, 'headers, 'stream_err) premature_close + | `Success of ('net_response, 'response, 'headers) success + | `Write_error of exn + | ('net_response, 'headers, 'conn_err) common_error ] + + val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_error ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + 'request -> + ('response, 'headers, 'stream_error, 'conn_error, 'net_response) result' +end + +module Client_streaming : sig + type ('a, 'headers, 'stream_err) stream_err = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + result : 'a; + stream_error : 'stream_err; + write_exn : exn option; + } + + type ('a, 'response, 'headers) success = { + result : 'a; + response : 'response; + trailers : 'headers; + write_exn : exn option; + } + + type ('a, 'headers) premature_close = { + result : 'a; + trailers : 'headers; + grpc_status : Grpc.Status.t; + write_exn : exn option; + } + + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response, 'response) result' = + [ `Premature_close of ('a, 'headers) premature_close + | `Stream_error of ('a, 'headers, 'stream_err) stream_err + | `Success of ('a, 'response, 'headers) success + | ('net_response, 'headers, 'conn_err) common_error ] + + val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_error ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + ('net_response -> writer:'request writer -> 'a) -> + ('a, 'headers, 'stream_error, 'conn_error, 'net_response, 'response) result' +end + +module Server_streaming : sig + val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_err ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + 'request -> + ('net_response -> read:(unit -> 'response Seq.node) -> 'a) -> + [ `Stream_result of ('a, 'headers, 'stream_error) streaming_result + | `Write_error of ('stream_error, 'headers) streaming_err option * 'headers + | ('net_response, 'headers, 'conn_err) common_error ] +end + +module Bidirectional_streaming : sig + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) result' = + [ `Stream_result of ('a, 'headers, 'stream_err) streaming_result + | ('net_response, 'headers, 'conn_err) common_error ] + + val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_error ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + ('net_response -> + writer:'request writer -> + read:(unit -> 'response Seq.node) -> + 'a) -> + ('a, 'headers, 'stream_error, 'conn_error, 'net_response) result' +end diff --git a/lib/eio/client/dune b/lib/eio/client/dune new file mode 100644 index 0000000..2f99c2a --- /dev/null +++ b/lib/eio/client/dune @@ -0,0 +1,4 @@ +(library + (name grpc_client_eio) + (public_name grpc-client-eio) + (libraries grpc eio grpc-client grpc-eio-core)) diff --git a/lib/eio/client/io.ml b/lib/eio/client/io.ml new file mode 100644 index 0000000..5a0272c --- /dev/null +++ b/lib/eio/client/io.ml @@ -0,0 +1,68 @@ +type 'request writer = { + write : 'request -> unit; + (* Returns true if the write was successful, false if the stream is in error state. Throws if the stream was closed. *) + close : unit -> unit; +} + +type ('net_response, 'response, 'headers, 'err) reader = { + response : 'net_response; + trailers : 'headers Eio.Promise.t; + next : ('response, 'err) Grpc_eio_core.Recv_seq.t; +} + +type ('net_response, + 'response, + 'headers, + 'stream_err, + 'conn_err) + reader_or_error = + (('net_response, 'response, 'headers, 'stream_err) reader, 'conn_err) result + +module type S = sig + module Headers : sig + type t + + val get : t -> string -> string option + end + + module Net_response : sig + type t + + val is_ok : t -> bool + val headers : t -> Headers.t + end + + type request + type response + type connection_error + type stream_error + + val send_request : + headers:Grpc_client.request_headers -> + string -> + ( request writer + * ( Net_response.t, + response, + Headers.t, + stream_error, + connection_error ) + reader_or_error + Eio.Promise.t, + connection_error ) + result +end + +type ('headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'connection_error) + t = + (module S + with type Net_response.t = 'net_response + and type Headers.t = 'headers + and type connection_error = 'connection_error + and type stream_error = 'stream_error + and type request = 'request + and type response = 'response) diff --git a/lib/eio/core/body_reader.ml b/lib/eio/core/body_reader.ml new file mode 100644 index 0000000..10f4278 --- /dev/null +++ b/lib/eio/core/body_reader.ml @@ -0,0 +1,285 @@ +open Recv_seq + +type t = { bytes : Bytes.t; len : int } + +let buffer_count = ref 0 +let total_length = ref 0 +let m2 = ref 0.0 +let mean = ref 0.0 +let buckets = [| 0; 0; 0; 0; 0 |] +(* Corresponds to ranges 0-100, 101-300, 301-700, 701-1000, 1001+ *) + +let free _ = decr buffer_count + +(* Function to update mean and variance dynamically *) +let update_statistics new_len = + let new_count = float_of_int (!buffer_count + 1) in + let delta = float_of_int new_len -. !mean in + mean := !mean +. (delta /. new_count); + let delta2 = float_of_int new_len -. !mean in + m2 := !m2 +. (delta *. delta2 (* This is an online formula for variance *)); + let index = + if new_len <= 100 then 0 + else if new_len <= 300 then 1 + else if new_len <= 700 then 2 + else if new_len <= 1000 then 3 + else 4 + in + buckets.(index) <- buckets.(index) + 1 + +let get_next msg_len = + incr buffer_count; + update_statistics msg_len; + total_length := !total_length + msg_len; + Bytes.create msg_len + +(* Calculate average message length *) +let average_msg_len () = + float_of_int !total_length /. float_of_int !buffer_count + +(* Calculate standard deviation of message length *) +let stddev_msg_len () = sqrt (!m2 /. float_of_int !buffer_count) + +(* Function to get the counts of each bucket *) +let get_buckets () = buckets + +type 'a consumer = { consume : 'b. ('a -> 'b) -> 'b } + +let to_consumer t = + { + consume = + (fun f -> + let res = f t in + free t; + res); + } + +let extract_msg_len ~data ~off = + let high = Bigstringaf.get_int16_be data off in + let low = Bigstringaf.get_int16_be data (off + 2) in + (high lsl 16) lor low + +let rec unwrap_message_with_header ~data ~off ~len ~into:promise ~read_next + ~read_more = + if len >= 5 then + let _compressed = Bigstringaf.get data off in + let msg_len = extract_msg_len ~data ~off:(off + 1) in + unwrap_message ~msg_len ~data ~off:(off + 5) ~len:(len - 5) ~into:promise + ~read_next ~read_more + else + let header_buffer = Bigstringaf.create 5 in + Bigstringaf.blit data ~src_off:off header_buffer ~dst_off:0 ~len; + read_more (`Header (header_buffer, 5 - len)) ~into:promise + +and unwrap_message ~msg_len ~data ~off ~len ~into:promise ~read_next ~read_more + = + if len >= msg_len then ( + let bytes = get_next msg_len in + let next_decoder = bytes in + Bigstringaf.blit_to_bytes data ~src_off:off bytes ~dst_off:0 ~len:msg_len; + if len = msg_len then + Eio.Promise.resolve promise + (Next + ( to_consumer { bytes = next_decoder; len = msg_len }, + fun () -> read_next () )) + else + let next, next_u = Eio.Promise.create () in + unwrap_message_with_header ~data ~off:(off + msg_len) ~len:(len - msg_len) + ~into:next_u ~read_more ~read_next; + + Eio.Promise.resolve promise + (Next + ( to_consumer { bytes = next_decoder; len = msg_len }, + fun () -> Eio.Promise.await next ))) + else + let bytes = Bytes.create msg_len in + Bigstringaf.blit_to_bytes data ~src_off:off bytes ~dst_off:0 ~len; + read_more (`Body (bytes, msg_len, msg_len - len)) ~into:promise + +let rec read_more schedule_read buffer ~into:promise = + schedule_read + ~on_eof:(fun () -> Eio.Promise.resolve promise (Err `Unexpected_eof)) + ~on_read:(fun bigstring ~off ~len -> + match buffer with + | `Header (buffer, remaining) -> + if len < remaining then ( + Bigstringaf.blit bigstring ~src_off:off buffer + ~dst_off:(5 - remaining) ~len; + read_more schedule_read + (`Header (buffer, remaining - len)) + ~into:promise) + else ( + Bigstringaf.blit bigstring ~src_off:off buffer + ~dst_off:(5 - remaining) ~len:remaining; + let _compressed = Bigstringaf.get buffer off in + let msg_len = extract_msg_len ~data:buffer ~off:(off + 1) in + unwrap_message ~msg_len ~data:buffer ~off:remaining + ~len:(len - remaining) ~into:promise + ~read_next:(fun () -> read_next schedule_read) + ~read_more:(read_more schedule_read)) + | `Body (buffer, msg_len, remaining) -> + if len >= remaining then ( + Bigstringaf.blit_to_bytes bigstring ~src_off:off buffer + ~dst_off:(msg_len - remaining) ~len:remaining; + if len > remaining then ( + let next, next_u = Eio.Promise.create () in + unwrap_message_with_header ~data:bigstring ~off:(off + remaining) + ~len:(len - remaining) ~into:next_u + ~read_next:(fun () -> read_next schedule_read) + ~read_more:(read_more schedule_read); + Eio.Promise.resolve promise + (Next + ( to_consumer { bytes = buffer; len = msg_len }, + fun () -> Eio.Promise.await next ))) + else + Eio.Promise.resolve promise + (Next + ( to_consumer { bytes = buffer; len = msg_len }, + fun () -> read_next schedule_read ))) + else ( + Bigstringaf.blit_to_bytes bigstring ~src_off:off buffer + ~dst_off:(msg_len - remaining) ~len; + read_more schedule_read + (`Body (buffer, msg_len, remaining - len)) + ~into:promise)) + +and read_next schedule_read = + let promise, promise_u = Eio.Promise.create () in + schedule_read + ~on_eof:(fun () -> Eio.Promise.resolve promise_u Done) + ~on_read:(fun bigstring ~off ~len -> + unwrap_message_with_header ~data:bigstring ~off ~len ~into:promise_u + ~read_next:(fun () -> read_next schedule_read) + ~read_more:(read_more schedule_read)); + Eio.Promise.await promise + +let fill_header ~pos ~length buffer = + (* write compressed flag (uint8) *) + Bigstringaf.set buffer pos '\x00'; + (* write msg length (uint32 be) *) + Bigstringaf.set_int16_be buffer (pos + 1) (length lsr 16); + Bigstringaf.set_int16_be buffer (pos + 3) (length land 0xFFFF) + +exception Unexpected_eof + +let to_seq_exn = + let rec iter s = + match s () with + | Next (msg, cons) -> Seq.Cons (msg, fun () -> iter cons) + | Done -> Seq.Nil + | Err `Unexpected_eof -> raise Unexpected_eof + in + fun sequence () -> iter sequence + +let%expect_test "extracting multiple messages" = + Eio_mock.Backend.run @@ fun _env -> + let promise, promise_u = Eio.Promise.create () in + let test_buffer = Bigstringaf.create ((3 * 5) + 1 + 2 + 3) in + fill_header ~pos:0 ~length:1 test_buffer; + fill_header ~pos:6 ~length:2 test_buffer; + fill_header ~pos:13 ~length:3 test_buffer; + Bigstringaf.blit_from_string "1" ~src_off:0 test_buffer ~dst_off:5 ~len:1; + Bigstringaf.blit_from_string "22" ~src_off:0 test_buffer ~dst_off:11 ~len:2; + Bigstringaf.blit_from_string "333" ~src_off:0 test_buffer ~dst_off:18 ~len:3; + + unwrap_message_with_header ~data:test_buffer ~off:0 + ~len:(Bigstringaf.length test_buffer) + ~into:promise_u + ~read_next:(fun () -> Done) + ~read_more:(fun _ -> raise Not_found); + + (fun () -> Eio.Promise.await promise) + |> to_seq_exn + |> Seq.iter (fun { consume } -> + consume (fun { bytes; len } -> + Bytes.sub_string bytes 0 len |> print_endline)); + [%expect {| + 1 + 22 + 333 + |}] + +let%expect_test "extracting single message" = + Eio_mock.Backend.run @@ fun _env -> + let promise, promise_u = Eio.Promise.create () in + let test_buffer = Bigstringaf.create 6 in + fill_header ~pos:0 ~length:1 test_buffer; + Bigstringaf.blit_from_string "1" ~src_off:0 test_buffer ~dst_off:5 ~len:1; + + unwrap_message_with_header ~data:test_buffer ~off:0 + ~len:(Bigstringaf.length test_buffer) + ~into:promise_u + ~read_next:(fun () -> Done) + ~read_more:(fun _ -> raise Not_found); + + (fun () -> Eio.Promise.await promise) + |> to_seq_exn + |> Seq.iter (fun { consume } -> + consume (fun { bytes; len } -> + Bytes.sub_string bytes 0 len |> print_endline)); + [%expect {| + 1 + |}] + +let%test_module "reading body" = + (module struct + let get_reader reads = + let buffer = Bigstringaf.create 65536 in + let packets = ref reads in + fun ~on_eof ~on_read -> + match !packets with + | [] -> on_eof () + | packet :: rest -> + packets := rest; + on_read buffer ~off:0 ~len:(packet buffer) + + let%test "reading partial body (error)" = + let schedule_read = + get_reader + [ + (fun buf -> + fill_header ~pos:0 ~length:3 buf; + 5); + (fun buf -> + Bigstringaf.blit_from_string "12" ~src_off:0 buf ~dst_off:0 ~len:2; + 2); + ] + in + let result = read_next schedule_read in + match result with Err `Unexpected_eof -> true | _ -> false + + let%expect_test "reading body in multiple chunks" = + Eio_mock.Backend.run @@ fun _env -> + let header = Bigstringaf.create 5 in + fill_header ~pos:0 ~length:10 header; + let schedule_read = + get_reader + [ + (fun buf -> + Bigstringaf.blit header ~src_off:0 buf ~dst_off:0 ~len:3; + 3); + (fun buf -> + Bigstringaf.blit header ~src_off:3 buf ~dst_off:0 ~len:2; + 2); + (fun buf -> + Bigstringaf.blit_from_string "55555" ~src_off:0 buf ~dst_off:0 + ~len:5; + 5); + (fun buf -> + Bigstringaf.blit_from_string "55555" ~src_off:0 buf ~dst_off:0 + ~len:5; + 5); + ] + in + let result = read_next schedule_read in + (match result with + | Done -> print_endline "failure" + | Err `Unexpected_eof -> print_endline "failure" + | Next ({ consume }, cons) -> ( + print_endline + (consume (fun { bytes; len } -> Bytes.sub_string bytes 0 len)); + match cons () with + | Done -> () + | Err `Unexpected_eof | Next _ -> failwith "expected end of sequence")); + [%expect "5555555555"] + end) diff --git a/lib/eio/core/dune b/lib/eio/core/dune new file mode 100644 index 0000000..5b8c879 --- /dev/null +++ b/lib/eio/core/dune @@ -0,0 +1,6 @@ +(library + (name grpc_eio_core) + (public_name grpc-eio-core) + (libraries eio eio.mock) + (preprocess + (pps ppx_expect))) diff --git a/lib/eio/core/recv_seq.ml b/lib/eio/core/recv_seq.ml new file mode 100644 index 0000000..7818c22 --- /dev/null +++ b/lib/eio/core/recv_seq.ml @@ -0,0 +1,20 @@ +type ('a, 'err) t = unit -> ('a, 'err) recv_item +and ('a, 'err) recv_item = Done | Next of 'a * ('a, 'err) t | Err of 'err + +let rec map f recv = + fun () -> + match recv () with + | Done -> Done + | Next (x, recv) -> Next (f x, map f recv) + | Err err -> Err err + +let to_seq ?err_to_exn recv = + let rec loop recv () = + match recv () with + | Done -> Seq.Nil + | Next (x, recv) -> Seq.Cons (x, loop recv) + | Err err -> match err_to_exn with + | None -> failwith "Unexpected error on read. Implement err_to_exn for a more granular error." + | Some f -> raise (f err) + in + loop recv diff --git a/lib/eio/io-client-h2-ocaml-protoc/dune b/lib/eio/io-client-h2-ocaml-protoc/dune new file mode 100644 index 0000000..eb5b505 --- /dev/null +++ b/lib/eio/io-client-h2-ocaml-protoc/dune @@ -0,0 +1,4 @@ +(library + (public_name grpc-eio-io-client-h2-ocaml-protoc) + (name io_client_h2_ocaml_protoc) + (libraries pbrt pbrt_services grpc-client-eio h2 eio h2-eio grpc_eio_core)) diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml new file mode 100644 index 0000000..4cd466f --- /dev/null +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml @@ -0,0 +1,209 @@ +module type Client = sig + val connection : H2_eio.Client.t + + (* This promise might eventually resolve at any point so we should handle it everywhere *) + val connection_error : H2.Client_connection.error Eio.Promise.t + val host : string + val scheme : string +end + +module Headers = struct + type t = H2.Headers.t + + let get = H2.Headers.get +end + +module Net_response = struct + type t = H2.Response.t + + let is_ok t = H2.Status.is_successful t.H2.Response.status + let headers t = t.H2.Response.headers +end + +type connection_error = H2.Client_connection.error +type stream_error = [ connection_error | `Unexpected_eof ] + +type t = + ( H2.Headers.t, + H2.Response.t, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + H2.Client_connection.error ) + Grpc_client_eio.Io.t + +type exn += Write_after_error + +module Growing_buffer = Grpc.Buffer + +(* type resp_consumer = { consume : 'a. (Pbrt.Decoder.t -> 'a) -> 'a } *) + +module Make_net (Client : Client) : + Grpc_client_eio.Io.S + with type Net_response.t = H2.Response.t + and type Headers.t = H2.Headers.t + and type connection_error = connection_error + and type request = Pbrt.Encoder.t -> unit + and type response = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + and type stream_error = stream_error = struct + module Net_response = Net_response + module Headers = Headers + + type nonrec connection_error = connection_error + type nonrec stream_error = stream_error + type request = Pbrt.Encoder.t -> unit + type response = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + + let send_request ~(headers : Grpc_client.request_headers) target = + (* We are flushing headers immediately but potentially for the + unary and server streaming cases we shouldn't do it + *) + let request = + H2.Request.create ~scheme:Client.scheme `POST target + ~headers: + (H2.Headers.of_list + [ + (":authority", Client.host); + ("te", headers.te); + ("content-type", headers.content_type); + ]) + in + (* Refs are used in order to prevent from leaky promises. + I find promises that never get resolved a bit of an anti-pattern + *) + let result, result_u = Eio.Promise.create () in + let trailers_handler = ref ignore in + let error_handler = + ref (fun error -> Eio.Promise.resolve_error result_u error) + in + (* Allocate once, use a pool of these *) + let errored = ref false in + (* + let report_net_error resolver trailers_resolver err = + errored := true; + Eio.Promise.resolve resolver + (Grpc_client_eio.Io.Err (err :> stream_error)); + Eio.Promise.resolve trailers_resolver H2.Headers.empty + in + *) + let response_handler response reader = + let trailers, trailers_u = Eio.Promise.create () in + let () = + trailers_handler := + fun trailers -> Eio.Promise.resolve trailers_u trailers + in + + let next = + (* FIXME: connection error handling + + Eio.Switch.run (fun sw -> + Eio.Fiber.fork_daemon ~sw (fun () -> + report_net_error next_item_u trailers_u + (Eio.Promise.await Client.connection_error); + `Stop_daemon); + Eio.Promise.await next_item |> ignore)); + *) + let _ = Client.connection_error in + (fun () -> + Grpc_eio_core.Body_reader.read_next + (H2.Body.Reader.schedule_read reader)) + |> Grpc_eio_core.Recv_seq.map + (fun { Grpc_eio_core.Body_reader.consume } -> + { + Grpc_eio_core.Body_reader.consume = + (fun f -> + consume (fun { Grpc_eio_core.Body_reader.bytes; len } -> + f (Pbrt.Decoder.of_subbytes bytes 0 len))); + }) + in + + Eio.Promise.resolve result_u + (Ok { Grpc_client_eio.Io.response; next; trailers }) + in + let body_writer = + H2_eio.Client.request ~flush_headers_immediately:true Client.connection + ~trailers_handler:(fun trailers -> !trailers_handler trailers) + ~error_handler:(fun error -> !error_handler error) + ~response_handler request + in + let encoder = Pbrt.Encoder.create ~size:65536 () in + Ok + ( { + Grpc_client_eio.Io.write = + (let header_buffer = Bytes.create 5 in + fun input -> + if !errored = true then raise Write_after_error + else ( + Pbrt.Encoder.clear encoder; + input encoder; + let msg = Pbrt.Encoder.to_bytes encoder in + Grpc.Message.fill_header ~length:(Bytes.length msg) + header_buffer; + H2.Body.Writer.write_string body_writer + (Bytes.unsafe_to_string header_buffer); + H2.Body.Writer.write_string body_writer + (Bytes.unsafe_to_string msg))); + close = (fun () -> H2.Body.Writer.close body_writer); + }, + result ) +end + +module Expert = struct + let create_with_socket ~sw ~(socket : _ Eio.Net.stream_socket) ~host ~scheme : + t = + let connection, connection_resolve = Eio.Promise.create () in + let connection_error, connection_error_resolve = Eio.Promise.create () in + Eio.Fiber.fork_daemon ~sw (fun () -> + Eio.Switch.run (fun sw' -> + let conn = + H2_eio.Client.create_connection ~sw:sw' + ~error_handler:(Eio.Promise.resolve connection_error_resolve) + socket + in + Eio.Switch.on_release sw' (fun () -> + Eio.Promise.await (H2_eio.Client.shutdown conn)); + (* For now we're ignoring the errors, we should probably inject them into grpc handlers to let them handle it *) + Eio.Promise.resolve connection_resolve conn); + `Stop_daemon); + let conn = Eio.Promise.await connection in + (module Make_net (struct + let connection = conn + let connection_error = connection_error + let host = host + let scheme = scheme + end)) + + let create_with_address ~(net : Eio_unix.Net.t) ~sw ~scheme ~host ~port = + let inet, port = + Eio_unix.run_in_systhread (fun () -> + Unix.getaddrinfo host (string_of_int port) + [ Unix.(AI_FAMILY PF_INET) ]) + |> List.filter_map (fun (addr : Unix.addr_info) -> + match addr.ai_addr with + | Unix.ADDR_UNIX _ -> None + | ADDR_INET (addr, port) -> Some (addr, port)) + |> List.hd + in + let addr = `Tcp (Eio_unix.Net.Ipaddr.of_unix inet, port) in + let socket = Eio.Net.connect ~sw net addr in + create_with_socket ~socket ~host ~scheme ~sw +end + +let create_client ~net ~sw addr = + let uri = Uri.of_string addr in + let scheme = Uri.scheme uri |> Option.value ~default:"http" in + let host = + match Uri.host uri with + | None -> invalid_arg "No host in uri" + | Some host -> host + in + let port = + Uri.port uri + |> Option.value + ~default: + (match scheme with + | "http" -> 80 + | "https" -> 443 + | _ -> failwith "Don't know default port for this scheme") + in + Expert.create_with_address ~net ~sw ~scheme ~host ~port diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli new file mode 100644 index 0000000..9804e41 --- /dev/null +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli @@ -0,0 +1,29 @@ +type stream_error = [ H2.Client_connection.error | `Unexpected_eof ] + +type t = + ( H2.Headers.t, + H2.Response.t, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + H2.Client_connection.error ) + Grpc_client_eio.Io.t + +module Expert : sig + val create_with_socket : + sw:Eio.Switch.t -> + socket:[> [> `Generic ] Eio.Net.stream_socket_ty ] Eio_unix.source -> + host:string -> + scheme:string -> + t + + val create_with_address : + net:Eio_unix.Net.t -> + sw:Eio.Switch.t -> + scheme:string -> + host:string -> + port:int -> + t +end + +val create_client : net:Eio_unix.Net.t -> sw:Eio.Switch.t -> string -> t diff --git a/lib/eio/io-server-h2-ocaml-protoc/dune b/lib/eio/io-server-h2-ocaml-protoc/dune new file mode 100644 index 0000000..5706f82 --- /dev/null +++ b/lib/eio/io-server-h2-ocaml-protoc/dune @@ -0,0 +1,7 @@ +(library + (public_name grpc-eio-io-server-h2-ocaml-protoc) + (name io_server_h2_ocaml_protoc) + (libraries grpc-server-eio h2-eio pbrt eio.mock grpc-eio-core) + (inline_tests) + (preprocess + (pps ppx_expect))) diff --git a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml new file mode 100644 index 0000000..f195ce4 --- /dev/null +++ b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml @@ -0,0 +1,129 @@ +exception Unexpected_eof + + +module Io = struct + type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + type response = Pbrt.Encoder.t -> unit + + module Growing_buffer = Grpc.Buffer + + module Net_request = struct + type t = Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t + + let is_post (_, _, req) = + match req with { H2.Request.meth = `POST; _ } -> true | _ -> false + + let target (_, _, req) = req.H2.Request.target + + (* Expose a way to interrupt *) + let get_header (_, _, req) name = H2.Headers.get req.H2.Request.headers name + + let to_seq recv = + let rec loop recv () = + match recv () with + | Grpc_eio_core.Recv_seq.Done -> Seq.Nil + | Next (x, recv) -> Seq.Cons (x, loop recv) + | Err `Unexpected_eof -> raise Unexpected_eof + in + loop recv + + let body (_, reqd, _) = + let body = H2.Reqd.request_body reqd in + (fun () -> + Grpc_eio_core.Body_reader.read_next (H2.Body.Reader.schedule_read body)) + |> Grpc_eio_core.Recv_seq.map + (fun { Grpc_eio_core.Body_reader.consume } -> + { + Grpc_eio_core.Body_reader.consume = + (fun f -> + consume (fun { Grpc_eio_core.Body_reader.bytes; len } -> + f (Pbrt.Decoder.of_subbytes bytes 0 len))); + }) + |> to_seq + end + + let write_trailers reqd (trailers : Grpc_server.trailers) = + try + H2.Reqd.schedule_trailers reqd + (H2.Headers.of_list + (("grpc-status", string_of_int trailers.grpc_status) + :: + (match trailers.grpc_message with + | None -> trailers.extra + | Some msg -> ("grpc-message", msg) :: trailers.extra))) + with + | ((Failure "h2.Reqd.schedule_trailers: stream already closed") + [@warning "-52"] (* https://github.com/anmonteiro/ocaml-h2/issues/175 *)) + -> + () + + let respond_streaming ~headers (_, reqd, _) = + let body_writer = + H2.Reqd.respond_with_streaming ~flush_headers_immediately:true reqd + (H2.Response.create + ~headers: + (H2.Headers.of_list + (("content-type", headers.Grpc_server.content_type) + :: headers.extra)) + `OK) + in + let encoder = Pbrt.Encoder.create () in + let close () = H2.Body.Writer.close body_writer in + let header_buffer = Bytes.create 5 in + let write input = + Pbrt.Encoder.clear encoder; + input encoder; + let data = Pbrt.Encoder.to_bytes encoder |> Bytes.unsafe_to_string in + Grpc.Message.fill_header ~length:(String.length data) header_buffer; + H2.Body.Writer.write_string body_writer + (Bytes.unsafe_to_string header_buffer); + H2.Body.Writer.write_string body_writer data; + H2.Body.Writer.flush body_writer ignore + in + let write_trailers = write_trailers reqd in + let is_closed () = H2.Body.Writer.is_closed body_writer in + { Grpc_server_eio.Io.close; write; write_trailers; is_closed } + + let respond_error ~status_code ~headers (_, reqd, _) = + H2.Reqd.respond_with_string reqd + (H2.Response.create + ~headers:(H2.Headers.of_list headers) + (H2.Status.of_code status_code)) + "" +end + +include Io + +let io = + (module Io : Grpc_server_eio.Io.S + with type Net_request.t = Io.Net_request.t + and type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + and type response = Pbrt.Encoder.t -> unit) + +let connection_handler ~sw ?config ?h2_error_handler ?grpc_error_handler server + : 'a Eio.Net.connection_handler = + fun socket addr -> + let error_handler client_address ?request error respond = + (* Report internal error via headers *) + let () = + match h2_error_handler with + | Some f -> f client_address ?request error + | None -> () + in + let writer = + respond + (H2.Headers.of_list + [ + ( "grpc-status", + string_of_int (Grpc.Status.int_of_code Grpc.Status.Internal) ); + ]) + in + H2.Body.Writer.close writer + in + H2_eio.Server.create_connection_handler ?config + ~request_handler:(fun client_addr reqd -> + Eio.Fiber.fork ~sw (fun () -> + Grpc_server_eio.handle_request ~io ?error_handler:grpc_error_handler + server + (client_addr, reqd, H2.Reqd.request reqd))) + ~error_handler addr socket ~sw diff --git a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli new file mode 100644 index 0000000..2184b73 --- /dev/null +++ b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli @@ -0,0 +1,17 @@ +include + Grpc_server_eio.Io.S + with type Net_request.t = Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t + and type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + and type response = Pbrt.Encoder.t -> unit + +val connection_handler : + sw:Eio.Switch.t -> + ?config:H2.Config.t -> + ?h2_error_handler: + (Eio.Net.Sockaddr.stream -> + ?request:H2.Request.t -> + H2.Server_connection.error -> + unit) -> + ?grpc_error_handler:(exn -> (string * string) list) -> + (Net_request.t, request, response) Grpc_server_eio.t -> + 'a Eio.Net.connection_handler diff --git a/lib/eio/server/dune b/lib/eio/server/dune new file mode 100644 index 0000000..d6d009d --- /dev/null +++ b/lib/eio/server/dune @@ -0,0 +1,8 @@ +(library + (name grpc_server_eio) + (public_name grpc-server-eio) + (libraries grpc eio grpc-server)) + +(deprecated_library_name + (old_public_name grpc-eio) + (new_public_name grpc-server-eio)) diff --git a/lib/eio/server/grpc_server_eio.ml b/lib/eio/server/grpc_server_eio.ml new file mode 100644 index 0000000..ed268c0 --- /dev/null +++ b/lib/eio/server/grpc_server_eio.ml @@ -0,0 +1,126 @@ +module Io = Io + +type extra_trailers = (string * string) list + +exception Server_error of Grpc.Status.t * (string * string) list + +module Rpc = struct + type ('request, 'response) unary = + 'request -> 'response * (string * string) list + + type ('req, 'res) client_streaming = + 'req Seq.t -> 'res * (string * string) list + + type ('req, 'res) server_streaming = + 'req -> ('res -> unit) -> (string * string) list + + type ('req, 'res) bidirectional_streaming = + 'req Seq.t -> ('res -> unit) -> (string * string) list + + type ('req, 'res) rpc_impl = + 'req Seq.t -> ('res -> unit) -> (string * string) list + + type rpc_complete = Rpc_complete + + type ('req, 'res) handler_accept = { + accept : + Grpc_server.headers -> + ('req Seq.t -> ('res -> unit) -> extra_trailers) -> + rpc_complete; + } + + type ('net_req, 'req, 'res) handler = + 'net_req -> ('req, 'res) handler_accept -> rpc_complete + + let unary (unary_handler : _ unary) : _ rpc_impl = + fun request_stream respond -> + match request_stream () with + | Seq.Cons (request, _) -> + let response, extra = unary_handler request in + respond response; + extra + (* TODO: Look up which error this is *) + | Seq.Nil -> raise (Server_error (Grpc.Status.make Not_found, [])) + + let client_streaming (client_streaming_handler : _ client_streaming) : + _ rpc_impl = + fun request_stream respond -> + let response, extra = client_streaming_handler request_stream in + respond response; + extra + + let server_streaming (server_streaming_handler : _ server_streaming) : + _ rpc_impl = + fun requests respond -> + match requests () with + | Seq.Cons (request, _) -> server_streaming_handler request respond + | Seq.Nil -> raise (Server_error (Grpc.Status.make Not_found, [])) + (* TODO: Look up which error this is *) +end + +module G = Grpc_server + +type ('net_request, 'req, 'resp) t = + service:string -> meth:string -> ('net_request, 'req, 'resp) Rpc.handler + +let handle_request ?error_handler (type net_request req resp) + ~(io : (net_request, req, resp) Io.t) server request = + let module Io' = (val io) in + let run_handler handler = + let Rpc.Rpc_complete = + handler request + { + Rpc.accept = + (fun headers f -> + let { Io.write; write_trailers; close; is_closed } = + Io'.respond_streaming ~headers request + in + try + let request_stream = Io'.Net_request.body request in + let extra = f request_stream write in + write_trailers + (Grpc_server.make_trailers ~extra (Grpc.Status.make OK)); + close (); + Rpc_complete + with + | Server_error (status, extra) -> + if not (is_closed ()) then ( + write_trailers (Grpc_server.make_trailers ~extra status); + close ()); + Rpc_complete + | exn -> + let extra = + Option.map (fun f -> f exn) error_handler + |> Option.value ~default:[] + in + if not (is_closed ()) then ( + write_trailers + (Grpc_server.make_trailers ~extra + (Grpc.Status.make Internal)); + close ()); + Rpc_complete); + } + in + () + in + match + G.parse_request + ~is_post_request:(Io'.Net_request.is_post request) + ~get_header:(fun header -> Io'.Net_request.get_header request header) + ~path:(Io'.Net_request.target request) + |> Result.map (fun { G.service; meth } -> server ~service ~meth) + |> Result.map run_handler + with + | Ok () -> () + | exception Server_error (status, extra) -> + let status_code, headers' = Grpc.Status.to_net_resp status in + let headers = List.concat [ headers'; extra ] in + Io'.respond_error request ~status_code ~headers + | exception exn -> + let headers = + Option.map (fun f -> f exn) error_handler |> Option.value ~default:[] + in + Io'.respond_error request ~status_code:500 ~headers + | Error e -> + let status_code, headers = Grpc_server.error_to_code_and_headers e in + Io'.respond_error request ~status_code ~headers diff --git a/lib/eio/server/grpc_server_eio.mli b/lib/eio/server/grpc_server_eio.mli new file mode 100644 index 0000000..e168992 --- /dev/null +++ b/lib/eio/server/grpc_server_eio.mli @@ -0,0 +1,41 @@ +module Io = Io + +exception Server_error of Grpc.Status.t * (string * string) list + +type extra_trailers = (string * string) list + +module Rpc : sig + type rpc_complete + + type ('req, 'res) handler_accept = { + accept : + Grpc_server.headers -> + ('req Seq.t -> ('res -> unit) -> extra_trailers) -> + rpc_complete; + } + + type ('net_req, 'req, 'res) handler = + 'net_req -> ('req, 'res) handler_accept -> rpc_complete + + type ('req, 'res) rpc_impl = 'req Seq.t -> ('res -> unit) -> extra_trailers + (** [handler] represents the most general signature of a gRPC handler. *) + + type ('req, 'res) unary = 'req -> 'res * extra_trailers + type ('req, 'res) client_streaming = 'req Seq.t -> 'res * extra_trailers + type ('req, 'res) server_streaming = 'req -> ('res -> unit) -> extra_trailers + type ('req, 'res) bidirectional_streaming = ('req, 'res) rpc_impl + + val unary : ('req, 'res) unary -> ('req, 'res) rpc_impl + val client_streaming : ('req, 'res) client_streaming -> ('req, 'res) rpc_impl + val server_streaming : ('req, 'res) server_streaming -> ('req, 'res) rpc_impl +end + +type ('net_request, 'req, 'resp) t = + service:string -> meth:string -> ('net_request, 'req, 'resp) Rpc.handler + +val handle_request : + ?error_handler:(exn -> extra_trailers) -> + io:('net_request, 'req, 'resp) Io.t -> + ('net_request, 'req, 'resp) t -> + 'net_request -> + unit diff --git a/lib/eio/server/io.ml b/lib/eio/server/io.ml new file mode 100644 index 0000000..43c352c --- /dev/null +++ b/lib/eio/server/io.ml @@ -0,0 +1,34 @@ +type 'request streaming_writer = { + (* replace dis string *) + write : 'request -> unit; + close : unit -> unit; + write_trailers : Grpc_server.trailers -> unit; + is_closed : unit -> bool; +} + +module type S = sig + type request + + module Net_request : sig + type t + + val body : t -> request Seq.t + val is_post : t -> bool + val target : t -> string + val get_header : t -> string -> string option + end + + type response + + val respond_streaming : + headers:Grpc_server.headers -> Net_request.t -> response streaming_writer + + val respond_error : + status_code:int -> headers:(string * string) list -> Net_request.t -> unit +end + +type ('net_request, 'request, 'response) t = + (module S + with type Net_request.t = 'net_request + and type request = 'request + and type response = 'response) diff --git a/lib/grpc-eio/readme.md b/lib/eio/server/readme.md similarity index 100% rename from lib/grpc-eio/readme.md rename to lib/eio/server/readme.md diff --git a/lib/grpc-client/dune b/lib/grpc-client/dune new file mode 100644 index 0000000..119484d --- /dev/null +++ b/lib/grpc-client/dune @@ -0,0 +1,4 @@ +(library + (name grpc_client) + (public_name grpc-client) + (libraries grpc)) diff --git a/lib/grpc-client/grpc_client.ml b/lib/grpc-client/grpc_client.ml new file mode 100644 index 0000000..942a749 --- /dev/null +++ b/lib/grpc-client/grpc_client.ml @@ -0,0 +1,31 @@ +type request_headers = { content_type : string; te : string } + +let make_request_headers ?(te = []) format = + { + content_type = Grpc.Message.format_to_content_type format; + te = + (match te with + | [] -> "trailers" + | te -> Printf.sprintf "trailers; %s" (String.concat "; " te)); + } + +let make_path ~service ~method_name = + Printf.sprintf "/%s/%s" service method_name + +let status_of_trailers ~get_header = + match get_header "grpc-status" with + | None -> + Grpc.Status.make ~error_message:"Server did not return grpc-status" + Grpc.Status.Unknown + | Some s -> ( + match Option.bind (int_of_string_opt s) Grpc.Status.code_of_int with + | None -> + Grpc.Status.make + ~error_message: + (Printf.sprintf "Server returned an invalid grpc-status %s" s) + Grpc.Status.Unknown + | Some status -> + Grpc.Status.make ?error_message:(get_header "grpc-message") status) + +let trailers_missing_status = + Grpc.Status.make ~error_message:"Trailers missing" Grpc.Status.Unknown diff --git a/lib/grpc-client/grpc_client.mli b/lib/grpc-client/grpc_client.mli new file mode 100644 index 0000000..93b2deb --- /dev/null +++ b/lib/grpc-client/grpc_client.mli @@ -0,0 +1,8 @@ +type request_headers = { content_type : string; te : string } + +val make_request_headers : + ?te:string list -> Grpc.Message.format -> request_headers + +val make_path : service:string -> method_name:string -> string +val status_of_trailers : get_header:(string -> string option) -> Grpc.Status.t +val trailers_missing_status : Grpc.Status.t diff --git a/lib/grpc-eio/client.ml b/lib/grpc-eio/client.ml deleted file mode 100644 index 4efe5cd..0000000 --- a/lib/grpc-eio/client.ml +++ /dev/null @@ -1,106 +0,0 @@ -type response_handler = H2.Client_connection.response_handler - -type do_request = - ?flush_headers_immediately:bool -> - ?trailers_handler:(H2.Headers.t -> unit) -> - H2.Request.t -> - response_handler:response_handler -> - H2.Body.Writer.t - -let make_request ~scheme ~service ~rpc ~headers = - H2.Request.create ~scheme `POST ("/" ^ service ^ "/" ^ rpc) ~headers - -let default_headers = - H2.Headers.of_list - [ ("te", "trailers"); ("content-type", "application/grpc+proto") ] - -let make_trailers_handler () = - let status, status_notify = Eio.Promise.create () in - let trailers_handler headers = - let code = - match H2.Headers.get headers "grpc-status" with - | None -> None - | Some s -> Option.bind (int_of_string_opt s) Grpc.Status.code_of_int - in - match (code, Eio.Promise.is_resolved status) with - | Some code, false -> - let message = H2.Headers.get headers "grpc-message" in - let status = Grpc.Status.v ?message code in - Eio.Promise.resolve status_notify status - | Some _, true (* This should never happen, but just in case. *) | _ -> () - in - (status, trailers_handler) - -let get_response_and_bodies request = - let response, response_notify = Eio.Promise.create () in - let read_body, read_body_notify = Eio.Promise.create () in - let response_handler response body = - Eio.Promise.resolve response_notify response; - Eio.Promise.resolve read_body_notify body - in - let write_body = request ~response_handler in - let response = Eio.Promise.await response in - let read_body = Eio.Promise.await read_body in - (response, read_body, write_body) - -let call ~service ~rpc ?(scheme = "https") ~handler ~(do_request : do_request) - ?(headers = default_headers) () = - let request = make_request ~service ~rpc ~scheme ~headers in - let status, trailers_handler = make_trailers_handler () in - let response, read_body, write_body = - get_response_and_bodies - (do_request ~flush_headers_immediately:true request ~trailers_handler) - in - match response.status with - | `OK -> - trailers_handler response.headers; - let result = handler write_body read_body in - let status = - match Eio.Promise.is_resolved status with - (* In case no grpc-status appears in headers or trailers. *) - | true -> Eio.Promise.await status - | false -> - Grpc.Status.v ~message:"Server did not return grpc-status" - Grpc.Status.Unknown - in - Ok (result, status) - | error_status -> Error error_status - -module Rpc = struct - type 'a handler = H2.Body.Writer.t -> H2.Body.Reader.t -> 'a - - let bidirectional_streaming ~f write_body read_body = - let response_reader, response_writer = Seq.create_reader_writer () in - let request_reader, request_writer = Seq.create_reader_writer () in - Connection.grpc_recv_streaming read_body response_writer; - let res, res_notify = Eio.Promise.create () in - Eio.Fiber.both - (fun () -> - Eio.Promise.resolve res_notify (f request_writer response_reader)) - (fun () -> - Connection.grpc_send_streaming_client write_body request_reader); - Eio.Promise.await res - - let client_streaming ~f = - bidirectional_streaming ~f:(fun request_writer responses -> - let response, response_resolver = Eio.Promise.create () in - Eio.Fiber.pair - (fun () -> f request_writer response) - (fun () -> - Eio.Promise.resolve response_resolver - (Seq.read_and_exhaust responses)) - |> fst) - - let server_streaming ~f request = - bidirectional_streaming ~f:(fun request_writer responses -> - Seq.write request_writer request; - Seq.close_writer request_writer; - f responses) - - let unary ~f request = - bidirectional_streaming ~f:(fun request_writer responses -> - Seq.write request_writer request; - Seq.close_writer request_writer; - let response = Seq.read_and_exhaust responses in - f response) -end diff --git a/lib/grpc-eio/client.mli b/lib/grpc-eio/client.mli deleted file mode 100644 index 745d33c..0000000 --- a/lib/grpc-eio/client.mli +++ /dev/null @@ -1,48 +0,0 @@ -module Rpc : sig - type 'a handler = H2.Body.Writer.t -> H2.Body.Reader.t -> 'a - - val bidirectional_streaming : - f:(string Seq.writer -> string Seq.t -> 'a) -> 'a handler - (** [bidirectional_streaming ~f write read] sets up the sending and receiving - logic using [write] and [read], then calls [f] with a push function for - requests and a stream of responses. *) - - val client_streaming : - f:(string Seq.writer -> string option Eio.Promise.t -> 'a) -> 'a handler - (** [client_streaming ~f write read] sets up the sending and receiving - logic using [write] and [read], then calls [f] with a push function for - requests and promise for the response. *) - - val server_streaming : f:(string Seq.t -> 'a) -> string -> 'a handler - (** [server_streaming ~f enc write read] sets up the sending and receiving - logic using [write] and [read], then sends [enc] and calls [f] with a - stream of responses. *) - - val unary : f:(string option -> 'a) -> string -> 'a handler - (** [unary ~f enc write read] sets up the sending and receiving - logic using [write] and [read], then sends [enc] and calls [f] with a - promise for the response. *) -end - -type response_handler = H2.Client_connection.response_handler - -type do_request = - ?flush_headers_immediately:bool -> - ?trailers_handler:(H2.Headers.t -> unit) -> - H2.Request.t -> - response_handler:response_handler -> - H2.Body.Writer.t -(** [do_request] is the type of a function that performs the request *) - -val call : - service:string -> - rpc:string -> - ?scheme:string -> - handler:'a Rpc.handler -> - do_request:do_request -> - ?headers:H2.Headers.t -> - unit -> - ('a * Grpc.Status.t, H2.Status.t) result -(** [call ~service ~rpc ~handler ~do_request ()] calls the rpc endpoint given - by [service] and [rpc] using the [do_request] function. The [handler] is - called when this request is set up to send and receive data. *) diff --git a/lib/grpc-eio/connection.ml b/lib/grpc-eio/connection.ml deleted file mode 100644 index 3de3965..0000000 --- a/lib/grpc-eio/connection.ml +++ /dev/null @@ -1,45 +0,0 @@ -let grpc_recv_streaming body message_buffer_writer = - let request_buffer = Grpc.Buffer.v () in - let on_eof () = Seq.close_writer message_buffer_writer in - let rec on_read buffer ~off ~len = - Grpc.Buffer.copy_from_bigstringaf ~src_off:off ~src:buffer - ~dst:request_buffer ~length:len; - Grpc.Message.extract_all (Seq.write message_buffer_writer) request_buffer; - H2.Body.Reader.schedule_read body ~on_read ~on_eof - in - H2.Body.Reader.schedule_read body ~on_read ~on_eof - -let grpc_send_streaming_client body encoder_stream = - Seq.iter - (fun encoder -> - let payload = Grpc.Message.make encoder in - H2.Body.Writer.write_string body payload) - encoder_stream; - H2.Body.Writer.close body - -let grpc_send_streaming request encoder_stream status_promise = - let body = - H2.Reqd.respond_with_streaming ~flush_headers_immediately:true request - (H2.Response.create - ~headers: - (H2.Headers.of_list [ ("content-type", "application/grpc+proto") ]) - `OK) - in - Seq.iter - (fun input -> - let payload = Grpc.Message.make input in - H2.Body.Writer.write_string body payload; - H2.Body.Writer.flush body (fun () -> ())) - encoder_stream; - let status = Eio.Promise.await status_promise in - H2.Reqd.schedule_trailers request - (H2.Headers.of_list - ([ - ( "grpc-status", - string_of_int (Grpc.Status.int_of_code (Grpc.Status.code status)) ); - ] - @ - match Grpc.Status.message status with - | None -> [] - | Some message -> [ ("grpc-message", message) ])); - H2.Body.Writer.close body diff --git a/lib/grpc-eio/dune b/lib/grpc-eio/dune deleted file mode 100644 index 39ce5ea..0000000 --- a/lib/grpc-eio/dune +++ /dev/null @@ -1,4 +0,0 @@ -(library - (name grpc_eio) - (public_name grpc-eio) - (libraries grpc h2 eio)) diff --git a/lib/grpc-eio/grpc_eio.ml b/lib/grpc-eio/grpc_eio.ml deleted file mode 100644 index c7e9399..0000000 --- a/lib/grpc-eio/grpc_eio.ml +++ /dev/null @@ -1,3 +0,0 @@ -module Server = Server -module Client = Client -module Seq = Seq diff --git a/lib/grpc-eio/seq.ml b/lib/grpc-eio/seq.ml deleted file mode 100644 index dcba634..0000000 --- a/lib/grpc-eio/seq.ml +++ /dev/null @@ -1,28 +0,0 @@ -include Stdlib.Seq -open Eio - -type 'a reader = 'a t -type 'a writer = { mutable resolver : 'a node Promise.u } - -let write writer item = - let promise, resolver = Promise.create () in - let next = Cons (item, fun () -> Promise.await promise) in - Promise.resolve writer.resolver next; - writer.resolver <- resolver - -let close_writer writer = Promise.resolve writer.resolver Nil -let read reader = reader () - -let rec exhaust_reader reader = - match reader () with Nil -> () | Cons (_, reader) -> exhaust_reader reader - -let read_and_exhaust reader = - match reader () with - | Nil -> None - | Cons (item, reader) -> - exhaust_reader reader; - Some item - -let create_reader_writer () = - let promise, resolver = Promise.create () in - ((fun () -> Promise.await promise), { resolver }) diff --git a/lib/grpc-eio/seq.mli b/lib/grpc-eio/seq.mli deleted file mode 100644 index 9f9ee78..0000000 --- a/lib/grpc-eio/seq.mli +++ /dev/null @@ -1,12 +0,0 @@ -include module type of Stdlib.Seq - -type 'a reader = 'a t -type 'a writer - -val create_reader_writer : unit -> 'a reader * 'a writer -val read : 'a reader -> 'a Stdlib.Seq.node -val read_and_exhaust : 'a reader -> 'a option -val exhaust_reader : 'a reader -> unit -val write : 'a writer -> 'a -> unit -val close_writer : 'a writer -> unit -(* val map_writer : ('a -> 'b) -> 'a writer -> 'b writer *) diff --git a/lib/grpc-eio/server.ml b/lib/grpc-eio/server.ml deleted file mode 100644 index ffd850c..0000000 --- a/lib/grpc-eio/server.ml +++ /dev/null @@ -1,129 +0,0 @@ -module ServiceMap = Map.Make (String) - -type service = H2.Reqd.t -> unit -type t = service ServiceMap.t - -let v () = ServiceMap.empty -let add_service ~name ~service t = ServiceMap.add name service t - -let handle_request t reqd = - let request = H2.Reqd.request reqd in - let respond_with code = - H2.Reqd.respond_with_string reqd (H2.Response.create code) "" - in - let route () = - let parts = String.split_on_char '/' request.target in - if List.length parts > 1 then - (* allow for arbitrary prefixes *) - let service_name = List.nth parts (List.length parts - 2) in - let service = ServiceMap.find_opt service_name t in - match service with - | Some service -> service reqd - | None -> respond_with `Not_found - else respond_with `Not_found - in - match request.meth with - | `POST -> ( - match H2.Headers.get request.headers "content-type" with - | Some s -> - if - Stringext.chop_prefix s ~prefix:"application/grpc" |> Option.is_some - then - match H2.Headers.get request.headers "grpc-encoding" with - | None | Some "identity" -> ( - match H2.Headers.get request.headers "grpc-accept-encoding" with - | None -> route () - | Some encodings -> - let encodings = String.split_on_char ',' encodings in - if List.mem "identity" encodings then route () - else respond_with `Not_acceptable) - | Some _ -> - (* TODO: not sure if there is a specific way to handle this in grpc *) - respond_with `Bad_request - else respond_with `Unsupported_media_type - | None -> respond_with `Unsupported_media_type) - | _ -> respond_with `Not_found - -module Rpc = struct - type unary = string -> Grpc.Status.t * string option - type client_streaming = string Seq.t -> Grpc.Status.t * string option - type server_streaming = string -> (string -> unit) -> Grpc.Status.t - - type bidirectional_streaming = - string Seq.t -> (string -> unit) -> Grpc.Status.t - - type t = - | Unary of unary - | Client_streaming of client_streaming - | Server_streaming of server_streaming - | Bidirectional_streaming of bidirectional_streaming - - let bidirectional_streaming ~f reqd = - let body = H2.Reqd.request_body reqd in - let request_reader, request_writer = Seq.create_reader_writer () in - let response_reader, response_writer = Seq.create_reader_writer () in - Connection.grpc_recv_streaming body request_writer; - let status_promise, status_notify = Eio.Promise.create () in - Eio.Fiber.both - (fun () -> - let respond = Seq.write response_writer in - let status = f request_reader respond in - Seq.close_writer response_writer; - Eio.Promise.resolve status_notify status) - (fun () -> - try Connection.grpc_send_streaming reqd response_reader status_promise - with exn -> - (* https://github.com/anmonteiro/ocaml-h2/issues/175 *) - Eio.traceln "%s" (Printexc.to_string exn)) - - let client_streaming ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - let status, response = f requests in - (match response with None -> () | Some response -> respond response); - status) - - let server_streaming ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - match Seq.read_and_exhaust requests with - | None -> Grpc.Status.(v OK) - | Some request -> f request respond) - - let unary ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - match Seq.read_and_exhaust requests with - | None -> Grpc.Status.(v OK) - | Some request -> - let status, response = f request in - (match response with - | None -> () - | Some response -> respond response); - status) -end - -module Service = struct - module RpcMap = Map.Make (String) - - type t = Rpc.t RpcMap.t - - let v () = RpcMap.empty - let add_rpc ~name ~rpc t = RpcMap.add name rpc t - - let handle_request (t : t) reqd = - let request = H2.Reqd.request reqd in - let respond_with code = - H2.Reqd.respond_with_string reqd (H2.Response.create code) "" - in - let parts = String.split_on_char '/' request.target in - if List.length parts > 1 then - let rpc_name = List.nth parts (List.length parts - 1) in - let rpc = RpcMap.find_opt rpc_name t in - match rpc with - | Some rpc -> ( - match rpc with - | Unary f -> Rpc.unary ~f reqd - | Client_streaming f -> Rpc.client_streaming ~f reqd - | Server_streaming f -> Rpc.server_streaming ~f reqd - | Bidirectional_streaming f -> Rpc.bidirectional_streaming ~f reqd) - | None -> respond_with `Not_found - else respond_with `Not_found -end diff --git a/lib/grpc-eio/server.mli b/lib/grpc-eio/server.mli deleted file mode 100644 index 40961f5..0000000 --- a/lib/grpc-eio/server.mli +++ /dev/null @@ -1,50 +0,0 @@ -include Grpc.Server.S - -module Rpc : sig - type unary = string -> Grpc.Status.t * string option - (** [unary] is the type for a unary grpc rpc, one request, one response. *) - - type client_streaming = string Seq.t -> Grpc.Status.t * string option - (** [client_streaming] is the type for an rpc where the client streams the requests and the server responds once. *) - - type server_streaming = string -> (string -> unit) -> Grpc.Status.t - (** [server_streaming] is the type for an rpc where the client sends one request and the server sends multiple responses. *) - - type bidirectional_streaming = - string Seq.t -> (string -> unit) -> Grpc.Status.t - (** [bidirectional_streaming] is the type for an rpc where both the client and server can send multiple messages. *) - - type t = - | Unary of unary - | Client_streaming of client_streaming - | Server_streaming of server_streaming - | Bidirectional_streaming of bidirectional_streaming - - (** [t] represents the types of rpcs available in gRPC. *) - - val unary : f:unary -> H2.Reqd.t -> unit - (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and handles sending the response. *) - - val client_streaming : f:client_streaming -> H2.Reqd.t -> unit - (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from and handles sending the response. *) - - val server_streaming : f:server_streaming -> H2.Reqd.t -> unit - (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] and handles sending the responses pushed out. *) - - val bidirectional_streaming : f:bidirectional_streaming -> H2.Reqd.t -> unit - (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests from and andles sending the responses pushed out. *) -end - -module Service : sig - type t - (** [t] represents a gRPC service with potentially multiple rpcs and the information needed to route to them. *) - - val v : unit -> t - (** [v ()] creates a new service *) - - val add_rpc : name:string -> rpc:Rpc.t -> t -> t - (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to it with [name]. *) - - val handle_request : t -> H2.Reqd.t -> unit - (** [handle_request t reqd] handles routing [reqd] to the correct rpc if available in [t]. *) -end diff --git a/lib/grpc-server/dune b/lib/grpc-server/dune new file mode 100644 index 0000000..ab421c0 --- /dev/null +++ b/lib/grpc-server/dune @@ -0,0 +1,4 @@ +(library + (name grpc_server) + (public_name grpc-server) + (libraries grpc)) diff --git a/lib/grpc-server/grpc_server.ml b/lib/grpc-server/grpc_server.ml new file mode 100644 index 0000000..4eaf558 --- /dev/null +++ b/lib/grpc-server/grpc_server.ml @@ -0,0 +1,80 @@ +module StringMap = Map.Make (String) + +type error = + [ `Not_found of [ `Service_not_found | `Invalid_url | `Bad_method ] + | `Unsupported_media_type + | `Bad_request + | `Grpc of Grpc.Status.t ] + +let error_to_code_and_headers error = + match error with + | `Not_found _ -> (404, []) + | `Unsupported_media_type -> (415, []) + | `Bad_request -> (400, []) + | `Grpc status -> Grpc.Status.to_net_resp status + +let rec service_name_and_method = function + | [] -> None + | [ _ ] -> None + | [ service_name; method_name ] -> Some (service_name, method_name) + | _ :: tl -> service_name_and_method tl + +type parsed_request = { service : string; meth : string } + +let parse_request ~is_post_request ~get_header ~path : + (parsed_request, error) result = + let route () = + let parts = String.split_on_char '/' path in + match service_name_and_method parts with + | Some (service, meth) -> Ok { service; meth } + | None -> Error (`Not_found `Invalid_url) + in + match is_post_request with + | true -> ( + match get_header "content-type" with + | Some s -> + if + Stringext.chop_prefix s ~prefix:"application/grpc" |> Option.is_some + then + match get_header "grpc-encoding" with + | None | Some "identity" -> ( + match get_header "grpc-accept-encoding" with + | None -> route () + | Some encodings -> + let encodings = String.split_on_char ',' encodings in + if List.mem "identity" encodings then route () + else Error (`Grpc (Grpc.Status.make Unimplemented))) + | Some _ -> + (* TODO: not sure if there is a specific way to handle this in grpc *) + Error `Bad_request + else Error `Unsupported_media_type + | None -> Error `Unsupported_media_type) + | _ -> Error (`Not_found `Bad_method) + +type headers = { content_type : string; extra : (string * string) list } +type format = [ `Json | `Proto | `Other of string ] + +let headers ?(extra = []) (format : format) = + { + content_type = + (match format with + | `Json -> "application/grpc+json" + | `Proto -> "application/grpc+proto" + | `Other s -> Printf.sprintf "application/grpc+%s" s); + extra; + } + +let headers_grpc_proto = headers `Proto + +type trailers = { + grpc_status : int; + grpc_message : string option; + extra : (string * string) list; +} + +let make_trailers ?(extra = []) status = + { + grpc_status = Grpc.Status.int_of_code (Grpc.Status.code status); + grpc_message = Grpc.Status.error_message status; + extra; + } diff --git a/lib/grpc-server/grpc_server.mli b/lib/grpc-server/grpc_server.mli new file mode 100644 index 0000000..0356db1 --- /dev/null +++ b/lib/grpc-server/grpc_server.mli @@ -0,0 +1,32 @@ +type error = + [ `Not_found of [ `Service_not_found | `Invalid_url | `Bad_method ] + | `Unsupported_media_type + | `Bad_request + | `Grpc of Grpc.Status.t ] + +val error_to_code_and_headers : error -> int * (string * string) list +(** [error_to_code_and_headers e] returns the HTTP status code and headers + corresponding to [e]. *) + +type parsed_request = { service : string; meth : string } + +val parse_request : + is_post_request:bool -> + get_header:(string -> string option) -> + path:string -> + (parsed_request, error) result +(** [handle_request t handler] handles a request using [handler] and the + services registered in [t]. *) + +type headers = { content_type : string; extra : (string * string) list } + +val headers : ?extra:(string * string) list -> Grpc.Message.format -> headers +val headers_grpc_proto : headers + +type trailers = { + grpc_status : int; + grpc_message : string option; + extra : (string * string) list; +} + +val make_trailers : ?extra:(string * string) list -> Grpc.Status.t -> trailers diff --git a/lib/grpc/buffer.ml b/lib/grpc/buffer.ml index a28c01f..f3edbdb 100644 --- a/lib/grpc/buffer.ml +++ b/lib/grpc/buffer.ml @@ -21,12 +21,7 @@ let copy_from_bigstringaf ~src_off ~src ~dst ~length = ~len:length; dst.length <- dst.length + length -let sub ~start ~length t = - let contents = Bytes.sub t.contents start length in - { contents; length } - -let to_bytes t = Bytes.sub t.contents 0 t.length -let to_string t = to_bytes t |> Bytes.to_string +let sub_string ~start ~length t = Bytes.sub_string t.contents start length let shift_left ~by t = Bytes.blit t.contents by t.contents 0 (t.length - by); @@ -38,3 +33,5 @@ let get_u32_be ~pos t = let high = Bytes.get_uint16_be t.contents pos in let low = Bytes.get_uint16_be t.contents (pos + 2) in (high lsl 16) lor low + +let internal_buffer t = t.contents diff --git a/lib/grpc/buffer.mli b/lib/grpc/buffer.mli index 61437a7..4f395fa 100644 --- a/lib/grpc/buffer.mli +++ b/lib/grpc/buffer.mli @@ -12,26 +12,25 @@ val length : t -> int val capacity : t -> int (** [capacity t] returns the total capacity of the buffer. *) -val to_bytes : t -> bytes -(** [to_bytes t] converts the valid data in the buffer into bytes. *) - -val to_string : t -> string -(** [to_string t] converts the valid data in the buffer into a string. *) - val copy_from_bigstringaf : src_off:int -> src:Bigstringaf.t -> dst:t -> length:int -> unit (** [copy_from_bigstringaf ~src_off ~src ~dst ~length] copies data from [src] into [dst] starting from [src_off] and ending at [src_off + length] to the end of the buffer. *) -val sub : start:int -> length:int -> t -> t -(** [sub ~start ~length t] creates a new buffer from the current, containing the data in the range \[start, start+length). *) +val sub_string : start:int -> length:int -> t -> string +(** [sub_string ~start ~length t] returns a string containing the data in the + range \[start, start+length). *) val get_u8 : pos:int -> t -> int (** [get_u8 ~pos t] returns the unsigned 8 bit integer at [pos] in [t]. *) val get_u32_be : pos:int -> t -> int -(** [get_u32_be ~pos t] returns the unsigned 32 bit big endian integer at [pos] in [t]. *) +(** [get_u32_be ~pos t] returns the unsigned 32 bit big endian integer at [pos] + in [t]. *) val shift_left : by:int -> t -> unit (** [shift_left ~by t] shifts [t] left by [by] positions, discarding the data. *) + +val internal_buffer : t -> Bytes.t +(** [internal_buffer t] returns the internal buffer. *) diff --git a/lib/grpc/dune b/lib/grpc/dune index 1170fe5..32defe6 100644 --- a/lib/grpc/dune +++ b/lib/grpc/dune @@ -3,4 +3,4 @@ (public_name grpc) (preprocess (pps ppx_deriving.show)) - (libraries h2 bigstringaf uri)) + (libraries bigstringaf uri)) diff --git a/lib/grpc/grpc.ml b/lib/grpc/grpc.ml index 00ca697..3ae2efa 100644 --- a/lib/grpc/grpc.ml +++ b/lib/grpc/grpc.ml @@ -1,4 +1,3 @@ -module Server = Server module Status = Status module Message = Message module Buffer = Buffer diff --git a/lib/grpc/message.ml b/lib/grpc/message.ml index 8ad3ce2..3c55f0c 100644 --- a/lib/grpc/message.ml +++ b/lib/grpc/message.ml @@ -1,20 +1,29 @@ +[@@@landmark "auto"] + +let fill_header ~length buffer = + (* write compressed flag (uint8) *) + Bytes.set buffer 0 '\x00'; + (* write msg length (uint32 be) *) + Bytes.set_uint16_be buffer 1 (length lsr 16); + Bytes.set_uint16_be buffer 3 (length land 0xFFFF) + let make content = let content_len = String.length content in let payload = Bytes.create @@ (content_len + 1 + 4) in - (* write compressed flag (uint8) *) - Bytes.set payload 0 '\x00'; - (* write msg length (uint32 be) *) - let length = String.length content in - Bytes.set_uint16_be payload 1 (length lsr 16); - Bytes.set_uint16_be payload 3 (length land 0xFFFF); + fill_header ~length:content_len payload; (* write msg *) Bytes.blit_string content 0 payload 5 content_len; Bytes.to_string payload -(** [extract_message buf] extracts the grpc message starting in [buf] - in the buffer if there is one *) -let extract_message buf = - if Buffer.length buf >= 5 then ( +let get_u32_be ~pos t = + let high = Bytes.get_uint16_be t pos in + let low = Bytes.get_uint16_be t (pos + 2) in + (high lsl 16) lor low + +(** [extract_message buf] extracts the grpc message starting in [buf] in the + buffer if there is one *) +let extract_message_pos ~start buf = + if Bytes.length buf >= 5 + start then ( let compressed = (* A Compressed-Flag value of 1 indicates that the binary octet sequence of Message is compressed using the mechanism declared by @@ -24,24 +33,22 @@ let extract_message buf = new context for each message in the stream. If the Message-Encoding header is omitted then the Compressed-Flag must be 0. *) (* encoded as 1 byte unsigned integer *) - Buffer.get_u8 buf ~pos:0 == 1 + Bytes.get_uint8 buf start == 1 and length = (* encoded as 4 byte unsigned integer (big endian) *) - Buffer.get_u32_be buf ~pos:1 + get_u32_be buf ~pos:(start + 1) in if compressed then failwith "Compressed flag set but not supported"; - if Buffer.length buf - 5 >= length then - Some (Buffer.sub buf ~start:5 ~length |> Buffer.to_string) - else None) + if Bytes.length buf - 5 >= length then Some (start + 5, length) else None) else None -(** [get_message_and_shift buf] tries to extract the first grpc message - from [buf] and if successful shifts these bytes out of the buffer *) +(** [get_message_and_shift buf] tries to extract the first grpc message from + [buf] and if successful shifts these bytes out of the buffer *) let get_message_and_shift buf = - let message = extract_message buf in - match message with + match extract_message_pos ~start:0 (Buffer.internal_buffer buf) with | None -> None - | Some message -> + | Some (start, length) -> + let message = Buffer.sub_string ~start ~length buf in let mlen = String.length message in Buffer.shift_left buf ~by:(5 + mlen); Some message @@ -57,3 +64,10 @@ let extract_all f buf = loop () in loop () + +type format = [ `Json | `Proto | `Other of string ] + +let format_to_content_type = function + | `Json -> "application/grpc+json" + | `Proto -> "application/grpc+proto" + | `Other s -> Printf.sprintf "application/grpc+%s" s diff --git a/lib/grpc/message.mli b/lib/grpc/message.mli index 05e5608..3463bff 100644 --- a/lib/grpc/message.mli +++ b/lib/grpc/message.mli @@ -1,8 +1,20 @@ +val fill_header : length:int -> Bytes.t -> unit +(** [fill_header ~length b] fills the header of a gRPC message in [b]. *) + val make : string -> string (** [make s] encodes a string as a gRPC message. *) val extract : Buffer.t -> string option (** [extract b] attempts to extract a gRPC message from [b]. *) +val extract_message_pos : start:int -> Bytes.t -> (int * int) option +(** [extract b] attempts to extract a gRPC message from [b] and exposes its + internal buffer. *) + val extract_all : (string -> unit) -> Buffer.t -> unit (** [extract_all f b] extracts and calls [f] on all gRPC messages from [b]. *) + +type format = [ `Json | `Proto | `Other of string ] + +val format_to_content_type : format -> string +(** [format_to_content_type f] returns the content type for [f]. *) diff --git a/lib/grpc/server.ml b/lib/grpc/server.ml deleted file mode 100644 index aaea758..0000000 --- a/lib/grpc/server.ml +++ /dev/null @@ -1,14 +0,0 @@ -(** The type of a Server *) -module type S = sig - type t - (** [t] represents a server and its associated services and routing information. *) - - val v : unit -> t - (** [v ()] creates a new server. *) - - val add_service : name:string -> service:(H2.Reqd.t -> unit) -> t -> t - (** [add_service ~name ~service t] adds [service] to [t] and ensures that it is routable via [name]. *) - - val handle_request : t -> H2.Reqd.t -> unit - (** [handle_request t reqd] routes [reqd] to the appropriate service in [t] if available. *) -end diff --git a/lib/grpc/status.ml b/lib/grpc/status.ml index 91bbd39..53b4dcb 100644 --- a/lib/grpc/status.ml +++ b/lib/grpc/status.ml @@ -59,13 +59,15 @@ let code_of_int = function type t = { code : code; message : string option } [@@deriving show] -let v ?message code = { code; message } +let make ?error_message code = { code; message = error_message } let code t = t.code -let message t = Option.map (fun message -> Uri.pct_encode message) t.message -let extract_status headers = +let error_message t = + Option.map (fun message -> Uri.pct_encode message) t.message + +let extract_status ~get_header = let code, message = - match H2.Headers.get headers "grpc-status" with + match get_header "grpc-status" with | None -> (Unknown, Some "Expected gprc-status header, got nothing") | Some s -> ( match int_of_string_opt s with @@ -81,6 +83,36 @@ let extract_status headers = Printf.sprintf "Expected valid gprc-status code, got %i" i in (Unknown, Some msg) - | Some c -> (c, H2.Headers.get headers "grpc-message"))) + | Some c -> (c, get_header "grpc-message"))) + in + make ?error_message:message code + +let status_to_headers status = + let message = error_message status in + ("grpc-status", string_of_int (int_of_code (code status))) + :: (match message with Some s -> [ ("grpc-message", s) ] | None -> []) + +let to_net_resp status = + (* https://cloud.google.com/apis/design/errors#error_model *) + let headers = status_to_headers status in + let status_code = + match code status with + | OK -> 200 + | Cancelled -> 499 + | Unknown -> 500 + | Invalid_argument -> 400 + | Deadline_exceeded -> 504 + | Not_found -> 404 + | Already_exists -> 409 + | Permission_denied -> 403 + | Resource_exhausted -> 429 + | Failed_precondition -> 400 + | Aborted -> 409 + | Out_of_range -> 400 + | Unimplemented -> 501 + | Internal -> 500 + | Unavailable -> 503 + | Data_loss -> 500 + | Unauthenticated -> 401 in - v ?message code + (status_code, headers) diff --git a/lib/grpc/status.mli b/lib/grpc/status.mli index 5494327..a0e2f13 100644 --- a/lib/grpc/status.mli +++ b/lib/grpc/status.mli @@ -29,15 +29,20 @@ val code_of_int : int -> code option type t [@@deriving show] (** [t] represents a full gRPC status, this includes code and optional message. *) -val v : ?message:string -> code -> t -(** [v ~message code] creates a new status with the given [code] and [message]. *) +val make : ?error_message:string -> code -> t +(** [v ~message code] creates a new status with the given [code] and [message]. + It is an error to construct an OK status with non-empty error_message *) val code : t -> code (** [code t] returns the code associated with [t]. *) -val message : t -> string option +val error_message : t -> string option (** [message t] returns the message associated with [t], if there is one. *) -val extract_status : H2.Headers.t -> t -(** [extract_status headers] returns the status embedded in the headers, or a default - when the status is invalid or missing. *) +val extract_status : get_header:(string -> string option) -> t +(** [extract_status ~get_header] returns the status embedded in the headers, or + a default when the status is invalid or missing. *) + +val to_net_resp : t -> int * (string * string) list +(** [to_net_resp t] returns the status code and headers to send over the + network. *) diff --git a/lib/grpc-lwt/client.ml b/lib/lwt/client.ml similarity index 100% rename from lib/grpc-lwt/client.ml rename to lib/lwt/client.ml diff --git a/lib/grpc-lwt/client.mli b/lib/lwt/client.mli similarity index 79% rename from lib/grpc-lwt/client.mli rename to lib/lwt/client.mli index f0f8270..7389b10 100644 --- a/lib/grpc-lwt/client.mli +++ b/lib/lwt/client.mli @@ -9,9 +9,9 @@ module Rpc : sig val client_streaming : f:((string option -> unit) -> string option Lwt.t -> 'a Lwt.t) -> 'a handler - (** [client_streaming ~f write read] sets up the sending and receiving - logic using [write] and [read], then calls [f] with a push function for - requests and promise for the response. *) + (** [client_streaming ~f write read] sets up the sending and receiving logic + using [write] and [read], then calls [f] with a push function for requests + and promise for the response. *) val server_streaming : f:(string Lwt_stream.t -> 'a Lwt.t) -> string -> 'a handler @@ -20,9 +20,9 @@ module Rpc : sig stream of responses. *) val unary : f:(string option Lwt.t -> 'a Lwt.t) -> string -> 'a handler - (** [unary ~f enc write read] sets up the sending and receiving - logic using [write] and [read], then sends [enc] and calls [f] with a - promise for the response. *) + (** [unary ~f enc write read] sets up the sending and receiving logic using + [write] and [read], then sends [enc] and calls [f] with a promise for the + response. *) end type response_handler = H2.Client_connection.response_handler @@ -44,6 +44,6 @@ val call : ?headers:H2.Headers.t -> unit -> ('a * Grpc.Status.t, H2.Status.t) result Lwt.t -(** [call ~service ~rpc ~handler ~do_request ()] calls the rpc endpoint given - by [service] and [rpc] using the [do_request] function. The [handler] is - called when this request is set up to send and receive data. *) +(** [call ~service ~rpc ~handler ~do_request ()] calls the rpc endpoint given by + [service] and [rpc] using the [do_request] function. The [handler] is called + when this request is set up to send and receive data. *) diff --git a/lib/grpc-lwt/connection.ml b/lib/lwt/connection.ml similarity index 100% rename from lib/grpc-lwt/connection.ml rename to lib/lwt/connection.ml diff --git a/lib/grpc-lwt/dune b/lib/lwt/dune similarity index 100% rename from lib/grpc-lwt/dune rename to lib/lwt/dune diff --git a/lib/grpc-lwt/grpc_lwt.ml b/lib/lwt/grpc_lwt.ml similarity index 100% rename from lib/grpc-lwt/grpc_lwt.ml rename to lib/lwt/grpc_lwt.ml diff --git a/lib/grpc-lwt/server.ml b/lib/lwt/server.ml similarity index 100% rename from lib/grpc-lwt/server.ml rename to lib/lwt/server.ml diff --git a/lib/grpc-lwt/server.mli b/lib/lwt/server.mli similarity index 75% rename from lib/grpc-lwt/server.mli rename to lib/lwt/server.mli index a319fa0..d363d8b 100644 --- a/lib/grpc-lwt/server.mli +++ b/lib/lwt/server.mli @@ -6,14 +6,17 @@ module Rpc : sig type client_streaming = string Lwt_stream.t -> (Grpc.Status.t * string option) Lwt.t - (** [client_streaming] is the type for an rpc where the client streams the requests and the server responds once. *) + (** [client_streaming] is the type for an rpc where the client streams the + requests and the server responds once. *) type server_streaming = string -> (string -> unit) -> Grpc.Status.t Lwt.t - (** [server_streaming] is the type for an rpc where the client sends one request and the server sends multiple responses. *) + (** [server_streaming] is the type for an rpc where the client sends one + request and the server sends multiple responses. *) type bidirectional_streaming = string Lwt_stream.t -> (string -> unit) -> Grpc.Status.t Lwt.t - (** [bidirectional_streaming] is the type for an rpc where both the client and server can send multiple messages. *) + (** [bidirectional_streaming] is the type for an rpc where both the client and + server can send multiple messages. *) type t = | Unary of unary @@ -24,29 +27,36 @@ module Rpc : sig (** [t] represents the types of rpcs available in gRPC. *) val unary : f:unary -> H2.Reqd.t -> unit Lwt.t - (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and handles sending the response. *) + (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and + handles sending the response. *) val client_streaming : f:client_streaming -> H2.Reqd.t -> unit Lwt.t - (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from and handles sending the response. *) + (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from + and handles sending the response. *) val server_streaming : f:server_streaming -> H2.Reqd.t -> unit Lwt.t - (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] and handles sending the responses pushed out. *) + (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] + and handles sending the responses pushed out. *) val bidirectional_streaming : f:bidirectional_streaming -> H2.Reqd.t -> unit Lwt.t - (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests from and andles sending the responses pushed out. *) + (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests + from and andles sending the responses pushed out. *) end module Service : sig type t - (** [t] represents a gRPC service with potentially multiple rpcs and the information needed to route to them. *) + (** [t] represents a gRPC service with potentially multiple rpcs and the + information needed to route to them. *) val v : unit -> t (** [v ()] creates a new service *) val add_rpc : name:string -> rpc:Rpc.t -> t -> t - (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to it with [name]. *) + (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to + it with [name]. *) val handle_request : t -> H2.Reqd.t -> unit - (** [handle_request t reqd] handles routing [reqd] to the correct rpc if available in [t]. *) + (** [handle_request t reqd] handles routing [reqd] to the correct rpc if + available in [t]. *) end diff --git a/overlay.nix b/overlay.nix new file mode 100644 index 0000000..d673a6c --- /dev/null +++ b/overlay.nix @@ -0,0 +1,68 @@ +self: super: +let inherit (super) fetchFromGitHub; +in { + ocaml-ng = super.ocaml-ng // { + ocamlPackages_5_1 = super.ocaml-ng.ocamlPackages_5_1.overrideScope' + (oself: super: + let + ocamlProtocSrc = fetchFromGitHub { + owner = "mransan"; + repo = "ocaml-protoc"; + rev = "292165b0f23f75973ac533ee48bae325544a42a9"; + sha256 = "sha256-P5Y+Sk9EfIgK1wSoMDImCoYEF/npdWVMTP5/3msDDhM="; + fetchSubmodules = true; + }; + in { + h2 = super.h2.overrideAttrs (_: { + src = fetchFromGitHub { + owner = "dialohq"; + repo = "ocaml-h2"; + rev = "5fc0a4976ed25248872bac487ba344ebcaa76de0"; + sha256 = "sha256-SZKv6Cv45hRrM1e/P7bmmWT96IERmF41wUvyaQGHj3g="; + fetchSubmodules = true; + }; + }); + h2-eio = super.h2-eio.overrideAttrs (_: { + src = fetchFromGitHub { + owner = "dialohq"; + repo = "ocaml-h2"; + rev = "5fc0a4976ed25248872bac487ba344ebcaa76de0"; + sha256 = "sha256-SZKv6Cv45hRrM1e/P7bmmWT96IERmF41wUvyaQGHj3g="; + + fetchSubmodules = true; + }; + }); + pbrt = super.pbrt.overrideAttrs (_: { src = ocamlProtocSrc; }); + pbrt_services = super.buildDunePackage ({ + pname = "pbrt_services"; + version = "3.0.1"; + duneVersion = "3"; + propagatedBuildInputs = [ oself.pbrt oself.pbrt_yojson ]; + src = ocamlProtocSrc; + }); + pbrt_yojson = super.buildDunePackage ({ + pname = "pbrt_yojson"; + version = "3.0.1"; + duneVersion = "3"; + propagatedBuildInputs = [ super.yojson super.base64 ]; + src = ocamlProtocSrc; + }); + ocaml-protoc = super.ocaml-protoc.overrideAttrs (_: { + propagatedBuildInputs = super.ocaml-protoc.propagatedBuildInputs + ++ [ oself.pbrt_yojson oself.pbrt_services ]; + src = ocamlProtocSrc; + }); + gluten-eio = super.gluten-eio.overrideAttrs (_: { + src = fetchFromGitHub { + owner = "dialohq"; + repo = "gluten"; + rev = "e9ae4690ebd65b143e69955b1dc26ac77c25fa91"; + sha256 = "sha256-hT62/TWFD11Irn+fy43nNGB8PKF1UAux0i9+9U3a/Ho="; + + fetchSubmodules = true; + }; + }); + }); + }; +} +