From acc7bcc1d9f49402bfa2c489e79ccd7072e4d448 Mon Sep 17 00:00:00 2001 From: Leandro Ostera Date: Sun, 24 Dec 2023 03:11:56 +0100 Subject: [PATCH] feat: impl writer for socket and test it --- riot/lib/io.ml | 17 ++--- riot/lib/net.ml | 9 +++ riot/riot.mli | 14 +++-- test/dune | 5 ++ test/io_writer_test.ml | 4 +- test/net_reader_writer_test.ml | 111 +++++++++++++++++++++++++++++++++ 6 files changed, 145 insertions(+), 15 deletions(-) create mode 100644 test/net_reader_writer_test.ml diff --git a/riot/lib/io.ml b/riot/lib/io.ml index 8da19bb9..e236ae7e 100644 --- a/riot/lib/io.ml +++ b/riot/lib/io.ml @@ -92,14 +92,14 @@ module Writer = struct let flush = B.flush end - type 't write = (module Write with type t = 't) - type 't writer = Writer : ('t write * 't) -> 't writer + type 'src write = (module Write with type t = 'src) + type 'src t = Writer : ('src write * 'src) -> 'src t - let of_write_src : type src. src write -> src -> src writer = + let of_write_src : type src. src write -> src -> src t = fun write src -> Writer (write, src) let write : - type src. src writer -> data:Buffer.t -> (int, [> `Closed ]) result = + type src. src t -> data:Buffer.t -> (int, [> `Closed ]) result = fun (Writer ((module W), src)) ~data -> W.write src ~data end @@ -116,13 +116,14 @@ module Reader = struct let read = B.read end - type 't read = (module Read with type t = 't) - type 't reader = Reader : ('t read * 't) -> 't reader + type 'src read = (module Read with type t = 'src) + type 'src t = Reader : ('src read * 'src) -> 'src t + type 'src reader = 'src t - let of_read_src : type src. src read -> src -> src reader = + let of_read_src : type src. src read -> src -> src t = fun read src -> Reader (read, src) - let read : type src. src reader -> buf:Buffer.t -> (int, [> `Closed ]) result + let read : type src. src t -> buf:Buffer.t -> (int, [> `Closed ]) result = fun (Reader ((module R), src)) ~buf -> R.read src ~buf diff --git a/riot/lib/net.ml b/riot/lib/net.ml index c2cc2265..19d1f26d 100644 --- a/riot/lib/net.ml +++ b/riot/lib/net.ml @@ -143,4 +143,13 @@ module Socket = struct end) let to_reader t = Io.Reader.of_read_src (module Read) t + + module Write = Io.Writer.Make (struct + type t = stream_socket + + let write t ~data = send ~data t + let flush _t = Ok () + end) + + let to_writer t = Io.Writer.of_write_src (module Write) t end diff --git a/riot/riot.mli b/riot/riot.mli index c698a2c0..360bc881 100644 --- a/riot/riot.mli +++ b/riot/riot.mli @@ -443,9 +443,9 @@ module IO : sig end module Writer : sig - type 'src writer + type 'src t - val write : 'src writer -> data:Buffer.t -> (int, [> `Closed ]) result + val write : 'src t -> data:Buffer.t -> (int, [> `Closed ]) result module Make (B : Write) : sig type t = B.t @@ -462,7 +462,8 @@ module IO : sig end module Reader : sig - type 'src reader + type 'src t + type 'src reader = 'src t val read : 'src reader -> buf:Buffer.t -> (int, [> `Closed ]) result @@ -488,8 +489,8 @@ module File : sig val open_write : string -> [ `w ] file val close : _ file -> unit val remove : string -> unit - val to_reader : [ `r ] file -> [ `r ] file IO.Reader.reader - val to_writer : [ `w ] file -> [ `w ] file IO.Writer.writer + val to_reader : [ `r ] file -> [ `r ] file IO.Reader.t + val to_writer : [ `w ] file -> [ `w ] file IO.Writer.t end module Net : sig @@ -556,6 +557,9 @@ module Net : sig Format.formatter -> [ IO.unix_error | `Closed | `Timeout | `System_limit ] -> unit + + val to_reader : stream_socket -> stream_socket IO.Reader.t + val to_writer : stream_socket -> stream_socket IO.Writer.t end end diff --git a/test/dune b/test/dune index 921648a4..bee0c4d8 100644 --- a/test/dune +++ b/test/dune @@ -41,6 +41,11 @@ (modules net_addr_uri_test) (libraries riot)) +(test + (name net_reader_writer_test) + (modules net_reader_writer_test) + (libraries riot)) + (test (name net_test) (modules net_test) diff --git a/test/io_writer_test.ml b/test/io_writer_test.ml index 8d7c25e0..b19ea89c 100644 --- a/test/io_writer_test.ml +++ b/test/io_writer_test.ml @@ -8,8 +8,8 @@ let () = Logger.set_log_level (Some Info); let now = Ptime_clock.now () in let path = - Format.asprintf "./test/generated/%a.io_writer_test.txt" - (Ptime.pp_rfc3339 ()) now + Format.asprintf "./generated/%a.io_writer_test.txt" (Ptime.pp_rfc3339 ()) + now in let file = File.open_write path in let writer = File.to_writer file in diff --git a/test/net_reader_writer_test.ml b/test/net_reader_writer_test.ml new file mode 100644 index 00000000..54e27c06 --- /dev/null +++ b/test/net_reader_writer_test.ml @@ -0,0 +1,111 @@ +open Riot + +type Message.t += Received of string + +(* rudimentary tcp echo server *) +let server port = + let socket = Net.Socket.listen ~port () |> Result.get_ok in + Logger.debug (fun f -> f "Started server on %d" port); + process_flag (Trap_exit true); + let conn, addr = Net.Socket.accept socket |> Result.get_ok in + Logger.debug (fun f -> + f "Accepted client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn); + let close () = + Net.Socket.close conn; + Logger.debug (fun f -> + f "Closed client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn) + in + + let reader = Net.Socket.to_reader conn in + let writer = Net.Socket.to_writer conn in + + let buf = IO.Buffer.with_capacity 1024 in + let rec echo () = + Logger.debug (fun f -> + f "Reading from client client %a (%a)" Net.Addr.pp addr Net.Socket.pp + conn); + match IO.Reader.read reader ~buf with + | Ok len -> ( + Logger.debug (fun f -> f "Server received %d bytes" len); + let data = IO.Buffer.sub ~off:0 ~len buf in + match IO.Writer.write ~data writer with + | Ok bytes -> + Logger.debug (fun f -> f "Server sent %d bytes" bytes); + echo () + | Error `Closed -> close () + | Error (`Unix_error unix_err) -> + Logger.error (fun f -> + f "send unix error %s" (Unix.error_message unix_err)); + close ()) + | Error (`Closed | `Timeout) -> close () + | Error (`Unix_error unix_err) -> + Logger.error (fun f -> + f "recv unix error %s" (Unix.error_message unix_err)); + close () + in + echo () + +let client port main = + let addr = Net.Addr.(tcp loopback port) in + let conn = Net.Socket.connect addr |> Result.get_ok in + Logger.debug (fun f -> f "Connected to server on %d" port); + let data = IO.Buffer.of_string "hello world" in + + let reader = Net.Socket.to_reader conn in + let writer = Net.Socket.to_writer conn in + + let rec send_loop n = + sleep 0.001; + if n = 0 then Logger.error (fun f -> f "client retried too many times") + else + match IO.Writer.write ~data writer with + | Ok bytes -> Logger.debug (fun f -> f "Client sent %d bytes" bytes) + | Error `Closed -> Logger.debug (fun f -> f "connection closed") + | Error (`Unix_error (ENOTCONN | EPIPE)) -> send_loop n + | Error (`Unix_error unix_err) -> + Logger.error (fun f -> + f "client unix error %s" (Unix.error_message unix_err)); + send_loop (n - 1) + in + send_loop 10_000; + + let buf = IO.Buffer.with_capacity 128 in + let recv_loop () = + match IO.Reader.read ~buf reader with + | Ok bytes -> + Logger.debug (fun f -> f "Client received %d bytes" bytes); + bytes + | Error (`Closed | `Timeout) -> + Logger.error (fun f -> f "Server closed the connection"); + 0 + | Error (`Unix_error unix_err) -> + Logger.error (fun f -> + f "client unix error %s" (Unix.error_message unix_err)); + 0 + in + let len = recv_loop () in + + if len = 0 then send main (Received "empty paylaod") + else send main (Received (IO.Buffer.to_string buf)) + +let () = + Riot.run @@ fun () -> + let _ = Logger.start () |> Result.get_ok in + Logger.set_log_level (Some Info); + let port = 2112 in + let main = self () in + let _server = spawn (fun () -> server port) in + let _client = spawn (fun () -> client port main) in + match receive () with + | Received "hello world" -> + Logger.info (fun f -> f "net_reader_writer_test: OK"); + sleep 0.001; + shutdown () + | Received other -> + Logger.error (fun f -> f "net_reader_writer_test: bad payload: %S" other); + sleep 0.001; + Stdlib.exit 1 + | _ -> + Logger.error (fun f -> f "net_reader_writer_test: unexpected message"); + sleep 0.001; + Stdlib.exit 1