diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..7655de0 --- /dev/null +++ b/dune-project @@ -0,0 +1 @@ +(lang dune 1.1) diff --git a/examples/lwt/dune b/examples/lwt/dune new file mode 100644 index 0000000..ddb3e52 --- /dev/null +++ b/examples/lwt/dune @@ -0,0 +1,3 @@ +(executables + (names wscat echo_server) + (libraries websocketaf websocketaf-lwt lwt lwt.unix)) diff --git a/examples/lwt/echo_server.ml b/examples/lwt/echo_server.ml new file mode 100644 index 0000000..7356301 --- /dev/null +++ b/examples/lwt/echo_server.ml @@ -0,0 +1,71 @@ +let connection_handler : Unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t = + let module Body = Httpaf.Body in + let module Headers = Httpaf.Headers in + let module Reqd = Httpaf.Reqd in + let module Response = Httpaf.Response in + let module Status = Httpaf.Status in + + let websocket_handler _ wsd = + let frame ~opcode ~is_fin:_ bs ~off ~len = + match opcode with + | `Continuation + | `Text + | `Binary -> + Websocketaf.Wsd.schedule wsd bs ~kind:`Text ~off ~len + | `Connection_close -> + Websocketaf.Wsd.close wsd + | `Ping -> + Websocketaf.Wsd.send_ping wsd + | `Pong + | `Other _ -> + () + in + let eof () = () + in + { Websocketaf.Server_connection.frame + ; eof + } + in + + let error_handler _client_address ?request:_ error start_response = + let response_body = start_response Headers.empty in + + begin match error with + | `Exn exn -> + Body.write_string response_body (Printexc.to_string exn); + Body.write_string response_body "\n"; + + | #Status.standard as error -> + Body.write_string response_body (Status.default_reason_phrase error) + end; + in + + Websocketaf_lwt.Server.create_connection_handler + ?config:None + ~websocket_handler + ~error_handler + + + +let () = + let open Lwt.Infix in + + let port = ref 8080 in + Arg.parse + ["-p", Arg.Set_int port, " Listening port number (8080 by default)"] + ignore + "Echoes websocket messages. Runs forever."; + + let listen_address = Unix.(ADDR_INET (inet_addr_loopback, !port)) in + + Lwt.async begin fun () -> + Lwt_io.establish_server_with_client_socket + listen_address connection_handler + >>= fun _server -> + Printf.printf "Listening on port %i and echoing websocket messages.\n" !port; + flush stdout; + Lwt.return_unit + end; + + let forever, _ = Lwt.wait () in + Lwt_main.run forever diff --git a/examples/lwt/wscat.ml b/examples/lwt/wscat.ml new file mode 100644 index 0000000..b32c69b --- /dev/null +++ b/examples/lwt/wscat.ml @@ -0,0 +1,66 @@ +open Lwt.Infix + +let websocket_handler wsd = + let rec input_loop wsd () = + Lwt_io.(read_line stdin) >>= fun line -> + let payload = Bytes.of_string line in + Websocketaf.Wsd.send_bytes wsd ~kind:`Text payload ~off:0 ~len:(Bytes.length payload); + input_loop wsd () + in + Lwt.async (input_loop wsd); + let frame ~opcode:_ ~is_fin:_ bs ~off ~len = + let payload = Bytes.create len in + Lwt_bytes.blit_to_bytes + bs off + payload 0 + len; + Printf.printf "%s\n" (Bytes.unsafe_to_string payload); + flush stdout + in + let eof () = + Printf.printf "[EOF]\n" + in + { Websocketaf.Client_connection.frame + ; eof + } + +let error_handler = function + | `Handshake_failure (rsp, _body) -> + Format.printf "Handshake failure: %a" Httpaf.Response.pp_hum rsp + | _ -> assert false + +let () = + let host = ref None in + let port = ref 80 in + + Arg.parse + ["-p", Set_int port, " Port number (80 by default)"] + (fun host_argument -> host := Some host_argument) + "wscat.exe [-p N] HOST"; + + let host = + match !host with + | None -> failwith "No hostname provided" + | Some host -> host + in + + Lwt_main.run begin + Lwt_unix.getaddrinfo host (string_of_int !port) [Unix.(AI_FAMILY PF_INET)] + >>= fun addresses -> + + let socket = Lwt_unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in + Lwt_unix.connect socket (List.hd addresses).Unix.ai_addr + >>= fun () -> + + let nonce = "0123456789ABCDEF" in + let resource = "/" in + let port = !port in + Websocketaf_lwt.Client.connect + socket + ~nonce + ~host + ~port + ~resource + ~error_handler + ~websocket_handler + end diff --git a/lib/bigstring.ml b/lib/bigstring.ml deleted file mode 100644 index f733224..0000000 --- a/lib/bigstring.ml +++ /dev/null @@ -1,111 +0,0 @@ -type bigstring = - (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t - -type t = bigstring - -let create size = Bigarray.(Array1.create char c_layout size) -let empty = create 0 - -module BA1 = Bigarray.Array1 - -let length t = BA1.dim t -external unsafe_get : t -> int -> char = "%caml_ba_unsafe_ref_1" -external unsafe_set : t -> int -> char -> unit = "%caml_ba_unsafe_set_1" - -external blit : t -> int -> t -> int -> int -> unit = - "angstrom_bigstring_blit_to_bigstring" [@@noalloc] - -external blit_to_bytes : t -> int -> Bytes.t -> int -> int -> unit = - "angstrom_bigstring_blit_to_bytes" [@@noalloc] - -external blit_from_bytes : Bytes.t -> int -> t -> int -> int -> unit = - "angstrom_bigstring_blit_from_bytes" [@@noalloc] - -let blit_from_string src src_off dst dst_off len = - blit_from_bytes (Bytes.unsafe_of_string src) src_off dst dst_off len - -let sub t ~off ~len = - BA1.sub t off len - -let copy src ~off ~len = - let dst = create len in - BA1.blit (BA1.sub src off len) dst; - dst - -let substring t ~off ~len = - let b = Bytes.create len in - blit_to_bytes t off b 0 len; - Bytes.unsafe_to_string b - -let of_string ~off ~len s = - let b = create len in - blit_from_string s off b 0 len; - b - -external caml_bigstring_set_16 : bigstring -> off:int -> int -> unit = "%caml_bigstring_set16u" -external caml_bigstring_set_32 : bigstring -> off:int -> int32 -> unit = "%caml_bigstring_set32u" -external caml_bigstring_set_64 : bigstring -> off:int -> int64 -> unit = "%caml_bigstring_set64u" - -external caml_bigstring_get_16 : bigstring -> off:int -> int = "%caml_bigstring_get16u" -external caml_bigstring_get_32 : bigstring -> off:int -> int32 = "%caml_bigstring_get32u" -external caml_bigstring_get_64 : bigstring -> off:int -> int64 = "%caml_bigstring_get64u" - -module Swap = struct - external bswap16 : int -> int = "%bswap16" - external bswap_int32 : int32 -> int32 = "%bswap_int32" - external bswap_int64 : int64 -> int64 = "%bswap_int64" - - let caml_bigstring_set_16 bs ~off i = - caml_bigstring_set_16 bs ~off (bswap16 i) - - let caml_bigstring_set_32 bs ~off i = - caml_bigstring_set_32 bs ~off (bswap_int32 i) - - let caml_bigstring_set_64 bs ~off i = - caml_bigstring_set_64 bs ~off (bswap_int64 i) - - let caml_bigstring_get_16 bs ~off = - bswap16 (caml_bigstring_get_16 bs ~off) - - let caml_bigstring_get_32 bs ~off = - bswap_int32 (caml_bigstring_get_32 bs ~off) - - let caml_bigstring_get_64 bs ~off = - bswap_int64 (caml_bigstring_get_64 bs ~off) -end - -let unsafe_set_16_le, unsafe_set_16_be = - if Sys.big_endian - then Swap.caml_bigstring_set_16, caml_bigstring_set_16 - else caml_bigstring_set_16 , Swap.caml_bigstring_set_16 - -let unsafe_set_32_le, unsafe_set_32_be = - if Sys.big_endian - then Swap.caml_bigstring_set_32, caml_bigstring_set_32 - else caml_bigstring_set_32 , Swap.caml_bigstring_set_32 - -let unsafe_set_64_le, unsafe_set_64_be = - if Sys.big_endian - then Swap.caml_bigstring_set_64, caml_bigstring_set_64 - else caml_bigstring_set_64 , Swap.caml_bigstring_set_64 - -let unsafe_get_u16_le, unsafe_get_u16_be = - if Sys.big_endian - then Swap.caml_bigstring_get_16, caml_bigstring_get_16 - else caml_bigstring_get_16 , Swap.caml_bigstring_get_16 - -let unsafe_get_16_le x ~off = - ((unsafe_get_u16_le x ~off) lsl (Sys.int_size - 16)) asr (Sys.int_size - 16) - -let unsafe_get_16_be x ~off = - ((unsafe_get_u16_be x~off ) lsl (Sys.int_size - 16)) asr (Sys.int_size - 16) - -let unsafe_get_32_le, unsafe_get_32_be = - if Sys.big_endian - then Swap.caml_bigstring_get_32, caml_bigstring_get_32 - else caml_bigstring_get_32 , Swap.caml_bigstring_get_32 - -let unsafe_get_64_le, unsafe_get_64_be = - if Sys.big_endian - then Swap.caml_bigstring_get_64, caml_bigstring_get_64 - else caml_bigstring_get_64 , Swap.caml_bigstring_get_64 diff --git a/lib/client_connection.ml b/lib/client_connection.ml index a101ad4..1742510 100644 --- a/lib/client_connection.ml +++ b/lib/client_connection.ml @@ -1,3 +1,5 @@ +module IOVec = Httpaf.IOVec + type state = | Uninitialized | Handshake of Client_handshake.t @@ -5,11 +7,19 @@ type state = type t = state ref +type error = + [ Httpaf.Client_connection.error + | `Handshake_failure of Httpaf.Response.t * [`read] Httpaf.Body.t ] + +type input_handlers = Client_websocket.input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + let passes_scrutiny ~accept headers = let upgrade = Httpaf.Headers.get headers "upgrade" in let connection = Httpaf.Headers.get headers "connection" in let sec_websocket_accept = Httpaf.Headers.get headers "sec-websocket-accept" in - sec_websocket_accept = accept + sec_websocket_accept = Some accept && (match upgrade with | None -> false | Some upgrade -> String.lowercase_ascii upgrade = "websocket") @@ -18,26 +28,36 @@ let passes_scrutiny ~accept headers = | Some connection -> String.lowercase_ascii connection = "upgrade") ;; -let create - ~nonce - ~host - ~port +let handshake_exn t = + match !t with + | Handshake handshake -> handshake + | Uninitialized + | Websocket _ -> assert false + +let create + ~nonce + ~host + ~port ~resource ~sha1 ~error_handler ~websocket_handler = let t = ref Uninitialized in + let nonce = B64.encode nonce in let response_handler response response_body = let accept = sha1 (nonce ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") in - match response.code with + match response.Httpaf.Response.status with | `Switching_protocols when passes_scrutiny ~accept response.headers -> - Body.close response_body response_body; - t := Websocket (Client_websocket.create ~websocket_handler ~eof_handler) + Httpaf.Body.close_reader response_body; + let handshake = handshake_exn t in + t := Websocket (Client_websocket.create ~websocket_handler); + Client_handshake.close handshake | _ -> error_handler (`Handshake_failure(response, response_body)) in - let handshake = + let handshake = + let error_handler = (error_handler :> Httpaf.Client_connection.error_handler) in Client_handshake.create ~nonce ~host @@ -64,18 +84,11 @@ let read t bs ~off ~len = | Websocket websocket -> Client_websocket.read websocket bs ~off ~len ;; -let yield_reader t f = +let read_eof t bs ~off ~len = match !t with | Uninitialized -> assert false - | Handshake handshake -> Client_handshake.yield_reader handshake f - | Websocket websocket -> Client_websocket.yield_reader websocket f -;; - -let shutdown_reader t = - match !t with - | Uninitialized -> assert false - | Handshake handshake -> Client_handshake.shutdown_reader handshake - | Websocket websocket -> Client_websocket.shutdown_reader websocket + | Handshake handshake -> Client_handshake.read handshake bs ~off ~len + | Websocket websocket -> Client_websocket.read_eof websocket bs ~off ~len ;; let next_write_operation t = @@ -95,13 +108,13 @@ let report_write_result t result = let yield_writer t f = match !t with | Uninitialized -> assert false - | Handshake handshake -> Client_handshake.yield_writer handshake f - | Websocket websocket -> Client_websocket.yield_writer websocket f + | Handshake handshake -> Client_handshake.yield_writer handshake f + | Websocket websocket -> Client_websocket.yield_writer websocket f ;; let close t = match !t with | Uninitialized -> assert false - | Handshake handshake -> Client_handshake.close handshake f - | Websocket websocket -> Client_websocket.close websocket f + | Handshake handshake -> Client_handshake.close handshake + | Websocket websocket -> Client_websocket.close websocket ;; diff --git a/lib/client_connection.mli b/lib/client_connection.mli index 625026b..e32d475 100644 --- a/lib/client_connection.mli +++ b/lib/client_connection.mli @@ -1,32 +1,32 @@ module IOVec = Httpaf.IOVec -type t +type t type error = [ Httpaf.Client_connection.error - | `Handshake_failure of Httpaf.Request.t * [`read] Httpaf.Body.t ] + | `Handshake_failure of Httpaf.Response.t * [`read] Httpaf.Body.t ] -type input_handlers = Client_websocket.input_handlers = - { frame : Websocket.Opcode.t -> is_fin:bool -> Bigstring.t -> off:int -> len:int -> unit - ; eof : unit -> unit } +type input_handlers = Client_websocket.input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } -val create +val create : nonce : string -> host : string -> port : int -> resource : string -> sha1 : (string -> string) - -> error_handler : (Httpaf.Response.t -> [`read] Httpaf.Body.t -> unit) - -> websocket_handler : (Wsd.t -> input_handlers) + -> error_handler : (error -> unit) + -> websocket_handler : (Wsd.t -> input_handlers) -> t -val next_read_operation : t -> [ `Read | `Yield | `Close ] -val next_write_operation : t -> [ `Write of Bigstring.t IOVec.t list | `Yield | `Close of int ] +val next_read_operation : t -> [ `Read | `Close ] +val next_write_operation : t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] -val read : t -> Bigstring.t -> off:int -> len:int -> int +val read : t -> Bigstringaf.t -> off:int -> len:int -> int +val read_eof : t -> Bigstringaf.t -> off:int -> len:int -> int val report_write_result : t -> [`Ok of int | `Closed ] -> unit -val yield_reader : t -> (unit -> unit) -> unit val yield_writer : t -> (unit -> unit) -> unit val close : t -> unit diff --git a/lib/client_handshake.ml b/lib/client_handshake.ml index f4e02f7..6b25b5e 100644 --- a/lib/client_handshake.ml +++ b/lib/client_handshake.ml @@ -1,29 +1,50 @@ module IOVec = Httpaf.IOVec -include Httpaf.Client_connection -let create - ~nonce - ~host - ~port +type t = + { connection : Httpaf.Client_connection.t + ; body : [`write] Httpaf.Body.t } + +let create + ~nonce + ~host + ~port ~resource ~error_handler - ~response_handler + ~response_handler = - let nonce = B64.encode nonce in let headers = [ "upgrade" , "websocket" ; "connection" , "upgrade" ; "host" , String.concat ":" [ host; string_of_int port ] - ; "sec-websocket-version", "13" + ; "sec-websocket-version", "13" ; "sec-websocket-key" , nonce ] |> Httpaf.Headers.of_list in - let request_body, t = + let body, connection = Httpaf.Client_connection.request (Httpaf.Request.create ~headers `GET resource) ~error_handler ~response_handler in - Httpaf.Body.close request_body; - t + { connection + ; body + } ;; + +let next_read_operation t = + Httpaf.Client_connection.next_read_operation t.connection + +let next_write_operation t = + Httpaf.Client_connection.next_write_operation t.connection + +let read t = + Httpaf.Client_connection.read t.connection + +let report_write_result t = + Httpaf.Client_connection.report_write_result t.connection + +let yield_writer t = + Httpaf.Client_connection.yield_writer t.connection + +let close t = + Httpaf.Body.close_writer t.body diff --git a/lib/client_handshake.mli b/lib/client_handshake.mli index e011380..78e4a97 100644 --- a/lib/client_handshake.mli +++ b/lib/client_handshake.mli @@ -1,8 +1,8 @@ module IOVec = Httpaf.IOVec -type t +type t -val create +val create : nonce : string -> host : string -> port : int @@ -12,10 +12,11 @@ val create -> t val next_read_operation : t -> [ `Read | `Close ] -val next_write_operation : t -> [ `Write of Bigstring.t IOVec.t list | `Yield | `Close of int ] +val next_write_operation : t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] -val read : t -> Bigstring.t -> off:int -> len:int -> int -val shutdown_reader : t -> unit +val read : t -> Bigstringaf.t -> off:int -> len:int -> int val report_write_result : t -> [`Ok of int | `Closed ] -> unit val yield_writer : t -> (unit -> unit) -> unit + +val close : t -> unit diff --git a/lib/client_websocket.ml b/lib/client_websocket.ml index 2f1016c..ab57350 100644 --- a/lib/client_websocket.ml +++ b/lib/client_websocket.ml @@ -1,14 +1,46 @@ +module IOVec = Httpaf.IOVec + type t = - { reader : Reader.t + { reader : [`Parse of string list * string] Reader.t ; wsd : Wsd.t } type input_handlers = - { frame : Websocket.Opcode.t -> is_fin:bool -> Bigstring.t -> off:int -> len:int -> unit - ; eof : unit -> unit } + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + +let random_int32 () = + Random.int32 Int32.max_int let create ~websocket_handler = - let wsd = Wsd.create () in - let { frame; eof } = websocket_handler wsd in + let mode = `Client random_int32 in + let wsd = Wsd.create mode in + let { frame; _ } = websocket_handler wsd in { reader = Reader.create frame - ; wds + ; wsd } + +let next_read_operation t = + Reader.next t.reader + +let next_write_operation t = + Wsd.next t.wsd + +let read t bs ~off ~len = + Reader.read_with_more t.reader bs ~off ~len Incomplete + +let read_eof t bs ~off ~len = + Reader.read_with_more t.reader bs ~off ~len Complete + +let report_write_result t result = + Wsd.report_result t.wsd result + +let yield_writer t k = + if Wsd.is_closed t.wsd + then begin + Wsd.close t.wsd; + k () + end else + Wsd.when_ready_to_write t.wsd k + +let close { wsd; _ } = + Wsd.close wsd diff --git a/lib/client_websocket.mli b/lib/client_websocket.mli index 3d3a279..cf76a79 100644 --- a/lib/client_websocket.mli +++ b/lib/client_websocket.mli @@ -3,20 +3,20 @@ module IOVec = Httpaf.IOVec type t type input_handlers = - { frame : Websocket.Opcode.t -> is_fin:bool -> Bigstring.t -> off:int -> len:int -> unit - ; eof : unit -> unit } + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } -val create +val create : websocket_handler : (Wsd.t -> input_handlers) -> t -val next_read_operation : t -> [ `Read | `Yield | `Close ] -val next_write_operation : t -> [ `Write of Bigstring.t IOVec.t list | `Yield | `Close of int ] +val next_read_operation : t -> [ `Read | `Close ] +val next_write_operation : t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] -val read : t -> Bigstring.t -> off:int -> len:int -> int +val read : t -> Bigstringaf.t -> off:int -> len:int -> int +val read_eof : t -> Bigstringaf.t -> off:int -> len:int -> int val report_write_result : t -> [`Ok of int | `Closed ] -> unit -val yield_reader : t -> (unit -> unit) -> unit val yield_writer : t -> (unit -> unit) -> unit val close : t -> unit diff --git a/lib/dune b/lib/dune new file mode 100644 index 0000000..43e0352 --- /dev/null +++ b/lib/dune @@ -0,0 +1,7 @@ +(library + (name websocketaf) + (public_name websocketaf) + (libraries + base64 angstrom faraday httpaf result bigstringaf) + (flags (:standard -safe-string))) + diff --git a/lib/jbuild b/lib/jbuild deleted file mode 100644 index f830add..0000000 --- a/lib/jbuild +++ /dev/null @@ -1,9 +0,0 @@ -(jbuild_version 1) - -(library - ((name websocketaf) - (public_name websocketaf) - (libraries - (base64 angstrom faraday httpaf result)) - (flags (:standard -safe-string)))) - diff --git a/lib/reader.ml b/lib/reader.ml index b7634c8..3d4f0e2 100644 --- a/lib/reader.ml +++ b/lib/reader.ml @@ -3,32 +3,35 @@ module AU = Angstrom.Unbuffered type 'error parse_state = | Done | Fail of 'error - | Partial of (Bigstring.t -> off:int -> len:int -> AU.more -> (unit, 'error) result AU.state) + | Partial of (Bigstringaf.t -> off:int -> len:int -> AU.more -> unit AU.state) type 'error t = { parser : unit Angstrom.t ; mutable parse_state : 'error parse_state ; mutable closed : bool } -let create frame_handler = - let open Angstrom in - Websocket.parser - >>| fun bs -> - let is_fin = Frame.is_fin bs in - let opcode = Frame.opcode bs in - Frame.unmask bs; - frame_handler ~is_fin ~opcode bs ~off ~len:(Bigstring.length bs) +let create frame_handler = + let parser = + let open Angstrom in + Websocket.Frame.parse + >>| fun frame -> + let is_fin = Websocket.Frame.is_fin frame in + let opcode = Websocket.Frame.opcode frame in + Websocket.Frame.unmask frame; + Websocket.Frame.with_payload frame ~f:(frame_handler ~opcode ~is_fin) + in + { parser + ; parse_state = Done + ; closed = false + } ;; let transition t state = match state with - | AU.Done(consumed, Ok ()) + | AU.Done(consumed, ()) | AU.Fail(0 as consumed, _, _) -> t.parse_state <- Done; consumed - | AU.Done(consumed, Error error) -> - t.parse_state <- Fail error; - consumed | AU.Fail(consumed, marks, msg) -> t.parse_state <- Fail (`Parse(marks, msg)); consumed @@ -37,10 +40,37 @@ let transition t state = committed and start t state = match state with - | AU.Done _ -> failwith "httpaf.Parse.unable to start parser" + | AU.Done _ -> failwith "websocketaf.Reader.unable to start parser" | AU.Fail(0, marks, msg) -> t.parse_state <- Fail (`Parse(marks, msg)) | AU.Partial { committed = 0; continue } -> t.parse_state <- Partial continue | _ -> assert false ;; + +let next t = + match t.parse_state with + | Done -> + if t.closed + then `Close + else `Read + | Fail _ -> `Close + | Partial _ -> `Read +;; + +let rec read_with_more t bs ~off ~len more = + let consumed = + match t.parse_state with + | Fail _ -> 0 + | Done -> + start t (AU.parse t.parser); + read_with_more t bs ~off ~len more; + | Partial continue -> + transition t (continue bs more ~off ~len) + in + begin match more with + | Complete -> t.closed <- true; + | Incomplete -> () + end; + consumed +;; diff --git a/lib/server_connection.ml b/lib/server_connection.ml new file mode 100644 index 0000000..72c93cd --- /dev/null +++ b/lib/server_connection.ml @@ -0,0 +1,101 @@ +module IOVec = Httpaf.IOVec + +type 'handle state = + | Uninitialized + | Handshake of 'handle Server_handshake.t + | Websocket of Server_websocket.t + +type 'handle t = 'handle state ref + +type input_handlers = Server_websocket.input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + +let passes_scrutiny _headers = + true (* XXX(andreas): missing! *) + +let create ~sha1 ~websocket_handler = + let t = ref Uninitialized in + let request_handler reqd = + let request = Httpaf.Reqd.request reqd in + if passes_scrutiny request.headers then begin + let key = Httpaf.Headers.get_exn request.headers "sec-websocket-key" in + let accept = sha1 (key ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") in + let headers = Httpaf.Headers.of_list [ + "upgrade", "websocket"; + "connection", "upgrade"; + "sec-websocket-accept", accept + ] in + let response = Httpaf.(Response.create ~headers `Switching_protocols) in + (* XXX(andreas): this is a hacky workaround for a missing flush hook *) + let body = Httpaf.Reqd.respond_with_streaming reqd response in + Httpaf.Body.write_string body " "; + Httpaf.Body.flush body (fun () -> + t := Websocket (Server_websocket.create ~websocket_handler); + Httpaf.Body.close_writer body + ) + end + in + let handshake = + Server_handshake.create + ~request_handler + in + t := Handshake handshake; + t +;; + +let next_read_operation t = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.next_read_operation handshake + | Websocket websocket -> (Server_websocket.next_read_operation websocket :> [ `Read | `Yield | `Close ]) +;; + +let read t bs ~off ~len = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.read handshake bs ~off ~len + | Websocket websocket -> Server_websocket.read websocket bs ~off ~len +;; + +let read_eof t bs ~off ~len = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.read_eof handshake bs ~off ~len + | Websocket websocket -> Server_websocket.read_eof websocket bs ~off ~len +;; + +let yield_reader t f = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.yield_reader handshake f + | Websocket _ -> assert false +;; + +let next_write_operation t = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.next_write_operation handshake + | Websocket websocket -> Server_websocket.next_write_operation websocket +;; + +let report_write_result t result = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.report_write_result handshake result + | Websocket websocket -> Server_websocket.report_write_result websocket result +;; + +let yield_writer t f = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.yield_writer handshake f + | Websocket websocket -> Server_websocket.yield_writer websocket f +;; + +let close t = + match !t with + | Uninitialized -> assert false + | Handshake handshake -> Server_handshake.close handshake + | Websocket websocket -> Server_websocket.close websocket +;; diff --git a/lib/server_connection.mli b/lib/server_connection.mli new file mode 100644 index 0000000..3384298 --- /dev/null +++ b/lib/server_connection.mli @@ -0,0 +1,24 @@ +module IOVec = Httpaf.IOVec + +type 'handle t + +type input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + +val create + : sha1 : (string -> string) + -> websocket_handler : (Wsd.t -> input_handlers) + -> 'handle t + +val next_read_operation : _ t -> [ `Read | `Yield | `Close ] +val next_write_operation : _ t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] + +val read : _ t -> Bigstringaf.t -> off:int -> len:int -> int +val read_eof : _ t -> Bigstringaf.t -> off:int -> len:int -> int +val report_write_result : _ t -> [`Ok of int | `Closed ] -> unit + +val yield_reader : _ t -> (unit -> unit) -> unit +val yield_writer : _ t -> (unit -> unit) -> unit + +val close : _ t -> unit diff --git a/lib/server_handshake.ml b/lib/server_handshake.ml new file mode 100644 index 0000000..4fced0d --- /dev/null +++ b/lib/server_handshake.ml @@ -0,0 +1,39 @@ +module IOVec = Httpaf.IOVec + +type 'handle t = + { connection : 'handle Httpaf.Server_connection.t + } + +let create + ~request_handler + = + let connection = + Httpaf.Server_connection.create + request_handler + in + { connection } +;; + +let next_read_operation t = + Httpaf.Server_connection.next_read_operation t.connection + +let next_write_operation t = + Httpaf.Server_connection.next_write_operation t.connection + +let read t = + Httpaf.Server_connection.read t.connection + +let read_eof t = + Httpaf.Server_connection.read_eof t.connection + +let report_write_result t = + Httpaf.Server_connection.report_write_result t.connection + +let yield_reader t = + Httpaf.Server_connection.yield_reader t.connection + +let yield_writer t = + Httpaf.Server_connection.yield_writer t.connection + +let close t = + Httpaf.Server_connection.shutdown t.connection diff --git a/lib/server_handshake.mli b/lib/server_handshake.mli new file mode 100644 index 0000000..b945606 --- /dev/null +++ b/lib/server_handshake.mli @@ -0,0 +1,19 @@ +module IOVec = Httpaf.IOVec + +type 'handle t + +val create + : request_handler : 'handle Httpaf.Server_connection.request_handler + -> 'handle t + +val next_read_operation : _ t -> [ `Read | `Close | `Yield ] +val next_write_operation : _ t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] + +val read : _ t -> Bigstringaf.t -> off:int -> len:int -> int +val read_eof : _ t -> Bigstringaf.t -> off:int -> len:int -> int +val report_write_result : _ t -> [`Ok of int | `Closed ] -> unit + +val yield_reader : _ t -> (unit -> unit) -> unit +val yield_writer : _ t -> (unit -> unit) -> unit + +val close : _ t -> unit diff --git a/lib/server_websocket.ml b/lib/server_websocket.ml new file mode 100644 index 0000000..317edc4 --- /dev/null +++ b/lib/server_websocket.ml @@ -0,0 +1,43 @@ +module IOVec = Httpaf.IOVec + +type t = + { reader : [`Parse of string list * string] Reader.t + ; wsd : Wsd.t } + +type input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + +let create ~websocket_handler = + let mode = `Server in + let wsd = Wsd.create mode in + let { frame; _ } = websocket_handler wsd in + { reader = Reader.create frame + ; wsd + } + +let next_read_operation t = + Reader.next t.reader + +let next_write_operation t = + Wsd.next t.wsd + +let read t bs ~off ~len = + Reader.read_with_more t.reader bs ~off ~len Incomplete + +let read_eof t bs ~off ~len = + Reader.read_with_more t.reader bs ~off ~len Complete + +let report_write_result t result = + Wsd.report_result t.wsd result + +let yield_writer t k = + if Wsd.is_closed t.wsd + then begin + Wsd.close t.wsd; + k () + end else + Wsd.when_ready_to_write t.wsd k + +let close { wsd; _ } = + Wsd.close wsd diff --git a/lib/server_websocket.mli b/lib/server_websocket.mli new file mode 100644 index 0000000..cf76a79 --- /dev/null +++ b/lib/server_websocket.mli @@ -0,0 +1,22 @@ +module IOVec = Httpaf.IOVec + +type t + +type input_handlers = + { frame : opcode:Websocket.Opcode.t -> is_fin:bool -> Bigstringaf.t -> off:int -> len:int -> unit + ; eof : unit -> unit } + +val create + : websocket_handler : (Wsd.t -> input_handlers) + -> t + +val next_read_operation : t -> [ `Read | `Close ] +val next_write_operation : t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] + +val read : t -> Bigstringaf.t -> off:int -> len:int -> int +val read_eof : t -> Bigstringaf.t -> off:int -> len:int -> int +val report_write_result : t -> [`Ok of int | `Closed ] -> unit + +val yield_writer : t -> (unit -> unit) -> unit + +val close : t -> unit diff --git a/lib/websocket.ml b/lib/websocket.ml index dec8cdc..5eac7b9 100644 --- a/lib/websocket.ml +++ b/lib/websocket.ml @@ -1,10 +1,10 @@ module Opcode = struct - type standard_non_control = + type standard_non_control = [ `Continuation | `Text | `Binary ] - type standard_control = + type standard_control = [ `Connection_close | `Ping | `Pong ] @@ -48,7 +48,7 @@ module Opcode = struct Array.unsafe_get code_table code let of_code code = - if code > 0xf + if code > 0xf then None else Some (Array.unsafe_get code_table code) @@ -60,6 +60,9 @@ module Opcode = struct let to_int = code let of_int = of_code let of_int_exn = of_code_exn + + let pp_hum fmt t = + Format.fprintf fmt "%d" (to_int t) end module Close_code = struct @@ -128,7 +131,7 @@ module Close_code = struct then failwith "Close_code.of_code_exn: value can't fit in two bytes"; if code < 1000 then failwith "Close_code.of_code_exn: value in invalid range 0-999"; - if code < 1016 + if code < 1016 then unsafe_of_code (code land 0b1111) else `Other code ;; @@ -139,66 +142,55 @@ module Close_code = struct end module Frame = struct - type t = Bigstring.t + type t = Bigstringaf.t let is_fin t = - let bits = Bigstring.unsafe_get t 0 |> Char.code in + let bits = Bigstringaf.unsafe_get t 0 |> Char.code in bits land (1 lsl 8) = 1 lsl 8 ;; let rsv t = - let bits = Bigstring.unsafe_get t 0 |> Char.code in + let bits = Bigstringaf.unsafe_get t 0 |> Char.code in (bits lsr 4) land 0b0111 ;; let opcode t = - let bits = Bigstring.unsafe_get t 0 |> Char.code in - bits land 4 |> Opcode.unsafe_of_code + let bits = Bigstringaf.unsafe_get t 0 |> Char.code in + bits land 0b1111 |> Opcode.unsafe_of_code ;; let payload_length_of_offset t off = - let bits = Bigstring.unsafe_get t (off + 1) |> Char.code in + let bits = Bigstringaf.unsafe_get t (off + 1) |> Char.code in let length = bits land 0b01111111 in - if length = 126 then Bigstring.unsafe_get_u16_be t ~off:(off + 2) else + if length = 126 then Bigstringaf.unsafe_get_int16_be t (off + 2) else (* This is technically unsafe, but if somebody's asking us to read 2^63 * bytes, then we're already screwd. *) - if length = 127 then Bigstring.unsafe_get_64_be t ~off:(off + 2) |> Int64.to_int else + if length = 127 then Bigstringaf.unsafe_get_int64_be t (off + 2) |> Int64.to_int else length ;; - let payload_length t = + let payload_length t = payload_length_of_offset t 0 ;; let has_mask t = - let bits = Bigstring.unsafe_get t 1 |> Char.code in - bits land (1 lsl 8) = 1 lsl 8 - ;; - - let mask t = - if not (has_mask t) - then None - else - Some ( - let bits = Bigstring.unsafe_get t 1 |> Char.code in - if bits = 254 then Bigstring.unsafe_get_32_be t ~off:4 else - if bits = 255 then Bigstring.unsafe_get_32_be t ~off:10 else - Bigstring.unsafe_get_32_be t ~off:2) + let bits = Bigstringaf.unsafe_get t 1 |> Char.code in + bits land (1 lsl 7) = 1 lsl 7 ;; let mask_exn t = - let bits = Bigstring.unsafe_get t 1 |> Char.code in - if bits = 254 then Bigstring.unsafe_get_32_be t ~off:4 else - if bits = 255 then Bigstring.unsafe_get_32_be t ~off:10 else - if bits >= 127 then Bigstring.unsafe_get_32_be t ~off:2 else + let bits = Bigstringaf.unsafe_get t 1 |> Char.code in + if bits = 254 then Bigstringaf.unsafe_get_int32_be t 4 else + if bits = 255 then Bigstringaf.unsafe_get_int32_be t 10 else + if bits >= 127 then Bigstringaf.unsafe_get_int32_be t 2 else failwith "Frame.mask_exn: no mask present" ;; let payload_offset_of_bits bits = let initial_offset = 2 in - let mask_offset = (bits land (1 lsl 8)) lsr (7 - 2) in - let length_offset = - let length = bits land 0b0111111 in + let mask_offset = (bits land (1 lsl 7)) lsr (7 - 2) in + let length_offset = + let length = bits land 0b01111111 in if length < 126 then 0 else 2 lsl ((length land 0b1) lsl 2) @@ -207,7 +199,7 @@ module Frame = struct ;; let payload_offset t = - let bits = Bigstring.unsafe_get t 1 |> Char.code in + let bits = Bigstringaf.unsafe_get t 1 |> Char.code in payload_offset_of_bits bits ;; @@ -218,21 +210,21 @@ module Frame = struct ;; let copy_payload t = - with_payload t ~f:Bigstring.copy + with_payload t ~f:Bigstringaf.copy ;; - let copy_payload_bytes t = - with_payload t ~f:(fun bs ~off ~len -> + let copy_payload_bytes t = + with_payload t ~f:(fun bs ~off:src_off ~len -> let bytes = Bytes.create len in - Bigstring.blit_to_bytes bs off bytes 0 len; + Bigstringaf.blit_to_bytes bs ~src_off bytes ~dst_off:0 ~len; bytes) ;; let length_of_offset t off = - let bits = Bigstring.unsafe_get t (off + 1) |> Char.code in + let bits = Bigstringaf.unsafe_get t (off + 1) |> Char.code in let payload_offset = payload_offset_of_bits bits in let payload_length = payload_length_of_offset t off in - 2 + payload_offset + payload_length + payload_offset + payload_length ;; let length t = @@ -240,19 +232,19 @@ module Frame = struct ;; let apply_mask mask bs ~off ~len = - for i = off to len - 1 do + for i = off to off + len - 1 do let j = (i - off) mod 4 in - let c = Bigstring.unsafe_get bs i |> Char.code in - let c = c lxor (Int32.(logand (shift_left mask (4 - j)) 0xffl) |> Int32.to_int) in - Bigstring.unsafe_set bs i (Char.unsafe_chr c) + let c = Bigstringaf.unsafe_get bs i |> Char.code in + let c = c lxor Int32.(logand (shift_right mask (8 * (3 - j))) 0xffl |> to_int) in + Bigstringaf.unsafe_set bs i (Char.unsafe_chr c) done ;; let apply_mask_bytes mask bs ~off ~len = - for i = off to len - 1 do + for i = off to off + len - 1 do let j = (i - off) mod 4 in let c = Bytes.unsafe_get bs i |> Char.code in - let c = c lxor (Int32.(logand (shift_left mask (4 - j)) 0xffl) |> Int32.to_int) in + let c = c lxor Int32.(logand (shift_right mask (8 * (3 - j))) 0xffl |> to_int) in Bytes.unsafe_set bs i (Char.unsafe_chr c) done ;; @@ -271,20 +263,20 @@ module Frame = struct let parse = let open Angstrom in - Unsafe.peek 2 (fun bs ~off ~len -> length_of_offset bs off) - >>= fun len -> Unsafe.take len Bigstring.sub + Unsafe.peek 2 (fun bs ~off ~len:_ -> length_of_offset bs off) + >>= fun len -> Unsafe.take len Bigstringaf.sub ;; let serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length = let opcode = Opcode.to_int opcode in - let is_fin = if is_fin then 1 lsl 8 else 0 in + let is_fin = if is_fin then 1 lsl 7 else 0 in let is_mask = match mask with | None -> 0 - | Some _ -> 1 lsl 8 + | Some _ -> 1 lsl 7 in - Faraday.write_uint8 faraday (is_fin lsl opcode); - if payload_length <= 125 then + Faraday.write_uint8 faraday (is_fin lor opcode); + if payload_length <= 125 then Faraday.write_uint8 faraday (is_mask lor payload_length) else if payload_length <= 0xffff then begin Faraday.write_uint8 faraday (is_mask lor 126); @@ -299,12 +291,12 @@ module Frame = struct end ;; - let serialize_control faraday ~opcode = - serialize_headers faraday ~is_fin:true ~opcode ~payload_length:0 + let serialize_control ?mask faraday ~opcode = + let opcode = (opcode :> Opcode.t) in + serialize_headers faraday ?mask ~is_fin:true ~opcode ~payload_length:0 let schedule_serialize ?mask faraday ~is_fin ~opcode ~payload ~off ~len = - let payload_length = Bigstring.length payload in - serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length; + serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length:len; begin match mask with | None -> () | Some mask -> apply_mask mask payload ~off ~len @@ -313,8 +305,16 @@ module Frame = struct ;; let serialize_bytes ?mask faraday ~is_fin ~opcode ~payload ~off ~len = - let payload_length = Bytes.length payload in - serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length; + serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length:len; + begin match mask with + | None -> () + | Some mask -> apply_mask_bytes mask payload ~off ~len + end; + Faraday.write_bytes faraday payload ~off ~len; + ;; + + let schedule_serialize_bytes ?mask faraday ~is_fin ~opcode ~payload ~off ~len = + serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length:len; begin match mask with | None -> () | Some mask -> apply_mask_bytes mask payload ~off ~len diff --git a/lib/websocket.mli b/lib/websocket.mli index 51786d0..ecff99f 100644 --- a/lib/websocket.mli +++ b/lib/websocket.mli @@ -26,6 +26,8 @@ module Opcode : sig val of_int : int -> t option val of_int_exn : int -> t + + val pp_hum : Format.formatter -> t -> unit end module Close_code : sig @@ -66,31 +68,48 @@ module Frame : sig val opcode : t -> Opcode.t val has_mask : t -> bool - val mask : t -> int32 option + val mask : t -> unit + val unmask : t -> unit + val mask_exn : t -> int32 + val length : t -> int + val payload_length : t -> int - val with_payload : t -> f:(Bigstring.t -> off:int -> len:int -> 'a) -> 'a + val with_payload : t -> f:(Bigstringaf.t -> off:int -> len:int -> 'a) -> 'a - val copy_payload : t -> Bigstring.t + val copy_payload : t -> Bigstringaf.t val copy_payload_bytes : t -> Bytes.t val parse : t Angstrom.t - val serialize_control : Faraday.t -> opcode:Opcode.standard_control -> unit + val serialize_control + : ?mask:int32 + -> Faraday.t + -> opcode:Opcode.standard_control + -> unit val schedule_serialize : ?mask:int32 -> Faraday.t -> is_fin:bool -> opcode:Opcode.t - -> payload:Bigstring.t + -> payload:Bigstringaf.t -> off:int -> len:int - -> Faraday.t -> unit val schedule_serialize_bytes + : ?mask:int32 + -> Faraday.t + -> is_fin:bool + -> opcode:Opcode.t + -> payload:Bytes.t + -> off:int + -> len:int + -> unit + + val serialize_bytes : ?mask:int32 -> Faraday.t -> is_fin:bool diff --git a/lib/websocketaf.ml b/lib/websocketaf.ml index 4401979..1c4b8f5 100644 --- a/lib/websocketaf.ml +++ b/lib/websocketaf.ml @@ -1,2 +1,6 @@ +module Bigstring = Bigstring module Client_handshake = Client_handshake -module Client_connetion = Client_connection +module Client_connection = Client_connection +module Server_connection = Server_connection +module Wsd = Wsd +module Websocket = Websocket diff --git a/lib/websocketaf.mli b/lib/websocketaf.mli deleted file mode 100644 index e69de29..0000000 diff --git a/lib/wsd.ml b/lib/wsd.ml index 99736c4..c9b500d 100644 --- a/lib/wsd.ml +++ b/lib/wsd.ml @@ -1,20 +1,76 @@ -type t = Faraday.t +module IOVec = Httpaf.IOVec -let schedule ?mask t ~kind payload ~off ~len = - Websocket.Frame.schedule_serialize ?mask t ~is_fin:true ~opcode:kind ~payload ~off ~len +type mode = + [ `Client of unit -> int32 + | `Server + ] -let send_bytes ?mask t ~kind payload ~off ~len = - Websocket.Frame.schedule_serialize_bytes ?mask t ~is_fin:true ~opcode:kind ~payload ~off ~len +type t = + { faraday : Faraday.t + ; mode : mode + ; mutable when_ready_to_write : unit -> unit + } + +let default_ready_to_write = Sys.opaque_identity (fun () -> ()) + +let create mode = + { faraday = Faraday.create 0x1000 + ; mode + ; when_ready_to_write = default_ready_to_write; + } + +let mask t = + match t.mode with + | `Client m -> Some (m ()) + | `Server -> None + +let ready_to_write t = + let callback = t.when_ready_to_write in + t.when_ready_to_write <- default_ready_to_write; + callback () + +let schedule t ~kind payload ~off ~len = + let mask = mask t in + Websocket.Frame.schedule_serialize t.faraday ?mask ~is_fin:true ~opcode:(kind :> Websocket.Opcode.t) ~payload ~off ~len; + ready_to_write t + +let send_bytes t ~kind payload ~off ~len = + let mask = mask t in + Websocket.Frame.schedule_serialize_bytes t.faraday ?mask ~is_fin:true ~opcode:(kind :> Websocket.Opcode.t) ~payload ~off ~len; + ready_to_write t let send_ping t = - Websocket.Frame.serialize_control t ~opcode:`Ping + Websocket.Frame.serialize_control t.faraday ~opcode:`Ping; + ready_to_write t let send_pong t = - Websocket.Frame.serialize_control t ~opcode:`Pong + Websocket.Frame.serialize_control t.faraday ~opcode:`Pong; + ready_to_write t -let flushed t f = Faraday.flush t f +let flushed t f = Faraday.flush t.faraday f let close t = - Websocket.Frame.serialize_control t ~opcode:`Connection_close; - Faraday.close t -;; + Websocket.Frame.serialize_control t.faraday ~opcode:`Connection_close; + Faraday.close t.faraday; + ready_to_write t + +let next t = + match Faraday.operation t.faraday with + | `Close -> `Close 0 (* XXX(andreas): should track unwritten bytes *) + | `Yield -> `Yield + | `Writev iovecs -> `Write iovecs + +let report_result t result = + match result with + | `Closed -> close t + | `Ok len -> Faraday.shift t.faraday len + +let is_closed t = + Faraday.is_closed t.faraday + +let when_ready_to_write t callback = + if not (t.when_ready_to_write == default_ready_to_write) + then failwith "Wsd.when_ready_to_write: only one callback can be registered at a time" + else if is_closed t + then callback () + else t.when_ready_to_write <- callback diff --git a/lib/wsd.mli b/lib/wsd.mli index d36497b..3f9d768 100644 --- a/lib/wsd.mli +++ b/lib/wsd.mli @@ -1,17 +1,26 @@ +module IOVec = Httpaf.IOVec + +type mode = + [ `Client of unit -> int32 + | `Server + ] + type t -val schedule - : ?mask:int32 +val create + : mode -> t + +val schedule + : t -> kind:[ `Text | `Binary ] - -> Bigstring.t + -> Bigstringaf.t -> off:int -> len:int -> unit val send_bytes - : ?mask:int32 - -> t + : t -> kind:[ `Text | `Binary ] -> Bytes.t -> off:int @@ -23,3 +32,10 @@ val send_pong : t -> unit val flushed : t -> (unit -> unit) -> unit val close : t -> unit + +val next : t -> [ `Write of Bigstringaf.t IOVec.t list | `Yield | `Close of int ] +val report_result : t -> [`Ok of int | `Closed ] -> unit + +val is_closed : t -> bool + +val when_ready_to_write : t -> (unit -> unit) -> unit diff --git a/lib_test/dune b/lib_test/dune new file mode 100644 index 0000000..29eb823 --- /dev/null +++ b/lib_test/dune @@ -0,0 +1,9 @@ +(executable + (libraries websocketaf alcotest) + (name test_websocketaf)) + +(alias + (name runtest) + (package websocketaf) + (deps (:test test_websocketaf.exe)) + (action (run %{test} -v))) diff --git a/lib_test/test_websocketaf.ml b/lib_test/test_websocketaf.ml new file mode 100644 index 0000000..5f8b73c --- /dev/null +++ b/lib_test/test_websocketaf.ml @@ -0,0 +1,48 @@ +module Websocket = struct + open Websocketaf.Websocket + + module Testable = struct + let opcode = Alcotest.testable Opcode.pp_hum (=) + end + + let parse_frame serialized_frame = + match Angstrom.parse_string Frame.parse serialized_frame with + | Ok frame -> frame + | Error err -> Alcotest.fail err + + let test_parsing_ping_frame () = + let frame = parse_frame "\137\128\000\000\046\216" in + Alcotest.check Testable.opcode "opcode" `Ping (Frame.opcode frame); + Alcotest.(check bool) "has mask" true (Frame.has_mask frame); + Alcotest.(check int32) "mask" 11992l (Frame.mask_exn frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 0; + Alcotest.(check int) "length" (Frame.length frame) 6 + + let test_parsing_close_frame () = + let frame = parse_frame "\136\000" in + Alcotest.check Testable.opcode "opcode" `Connection_close (Frame.opcode frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 0; + Alcotest.(check int) "length" (Frame.length frame) 2 + + let test_parsing_text_frame () = + let frame = parse_frame "\129\139\086\057\046\216\103\011\029\236\099\015\025\224\111\009\036" in + Alcotest.check Testable.opcode "opcode" `Text (Frame.opcode frame); + Alcotest.(check bool) "has mask" true (Frame.has_mask frame); + Alcotest.(check int32) "mask" 1446588120l (Frame.mask_exn frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 11; + Alcotest.(check int) "length" (Frame.length frame) 17; + Frame.unmask frame; + let payload = Bytes.to_string (Frame.copy_payload_bytes frame) in + Alcotest.(check string) "payload" "1234567890\n" payload + + let tests = + [ "parsing ping frame", `Quick, test_parsing_ping_frame + ; "parsing close frame", `Quick, test_parsing_close_frame + ; "parsing text frame", `Quick, test_parsing_text_frame + ] +end + +let () = + Alcotest.run "websocketaf unit tests" + [ "websocket", Websocket.tests + ] diff --git a/lwt/dune b/lwt/dune new file mode 100644 index 0000000..da965ba --- /dev/null +++ b/lwt/dune @@ -0,0 +1,5 @@ +(library + (name websocketaf_lwt) + (public_name websocketaf-lwt) + (libraries faraday-lwt-unix websocketaf lwt.unix digestif.ocaml base64) + (flags (:standard -safe-string))) diff --git a/lwt/websocketaf_lwt.ml b/lwt/websocketaf_lwt.ml new file mode 100644 index 0000000..d071779 --- /dev/null +++ b/lwt/websocketaf_lwt.ml @@ -0,0 +1,274 @@ +open Lwt.Infix + +let sha1 s = + s + |> Digestif.SHA1.digest_string + |> Digestif.SHA1.to_raw_string + |> B64.encode ~pad:true + +module Buffer : sig + type t + + val create : int -> t + + val get : t -> f:(Lwt_bytes.t -> off:int -> len:int -> int) -> int + val put : t -> f:(Lwt_bytes.t -> off:int -> len:int -> int Lwt.t) -> int Lwt.t +end = struct + type t = + { buffer : Lwt_bytes.t + ; mutable off : int + ; mutable len : int } + + let create size = + let buffer = Lwt_bytes.create size in + { buffer; off = 0; len = 0 } + + let compress t = + if t.len = 0 + then begin + t.off <- 0; + t.len <- 0; + end else if t.off > 0 + then begin + Lwt_bytes.blit t.buffer t.off t.buffer 0 t.len; + t.off <- 0; + end + + let get t ~f = + let n = f t.buffer ~off:t.off ~len:t.len in + t.off <- t.off + n; + t.len <- t.len - n; + if t.len = 0 + then t.off <- 0; + n + + let put t ~f = + compress t; + f t.buffer ~off:(t.off + t.len) ~len:(Lwt_bytes.length t.buffer - t.len) + >>= fun n -> + t.len <- t.len + n; + Lwt.return n +end + + +let read fd buffer = + Lwt.catch + (fun () -> + Buffer.put buffer ~f:(fun bigstring ~off ~len -> + Lwt_bytes.read fd bigstring off len)) + (function + | Unix.Unix_error (Unix.EBADF, _, _) as exn -> + Lwt.fail exn + | exn -> + Lwt.async (fun () -> + Lwt_unix.close fd); + Lwt.fail exn) + + >>= fun bytes_read -> + if bytes_read = 0 then + Lwt.return `Eof + else + Lwt.return (`Ok bytes_read) + +let shutdown socket command = + try Lwt_unix.shutdown socket command + with Unix.Unix_error (Unix.ENOTCONN, _, _) -> () + + + +module Server = struct + let create_connection_handler ?config:_ ~websocket_handler ~error_handler:_ = + fun client_addr socket -> + let module Server_connection = Websocketaf.Server_connection in + let connection = + Server_connection.create + ~sha1 + ~websocket_handler:(websocket_handler client_addr) + in + + + let read_buffer = Buffer.create 0x1000 in + let read_loop_exited, notify_read_loop_exited = Lwt.wait () in + + let rec read_loop () = + let rec read_loop_step () = + match Server_connection.next_read_operation connection with + | `Read -> + read socket read_buffer >>= begin function + | `Eof -> + Buffer.get read_buffer ~f:(fun bigstring ~off ~len -> + Server_connection.read_eof connection bigstring ~off ~len) + |> ignore; + read_loop_step () + | `Ok _ -> + Buffer.get read_buffer ~f:(fun bigstring ~off ~len -> + Server_connection.read connection bigstring ~off ~len) + |> ignore; + read_loop_step () + end + + | `Yield -> + Server_connection.yield_reader connection read_loop; + Lwt.return_unit + + | `Close -> + Lwt.wakeup_later notify_read_loop_exited (); + if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin + shutdown socket Unix.SHUTDOWN_RECEIVE + end; + Lwt.return_unit + in + + Lwt.async (fun () -> + Lwt.catch + read_loop_step + (fun exn -> + (* XXX(andreas): missing error reporting *) + (* Server_connection.report_exn connection exn;*) + Printexc.print_backtrace stdout; + ignore(raise exn); + Lwt.return_unit)) + in + + + let writev = Faraday_lwt_unix.writev_of_fd socket in + let write_loop_exited, notify_write_loop_exited = Lwt.wait () in + + let rec write_loop () = + let rec write_loop_step () = + match Server_connection.next_write_operation connection with + | `Write io_vectors -> + writev io_vectors >>= fun result -> + Server_connection.report_write_result connection result; + write_loop_step () + + | `Yield -> + Server_connection.yield_writer connection write_loop; + Lwt.return_unit + + | `Close _ -> + Lwt.wakeup_later notify_write_loop_exited (); + if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin + shutdown socket Unix.SHUTDOWN_SEND + end; + Lwt.return_unit + in + + Lwt.async (fun () -> + Lwt.catch + write_loop_step + (fun exn -> + (* XXX(andreas): missing error reporting *) + (*Server_connection.report_exn connection exn;*) + Printexc.print_backtrace stdout; + ignore(raise exn); + Lwt.return_unit)) + in + + + read_loop (); + write_loop (); + Lwt.join [read_loop_exited; write_loop_exited] >>= fun () -> + + if Lwt_unix.state socket <> Lwt_unix.Closed then + Lwt.catch + (fun () -> Lwt_unix.close socket) + (fun _exn -> Lwt.return_unit) + else + Lwt.return_unit +end + + + +module Client = struct + let connect socket ~nonce ~host ~port ~resource ~error_handler ~websocket_handler = + let module Client_connection = Websocketaf.Client_connection in + let connection = + Client_connection.create ~nonce ~host ~port ~resource ~sha1 ~error_handler ~websocket_handler in + + let read_buffer = Buffer.create 0x1000 in + let read_loop_exited, notify_read_loop_exited = Lwt.wait () in + + let read_loop () = + let rec read_loop_step () = + match Client_connection.next_read_operation connection with + | `Read -> + read socket read_buffer >>= begin function + | `Ok _ -> + Buffer.get read_buffer ~f:(fun bigstring ~off ~len -> + Client_connection.read connection bigstring ~off ~len + ) + |> ignore; + read_loop_step () + | `Eof -> + Buffer.get read_buffer ~f:(fun bigstring ~off ~len -> + Client_connection.read_eof connection bigstring ~off ~len) + |> ignore; + read_loop_step () + end + + | `Close -> + Lwt.wakeup_later notify_read_loop_exited (); + if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin + shutdown socket Unix.SHUTDOWN_RECEIVE + end; + Lwt.return_unit + in + + Lwt.async (fun () -> + Lwt.catch + read_loop_step + (fun exn -> + (*Client_connection.report_exn connection exn;*) + Printexc.print_backtrace stdout; + ignore(raise exn); + Lwt.return_unit)) + in + + + let writev = Faraday_lwt_unix.writev_of_fd socket in + let write_loop_exited, notify_write_loop_exited = Lwt.wait () in + + let rec write_loop () = + let rec write_loop_step () = + flush stdout; + match Client_connection.next_write_operation connection with + | `Write io_vectors -> + writev io_vectors >>= fun result -> + Client_connection.report_write_result connection result; + write_loop_step () + + | `Yield -> + Client_connection.yield_writer connection write_loop; + Lwt.return_unit + + | `Close _ -> + Lwt.wakeup_later notify_write_loop_exited (); + if not (Lwt_unix.state socket = Lwt_unix.Closed) then begin + shutdown socket Unix.SHUTDOWN_SEND + end; + Lwt.return_unit + in + + Lwt.async (fun () -> + Lwt.catch + write_loop_step + (fun exn -> + (*Client_connection.report_exn connection exn;*) + ignore(raise exn); + Lwt.return_unit)) + in + + + read_loop (); + write_loop (); + + Lwt.join [read_loop_exited; write_loop_exited] >>= fun () -> + + if Lwt_unix.state socket <> Lwt_unix.Closed then + Lwt.catch + (fun () -> Lwt_unix.close socket) + (fun _exn -> Lwt.return_unit) + else + Lwt.return_unit; +end diff --git a/lwt/websocketaf_lwt.mli b/lwt/websocketaf_lwt.mli new file mode 100644 index 0000000..417882c --- /dev/null +++ b/lwt/websocketaf_lwt.mli @@ -0,0 +1,19 @@ +module Client : sig + val connect + : Lwt_unix.file_descr + -> nonce : string + -> host : string + -> port : int + -> resource : string + -> error_handler : (Websocketaf.Client_connection.error -> unit) + -> websocket_handler : (Websocketaf.Wsd.t -> Websocketaf.Client_connection.input_handlers) + -> unit Lwt.t +end + +module Server : sig + val create_connection_handler + : ?config : Httpaf.Server_connection.Config.t + -> websocket_handler : (Unix.sockaddr -> Websocketaf.Wsd.t -> Websocketaf.Server_connection.input_handlers) + -> error_handler : (Unix.sockaddr -> Httpaf.Server_connection.error_handler) + -> (Unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t) +end diff --git a/websocketaf-lwt.opam b/websocketaf-lwt.opam new file mode 100644 index 0000000..d304478 --- /dev/null +++ b/websocketaf-lwt.opam @@ -0,0 +1,22 @@ +opam-version: "2.0" +name: "websocketaf-lwt" +maintainer: "Spiros Eliopoulos " +authors: [ "Andreas Garnæs " ] +license: "BSD-3-clause" +homepage: "https://github.com/inhabitedtype/websocketaf" +bug-reports: "https://github.com/inhabitedtype/websocketaf/issues" +dev-repo: "git+https://github.com/inhabitedtype/websocketaf.git" +build: [ + ["dune" "subst" "-p" name] {pinned} + ["dune" "build" "-p" name "-j" jobs] +] +depends: [ + "ocaml" {>= "4.03.0"} + "faraday-lwt-unix" + "websocketaf" + "dune" {build} + "lwt" + "digestif" + "base64" +] +synopsis: "Lwt support for websocket/af" diff --git a/websocketaf.opam b/websocketaf.opam index 1f115db..f7305b5 100644 --- a/websocketaf.opam +++ b/websocketaf.opam @@ -6,16 +6,17 @@ homepage: "https://github.com/inhabitedtype/websocketaf" bug-reports: "https://github.com/inhabitedtype/websocketaf/issues" dev-repo: "https://github.com/inhabitedtype/websocketaf.git" build: [ - ["jbuilder" "subst"] {pinned} - ["jbuilder" "build" "-p" name "-j" jobs] + ["dune" "subst"] {pinned} + ["dune" "build" "-p" name "-j" jobs] ] build-test: [ - ["jbuilder" "runtest" "-p" name] + ["dune" "runtest" "-p" name] ] depends: [ - "jbuilder" {build & >= "1.0+beta10"} + "dune" {build} "base64" "alcotest" {test} + "bigstringaf" "angstrom" {>= "0.7.0"} "faraday" {>= "0.5.0"} "httpaf"