From 7ede6c1d989ca3135e6305f96e3abe4e83a59a60 Mon Sep 17 00:00:00 2001 From: Matt Bray Date: Mon, 22 Nov 2021 22:47:43 -0600 Subject: [PATCH] feat: obj decoders that consume fields from a string map Inspired by https://github.com/sneeuwballen/benchpress/blob/cf60d86a4a9d06b16f869328e19ec2341fef96bb/src/core/Stanza.ml#L183-L229 --- src/decode.ml | 94 +++++++++++++++++++++++++++++++++++++++++++ src/decoders_util.ml | 12 ++++++ src/decoders_util.mli | 12 ++++++ src/sig.ml | 30 ++++++++++++++ test-yojson/main.ml | 53 ++++++++++++++++++++++++ 5 files changed, 201 insertions(+) diff --git a/src/decode.ml b/src/decode.ml index 0fc89ed..2ef01be 100644 --- a/src/decode.ml +++ b/src/decode.ml @@ -289,6 +289,100 @@ module Make (Decodeable : Decodeable) : key_value_pairs_seq' string value_decoder + module Obj = struct + type t = + { context : value + ; map : value U.String_map.t + } + + type 'a obj = (t, 'a * t, value Error.t) Decoder.t + + let succeed x t = Ok (x, t) + + let bind : ('a -> 'b obj) -> 'a obj -> 'b obj = + fun f dec t -> match dec t with Ok (x, t) -> f x t | Error e -> Error e + + + let map f dec t = + match dec t with Ok (x, t) -> Ok (f x, t) | Error e -> Error e + + + let apply f dec t = + match f t with + | Ok (f, t) -> + (match dec t with Ok (x, t) -> Ok (f x, t) | Error e -> Error e) + | Error e -> + Error e + + + module Infix = struct + let ( >>= ) x f = bind f x + + let ( >|= ) x f = map f x + + let ( <*> ) x f = apply f x + + (* let monoid_product a b = map (fun x y -> (x, y)) a <*> b *) + + let ( let+ ) = ( >|= ) + + (* let ( and+ ) = monoid_product *) + + let ( let* ) = ( >>= ) + + (* let ( and* ) = monoid_product *) + end + + let field_opt key v_dec : 'a option obj = + fun t -> + match U.String_map.get key t.map with + | None -> + Ok (None, t) + | Some value -> + let m = U.String_map.remove key t.map in + let t = { t with map = m } in + (match v_dec value with Ok x -> Ok (Some x, t) | Error e -> Error e) + + + let field key v_dec : 'a obj = + fun t -> + match field_opt key v_dec t with + | Ok (Some x, t) -> + Ok (x, t) + | Ok (None, _t) -> + Error + (Error.make + (Printf.sprintf "Expected an object with an attribute %S" key) + ~context:t.context ) + | Error e -> + Error e + + + let empty : unit obj = + fun t -> + match U.String_map.choose_opt t.map with + | None -> + Ok ((), t) + | Some (k, _) -> + Error + (Error.make + (Printf.sprintf + "Expected an empty object, but have unconsumed field %S" + k ) + ~context:t.context ) + + + let run : 'a obj -> 'a decoder = + fun dec context -> + match key_value_pairs value context with + | Ok l -> + let map = U.String_map.of_list l in + let t = { context; map } in + dec t |> U.My_result.map (fun (x, _) -> x) + | Error e -> + Error e + end + let field : string -> 'a decoder -> 'a decoder = fun key value_decoder t -> let value = diff --git a/src/decoders_util.ml b/src/decoders_util.ml index 20b3fa6..95023df 100644 --- a/src/decoders_util.ml +++ b/src/decoders_util.ml @@ -163,6 +163,18 @@ module My_list = struct aux f l (fun l -> l) end +module String_map = struct + include Map.Make (String) + + let add_list m l = List.fold_left (fun m (k, v) -> add k v m) m l + + let of_list l = add_list empty l + + let get = find_opt + + let choose_opt m = try Some (choose m) with Not_found -> None +end + let with_file_in file f = let ic = open_in file in try diff --git a/src/decoders_util.mli b/src/decoders_util.mli index 913a1db..7a59ed8 100644 --- a/src/decoders_util.mli +++ b/src/decoders_util.mli @@ -47,6 +47,18 @@ module My_list : sig val flat_map : ('a -> 'b list) -> 'a list -> 'b list end +module String_map : sig + type 'a t + + val of_list : (string * 'a) list -> 'a t + + val get : string -> 'a t -> 'a option + + val remove : string -> 'a t -> 'a t + + val choose_opt : 'a t -> (string * 'a) option +end + val with_file_in : string -> (in_channel -> 'a) -> 'a val read_all : in_channel -> string diff --git a/src/sig.ml b/src/sig.ml index 00f40c0..b82017d 100644 --- a/src/sig.ml +++ b/src/sig.ml @@ -119,6 +119,36 @@ module type S = sig (** {1 Object primitives} *) + module Obj : sig + type 'a obj + + val run : 'a obj -> 'a decoder + + val succeed : 'a -> 'a obj + + val bind : ('a -> 'b obj) -> 'a obj -> 'b obj + + val map : ('a -> 'b) -> 'a obj -> 'b obj + + val field : string -> 'a decoder -> 'a obj + + val field_opt : string -> 'a decoder -> 'a option obj + + val empty : unit obj + + module Infix : sig + val ( >>= ) : 'a obj -> ('a -> 'b obj) -> 'b obj + + val ( >|= ) : 'a obj -> ('a -> 'b) -> 'b obj + + val ( <*> ) : 'a obj -> ('a -> 'b) obj -> 'b obj + + val ( let* ) : 'a obj -> ('a -> 'b obj) -> 'b obj + + val ( let+ ) : 'a obj -> ('a -> 'b) -> 'b obj + end + end + val field : string -> 'a decoder -> 'a decoder (** Decode an object, requiring a particular field. *) diff --git a/test-yojson/main.ml b/test-yojson/main.ml index 3bced63..42f375c 100644 --- a/test-yojson/main.ml +++ b/test-yojson/main.ml @@ -172,6 +172,57 @@ let yojson_basic_suite = Format.asprintf "@,@[%a@]" pp_error e) in + let obj_test = + "objects" + >:: fun _test_ctxt -> + let obj = + Obj.( + let open Infix in + let* name = field "name" string in + let* age = field "age" int in + let* () = empty in + succeed (name, age)) + in + let decoder = Obj.run obj in + let input = {| {"name": "Jim", "age": 42} |} in + match decode_string decoder input with + | Ok value -> + assert_equal value ("Jim", 42) + | Error error -> + assert_string (Format.asprintf "%a" pp_error error) + in + + let obj_test_2 = + "objects with remaining fields" + >:: fun _test_ctxt -> + let obj = + Obj.( + let open Infix in + let* name = field "name" string in + let* age = field "age" int in + let* () = empty in + succeed (name, age)) + in + let decoder = Obj.run obj in + let input = {| {"name": "Jim", "age": 42, "another": "thing"} |} in + match decode_string decoder input with + | Ok _ -> + assert_string "Expected an error" + | Error error -> + let open Decoders in + assert_equal + error + (Error.make + {|Expected an empty object, but have unconsumed field "another"|} + ~context: + (`Assoc + [ ("name", `String "Jim") + ; ("age", `Int 42) + ; ("another", `String "thing") + ] ) ) + ~printer:(fun e -> Format.asprintf "@,@[%a@]" pp_error e) + in + "Yojson.Basic" >::: [ list_string_test ; array_string_test @@ -179,6 +230,8 @@ let yojson_basic_suite = ; mut_rec_test ; string_or_floatlit_test ; grouping_errors_test + ; obj_test + ; obj_test_2 ]