diff --git a/lib/websocket.ml b/lib/websocket.ml index c7a4bb5..394891c 100644 --- a/lib/websocket.ml +++ b/lib/websocket.ml @@ -1,10 +1,10 @@ module Opcode = struct - type standard_non_control = + type standard_non_control = [ `Continuation | `Text | `Binary ] - type standard_control = + type standard_control = [ `Connection_close | `Ping | `Pong ] @@ -48,7 +48,7 @@ module Opcode = struct Array.unsafe_get code_table code let of_code code = - if code > 0xf + if code > 0xf then None else Some (Array.unsafe_get code_table code) @@ -60,6 +60,9 @@ module Opcode = struct let to_int = code let of_int = of_code let of_int_exn = of_code_exn + + let pp_hum fmt t = + Format.fprintf fmt "%d" (to_int t) end module Close_code = struct @@ -128,7 +131,7 @@ module Close_code = struct then failwith "Close_code.of_code_exn: value can't fit in two bytes"; if code < 1000 then failwith "Close_code.of_code_exn: value in invalid range 0-999"; - if code < 1016 + if code < 1016 then unsafe_of_code (code land 0b1111) else `Other code ;; @@ -153,7 +156,7 @@ module Frame = struct let opcode t = let bits = Bigstringaf.unsafe_get t 0 |> Char.code in - bits land 4 |> Opcode.unsafe_of_code + bits land 0b1111 |> Opcode.unsafe_of_code ;; let payload_length_of_offset t off = @@ -166,14 +169,13 @@ module Frame = struct length ;; - let payload_length t = + let payload_length t = payload_length_of_offset t 0 ;; let has_mask t = let bits = Bigstringaf.unsafe_get t 1 |> Char.code in - bits land (1 lsl 8) = 1 lsl 8 - ;; + bits land (1 lsl 7) = 1 lsl 7 let mask t = if not (has_mask t) @@ -196,9 +198,9 @@ module Frame = struct let payload_offset_of_bits bits = let initial_offset = 2 in - let mask_offset = (bits land (1 lsl 8)) lsr (7 - 2) in - let length_offset = - let length = bits land 0b0111111 in + let mask_offset = (bits land (1 lsl 7)) lsr (7 - 2) in + let length_offset = + let length = bits land 0b01111111 in if length < 126 then 0 else 2 lsl ((length land 0b1) lsl 2) @@ -232,23 +234,27 @@ module Frame = struct let bits = Bigstringaf.unsafe_get t (off + 1) |> Char.code in let payload_offset = payload_offset_of_bits bits in let payload_length = payload_length_of_offset t off in - 2 + payload_offset + payload_length + payload_offset + payload_length ;; + let length t = + length_of_offset t 0 + ;; + let apply_mask mask bs ~off ~len = - for i = off to len - 1 do + for i = off to off + len - 1 do let j = (i - off) mod 4 in let c = Bigstringaf.unsafe_get bs i |> Char.code in - let c = c lxor (Int32.(logand (shift_left mask (4 - j)) 0xffl) |> Int32.to_int) in + let c = c lxor Int32.(logand (shift_right mask (8 * (3 - j))) 0xffl |> to_int) in Bigstringaf.unsafe_set bs i (Char.unsafe_chr c) done ;; let apply_mask_bytes mask bs ~off ~len = - for i = off to len - 1 do + for i = off to off + len - 1 do let j = (i - off) mod 4 in let c = Bytes.unsafe_get bs i |> Char.code in - let c = c lxor (Int32.(logand (shift_left mask (4 - j)) 0xffl) |> Int32.to_int) in + let c = c lxor Int32.(logand (shift_right mask (8 * (3 - j))) 0xffl |> to_int) in Bytes.unsafe_set bs i (Char.unsafe_chr c) done ;; @@ -273,14 +279,14 @@ module Frame = struct let serialize_headers faraday ?mask ~is_fin ~opcode ~payload_length = let opcode = Opcode.to_int opcode in - let is_fin = if is_fin then 1 lsl 8 else 0 in + let is_fin = if is_fin then 1 lsl 7 else 0 in let is_mask = match mask with | None -> 0 - | Some _ -> 1 lsl 8 + | Some _ -> 1 lsl 7 in - Faraday.write_uint8 faraday (is_fin lsl opcode); - if payload_length <= 125 then + Faraday.write_uint8 faraday (is_fin lor opcode); + if payload_length <= 125 then Faraday.write_uint8 faraday (is_mask lor payload_length) else if payload_length <= 0xffff then begin Faraday.write_uint8 faraday (is_mask lor 126); diff --git a/lib/websocket.mli b/lib/websocket.mli index 9c5dc84..e49abbc 100644 --- a/lib/websocket.mli +++ b/lib/websocket.mli @@ -26,6 +26,8 @@ module Opcode : sig val of_int : int -> t option val of_int_exn : int -> t + + val pp_hum : Format.formatter -> t -> unit end module Close_code : sig @@ -66,12 +68,15 @@ module Frame : sig val opcode : t -> Opcode.t val has_mask : t -> bool + val mask : t -> int32 option val mask_exn : t -> int32 val mask_inplace : t -> unit val unmask_inplace : t -> unit + val length : t -> int + val payload_length : t -> int val with_payload : t -> f:(Bigstringaf.t -> off:int -> len:int -> 'a) -> 'a diff --git a/lib/websocketaf.ml b/lib/websocketaf.ml index 4401979..5b596a2 100644 --- a/lib/websocketaf.ml +++ b/lib/websocketaf.ml @@ -1,2 +1,4 @@ module Client_handshake = Client_handshake module Client_connetion = Client_connection +module Wsd = Wsd +module Websocket = Websocket diff --git a/lib/websocketaf.mli b/lib/websocketaf.mli deleted file mode 100644 index e69de29..0000000 diff --git a/lib_test/dune b/lib_test/dune new file mode 100644 index 0000000..085efdf --- /dev/null +++ b/lib_test/dune @@ -0,0 +1,9 @@ +(executable + (libraries websocketaf alcotest) + (name test_websocketaf)) + +(alias + (name runtest) + (package websocketaf) + (deps (:test test_websocketaf.exe)) + (action (run %{test}))) diff --git a/lib_test/test_websocketaf.ml b/lib_test/test_websocketaf.ml new file mode 100644 index 0000000..3250b7b --- /dev/null +++ b/lib_test/test_websocketaf.ml @@ -0,0 +1,48 @@ +module Websocket = struct + open Websocketaf.Websocket + + module Testable = struct + let opcode = Alcotest.testable Opcode.pp_hum (=) + end + + let parse_frame serialized_frame = + match Angstrom.parse_string Frame.parse serialized_frame with + | Ok frame -> frame + | Error err -> Alcotest.fail err + + let test_parsing_ping_frame () = + let frame = parse_frame "\137\128\000\000\046\216" in + Alcotest.check Testable.opcode "opcode" `Ping (Frame.opcode frame); + Alcotest.(check bool) "has mask" true (Frame.has_mask frame); + Alcotest.(check int32) "mask" 11992l (Frame.mask_exn frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 0; + Alcotest.(check int) "length" (Frame.length frame) 6 + + let test_parsing_close_frame () = + let frame = parse_frame "\136\000" in + Alcotest.check Testable.opcode "opcode" `Connection_close (Frame.opcode frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 0; + Alcotest.(check int) "length" (Frame.length frame) 2 + + let test_parsing_text_frame () = + let frame = parse_frame "\129\139\086\057\046\216\103\011\029\236\099\015\025\224\111\009\036" in + Alcotest.check Testable.opcode "opcode" `Text (Frame.opcode frame); + Alcotest.(check bool) "has mask" true (Frame.has_mask frame); + Alcotest.(check int32) "mask" 1446588120l (Frame.mask_exn frame); + Alcotest.(check int) "payload_length" (Frame.payload_length frame) 11; + Alcotest.(check int) "length" (Frame.length frame) 17; + Frame.unmask_inplace frame; + let payload = Bytes.to_string (Frame.copy_payload_bytes frame) in + Alcotest.(check string) "payload" "1234567890\n" payload + + let tests = + [ "parsing ping frame", `Quick, test_parsing_ping_frame + ; "parsing close frame", `Quick, test_parsing_close_frame + ; "parsing text frame", `Quick, test_parsing_text_frame + ] +end + +let () = + Alcotest.run "websocketaf unit tests" + [ "websocket", Websocket.tests + ]