diff --git a/README.md b/README.md index fee9b27..c94cf8b 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,20 @@ Type inference and elaborator are implemented but the environment construction doesn't check for overlapping instances yet. +- `bidi_local` attempts to implement ["Local Type Inference"][bidi_local] by + Piece and Turner. + + This is a bidirectional type checking mechanism where we alternate between + checking mode (where we have type annotations available) and synthesis mode. + + Thus this requires type annotations for let-functions declarations. + + The type system supports subtyping (with `null`). The way it works is it + collects lower/upper bound constraints for type variables and then synthesises + the most general type for type variables at polymorphic abstraction + elimination. + + # Development ``` @@ -48,3 +62,4 @@ make test [inferno]: https://gitlab.inria.fr/fpottier/inferno/ [THIH]: https://web.cecs.pdx.edu/~mpj/thih/thih.pdf [tomprimozic/type-systems]: https://github.com/tomprimozic/type-systems +[bidi_local]: https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf diff --git a/bidi_local/bidi_local.ml b/bidi_local/bidi_local.ml new file mode 100644 index 0000000..ac0bfdd --- /dev/null +++ b/bidi_local/bidi_local.ml @@ -0,0 +1,69 @@ +module Expr = struct + include Syntax.Expr + + include Nice_parser.Make (struct + type token = Parser.token + + type result = Syntax.expr + + let parse = Parser.expr_eof + + let next_token = Lexer.token + + exception ParseError = Parser.Error + + exception LexError = Lexer.Error + end) +end + +module Ty = struct + include Syntax.Ty + + include Nice_parser.Make (struct + type token = Parser.token + + type result = Syntax.ty + + let parse = Parser.ty_eof + + let next_token = Lexer.token + + exception ParseError = Parser.Error + + exception LexError = Lexer.Error + end) +end + +module Ty_sch = struct + include Syntax.Ty_sch + + include Nice_parser.Make (struct + type token = Parser.token + + type result = Syntax.ty_sch + + let parse = Parser.ty_sch_eof + + let next_token = Lexer.token + + exception ParseError = Parser.Error + + exception LexError = Lexer.Error + end) +end + +module Type_error = Type_error +module Var = Var + +module Env = struct + include Infer.Env + + let assume_val name ty env = + add_val ~kind:Val_top env name (Ty_sch.parse_string ty) +end + +let infer = Infer.infer + +let enable_colors = Layout.enable_colors + +let () = Expr.pp_exceptions () diff --git a/bidi_local/bin/dune b/bidi_local/bin/dune new file mode 100644 index 0000000..6b2ec25 --- /dev/null +++ b/bidi_local/bin/dune @@ -0,0 +1,4 @@ +(executable + (public_name bidi_local) + (name main) + (libraries bidi_local base stdio unix)) diff --git a/bidi_local/bin/main.ml b/bidi_local/bin/main.ml new file mode 100644 index 0000000..ce1faba --- /dev/null +++ b/bidi_local/bin/main.ml @@ -0,0 +1,31 @@ +open Base +open Bidi_local + +let () = + let () = if Unix.isatty Unix.stdout then enable_colors true in + let env = + Env.empty + |> Env.assume_val "id" "a . a -> a" + |> Env.assume_val "null" "a . a?" + |> Env.assume_val "one" "int" + |> Env.assume_val "succ" "int -> int" + |> Env.assume_val "nil" "a . list[a]" + |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" + |> Env.assume_val "map" "a, b . (a -> b, list[a]) -> list[b]" + |> Env.assume_val "choose" "a . (a, a) -> a" + |> Env.assume_val "choose3" "a . (a, a, a) -> a" + |> Env.assume_val "choose4" "a . (a, a, a, a) -> a" + |> Env.assume_val "hello" "string" + |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" + |> Env.assume_val "plus" "(int, int) -> int" + |> Env.assume_val "true" "bool" + |> Env.assume_val "ifnull" "a . (a?, a) -> a" + |> Env.assume_val "eq" "a . (a, a) -> bool" + in + let s = Stdio.In_channel.input_all Stdio.stdin in + let e = Expr.parse_string (String.strip s) in + match infer ~env e with + | Ok e -> Caml.Format.printf "%s@." (Expr.show e) + | Error err -> + let report = Type_error.to_report err in + Caml.Format.printf "%a@." Location.print_report report diff --git a/bidi_local/debug.ml b/bidi_local/debug.ml new file mode 100644 index 0000000..b6f6ad6 --- /dev/null +++ b/bidi_local/debug.ml @@ -0,0 +1,26 @@ +(** + + Debug flags which influence the amount of logging information and pretty + printing. + + Flags are controlled through DEBUG environment variables, use it like: + + $ DEBUG=sgi dune exec -- COMMAND + + The above invocation will toggle glags [s], [g] and [i] to be "on". + + *) + +open! Base + +let flags = Caml.Sys.getenv_opt "DEBUG" |> Option.value ~default:"" + +let log_solve = String.mem flags 's' + +let log_generalize = String.mem flags 'g' + +let log_instantiate = String.mem flags 'i' + +let log_check = String.mem flags 'c' + +let log_levels = String.mem flags 'l' diff --git a/bidi_local/dune b/bidi_local/dune new file mode 100644 index 0000000..5969ccb --- /dev/null +++ b/bidi_local/dune @@ -0,0 +1,11 @@ +(library + (name bidi_local) + (preprocess (pps ppx_sexp_conv)) + (libraries base pprint nice_parser) + ) + +(ocamllex lexer) + +(menhir + (modules parser) + (flags --explain --dump)) diff --git a/bidi_local/import.ml b/bidi_local/import.ml new file mode 100644 index 0000000..57aedda --- /dev/null +++ b/bidi_local/import.ml @@ -0,0 +1,108 @@ +include Base + +module Monad = struct + include Base.Monad + + (** A signature for modules implementing monadic let-syntax. *) + module type MONAD_SYNTAX = sig + type 'a t + + val return : 'a -> 'a t + + val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t + + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t + end + + (** Make a monadic syntax out of the monad. *) + module Make_monad_syntax (P : Base.Monad.S) : + MONAD_SYNTAX with type 'a t := 'a P.t = struct + let return = P.return + + let ( let* ) v f = P.bind v ~f + + let ( let+ ) v f = P.map v ~f + end + + module type S = sig + include Base.Monad.S + + module Monad_syntax : MONAD_SYNTAX with type 'a t := 'a t + end + + module Make (P : Basic) : S with type 'a t := 'a P.t = struct + module Self = Base.Monad.Make (P) + include Self + + module Monad_syntax = Make_monad_syntax (struct + type 'a t = 'a P.t + + include Self + end) + end +end + +module MakeId () = struct + let c = ref 0 + + let fresh () = + Int.incr c; + !c + + let reset () = c := 0 +end + +module type SHOWABLE = sig + type t + + val show : t -> string + + val print : ?label:string -> t -> unit +end + +module Showable (S : sig + type t + + val layout : t -> PPrint.document +end) : SHOWABLE with type t = S.t = struct + type t = S.t + + let show v = + let width = 60 in + let buf = Buffer.create 100 in + PPrint.ToBuffer.pretty 1. width buf (S.layout v); + Buffer.contents buf + + let print ?label v = + match label with + | Some label -> Caml.print_endline (label ^ ": " ^ show v) + | None -> Caml.print_endline (show v) +end + +module type DUMPABLE = sig + type t + + val dump : ?label:string -> t -> unit + + val sdump : ?label:string -> t -> string +end + +module Dumpable (S : sig + type t + + val sexp_of_t : t -> Sexp.t +end) : DUMPABLE with type t = S.t = struct + type t = S.t + + let dump ?label v = + let s = S.sexp_of_t v in + match label with + | None -> Caml.Format.printf "%a@." Sexp.pp_hum s + | Some label -> Caml.Format.printf "%s %a@." label Sexp.pp_hum s + + let sdump ?label v = + let s = S.sexp_of_t v in + match label with + | None -> Caml.Format.asprintf "%a@." Sexp.pp_hum s + | Some label -> Caml.Format.asprintf "%s %a@." label Sexp.pp_hum s +end diff --git a/bidi_local/infer.ml b/bidi_local/infer.ml new file mode 100644 index 0000000..1ffc5af --- /dev/null +++ b/bidi_local/infer.ml @@ -0,0 +1,465 @@ +open! Import +open! Syntax + +(** Typing environment. *) +module Env : sig + type t + + type val_kind = Val_top | Val_local + + val empty : t + + val add_val : ?kind:val_kind -> t -> name -> ty_sch -> t + + val find_val : t -> name -> (val_kind * ty_sch) option + + val add_var : t -> name -> var -> t + + val build_ty_subst : ?init:Ty_subst.t -> t -> Ty_subst.t +end = struct + type t = { + vals : (name, val_kind * ty_sch, String.comparator_witness) Map.t; + vars : (name, var, String.comparator_witness) Map.t; + } + + and val_kind = Val_top | Val_local + + let add_val ?(kind = Val_local) env name ty_sch = + { env with vals = Map.set env.vals ~key:name ~data:(kind, ty_sch) } + + let add_var env name var = + { env with vars = Map.set env.vars ~key:name ~data:var } + + let find_val env name = Map.find env.vals name + + let empty = + { vals = Map.empty (module String); vars = Map.empty (module String) } + + let build_ty_subst ?(init = Ty_subst.empty) env = + Map.fold env.vars ~init ~f:(fun ~key:name ~data:v subst -> + Ty_subst.add_name subst name v) +end + +type ctx = { env : Env.t; lvl : lvl; variance : Variance.t } +(** Information we need to pass between different typechecking routines. *) + +(** [instantiate ~lvl ty_sch] instantiates type scheme [ty_sch] into [ty] type. + + It does so by substituting all generic type variables with fresh type + variables at specific level [lvl]. + *) +let instantiate ?(val_kind = Env.Val_local) ?env ~lvl ~variance ty_sch = + let vs, ty = ty_sch in + let subst = + List.fold vs ~init:Ty_subst.empty ~f:(fun subst v -> + let v' = + match val_kind with + | Env.Val_local -> Var.refresh ~lvl v + | Env.Val_top -> Var.fresh ~lvl () + in + Ty_subst.add_var subst v v') + in + let subst = + match env with + | None -> subst + | Some env -> Env.build_ty_subst ~init:subst env + in + let ty = Ty_subst.apply_ty ~variance subst ty in + if Debug.log_instantiate then + Caml.Format.printf "INST lvl:%i %s@." lvl (Ty.show ty); + ty + +(** [generalize ~lvl ty] generalizes type [ty] to a type scheme. + + It finds all unresolved type variables marked with level > [lvl] and makes + them generalized type variables. The level check acts as a scope check so we + don't generalize over variables from the outer scope. + *) +let generalize ~lvl ty = + let vs = ref [] in + let rec aux = function + | Ty_top + | Ty_bot + | Ty_const _ -> + () + | Ty_arr (args, ty) -> + List.iter args ~f:aux; + aux ty + | Ty_app (ty, args) -> + List.iter args ~f:aux; + aux ty + | Ty_nullable ty -> aux ty + | Ty_var v -> ( + match Var.ty v with + | Some ty -> aux ty + | None -> if Var.lvl v > lvl then vs := v :: !vs else ()) + | Ty_record row -> aux row + | Ty_row_empty -> () + | Ty_row_extend ((_, ty), row) -> + aux ty; + aux row + in + aux ty; + let ty_sch = (List.dedup_and_sort ~compare:Var.compare !vs, ty) in + if Debug.log_instantiate then + Caml.Format.printf "GENR lvl:%i %s@." lvl (Ty_sch.show ty_sch); + ty_sch + +let rec promote ~lvl (ty : ty) = + match ty with + | Ty_const _ -> ty + | Ty_var v -> ( + match Var.ty v with + | None -> if Var.lvl v > lvl then Ty_top else ty + | Some ty -> promote ~lvl ty) + | Ty_app (ty, args) -> + Ty_app (promote ~lvl ty, List.map args ~f:(promote ~lvl)) + | Ty_nullable ty -> Ty_nullable (promote ~lvl ty) + | Ty_arr (args, ty) -> Ty_arr (List.map args ~f:(demote ~lvl), promote ~lvl ty) + | Ty_record row -> Ty_record (promote ~lvl row) + | Ty_row_empty -> ty + | Ty_row_extend ((name, ty), row) -> + Ty_row_extend ((name, promote ~lvl ty), promote ~lvl row) + | Ty_bot -> ty + | Ty_top -> ty + +and demote ~lvl (ty : ty) = + match ty with + | Ty_const _ -> ty + | Ty_var v -> ( + match Var.ty v with + | None -> if Var.lvl v > lvl then Ty_bot else ty + | Some ty -> demote ~lvl ty) + | Ty_app (ty, args) -> Ty_app (promote ~lvl ty, List.map args ~f:(demote ~lvl)) + | Ty_nullable ty -> Ty_nullable (demote ~lvl ty) + | Ty_arr (args, ty) -> Ty_arr (List.map args ~f:(promote ~lvl), demote ~lvl ty) + | Ty_record row -> Ty_record (demote ~lvl row) + | Ty_row_empty -> ty + | Ty_row_extend ((name, ty), row) -> + Ty_row_extend ((name, demote ~lvl ty), demote ~lvl row) + | Ty_bot -> ty + | Ty_top -> ty + +let subsumes ~lvl constraints ~sub:(sub_loc, sub_ty) ~super:(super_loc, super_ty) + = + let exception Subsumption_failed of { sub_ty : ty; super_ty : ty } in + let rec aux ~sub_ty ~super_ty = + if Debug.log_solve then + Caml.Format.printf "??? %s <: %s@." (Ty.show sub_ty) (Ty.show super_ty); + if phys_equal super_ty sub_ty || Ty.equal super_ty sub_ty then () + else + match (sub_ty, super_ty) with + | Ty_const name, Ty_const name' -> + if not (String.equal name name') then + raise (Subsumption_failed { sub_ty; super_ty }) + | Ty_nullable sub_ty', Ty_nullable super_ty' -> + aux ~super_ty:super_ty' ~sub_ty:sub_ty' + | sub_ty', Ty_nullable super_ty' -> + aux ~super_ty:super_ty' ~sub_ty:sub_ty' + | Ty_app (sub_ty', sub_args), Ty_app (super_ty', super_args) -> ( + aux ~super_ty:super_ty' ~sub_ty:sub_ty'; + match + List.iter2 super_args sub_args ~f:(fun super_ty' sub_ty' -> + aux ~super_ty:super_ty' ~sub_ty:sub_ty') + with + | Unequal_lengths -> raise (Subsumption_failed { sub_ty; super_ty }) + | Ok () -> ()) + | Ty_arr (sub_args, sub_ty'), Ty_arr (super_args, super_ty') -> + (match + List.iter2 super_args sub_args ~f:(fun super_ty' sub_ty' -> + aux ~sub_ty:super_ty' ~super_ty:sub_ty') + with + | Unequal_lengths -> raise (Subsumption_failed { sub_ty; super_ty }) + | Ok () -> ()); + aux ~sub_ty:sub_ty' ~super_ty:super_ty' + | Ty_record sub_row, Ty_record super_row -> ( + try + Subtyping.unify_rows + (fun sub_ty super_ty -> aux ~sub_ty ~super_ty) + sub_row super_row + with + | Subtyping.Unification_error -> + raise (Subsumption_failed { sub_ty; super_ty })) + | Ty_var sub_v, Ty_var super_v -> ( + match (Var.ty sub_v, Var.ty super_v) with + | None, None -> Subtyping.union_vars sub_v super_v + | Some sub_ty, None -> + Subtyping.Constraint_set.add constraints super_v + (promote ~lvl sub_ty, Ty_top) + | None, Some super_ty -> + Subtyping.Constraint_set.add constraints sub_v + (Ty_bot, demote ~lvl super_ty) + | Some sub_ty, Some super_ty -> aux ~sub_ty ~super_ty) + | Ty_var sub_v, super_ty -> ( + match Var.ty sub_v with + | Some sub_ty -> aux ~super_ty ~sub_ty + | None -> + Subtyping.Constraint_set.add constraints sub_v + (Ty_bot, demote ~lvl super_ty)) + | sub_ty, Ty_var super_v -> ( + match Var.ty super_v with + | Some super_ty -> aux ~super_ty ~sub_ty + | None -> + Subtyping.Constraint_set.add constraints super_v + (promote ~lvl sub_ty, Ty_top)) + | _, Ty_top -> () + | Ty_bot, _ -> () + | Ty_row_empty, _ + | _, Ty_row_empty -> + assert false + | Ty_row_extend _, _ + | _, Ty_row_extend _ -> + assert false + | _ -> raise (Subsumption_failed { sub_ty; super_ty }) + in + try aux ~sub_ty ~super_ty with + | Subsumption_failed _ -> + Type_error.raise_not_a_subtype ~sub_loc ~sub_ty ~super_loc ~super_ty () + +let fresh_fun_ty ~arity v = + let lvl = Var.lvl v in + let args_ty = List.init arity ~f:(fun _ -> Ty.var (Var.fresh ~lvl ())) in + let body_ty = Ty.var (Var.fresh ~lvl ()) in + Var.set_ty v (Ty_arr (args_ty, body_ty)); + (args_ty, body_ty) + +let fresh_record_ty v = + let lvl = Var.lvl v in + let row = Ty_var (Var.fresh ~lvl ()) in + let ty = Ty_record row in + Var.set_ty v ty; + ty + +let resolve_ty ty = + match ty with + | Ty_var v as ty -> Option.value (Var.ty v) ~default:ty + | ty -> ty + +let rec synth ~ctx expr = + let ty, expr = synth' ~ctx expr in + (resolve_ty ty, expr) + +and synth' ~ctx expr = + if Debug.log_check then + Caml.Format.printf "SYNTH%s %s@." + (Variance.show ctx.variance) + (Expr.show expr); + match expr with + | loc, E_var name -> + let ty = + match Env.find_val ctx.env name with + | None -> Type_error.raise (Error_unknown_name { name = (loc, name) }) + | Some (val_kind, ty_sch) -> + instantiate ~val_kind ~lvl:ctx.lvl ~variance:ctx.variance ty_sch + in + (ty, expr) + | _, E_ann (expr, ty_sch) -> + let ty = + instantiate ~env:ctx.env ~lvl:ctx.lvl ~variance:ctx.variance ty_sch + in + (* TODO: here we drop E_ann, is this ok? *) + (ty, check ~ctx expr ty) + | loc, E_abs (vs, args, body) -> + let env, vs = + List.fold vs ~init:(ctx.env, []) ~f:(fun (env, vs) v -> + let v = Var.refresh ~lvl:ctx.lvl v in + (Env.add_var env (Var.name v) v, v :: vs)) + in + let env, args, args_ty = + List.fold args ~init:(env, [], []) + ~f:(fun (env, args, args_ty) ((loc, name), ty) -> + match ty with + | None -> + Type_error.raise + (Error_missing_type_annotation { expr = (loc, E_var name) }) + | Some ty -> + let ty = + instantiate ~env ~lvl:ctx.lvl + ~variance:(Variance.inv ctx.variance) + ([], ty) + in + ( Env.add_val env name ([], ty), + ((loc, name), Some ty) :: args, + ty :: args_ty )) + in + let body_ty, body = synth ~ctx:{ ctx with env } body in + let vs = + (* Only keep variables which were not solved during checking args and + body. *) + List.filter vs ~f:Var.is_empty + in + ( Ty_arr (List.rev args_ty, body_ty), + (loc, E_abs (List.rev vs, List.rev args, body)) ) + | loc, E_app (f, args) -> + (* S-App-InfAlg *) + let (args_tys, body_ty), f = + match synth ~ctx f with + | Ty_arr (args_tys, body_ty), f -> ((args_tys, body_ty), f) + | Ty_var v, f -> + assert (Var.is_empty v); + (fresh_fun_ty v ~arity:(List.length args), f) + | ty, _ -> + Type_error.raise (Error_expected_a_function { loc = Expr.loc f; ty }) + in + let constraints = Subtyping.Constraint_set.empty () in + let args = + match + List.map2 args_tys args ~f:(fun ty arg -> + check' + ~ctx:{ ctx with variance = Variance.inv ctx.variance } + ~constraints arg ty) + with + | Unequal_lengths -> Type_error.raise (Error_arity_mismatch { loc }) + | Ok args -> args + in + Subtyping.Constraint_set.solve constraints; + let expr = E_app (f, args) in + if Debug.log_solve then + Caml.Format.printf "== SOLVED %s@." + (Expr.show (loc, E_ann ((loc, expr), ([], body_ty)))); + (body_ty, (loc, E_app (f, args))) + | loc, E_record fields -> + let row, fields = + (* List.fold fields ~init:(Ty_row_empty, []) *) + List.fold_right fields ~init:(Ty_row_empty, []) + ~f:(fun (name, expr) (row, fields) -> + let ty, expr = synth ~ctx expr in + (Ty_row_extend ((name, ty), row), (name, expr) :: fields)) + in + (Ty_record row, (loc, E_record fields)) + | loc, E_record_project (expr, name) -> + let row = Ty.var @@ Var.fresh ~lvl:ctx.lvl () in + let ty = Ty.var @@ Var.fresh ~lvl:ctx.lvl () in + let record_ty = Ty_record (Ty_row_extend ((name, ty), row)) in + let constraints = Subtyping.Constraint_set.empty () in + let expr = check' ~ctx ~constraints expr record_ty in + Subtyping.Constraint_set.solve constraints; + (ty, (loc, E_record_project (expr, name))) + | loc, E_record_extend (expr, fields) -> + let constraints = Subtyping.Constraint_set.empty () in + let row = Ty.var @@ Var.fresh ~lvl:ctx.lvl () in + let expr = check' ~ctx ~constraints expr (Ty_record row) in + let row, fields = + List.fold fields ~init:(row, []) ~f:(fun (row, fields) (name, expr) -> + let ty, expr = synth ~ctx expr in + (Ty_row_extend ((name, ty), row), (name, expr) :: fields)) + in + let ty = Ty_record row in + Subtyping.Constraint_set.solve constraints; + (ty, (loc, E_record_extend (expr, List.rev fields))) + | loc, E_record_update (expr, fields) -> + let constraints = Subtyping.Constraint_set.empty () in + let row, fields = + let row = Ty.var @@ Var.fresh ~lvl:ctx.lvl () in + List.fold fields ~init:(row, []) ~f:(fun (row, fields) (name, expr) -> + let ty, expr = synth ~ctx expr in + (Ty_row_extend ((name, ty), row), (name, expr) :: fields)) + in + let ty = Ty_record row in + let ty', expr = synth ~ctx expr in + subsumes ~lvl:ctx.lvl constraints ~sub:(loc, ty) ~super:(Expr.loc expr, ty'); + Subtyping.Constraint_set.solve constraints; + (ty, (loc, E_record_update (expr, List.rev fields))) + | _, E_lit (Lit_string _) -> (Ty_const "string", expr) + | _, E_lit (Lit_int _) -> (Ty_const "int", expr) + | loc, E_let ((name, expr, e_ty_sch), body) -> + let e_ty, expr = + match e_ty_sch with + | None -> synth ~ctx:{ ctx with lvl = ctx.lvl + 1 } expr + | Some e_ty_sch -> + let e_ty = + instantiate ~env:ctx.env ~lvl:(ctx.lvl + 1) ~variance:ctx.variance + e_ty_sch + in + (e_ty, check ~ctx:{ ctx with lvl = ctx.lvl + 1 } expr e_ty) + in + let e_ty_sch = generalize ~lvl:ctx.lvl e_ty in + let env = Env.add_val ctx.env name e_ty_sch in + let body_ty, body = synth ~ctx:{ ctx with env } body in + (body_ty, (loc, E_let ((name, expr, Some e_ty_sch), body))) + +and check' ~ctx ~constraints expr ty = + if Debug.log_check then + Caml.Format.printf "CHECK%s %s : %s@." + (Variance.show ctx.variance) + (Expr.show expr) (Ty.show ty); + match expr with + | loc, E_abs (vs, args, body) -> + let env, vs = + List.fold vs ~init:(ctx.env, []) ~f:(fun (env, vs) v -> + let v = Var.refresh ~lvl:ctx.lvl v in + (Env.add_var env (Var.name v) v, v :: vs)) + in + let args_ty, body_ty = + match resolve_ty ty with + | Ty_arr (args_ty, ret_ty) -> (args_ty, ret_ty) + | Ty_var v -> + assert (Var.is_empty v); + fresh_fun_ty v ~arity:(List.length args) + | ty -> Type_error.raise (Error_expected_a_function { loc; ty }) + in + let env, args = + match + List.fold2 args args_ty ~init:(env, []) + ~f:(fun (env, args) ((loc, name), ty') ty -> + Option.iter ty' ~f:(fun ty' -> + subsumes ~lvl:ctx.lvl constraints ~sub:(Location.none, ty) + ~super:(loc, ty')); + let env = Env.add_val env name ([], ty) in + (env, ((loc, name), Some ty) :: args)) + with + | Unequal_lengths -> Type_error.raise (Error_arity_mismatch { loc }) + | Ok (env, args) -> (env, List.rev args) + in + let body = check' ~ctx:{ ctx with env } ~constraints body body_ty in + (loc, E_abs (List.rev vs, args, body)) + | loc, E_app (f, args) -> + let f_ty, f = synth ~ctx f in + let args = args |> List.map ~f:(synth ~ctx) in + let args_tys', ty' = + match resolve_ty f_ty with + | Ty_arr (args_tys', ty') -> (args_tys', ty') + | Ty_var v -> + assert (Var.is_empty v); + fresh_fun_ty v ~arity:(List.length args) + | ty -> + Type_error.raise (Error_expected_a_function { loc = Expr.loc f; ty }) + in + let () = + match + List.iter2 args_tys' args ~f:(fun ty' (ty, expr) -> + subsumes ~lvl:ctx.lvl constraints + ~sub:(Expr.loc expr, ty) + ~super:(Location.none, ty')) + with + | Unequal_lengths -> Type_error.raise (Error_arity_mismatch { loc }) + | Ok () -> () + in + subsumes ~lvl:ctx.lvl constraints ~sub:(loc, ty') ~super:(Location.none, ty); + (loc, E_app (f, List.map args ~f:snd)) + | expr -> + let ty', expr = synth ~ctx expr in + subsumes ~lvl:ctx.lvl constraints + ~sub:(Expr.loc expr, ty') + ~super:(Location.none, ty); + expr + +and check ~ctx expr ty = + let constraints = Subtyping.Constraint_set.empty () in + let expr = check' ~ctx ~constraints expr ty in + Subtyping.Constraint_set.solve constraints; + expr + +let infer ~env expr : (expr, Type_error.t) Result.t = + let ctx = { lvl = 1; env; variance = Covariant } in + try + Ok + (let ty, expr = synth ~ctx expr in + let ty = generalize ~lvl:0 ty in + match expr with + | _, E_let ((name, _, _), (_, E_var name')) when String.equal name name' + -> + expr + | loc, expr -> (loc, E_ann ((loc, expr), ty))) + with + | Type_error.Type_error err -> Error err diff --git a/bidi_local/layout.ml b/bidi_local/layout.ml new file mode 100644 index 0000000..17c8a5d --- /dev/null +++ b/bidi_local/layout.ml @@ -0,0 +1,190 @@ +open! Import +open! Syntax0 + +module Var_name : sig + type t + + val make : string -> int -> t + + val to_string : t -> string + + (* val of_string : string -> t *) + + val succ : t -> t + + (* val next : t -> t *) + + include Comparator.S with type t := t +end = struct + type t = string * int + + let make s n = (s, n) + + let to_string = function + | s, 0 -> s + | s, n -> Printf.sprintf "%s/%i" s n + + (* let of_string s = *) + (* if String.length s = 0 then raise (Invalid_argument s); *) + (* (s, 0) *) + + let succ (s, n) = + if String.length s > 1 then (s ^ "'", n) + else + match String.get s 0 with + | 'z' -> (s ^ "'", n) + | c -> (String.of_char (Char.of_int_exn (Char.to_int c + 1)), n) + + (* let next (c, n) = (c, n + 1) *) + + include Comparator.Make (struct + type nonrec t = t + + let compare (ac, an) (bc, bn) = + match String.compare ac bc with + | 0 -> Int.compare an bn + | n -> n + + let sexp_of_t a = Sexp.Atom (to_string a) + end) +end + +module Names : sig + type t + + val empty : t + + val alloc_var : var -> t -> t * string + + val lookup_var : var -> t -> string option +end = struct + type t = { + by_id : (int, string, Int.comparator_witness) Map.t; + by_name : (string, int, String.comparator_witness) Map.t; + } + + let empty = + { by_id = Map.empty (module Int); by_name = Map.empty (module String) } + + let lookup_var var names = + let var = Union_find.value var in + Map.find names.by_id var.id + + let alloc_var var names = + let var = Union_find.value var in + match Map.find names.by_id var.id with + | Some name -> (names, name) + | None -> + let names, name = + match var.name with + | Some name -> + let name', n = + match Map.find names.by_name name with + | None -> (name, 1) + | Some n -> (Printf.sprintf "%s/%i" name n, n + 1) + in + let by_name = Map.set names.by_name ~key:name ~data:n in + ({ names with by_name }, name') + | None -> + let name = + Var_name.succ + (match Map.max_elt names.by_name with + | None -> Var_name.make "a" 0 + | Some (s, n) -> Var_name.make s n) + in + (names, Var_name.to_string name) + in + let by_id = Map.set names.by_id ~key:var.id ~data:name in + ({ names with by_id }, name) +end + +include PPrint + +type 'a t = Names.t -> Names.t * 'a + +let render layout = + let _, v = layout Names.empty in + v + +let to_string layout = + let doc = render layout in + let width = 60 in + let buf = Buffer.create 100 in + PPrint.ToBuffer.pretty 1. width buf doc; + Buffer.contents buf + +type layout = document t + +let alloc_var = Names.alloc_var + +let lookup_var var : string option t = + fun names -> (names, Names.lookup_var var names) + +type names = Names.t + +let names names = (names, names) + +let closed ?names v cnames = + let names = + match names with + | None -> cnames + | Some names -> names + in + let _names, v = v names in + (names, v) + +type color = Black | Red | Green | Yellow | Blue | Magenta | Cyan | White + +let int_of_color col = + match col with + | Black -> 0 + | Red -> 1 + | Green -> 2 + | Yellow -> 3 + | Blue -> 4 + | Magenta -> 5 + | Cyan -> 6 + | White -> 7 + +let enable_colors_flag = ref false + +let enable_colors enable = enable_colors_flag := enable + +let bold doc = + if not !enable_colors_flag then doc + else fancystring "\x1B[1m" 0 ^^ doc ^^ fancystring "\x1B[0m" 0 + +let fg color doc = + if not !enable_colors_flag then doc + else + let opening = Printf.sprintf "\x1B[3%dm" (int_of_color color) in + fancystring opening 0 ^^ doc ^^ fancystring "\x1B[0m" 0 + +include Monad.Make (struct + type nonrec 'a t = 'a t + + let return v names = (names, v) + + let bind v ~f names = + let names, v = v names in + let names, v = f v names in + (names, v) + + let map v ~f names = + let names, v = v names in + (names, f v) + + let map = `Custom map +end) + +include Monad_syntax + +let list_map xs ~f = + let rec aux = function + | [] -> return [] + | x :: xs -> + let* x = f x in + let* xs = aux xs in + return (x :: xs) + in + aux xs diff --git a/bidi_local/layout.mli b/bidi_local/layout.mli new file mode 100644 index 0000000..712b8d4 --- /dev/null +++ b/bidi_local/layout.mli @@ -0,0 +1,53 @@ +(** Layout data structures in a human readable way. + + This module is designed to be openned. + *) + +open! Import +open! Syntax0 + +include module type of PPrint +(** We include all pprint API. *) + +type 'a t +(** Computation which can allocate and lookup names. *) + +val alloc_var : var -> string t +(** Allocate name for a variable. *) + +val lookup_var : var -> string option t +(** Lookup name for a variable. *) + +type names + +val names : names t +(** Get names at current position. *) + +val closed : ?names:names -> 'a t -> 'a t +(** Closed term which doesn't leak names outside. *) + +include Monad.S with type 'a t := 'a t +(** ['a t] is a monad. *) + +include Monad.MONAD_SYNTAX with type 'a t := 'a t + +val list_map : 'a list -> f:('a -> 'b t) -> 'b list t +(** Map with a monadic action over a list of items. *) + +type layout = document t +(** A layout is a computation which produces a document. *) + +val render : layout -> document +(** Render layout into a document. *) + +val to_string : layout -> string + +(** Colors *) + +type color = Black | Red | Green | Yellow | Blue | Magenta | Cyan | White + +val bold : document -> document + +val fg : color -> document -> document + +val enable_colors : bool -> unit diff --git a/bidi_local/lexer.mll b/bidi_local/lexer.mll new file mode 100644 index 0000000..fbf1e10 --- /dev/null +++ b/bidi_local/lexer.mll @@ -0,0 +1,62 @@ +{ + +open Parser + +exception Error of string + +} + + +let ident = ['_' 'A'-'Z' 'a'-'z'] ['_' 'A'-'Z' 'a'-'z' '0'-'9']* +let integer = ['0'-'9']+ + +rule token = parse + | [' ' '\t' '\r' '\n'] { token lexbuf } + | "fun" { FUN } + | "let" { LET } + | "in" { IN } + | "with" { WITH } + | ident { IDENT (Lexing.lexeme lexbuf) } + | '(' { LPAREN } + | ')' { RPAREN } + | '[' { LBRACKET } + | ']' { RBRACKET } + | '{' { LBRACE } + | '}' { RBRACE } + | '=' { EQUALS } + | ':' '=' { ASSIGN } + | "->" { ARROW } + | ',' { COMMA } + | '.' { DOT } + | '.' '.' '.' { ELLIPSIS } + | ':' { COLON } + | '?' { QUESTION } + | eof { EOF } + | _ as c { raise (Error ("unexpected token: '" ^ Char.escaped c ^ "'")) } + + +{ + +let string_of_token = function + | FUN -> "fun" + | LET -> "let" + | IN -> "in" + | WITH -> "forall" + | IDENT ident -> ident + | LPAREN -> "(" + | RPAREN -> ")" + | LBRACKET -> "[" + | RBRACKET -> "]" + | LBRACE -> "{" + | RBRACE -> "}" + | EQUALS -> "=" + | ASSIGN -> ":=" + | ARROW -> "->" + | COMMA -> "," + | DOT -> "." + | ELLIPSIS -> "." + | COLON -> ":" + | QUESTION -> "?" + | EOF -> "" + +} diff --git a/bidi_local/parser.mly b/bidi_local/parser.mly new file mode 100644 index 0000000..a604e62 --- /dev/null +++ b/bidi_local/parser.mly @@ -0,0 +1,194 @@ +%{ + +open Syntax + +let makeloc s e = + {Location.loc_start=s; loc_end=e; loc_ghost=false} + +let makeenv vars = + let open Base in + Var.reset (); + let vs, map = List.fold_left + vars + ~init:([], Map.empty (module String)) + ~f:(fun (vs, env) name -> + let v = Var.fresh ~name () in + v::vs, + Map.set env ~key:name ~data:(Ty.var v)) in + List.rev vs, map + +let build_ty_sch (vs, env) ty = + let open Base in + let rec build_ty ty = match ty with + | Ty_const name -> ( + match Map.find env name with + | Some ty -> ty + | None -> ty) + | Ty_top + | Ty_bot + | Ty_var _ -> ty + | Ty_nullable ty -> Ty_nullable (build_ty ty) + | Ty_app (fty, atys) -> Ty_app (build_ty fty, List.map atys ~f:build_ty) + | Ty_arr (atys, rty) -> Ty_arr (List.map atys ~f:build_ty, build_ty rty) + | Ty_record row -> Ty_record (build_ty row) + | Ty_row_empty -> ty + | Ty_row_extend ((name, ty), row) -> + Ty_row_extend ((name, build_ty ty), build_ty row) + in + vs, build_ty ty +%} + +%token IDENT +%token FUN LET IN WITH +%token LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE +%token ARROW EQUALS COMMA DOT COLON ASSIGN QUESTION ELLIPSIS +%token EOF + +%start expr_eof +%type expr_eof +%start ty_sch_eof +%type ty_sch_eof +%start ty_eof +%type ty_eof + +%% + +expr_eof: + e = expr EOF { e } + +ty_sch_eof: + t = ty_sch EOF { t } + +ty_eof: + t = ty EOF { t } + +expr: + e = simple_expr { e } + + | LPAREN e = expr t = expr_annot RPAREN + { makeloc $startpos $endpos, E_ann (e, t) } + + (* let-bindings *) + | LET n = IDENT t = option(expr_annot) EQUALS e = expr IN b = expr + { makeloc $startpos $endpos, E_let ((n, e, t), b) } + + (* functions *) + | FUN ty_args = ty_args arg = ident ARROW body = expr + { makeloc $startpos $endpos, E_abs (ty_args, [arg, None], body) } + | FUN ty_args = ty_args args = args ARROW body = expr + { makeloc $startpos $endpos, E_abs (ty_args, args, body) } + + | LET n = IDENT ty_args = ty_args arg = ident EQUALS e = expr IN b = expr + { + let e = makeloc $startpos $endpos, E_abs (ty_args, [arg, None], e) in + makeloc $startpos $endpos, E_let ((n, e, None), b) + } + | LET n = IDENT ty_args = ty_args args = args EQUALS e = expr IN b = expr + { + let e = makeloc $startpos $endpos, E_abs (ty_args, args, e) in + makeloc $startpos $endpos, E_let ((n, e, None), b) + } + +ident: + n = IDENT { makeloc $startpos $endpos, n } + +%inline expr_annot: + COLON t = ty_sch { t } + +args: + LPAREN args = flex_list(COMMA, arg) RPAREN { args } + +arg: + n = ident t = option(arg_annot) { n, t } + +ty_args: + (* empty *) { [] } + | LBRACKET vs = nonempty_flex_list(COMMA, ty_arg) RBRACKET { vs } + +ty_arg: + n = IDENT { Var.fresh ~name:n () } + +arg_annot: + COLON t = ty { t } + +simple_expr: + n = IDENT + { makeloc $startpos $endpos, E_var n } + | LPAREN e = expr RPAREN { e } + | f = simple_expr LPAREN args = flex_list(COMMA, expr) RPAREN + { makeloc $startpos $endpos, E_app (f, args) } + | LBRACE fs = flex_list(COMMA, record_field) RBRACE + { makeloc $startpos $endpos, E_record fs } + | LBRACE e = expr WITH fs = nonempty_flex_list(COMMA, record_field) RBRACE + { makeloc $startpos $endpos, E_record_extend (e, fs) } + | LBRACE e = expr WITH fs = nonempty_flex_list(COMMA, record_field_update) RBRACE + { makeloc $startpos $endpos, E_record_update (e, fs) } + | e = simple_expr DOT n = IDENT + { makeloc $startpos $endpos, E_record_project (e, n) } + +record_field: + n = IDENT EQUALS e = expr + { n, e } + +record_field_update: + n = IDENT ASSIGN e = expr + { n, e } + +ident_list: + xs = nonempty_flex_list(COMMA, IDENT) { xs } + +ty_sch: + t = ty { [], t } + | vars = ident_list DOT t = ty + { let env = makeenv vars in build_ty_sch env t } + +ty: + t = simple_ty + { t } + | LPAREN RPAREN ARROW ret = ty + { Ty_arr ([], ret) } + | arg = simple_ty ARROW ret = ty + { Ty_arr ([arg], ret) } + | LPAREN arg = ty COMMA args = flex_list(COMMA, ty) RPAREN ARROW ret = ty + { Ty_arr (arg :: args, ret) } + +simple_ty: + n = IDENT { Ty_const n } + | LPAREN t = ty RPAREN { t } + | f = simple_ty LBRACKET args = nonempty_flex_list(COMMA, ty) RBRACKET + { Ty_app (f, args) } + | LBRACE RBRACE + { Ty_record Ty_row_empty } + | LBRACE row = ty_row RBRACE + { Ty_record row } + | t = simple_ty QUESTION + { Ty.nullable t } + +ty_row: + ELLIPSIS n = IDENT + { Ty_const n } + | n = IDENT COLON ty = ty COMMA? + { Ty_row_extend ((n, ty), Ty_row_empty) } + | n = IDENT COLON ty = ty COMMA row = ty_row + { Ty_row_extend ((n, ty), row) } + +(* Utilities for flexible lists (and its non-empty version). + + A flexible list [flex_list(delim, X)] is the delimited with [delim] list of + it [X] items where it is allowed to have a trailing [delim]. + + A non-empty [nonempty_flex_list(delim, X)] version of flexible list is + provided as well. + + From http://gallium.inria.fr/blog/lr-lists/ + + *) + +flex_list(delim, X): + { [] } + | x = X { [x] } + | x = X delim xs = flex_list(delim, X) { x::xs } + +nonempty_flex_list(delim, X): + x = X { [x] } + | x = X delim xs = flex_list(delim, X) { x::xs } diff --git a/bidi_local/subtyping.ml b/bidi_local/subtyping.ml new file mode 100644 index 0000000..3a376f6 --- /dev/null +++ b/bidi_local/subtyping.ml @@ -0,0 +1,378 @@ +open Import +open Syntax + +exception Unification_error + +let rec unify (a : ty) (b : ty) = + if Debug.log_solve then + Caml.Format.printf "UNIFY %s %s@." (Ty.show a) (Ty.show b); + match (a, b) with + | Ty_const a, Ty_const b -> if not String.(a = b) then raise Unification_error + | Ty_app (af, aargs), Ty_app (bf, bargs) -> ( + unify af bf; + match List.iter2 aargs bargs ~f:unify with + | Unequal_lengths -> raise Unification_error + | Ok () -> ()) + | Ty_nullable a, Ty_nullable b -> unify a b + | Ty_arr (aargs, ab), Ty_arr (bargs, bb) -> + (match List.iter2 aargs bargs ~f:unify with + | Unequal_lengths -> raise Unification_error + | Ok () -> ()); + unify ab bb + | Ty_record a, Ty_record b -> unify_rows unify a b + | Ty_row_empty, _ + | _, Ty_row_empty -> + assert false + | Ty_row_extend _, _ + | _, Ty_row_extend _ -> + assert false + | Ty_bot, Ty_bot -> () + | Ty_top, Ty_top -> () + | Ty_var v1, Ty_var v2 -> ( + match (Var.ty v1, Var.ty v2) with + | None, None -> union_vars v1 v2 + | Some a, Some b -> unify a b + | None, Some ty -> Var.set_ty v1 ty + | Some ty, None -> Var.set_ty v2 ty) + | Ty_var v, ty + | ty, Ty_var v -> ( + match Var.ty v with + | None -> Var.set_ty v ty + | Some ty' -> unify ty ty') + | _, _ -> raise Unification_error + +and unify_rows unify_ty a b = + if Debug.log_solve then + Caml.Format.printf "UNIFY %s %s@." (Ty.show a) (Ty.show b); + match (a, b) with + | Ty_row_empty, Ty_row_empty -> () + | Ty_row_empty, Ty_row_extend _ -> raise Unification_error + | Ty_row_extend ((name, ty), a), b -> + let a_unbound = + match a with + | Ty_var v -> if Var.is_empty v then Some v else None + | _ -> None + in + let rec rewrite = function + | Ty_row_empty -> raise Unification_error + | Ty_row_extend ((name', ty'), b) -> + if String.(name = name') then ( + unify_ty ty ty'; + b) + else Ty_row_extend ((name', ty'), rewrite b) + | Ty_var v -> ( + match Var.ty v with + | Some b -> rewrite b + | None -> + let b = Ty.var @@ Var.fresh ~lvl:(Var.lvl v) () in + Var.set_ty v (Ty_row_extend ((name, ty), b)); + b) + | _ -> assert false + in + let b = rewrite b in + (match a_unbound with + | Some v -> + if not (Var.is_empty v) then Type_error.raise Error_recursive_record_type + | _ -> ()); + unify_rows unify_ty a b + | Ty_var av, Ty_var bv -> ( + match (Var.ty av, Var.ty bv) with + | None, None -> union_vars av bv + | Some a, Some b -> unify_rows unify_ty a b + | None, Some ty -> Var.set_ty av ty + | Some ty, None -> Var.set_ty bv ty) + | Ty_var v, b -> ( + match Var.ty v with + | None -> Var.set_ty v b + | Some ty -> unify_rows unify_ty ty b) + | a, Ty_var v -> ( + match Var.ty v with + | None -> Var.set_ty v a + | Some b -> unify_rows unify_ty a b) + | _, _ -> assert false + +and unifiable a b = + try + unify a b; + true + with + | Unification_error -> false + +and meet_rows meet_ty = + let exception Meet_error in + let rec aux a b = + if Debug.log_solve then + Caml.Format.printf "UNIFY %s %s@." (Ty.show a) (Ty.show b); + match (a, b) with + | Ty_row_empty, Ty_row_empty -> Ty_row_empty + | Ty_row_empty, Ty_row_extend _ -> raise Meet_error + | Ty_row_extend ((name, ty), a), b -> + let a_unbound = + match a with + | Ty_var v -> if Var.is_empty v then Some v else None + | _ -> None + in + let rec rewrite = function + | Ty_row_empty -> raise Meet_error + | Ty_row_extend ((name', ty'), b) -> + if String.(name = name') then + let ty = meet_ty ty ty' in + (ty, b) + else + let ty, b = rewrite b in + (ty, Ty_row_extend ((name', ty'), b)) + | Ty_var v -> ( + match Var.ty v with + | Some b -> rewrite b + | None -> + let b = Ty.var @@ Var.fresh ~lvl:(Var.lvl v) () in + Var.set_ty v (Ty_row_extend ((name, ty), b)); + (Ty_var v, b)) + | _ -> assert false + in + let ty, b = rewrite b in + (match a_unbound with + | Some v -> + if not (Var.is_empty v) then + Type_error.raise Error_recursive_record_type + | _ -> ()); + Ty_row_extend ((name, ty), aux a b) + | Ty_var v1, Ty_var v2 -> ( + match (Var.ty v1, Var.ty v2) with + | None, None -> + union_vars v1 v2; + a + | Some a, Some b -> aux a b + | None, Some ty -> + Var.set_ty v1 ty; + ty + | Some ty, None -> + Var.set_ty v2 ty; + ty) + | Ty_var v, ty + | ty, Ty_var v -> ( + match Var.ty v with + | None -> + Var.set_ty v ty; + ty + | Some ty' -> aux ty ty') + | _, _ -> assert false + in + fun a b -> + try Some (aux a b) with + | Meet_error -> None + +(** [greatest_lower_bound' a b] computes a Greatest-Lower-Bound of [a] and [b]. *) +and greatest_lower_bound = + let rec aux a b = + if Debug.log_solve then + Caml.Format.printf "GLB %s %s@." (Ty.show a) (Ty.show b); + if phys_equal a b then a + else + match (a, b) with + | a, Ty_top + | Ty_top, a -> + a + | _, Ty_bot + | Ty_bot, _ -> + Ty_bot + | Ty_const aname, Ty_const bname -> + if String.equal aname bname then a else Ty_bot + | Ty_arr (aargs, aty), Ty_arr (bargs, bty) -> ( + match List.map2 aargs bargs ~f:least_upper_bound with + | Unequal_lengths -> Ty_bot + | Ok args -> Ty_arr (args, aux aty bty)) + | Ty_app (aty, aargs), Ty_app (bty, bargs) -> ( + match List.map2 aargs bargs ~f:aux with + | Unequal_lengths -> Ty_bot + | Ok args -> Ty_app (aux aty bty, args)) + | Ty_record a, Ty_record b -> ( + match meet_rows aux a b with + | Some row -> Ty_record row + | None -> Ty_bot) + | Ty_row_empty, Ty_row_empty -> assert false + | Ty_row_extend _, Ty_row_extend _ -> assert false + | Ty_nullable a, b -> aux a b + | a, Ty_nullable b -> aux a b + | Ty_var v1, Ty_var v2 -> ( + match (Var.ty v1, Var.ty v2) with + | None, None -> + union_vars v1 v2; + a + | Some a, Some b -> aux a b + | None, Some ty -> + Var.set_ty v1 ty; + ty + | Some ty, None -> + Var.set_ty v2 ty; + ty) + | Ty_var v, ty + | ty, Ty_var v -> ( + match Var.ty v with + | None -> + Var.set_ty v ty; + ty + | Some ty' -> aux ty ty') + | _, _ -> Ty_bot + in + aux + +(** [least_upper_bound a b] computes a Least-Upper-Bound of [a] and [b]. *) +and least_upper_bound = + let rec aux a b = + if Debug.log_solve then + Caml.Format.printf "LUB %s & %s@." (Ty.show a) (Ty.show b); + if phys_equal a b then a + else + match (a, b) with + | _, Ty_top + | Ty_top, _ -> + Ty_top + | a, Ty_bot + | Ty_bot, a -> + a + | Ty_const aname, Ty_const bname -> + if String.equal aname bname then a else Ty_top + | Ty_arr (aargs, aty), Ty_arr (bargs, bty) -> ( + match List.map2 aargs bargs ~f:greatest_lower_bound with + | Unequal_lengths -> Ty_top + | Ok args -> Ty_arr (args, aux aty bty)) + | Ty_app (aty, aargs), Ty_app (bty, bargs) -> ( + match List.map2 aargs bargs ~f:aux with + | Unequal_lengths -> Ty_top + | Ok args -> Ty_app (aux aty bty, args)) + | Ty_record a, Ty_record b -> ( + match meet_rows aux a b with + | Some row -> Ty_record row + | None -> Ty_top) + | Ty_row_empty, Ty_row_empty -> assert false + | Ty_row_extend _, Ty_row_extend _ -> assert false + | Ty_nullable a, b -> Ty.nullable (aux a b) + | a, Ty_nullable b -> Ty.nullable (aux a b) + | Ty_var v1, Ty_var v2 -> ( + match (Var.ty v1, Var.ty v2) with + | None, None -> + union_vars v1 v2; + a + | Some a, Some b -> aux a b + | None, Some ty -> + Var.set_ty v1 ty; + ty + | Some ty, None -> + Var.set_ty v2 ty; + ty) + | Ty_var v, ty + | ty, Ty_var v -> ( + match Var.ty v with + | None -> + Var.set_ty v ty; + ty + | Some ty' -> aux ty ty') + | _, _ -> Ty_top + in + aux + +and union_vars a b = + Var.union ~merge_lower:least_upper_bound ~merge_upper:greatest_lower_bound a b + +let is_subtype ~sub_ty ~super_ty = + let lub = least_upper_bound sub_ty super_ty in + Ty.equal lub super_ty + +module Constraint_set : sig + type t + (* A set of constraints (mutable). *) + + val empty : unit -> t + (** Create new constraint set. *) + + val add : t -> var -> ty * ty -> unit + (** [add cset v (lower, upper)] adds a new constraint for variable [v] with + [lower] and [upper] bounds. *) + + val solve : t -> unit + (** [solve cset] solves all constraints in [cset], raising [Type_error] if + it's unable to solve it. *) +end = struct + module Elem = struct + type t = var + + let layout v = + let open Layout in + let v' = Union_find.value v in + let* lower = layout_ty' v'.lower in + let* v = Var.layout v in + let* upper = layout_ty' v'.upper in + return (lower ^^ string " <: " ^^ v ^^ string " <: " ^^ upper) + + include ( + Showable (struct + type nonrec t = t + + let layout c = Layout.render (layout c) + end) : + SHOWABLE with type t := t) + end + + type t = var list ref + (** Set of constraints on type variables. For each type variable we have a + lower and an upper bound. *) + + let layout set = + let open Layout in + let* items = list_map !set ~f:Elem.layout in + let sep = string ", " in + return (braces (separate sep items)) + + include ( + Showable (struct + type nonrec t = t + + let layout v = Layout.render (layout v) + end) : + SHOWABLE with type t := t) + + let empty () = ref [] + + let add cs v (lower, upper) = + if Debug.log_solve then + Caml.Format.printf "ADD %s <: %s <: %s@." (Ty.show lower) (Var.show v) + (Ty.show upper); + cs := v :: !cs; + let v' = Union_find.value v in + v'.lower <- least_upper_bound v'.lower lower; + v'.upper <- greatest_lower_bound v'.upper upper + + let ensure_is_subtype ~sub_ty ~super_ty = + if not (is_subtype ~sub_ty ~super_ty) then + Type_error.raise_not_a_subtype ~sub_ty ~super_ty () + + let solve set = + set := List.dedup_and_sort ~compare:Var.compare !set; + if Debug.log_solve then Caml.Format.printf "SOLVE %s@." (show set); + let solve v = + if Debug.log_solve then Caml.Format.printf "LOOKING %s@." (Elem.show v); + let v' = Union_find.value v in + match (v'.variance, v'.lower, v'.ty, v'.upper) with + | _, _, Some _, _ -> failwith "constraints against a resolved var" + | (None | Some Covariant), lower, None, upper -> + ensure_is_subtype ~sub_ty:lower ~super_ty:upper; + Var.set_ty v lower + | Some Contravariant, lower, None, Ty_top -> Var.set_ty v lower + | Some Contravariant, lower, None, upper -> + ensure_is_subtype ~sub_ty:lower ~super_ty:upper; + Var.set_ty v upper + | Some Invariant, Ty_bot, None, Ty_top -> assert false + | Some Invariant, Ty_bot, None, upper -> + (* TODO: not sure this case is ok *) + Var.set_ty v upper + | Some Invariant, lower, None, Ty_top -> + (* TODO: not sure this case is ok *) + Var.set_ty v lower + | Some Invariant, lower, None, upper -> + if not (unifiable lower upper) then + Type_error.raise (Error_not_equal (lower, upper)); + Var.set_ty v lower + in + List.iter !set ~f:(fun v -> solve v) +end diff --git a/bidi_local/syntax.ml b/bidi_local/syntax.ml new file mode 100644 index 0000000..142f5ce --- /dev/null +++ b/bidi_local/syntax.ml @@ -0,0 +1,457 @@ +open! Import +include Syntax0 + +let kw = Layout.bold + +let punct = Layout.fg White + +let record_braces doc = Layout.(punct (string "{") ^^ doc ^^ punct (string "}")) + +let ty_brackets doc = Layout.(punct (string "[") ^^ doc ^^ punct (string "]")) + +let f_parens doc = Layout.(punct (string "(") ^^ doc ^^ punct (string ")")) + +let rec layout_expr' (loc, expr) : Layout.layout = + let open Layout in + let is_simple_expr (_, expr) = + match expr with + | E_var _ + | E_app _ + | E_record _ -> + true + | _ -> false + in + match expr with + | E_ann (expr, ty_sch) -> + let* ty_sch = layout_ty_sch' ty_sch in + let* expr = layout_expr' expr in + return + (group + (f_parens (align (expr ^^ break 1 ^^ punct (string ": ") ^^ ty_sch)))) + | E_var name -> return (string name) + | E_abs (vs, args, body) -> + let sep = punct comma ^^ blank 1 in + let* vs = + match vs with + | [] -> return empty + | vs -> + let+ items = list_map vs ~f:layout_poly_var' in + ty_brackets (separate sep items) + in + let newline = + (* Always break on let inside the body. *) + match body with + | _, E_let _ -> hardline + | _ -> break 1 + in + let* args = + match args with + | [ ((_loc, name), None) ] -> return (string name) + | args -> + let layout_arg = function + | (_loc, name), None -> return (string name) + | (_loc, name), Some ty -> + let* ty = layout_ty' ty in + return (string name ^^ punct (string ": ") ^^ ty) + in + let+ items = list_map args ~f:layout_arg in + f_parens (separate sep items) + in + let* body = layout_expr' body in + return + (group + (group + (kw (string "fun") + ^^ vs + ^^ string " " + ^^ args + ^^ punct (string " ->")) + ^^ nest 2 (group (newline ^^ group body)))) + | E_app (f, args) -> + let sep = punct comma ^^ break 1 in + let* f = layout_expr' f in + let* args = list_map args ~f:layout_expr' in + return + (group (f ^^ f_parens (nest 2 (group (break 0 ^^ separate sep args))))) + | E_let _ as e -> + let es = + (* We do not want to print multiple nested let-expression with indents and + therefore we linearize them first and print on the same indent instead. *) + let rec linearize es e = + match e with + | E_let (_, (_, b)) -> linearize (e :: es) b + | e -> e :: es + in + List.rev (linearize [] e) + in + let newline = + (* If there's more than a single let-expression found (checking length > 2 + because es containts the body of the last let-expression too) we split + them with a hardline. *) + if List.length es > 2 then hardline else break 1 + in + let+ items = + list_map es ~f:(function + | E_let ((name, expr, ty), _) -> + let* ascription = + (* We need to layout ty_sch first as it will allocate names for use + down the road. *) + match ty with + | None -> return empty + | Some ty -> + let ty_break, ty_nest = + match ty with + | [], Ty_record _ -> (space, 0) + | _ -> (break 1, 4) + in + let+ ty = layout_ty_sch' ty in + punct (string " :") ^^ nest ty_nest (ty_break ^^ ty) + in + let expr_newline, expr_nest = + (* If there's [let x = let y = ... in ... in ...] then we want to + force break. *) + match expr with + | _, E_let _ -> (hardline, 2) + | _, E_record _ -> (space, 0) + | _ -> (break 1, 2) + in + let* expr = layout_expr' expr in + return + (group + (group + (kw (string "let ") + ^^ string name + ^^ ascription + ^^ punct (string " =")) + ^^ nest expr_nest (expr_newline ^^ expr) + ^^ expr_newline + ^^ kw (string "in")) + ^^ newline) + | e -> layout_expr' (loc, e)) + in + concat items + | E_record fields -> + let layout_field (name, expr) = + let+ expr = layout_expr' expr in + string name ^^ punct (string " = ") ^^ expr + in + let+ fields = Layout.list_map fields ~f:layout_field in + let sep = punct comma ^^ break 1 in + group (record_braces (nest 2 (break 0 ^^ separate sep fields) ^^ break 0)) + | E_record_project (expr, name) -> + if is_simple_expr expr then + let+ expr = layout_expr' expr in + expr ^^ dot ^^ string name + else + let+ expr = layout_expr' expr in + f_parens expr ^^ dot ^^ string name + | E_record_extend (expr, fields) -> + let layout_field (name, expr) = + let+ expr = layout_expr' expr in + string name ^^ punct (string " = ") ^^ expr + in + let* expr = layout_expr' expr in + let* fields = Layout.list_map fields ~f:layout_field in + let sep = punct comma ^^ break 1 in + return (record_braces (expr ^^ kw (string " with ") ^^ separate sep fields)) + | E_record_update (expr, fields) -> + let layout_field (name, expr) = + let+ expr = layout_expr' expr in + string name ^^ punct (string " := ") ^^ expr + in + let* expr = layout_expr' expr in + let* fields = Layout.list_map fields ~f:layout_field in + let sep = punct comma ^^ break 1 in + return (record_braces (expr ^^ kw (string " with ") ^^ separate sep fields)) + | E_lit (Lit_string v) -> return (dquotes (string v)) + | E_lit (Lit_int v) -> return (dquotes (string (Int.to_string v))) + +and layout_ty' ty = + let open Layout in + let rec is_ty_row_empty = function + | Ty_row_empty -> true + | Ty_bot -> true + | Ty_var var -> ( + match (Union_find.value var).ty with + | None -> false + | Some ty -> is_ty_row_empty ty) + | _ -> false + in + let rec is_ty_arr = function + | Ty_var var -> ( + match (Union_find.value var).ty with + | None -> false + | Some ty -> is_ty_arr ty) + | Ty_arr _ -> true + | _ -> false + in + let rec layout_ty = function + | Ty_const name -> return (string name) + | Ty_arr ([ aty ], rty) -> + (* Check if we can layout this as simply as [aty -> try] in case of a + single argument. *) + let is_ty_arr_to_the_left = is_ty_arr aty in + let* aty = layout_ty aty in + let* rty = layout_ty rty in + return + ((if is_ty_arr_to_the_left then + (* If the single arg is the Ty_arr we need to wrap it in parens. *) + f_parens aty + else aty) + ^^ punct (string " -> ") + ^^ rty) + | Ty_arr (atys, rty) -> + let sep = punct comma ^^ blank 1 in + let* atys = list_map atys ~f:layout_ty in + let* rty = layout_ty rty in + return (f_parens (separate sep atys) ^^ punct (string " -> ") ^^ rty) + | Ty_app (fty, atys) -> + let sep = punct comma ^^ blank 1 in + let* fty = layout_ty fty in + let* atys = list_map atys ~f:layout_ty in + return (fty ^^ ty_brackets (separate sep atys)) + | Ty_nullable ty -> + let* ty = layout_ty ty in + return (ty ^^ string "?") + | Ty_var var -> ( + match (Union_find.value var).ty with + | None -> layout_var' var + | Some ty -> layout_ty ty) + | Ty_record ty_row -> + let+ ty_row = layout_ty_row ty_row in + group (record_braces (nest 2 (break 0 ^^ ty_row) ^^ break 0)) + | (Ty_row_empty | Ty_row_extend _) as ty -> + let+ doc = layout_ty_row ty in + group doc + | Ty_bot -> return (string "⊥") + | Ty_top -> return (string "⊤") + and layout_ty_row = function + | Ty_row_extend ((name, ty), next) -> + let* field = + let+ ty = layout_ty ty in + string name ^^ punct (string ": ") ^^ ty + in + if is_ty_row_empty next then return field + else + let* next = layout_ty_row next in + let sep = punct comma ^^ break 1 in + return (field ^^ sep ^^ next) + | Ty_row_empty -> return empty + | Ty_var var -> ( + match (Union_find.value var).ty with + | None -> + let+ var = layout_var' var in + punct (string "...") ^^ var + | Some ty -> layout_ty_row ty) + | Ty_const name -> return (string name) + | ty -> + Caml.Format.printf "%a@." Sexp.pp_hum (sexp_of_ty ty); + assert false + in + layout_ty ty + +and layout_unbound_var v = + let open Layout in + let+ variance = + match v.variance with + | None -> return empty + | Some v -> layout_variance' v + in + variance ^^ string (Printf.sprintf "_%i" v.id) + +and layout_var' v = + let open Layout in + let v' = Union_find.value v in + match v'.ty with + | Some ty -> layout_ty' ty + | None -> + let* name = lookup_var v in + let+ doc = + match name with + | Some name -> return (string name) + | None -> layout_unbound_var v' + in + if Debug.log_levels then doc ^^ layout_v_debug' v else doc + +and layout_ty_sch' (ty_sch : ty_sch) = + let open Layout in + match ty_sch with + | [], ty -> layout_ty' ty + | vs, ty -> + let* vs = layout_var_prenex' vs in + let* ty = layout_ty' ty in + return (group (vs ^^ ty)) + +and layout_var_prenex' vs = + let open Layout in + let sep = punct comma ^^ blank 1 in + let* vs = list_map vs ~f:layout_poly_var' in + return (separate sep vs ^^ punct (string " . ")) + +and layout_poly_var' v : Layout.layout = + let open Layout in + let+ doc = alloc_var v in + if Debug.log_levels then string doc ^^ layout_v_debug' v else string doc + +and layout_v_debug' v = + let open Layout in + let v = Union_find.value v in + let lvl = Option.value v.lvl ~default:(-1) in + string ("{" ^ Int.to_string v.id ^ "}" ^ "@" ^ Int.to_string lvl) + +and layout_variance' v = + let open Layout in + Layout.return + (match v with + | Covariant -> string "+" + | Contravariant -> string "-" + | Invariant -> string "=") + +module Expr = struct + type t = expr + + let loc (loc, _) = loc + + let layout = layout_expr' + + include ( + Showable (struct + type t = expr + + let layout e = Layout.render (layout e) + end) : + SHOWABLE with type t := t) + + include ( + Dumpable (struct + type t = expr + + let sexp_of_t = sexp_of_expr + end) : + DUMPABLE with type t := t) +end + +module Ty = struct + type t = ty + + let arr a b = Ty_arr (a, b) + + let var var = Ty_var var + + let nullable ty = + match ty with + | Ty_nullable _ -> ty + | Ty_bot -> ty + | ty -> Ty_nullable ty + + let rec equal a b = + match (a, b) with + | Ty_const a, Ty_const b -> String.equal a b + | Ty_var a, Ty_var b -> ( + Union_find.equal a b + || + let a = Union_find.value a + and b = Union_find.value b in + match (a.ty, b.ty) with + | None, None -> Int.equal a.id b.id + | Some a, Some b -> equal a b + | _ -> false) + | Ty_var v, b -> ( + match (Union_find.value v).ty with + | None -> false + | Some a -> equal a b) + | a, Ty_var v -> ( + match (Union_find.value v).ty with + | None -> false + | Some b -> equal a b) + | Ty_app (a, args), Ty_app (b, brgs) -> ( + equal a b + && + match List.for_all2 args brgs ~f:equal with + | Unequal_lengths -> false + | Ok v -> v) + | Ty_nullable a, Ty_nullable b -> equal a b + | Ty_arr (args, a), Ty_arr (brgs, b) -> ( + equal a b + && + match List.for_all2 args brgs ~f:equal with + | Unequal_lengths -> false + | Ok v -> v) + | Ty_record row1, Ty_record row2 -> equal row1 row2 + | Ty_row_empty, Ty_row_empty -> true + | Ty_row_extend ((name1, ty1), row1), Ty_row_extend ((name2, ty2), row2) -> + String.equal name1 name2 && equal ty1 ty2 && equal row1 row2 + | Ty_bot, Ty_bot -> true + | Ty_top, Ty_top -> true + | _, _ -> false + + let layout = layout_ty' + + include ( + Showable (struct + type t = ty + + let layout ty = Layout.render (layout ty) + end) : + SHOWABLE with type t := t) + + include ( + Dumpable (struct + type t = ty + + let sexp_of_t = sexp_of_ty + end) : + DUMPABLE with type t := t) +end + +module Ty_sch = struct + type t = ty_sch + + let layout = layout_ty_sch' + + include ( + Showable (struct + type t = ty_sch + + let layout ty_sch = Layout.render (layout ty_sch) + end) : + SHOWABLE with type t := t) + + include ( + Dumpable (struct + type t = ty_sch + + let sexp_of_t = sexp_of_ty_sch + end) : + DUMPABLE with type t := t) +end + +module Variance = struct + type t = variance + + let inv = function + | Covariant -> Contravariant + | Contravariant -> Covariant + | Invariant -> Invariant + + let join a b = + match (a, b) with + | Invariant, _ + | _, Invariant -> + Invariant + | Covariant, Contravariant + | Contravariant, Covariant -> + Invariant + | Covariant, Covariant -> Covariant + | Contravariant, Contravariant -> Contravariant + + let layout = layout_variance' + + include ( + Showable (struct + type nonrec t = t + + let layout v = Layout.render (layout v) + end) : + SHOWABLE with type t := t) +end diff --git a/bidi_local/syntax0.ml b/bidi_local/syntax0.ml new file mode 100644 index 0000000..6283f38 --- /dev/null +++ b/bidi_local/syntax0.ml @@ -0,0 +1,61 @@ +open! Import + +type name = string [@@deriving sexp_of] + +and id = int + +and lvl = int + +type variance = Covariant | Contravariant | Invariant [@@deriving sexp_of] + +type ty = + | Ty_const of name + | Ty_var of var + | Ty_app of ty * ty list + | Ty_nullable of ty + | Ty_arr of ty list * ty + | Ty_record of ty_row + | Ty_row_empty + | Ty_row_extend of (name * ty) * ty_row + | Ty_bot + | Ty_top + +and ty_row = ty + +and var = var' Union_find.t + +and var' = { + id : int; + mutable name : string option; + mutable lvl : lvl option; + mutable ty : ty option; + mutable variance : variance option; + mutable lower : ty; + mutable upper : ty; +} +[@@deriving sexp_of] + +module Location = struct + include Location + + let sexp_of_t _loc = Sexp.Atom "LOC" +end + +type expr = Location.t * exprsyn + +and exprsyn = + | E_var of name + | E_abs of var list * ((Location.t * name) * ty option) list * expr + | E_app of expr * expr list + | E_let of (name * expr * ty_sch option) * expr + | E_lit of lit + | E_ann of expr * ty_sch + | E_record of (name * expr) list + | E_record_project of expr * name + | E_record_extend of expr * (name * expr) list + | E_record_update of expr * (name * expr) list +[@@deriving sexp_of] + +and lit = Lit_string of string | Lit_int of int + +and ty_sch = var list * ty diff --git a/bidi_local/test/dune b/bidi_local/test/dune new file mode 100644 index 0000000..eab027f --- /dev/null +++ b/bidi_local/test/dune @@ -0,0 +1,6 @@ +(library + (name test_bidi_local) + (inline_tests) + (preprocess + (pps ppx_expect)) + (libraries base bidi_local)) diff --git a/bidi_local/test/test_infer.ml b/bidi_local/test/test_infer.ml new file mode 100644 index 0000000..2651447 --- /dev/null +++ b/bidi_local/test/test_infer.ml @@ -0,0 +1,1235 @@ +open Base +open Bidi_local + +let env = + Env.empty + |> Env.assume_val "null" "a . a?" + |> Env.assume_val "one" "int" + |> Env.assume_val "nil" "a . list[a]" + |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" + |> Env.assume_val "map" "a, b . (a -> b, list[a]) -> list[b]" + |> Env.assume_val "choose" "a . (a, a) -> a" + |> Env.assume_val "choose3" "a . (a, a, a) -> a" + |> Env.assume_val "choose4" "a . (a, a, a, a) -> a" + |> Env.assume_val "hello" "string" + |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" + |> Env.assume_val "plus" "(int, int) -> int" + |> Env.assume_val "true" "bool" + |> Env.assume_val "ifnull" "a . (a?, a) -> a" + |> Env.assume_val "eq" "a . (a, a) -> bool" + +let infer ~env code = + Var.reset (); + let code = String.strip code in + let prog = Expr.parse_string code in + match infer ~env prog with + | Ok e -> Caml.Format.printf "%s@.|" (Expr.show e) + | Error err -> + let report = Type_error.to_report err in + Caml.Format.printf "%a@.|" Location.print_report report + +let%expect_test "" = + infer ~env "choose(one, one)"; + [%expect {| + (choose(one, one) : int) + | |}] + +let%expect_test "" = + infer ~env "choose(null, one)"; + [%expect {| + (choose(null, one) : int?) + | |}] + +let%expect_test "" = + infer ~env "choose(one, null)"; + [%expect {| + (choose(one, null) : int?) + | |}] + +let%expect_test "" = + infer ~env "choose(null, null)"; + [%expect {| + (choose(null, null) : b . b?) + | |}] + +let%expect_test "" = + infer ~env + {| + let f = fun (cb: int? -> int) -> cb(one) in + f(fun z -> ifnull(z, one)) + |}; + [%expect + {| + (let f : (int? -> int) -> int = + fun (cb: int? -> int) -> cb(one) + in + f(fun (z: int?) -> ifnull(z, one)) + : int) + | |}] + +let%expect_test "" = + infer ~env "map(fun x -> plus(x, x), cons(one, nil))"; + [%expect + {| + (map(fun (x: int) -> plus(x, x), cons(one, nil)) + : list[int]) + | |}] + +let%expect_test "" = + infer ~env "(null : int?)"; + [%expect {| + (null : int?) + | |}] + +let env = + env + |> Env.assume_val "fix" "a . (a -> a) -> a" + |> Env.assume_val "head" "a . list[a] -> a" + |> Env.assume_val "tail" "a . list[a] -> list[a]" + |> Env.assume_val "nil" "a . list[a]" + |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" + |> Env.assume_val "cons_curry" "a . a -> list[a] -> list[a]" + |> Env.assume_val "map" "a, b . (a -> b, list[a]) -> list[b]" + |> Env.assume_val "map_curry" "a, b . (a -> b) -> list[a] -> list[b]" + |> Env.assume_val "one" "int" + |> Env.assume_val "zero" "int" + |> Env.assume_val "succ" "int -> int" + |> Env.assume_val "plus" "(int, int) -> int" + |> Env.assume_val "eq" "a . (a, a) -> bool" + |> Env.assume_val "eq_curry" "a . a -> a -> bool" + |> Env.assume_val "not" "bool -> bool" + |> Env.assume_val "true" "bool" + |> Env.assume_val "false" "bool" + |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" + |> Env.assume_val "pair_curry" "a, b . a -> b -> pair[a, b]" + |> Env.assume_val "first" "a, b . pair[a, b] -> a" + |> Env.assume_val "second" "a, b . pair[a, b] -> b" + |> Env.assume_val "id" "a . a -> a" + |> Env.assume_val "const" "a, b . a -> b -> a" + |> Env.assume_val "apply" "a, b . (a -> b, a) -> b" + |> Env.assume_val "apply_curry" "a, b . (a -> b) -> a -> b" + |> Env.assume_val "choose" "a . (a, a) -> a" + |> Env.assume_val "choose_curry" "a . a -> a -> a" + |> Env.assume_val "age" "int" + |> Env.assume_val "world" "string" + |> Env.assume_val "print" "string -> string" + +let%expect_test "" = + infer ~env "world"; + [%expect {| + (world : string) + | |}] + +let%expect_test "" = + infer ~env "print"; + [%expect {| + (print : string -> string) + | |}] + +let%expect_test "" = + infer ~env "let x = world in x"; + [%expect {| + let x : string = world in + x + | |}] + +let%expect_test "" = + infer ~env "fun () -> world"; + [%expect {| + (fun () -> world : () -> string) + | |}] + +let%expect_test "" = + infer ~env "let x = fun () -> world in world"; + [%expect + {| + (let x : () -> string = fun () -> world in world : string) + | |}] + +let%expect_test "" = + infer ~env "let x = fun () -> world in x"; + [%expect + {| + let x : () -> string = fun () -> world in + x + | |}] + +let%expect_test "" = + infer ~env "print(world)"; + [%expect {| + (print(world) : string) + | |}] + +let%expect_test "" = + infer ~env "let hello = fun (msg: string) -> print(msg) in hello(world)"; + [%expect + {| + (let hello : string -> string = + fun (msg: string) -> print(msg) + in + hello(world) + : string) + | |}] + +let%expect_test "" = + infer ~env + "(fun x -> let y : a . a -> a = fun z -> z in y : a, b . a -> b -> b)"; + [%expect + {| + (fun (x: a) -> + let y : a/1 . a/1 -> a/1 = fun (z: a/1) -> z in y + : a, b . a -> b -> b) + | |}] + +let%expect_test "" = + infer ~env "(fun x -> let y = x in y : a . a -> a)"; + [%expect + {| + (fun (x: a) -> + let y : a = x in y + : a . a -> a) + | |}] + +let%expect_test "" = + infer ~env "fun [a] (x: a) -> let y = fun [b] (z: b) -> x in y"; + [%expect + {| + (fun[a] (x: a) -> + let y : b/1 . b/1 -> a = fun[b/1] (z: b/1) -> x in y + : a, b . a -> b -> a) + | |}] + +let%expect_test "" = + infer ~env "id"; + [%expect {| + (id : b . b -> b) + | |}] + +let%expect_test "" = + infer ~env "one"; + [%expect {| + (one : int) + | |}] + +let%expect_test "" = + infer ~env "x"; + [%expect + {| + Line 1, characters 0-1: + 1 | x + ^ + Error: 'x' is not defined + + | |}] + +let%expect_test "" = + infer ~env "let x = x in x"; + [%expect + {| + Line 1, characters 8-9: + 1 | let x = x in x + ^ + Error: 'x' is not defined + + | |}] + +let%expect_test "" = + infer ~env "let x = id in x"; + [%expect {| + let x : b . b -> b = id in + x + | |}] + +let%expect_test "" = + infer ~env "let x : a . a -> a = fun y -> y in x"; + [%expect + {| + let x : a . a -> a = fun (y: a) -> y in + x + | |}] + +let%expect_test "" = + infer ~env "(fun x -> x : a . a -> a)"; + [%expect {| + (fun (x: a) -> x : a . a -> a) + | |}] + +let%expect_test "" = + infer ~env "pair"; + [%expect {| + (pair : b, b . (b, b) -> pair[b, b]) + | |}] + +let%expect_test "" = + (* TODO: missing b variable here *) + infer ~env "(fun x -> let y = fun[b](z : b) -> z in y : a, b . a -> b -> b)"; + [%expect + {| + (fun (x: a) -> + let y : b/1 . b/1 -> b/1 = fun[b/1] (z: b/1) -> z in y + : a, b . a -> b -> b) + | |}] + +let%expect_test "" = + infer ~env + {| + let f : a . a -> a = fun x -> x in + let id : a . a -> a = fun y -> y in + eq(f, id) + |}; + [%expect + {| + (let f : a . a -> a = fun (x: a) -> x in + let id : a/1 . a/1 -> a/1 = fun (y: a/1) -> y in + eq(f, id) + : bool) + | |}] + +let%expect_test "" = + infer ~env + {| + let f : a . a -> a = fun x -> one in + let id : a . a -> a = fun y -> true in + choose(f, id) + |}; + [%expect + {| + (let f : int -> int = fun (x: int) -> one in + let id : bool -> bool = fun (y: bool) -> true in + choose(f, id) + : ⊥ -> ⊤) + | |}] + +let%expect_test "" = + infer ~env + "let f : a . a -> a = fun x -> x in let id : b . b -> b = fun y -> y in \ + eq_curry(f)(id)"; + [%expect + {| + (let f : a . a -> a = fun (x: a) -> x in + let id : b . b -> b = fun (y: b) -> y in + eq_curry(f)(id) + : bool) + | |}] + +let%expect_test "" = + infer ~env + {| + let f : a . a -> a = fun x -> one in + let id : b . b -> b = fun y -> true in + choose_curry(f)(id) + |}; + [%expect + {| + Line 1, characters 100-102: + -1 | .................................... + 0 | .......................................... + 1 | ....................id. + Error: type bool -> bool is not a subtype of int -> int + + | |}] + +let%expect_test "" = + infer ~env "let f : a . a -> a = fun x -> x in eq(f, succ)"; + [%expect + {| + (let f : a . a -> a = fun (x: a) -> x in eq(f, succ) : bool) + | |}] + +let%expect_test "" = + infer ~env "let f : a . a -> a = fun x -> x in eq_curry(f)(succ)"; + [%expect + {| + (let f : a . a -> a = fun (x: a) -> x in + eq_curry(f)(succ) + : bool) + | |}] + +let%expect_test "" = + infer ~env "let f : a . a -> a = fun x -> x in pair(f(one), f(true))"; + [%expect + {| + (let f : a . a -> a = fun (x: a) -> x in + pair(f(one), f(true)) + : pair[int, bool]) + | |}] + +let%expect_test "" = + infer ~env + {| + let f : a . (a, a) -> bool = + fun (x, y) -> let a = eq(x, y) in eq(x, y) + in f + |}; + [%expect + {| + let f : a . (a, a) -> bool = + fun (x: a, y: a) -> + let a : bool = eq(x, y) in eq(x, y) + in + f + | |}] + +let%expect_test "" = + infer ~env + {| + let f : a . (a, a) -> bool = + fun (x, y) -> let a = eq_curry(x)(y) in eq_curry(x)(y) + in f + |}; + [%expect + {| + let f : a . (a, a) -> bool = + fun (x: a, y: a) -> + let a : bool = eq_curry(x)(y) in eq_curry(x)(y) + in + f + | |}] + +let%expect_test "" = + infer ~env "id(id)"; + [%expect {| + (id(id) : b . b -> b) + | |}] + +let%expect_test "" = + infer ~env "choose(fun (x, y) -> x, fun (x, y) -> y)"; + [%expect + {| + (choose(fun (x: b, y: b) -> x, fun (x: b, y: b) -> y) + : b . (b, b) -> b) + | |}] + +let%expect_test "" = + infer ~env "choose_curry(fun (x, y) -> x)(fun (x, y) -> y)"; + [%expect + {| + (choose_curry(fun (x: b, y: b) -> x)(fun (x: b, y: b) -> y) + : b . (b, b) -> b) + | |}] + +let%expect_test "" = + infer ~env "let x = id in let y = let z = x(id) in z in y"; + [%expect + {| + (let x : b . b -> b = id in + let y : b . b -> b = + let z : b . b -> b = x(id) in + z + in + y + : b . b -> b) + | |}] + +let%expect_test "" = + infer ~env "cons(id, nil)"; + [%expect {| + (cons(id, nil) : b . list[b -> b]) + | |}] + +let%expect_test "" = + infer ~env "cons_curry(id)(nil)"; + [%expect {| + (cons_curry(id)(nil) : b . list[b -> b]) + | |}] + +let%expect_test "" = + infer ~env "let lst1 = cons(id, nil) in let lst2 = cons(succ, lst1) in lst2"; + [%expect + {| + (let lst1 : b . list[b -> b] = cons(id, nil) in + let lst2 : list[int -> int] = cons(succ, lst1) in + lst2 + : list[int -> int]) + | |}] + +let%expect_test "" = + infer ~env "cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil)))"; + [%expect + {| + (cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil))) + : list[int -> int]) + | |}] + +let%expect_test "" = + infer ~env "plus(one, true)"; + [%expect + {| + Line 1, characters 10-14: + 1 | plus(one, true) + ^^^^ + Error: type bool is not a subtype of int + + | |}] + +let%expect_test "" = + infer ~env "plus(one)"; + [%expect + {| + Line 1, characters 0-9: + 1 | plus(one) + ^^^^^^^^^ + Error: arity mismatch + + | |}] + +let%expect_test "" = + infer ~env "(fun x -> let y = x in y : a . a -> a)"; + [%expect + {| + (fun (x: a) -> + let y : a = x in y + : a . a -> a) + | |}] + +let%expect_test "" = + infer ~env + {| + (fun x -> let y = let z = x(fun x -> x) in z in y + : a, b . ((a -> a) -> b) -> b) + |}; + [%expect + {| + (fun (x: (a -> a) -> b) -> + let y : b = + let z : b = x(fun (x: a) -> x) in + z + in + y + : a, b . ((a -> a) -> b) -> b) + | |}] + +let%expect_test "" = + infer ~env + {| + (fun x -> fun y -> let x = x(y) in x(y) + : a, b . (a -> a -> b) -> a -> b) + |}; + [%expect + {| + (fun (x: a -> a -> b) -> + fun (y: a) -> + let x : a -> b = x(y) in x(y) + : a, b . (a -> a -> b) -> a -> b) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[a, b] (x: a -> b) -> let y = fun [c] (z: c) -> x(z) in y + |}; + [%expect + {| + (fun[a, b] (x: a -> b) -> + let y : a -> b = fun[a] (z: a) -> x(z) in y + : b, a . (a -> b) -> a -> b) + | |}] + +let%expect_test "" = + infer ~env {| + fun[a] (x: a) -> let y = fun [b] (z: b) -> x in y + |}; + [%expect + {| + (fun[a] (x: a) -> + let y : b/1 . b/1 -> a = fun[b/1] (z: b/1) -> x in y + : a, b . a -> b -> a) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[a, b](x: a -> b) -> + fun [c, d](y: c -> d) -> + let x = x(y) in fun [e](x: e) -> y(x) + |}; + [%expect + {| + (fun[b] (x: (e -> d) -> b) -> + fun[e, d] (y: e -> d) -> + let x : b = x(y) in fun[e] (x: e) -> y(x) + : b, d, e . ((e -> d) -> b) -> (e -> d) -> e -> d) + | |}] + +let%expect_test "" = + infer ~env {| + fun[a](x: a) -> let y = fun[a](z: a) -> z in y(y) + |}; + [%expect + {| + (fun[a] (x: a) -> + let y : a/2 . a/2 -> a/2 = fun[a/2] (z: a/2) -> z in y(y) + : a, a/1 . a -> a/1 -> a/1) + | |}] + +let%expect_test "" = + infer ~env "one(id)"; + [%expect + {| + Line 1, characters 0-3: + 1 | one(id) + ^^^ + Error: expected a function + + | |}] + +let%expect_test "" = + infer ~env + {| + fun[a, b](f: a -> b) -> + let x = + fun (g: a -> b, y: a) -> + let _ = g(y) in eq(f, g) + in x + |}; + [%expect + {| + (fun[a, b] (f: a -> b) -> + let x : (a -> b, a) -> bool = + fun (g: a -> b, y: a) -> + let _ : b = g(y) in eq(f, g) + in + x + : a, b . (a -> b) -> (a -> b, a) -> bool) + | |}] + +let%expect_test "" = + infer ~env + {| + let const : a, b . b -> a -> b = fun x -> fun y -> x in const + |}; + [%expect + {| + let const : a, b . b -> a -> b = + fun (x: b) -> fun (y: a) -> x + in + const + | |}] + +let%expect_test "" = + infer ~env + {| + let apply : a, b . (a -> b, a) -> b = + fun (f, x) -> f(x) + in apply + |}; + [%expect + {| + let apply : a, b . (a -> b, a) -> b = + fun (f: a -> b, x: a) -> f(x) + in + apply + | |}] + +let%expect_test "" = + infer ~env + {| + let apply_curry : a, b . (a -> b) -> a -> b = + fun f -> fun x -> f(x) + in apply_curry + |}; + [%expect + {| + let apply_curry : a, b . (a -> b) -> a -> b = + fun (f: a -> b) -> fun (x: a) -> f(x) + in + apply_curry + | |}] + +let%expect_test "" = + infer ~env {| + {a = one, b = one} + |}; + [%expect {| + ({a = one, b = one} : {a: int, b: int}) + | |}] + +let%expect_test "" = + infer ~env {| + {a = one, b = one}.a + |}; + [%expect {| + ({a = one, b = one}.a : int) + | |}] + +let%expect_test "" = + infer ~env {| + {a = one, b = one}.b + |}; + [%expect {| + ({a = one, b = one}.b : int) + | |}] + +let%expect_test "" = + infer ~env + {| + let extend_a[r, a](data : {...r}, v : a) = + {data with a = v} + in + extend_a({}, one) + |}; + [%expect + {| + (let extend_a : r, a . ({...r}, a) -> {a: a, ...r} = + fun[r, a] (data: {...r}, v: a) -> {data with a = v} + in + extend_a({}, one) + : {a: int}) + | |}] + +let%expect_test "" = + infer ~env + {| + let extend_a[r, a](data : {...r}, v : a) = + {data with a = v} + in + extend_a({b = one}, one) + |}; + [%expect + {| + (let extend_a : r, a . ({...r}, a) -> {a: a, ...r} = + fun[r, a] (data: {...r}, v: a) -> {data with a = v} + in + extend_a({b = one}, one) + : {a: int, b: int}) + | |}] + +let%expect_test "" = + infer ~env + {| + let update_a[r, a](data : {a : a, ...r}, v : a) = + {data with a := v} + in + update_a({a = one, b = true}, null) + |}; + [%expect + {| + (let update_a : a, r . ({a: a, ...r}, a) -> {a: a, ...r} = + fun[r, a] (data: {a: a, ...r}, v: a) -> + {data with a := v} + in + update_a({a = one, b = true}, null) + : {a: int?, b: bool}) + | |}] + +let%expect_test "" = + infer ~env + {| + let update_a[r, a](data : {a : a, ...r}, v : int) = + {data with a := plus(data.a, v)} + in + let data = {a = one, b = true} in + update_a(data, one) + |}; + [%expect + {| + (let update_a : + r . ({a: int, ...r}, int) -> {a: int, ...r} = + fun[r] (data: {a: int, ...r}, v: int) -> + {data with a := plus(data.a, v)} + in + let data : {a: int, b: bool} = {a = one, b = true} in + update_a(data, one) + : {a: int, b: bool}) + | |}] + +let%expect_test "" = + infer ~env + {| + let plusg[a](a : a, b : a) = plus(a, b) in + plusg(one, one) + |}; + [%expect + {| + (let plusg : (int, int) -> int = + fun (a: int, b: int) -> plus(a, b) + in + plusg(one, one) + : int) + | |}] + +let%expect_test "" = + infer ~env "{}"; + [%expect {| + ({} : {}) + | |}] + +let%expect_test "" = + infer ~env "({}).x"; + [%expect + {| + Line 1, characters 1-3: + 1 | ({}).x + ^^ + Error: type {} is not a subtype of {x: _2, ..._1} + + | |}] + +let%expect_test "" = + infer ~env "{a = one}"; + [%expect {| + ({a = one} : {a: int}) + | |}] + +let%expect_test "" = + infer ~env "{a = one, b = true}"; + [%expect {| + ({a = one, b = true} : {a: int, b: bool}) + | |}] + +let%expect_test "" = + infer ~env "{b = true, a = one}"; + [%expect {| + ({b = true, a = one} : {b: bool, a: int}) + | |}] + +let%expect_test "" = + infer ~env "({a = one, b = true}).a"; + [%expect {| + ({a = one, b = true}.a : int) + | |}] + +let%expect_test "" = + infer ~env "({a = one, b = true}).b"; + [%expect {| + ({a = one, b = true}.b : bool) + | |}] + +let%expect_test "" = + infer ~env "({a = one, b = true}).c"; + [%expect + {| + Line 1, characters 1-20: + 1 | ({a = one, b = true}).c + ^^^^^^^^^^^^^^^^^^^ + Error: type + {a: int, b: bool} + is not a subtype of + {c: _2, a: int, b: bool, ..._4} + + | |}] + +let%expect_test "" = + infer ~env "{f = fun[a](x: a) -> x}"; + [%expect {| + ({f = fun[a] (x: a) -> x} : a . {f: a -> a}) + | |}] + +let%expect_test "" = + infer ~env "let r = {a = id, b = succ} in choose(r.a, r.b)"; + [%expect + {| + (let r : b . {a: b -> b, b: int -> int} = {a = id, b = succ} in + choose(r.a, r.b) + : int -> int) + | |}] + +let%expect_test "" = + infer ~env "let r = {a = id, b = fun[a](x: a) -> x} in choose(r.a, r.b)"; + [%expect + {| + (let r : a/1, b/2 . {a: b/2 -> b/2, b: a/1 -> a/1} = { + a = id, + b = fun[a/1] (x: a/1) -> x + } in + choose(r.a, r.b) + : a . a -> a) + | |}] + +let%expect_test "" = + infer ~env "choose({a = one}, {})"; + [%expect {| + (choose({a = one}, {}) : ⊤) + | |}] + +let%expect_test "" = + infer ~env "{ { {} with y = one } with x = zero }"; + [%expect + {| + ({{{} with y = one} with x = zero} : {x: int, y: int}) + | |}] + +let%expect_test "" = + infer ~env + "choose({ { {} with y = one } with x = zero }, {x = one, y = zero})"; + [%expect + {| + (choose( + {{{} with y = one} with x = zero}, + {x = one, y = zero}) + : {x: int, y: int}) + | |}] + +let%expect_test "" = + infer ~env "{ {x = one } with x = true }"; + [%expect {| + ({{x = one} with x = true} : {x: bool, x: int}) + | |}] + +let%expect_test "" = + infer ~env "{ {x = one } with x := true }"; + [%expect + {| + Line 1, characters 0-29: + 1 | { {x = one } with x := true } + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + Error: type {x: bool, ..._1} is not a subtype of {x: int} + + | |}] + +let%expect_test "" = + infer ~env "let a = {} in {a with b := one}"; + [%expect + {| + Line 1, characters 14-31: + 1 | let a = {} in {a with b := one} + ^^^^^^^^^^^^^^^^^ + Error: type {b: int, ..._1} is not a subtype of {} + + | |}] + +let%expect_test "" = + infer ~env "let a = {x = one} in ({a with x = true}).x"; + [%expect + {| + (let a : {x: int} = {x = one} in + ({a with x = true}).x + : bool) + | |}] + +let%expect_test "" = + infer ~env "let a = {x = one} in ({a with x := true}).x"; + [%expect + {| + Line 1, characters 22-40: + 1 | let a = {x = one} in ({a with x := true}).x + ^^^^^^^^^^^^^^^^^^ + Error: type {x: bool, ..._3} is not a subtype of {x: int} + + | |}] + +let%expect_test "" = + infer ~env "let a = {x = one} in a.y"; + [%expect + {| + Line 1, characters 21-22: + 1 | let a = {x = one} in a.y + ^ + Error: type {x: int} is not a subtype of {y: _2, x: int, ..._3} + + | |}] + +let%expect_test "" = + infer ~env "fun[a](r: {...a}) -> {r with x = one}"; + [%expect + {| + (fun[a] (r: {...a}) -> {r with x = one} + : a . {...a} -> {x: int, ...a}) + | |}] + +let%expect_test "" = + infer ~env "fun[r, x](r: {x: x, ...r}) -> {r with x := one}"; + [%expect + {| + (fun[r] (r: {x: int, ...r}) -> {r with x := one} + : r . {x: int, ...r} -> {x: int, ...r}) + | |}] + +let%expect_test "" = + infer ~env "let addx = fun[r](r: {...r}) -> {r with x = one} in addx({})"; + [%expect + {| + (let addx : r . {...r} -> {x: int, ...r} = + fun[r] (r: {...r}) -> {r with x = one} + in + addx({}) + : {x: int}) + | |}] + +let%expect_test "" = + infer ~env + "let addx = fun[r, x](r: {x: x, ...r}) -> {r with x := one} in addx({})"; + [%expect + {| + Line 1, characters 67-69: + 1 | let addx = fun[r, x](r: {x: x, ...r}) -> {r with x := one} in addx({}) + ^^ + Error: type {} is not a subtype of {x: int, ...=_6} + + | |}] + +let%expect_test "" = + infer ~env + "let addx = fun[r](r: {...r}) -> {r with x = one} in addx({x = one})"; + [%expect + {| + (let addx : r . {...r} -> {x: int, ...r} = + fun[r] (r: {...r}) -> {r with x = one} + in + addx({x = one}) + : {x: int, x: int}) + | |}] + +let%expect_test "" = + infer ~env + {| + let addx = fun[r, x](r: {x: x, ...r}) -> {r with x := one} in + addx({x = one}) + |}; + [%expect + {| + (let addx : r . {x: int, ...r} -> {x: int, ...r} = + fun[r] (r: {x: int, ...r}) -> {r with x := one} + in + addx({x = one}) + : {x: int}) + | |}] + +let%expect_test "" = + infer ~env "fun[x, r](r: {x: x, ...r}) -> r.x"; + [%expect + {| + (fun[x, r] (r: {x: x, ...r}) -> r.x + : x, r . {x: x, ...r} -> x) + | |}] + +let%expect_test "" = + infer ~env + {| + let get_x = fun[x, r](r: {x: x, ...r}) -> r.x in + get_x({y = one, x = zero}) + |}; + [%expect + {| + (let get_x : x, r . {x: x, ...r} -> x = + fun[x, r] (r: {x: x, ...r}) -> r.x + in + get_x({y = one, x = zero}) + : int) + | |}] + +let%expect_test "" = + infer ~env + {| + let get_x = fun[x, r](r: {x: x, ...r}) -> r.x in + get_x({y = one, z = true}) + |}; + [%expect + {| + Line 1, characters 59-78: + 0 | ................................................ + 1 | ..........{y = one, z = true}. + Error: type + {y: int, z: bool} + is not a subtype of + {x: =_7, y: int, z: bool, ..._10} + + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {...r}) -> choose({r with x = zero}, {{} with x = one}) + |}; + [%expect + {| + (fun (r: {}) -> choose({r with x = zero}, {{} with x = one}) + : {} -> {x: int}) + | |}] + +let%expect_test "" = + infer ~env + "fun[r](r: {...r}) -> choose({r with x := zero}, {{} with x = one})"; + [%expect + {| + (fun (r: {x: int}) -> + choose({r with x := zero}, {{} with x = one}) + : {x: int} -> {x: int}) + | |}] + +let%expect_test "" = + infer ~env "fun[r](r: {...r}) -> choose({r with x = zero}, {x = one})"; + [%expect + {| + (fun (r: {}) -> choose({r with x = zero}, {x = one}) + : {} -> {x: int}) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[x, r](r: {x: x, ...r}) -> + choose({r with x := zero}, {x = one}) + |}; + [%expect + {| + (fun (r: {x: int}) -> choose({r with x := zero}, {x = one}) + : {x: int} -> {x: int}) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {...r}) -> + choose({r with x = zero}, {r with x = one}) + |}; + [%expect + {| + (fun[r] (r: {...r}) -> + choose({r with x = zero}, {r with x = one}) + : r . {...r} -> {x: int, ...r}) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {...r}) -> + choose({r with x := zero}, {r with x := one}) + |}; + [%expect + {| + (fun (r: {x: int, ...b}) -> + choose({r with x := zero}, {r with x := one}) + : b . {x: int, ...b} -> {x: int, ...b}) + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {...r}) -> + choose({r with x = zero}, {r with y = one}) + |}; + [%expect + {| + File "_none_", line 1: + Error: recursive record type + + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {x: int, y: int, ...r}) -> + choose({r with x := zero}, {r with y := one}) + |}; + [%expect + {| + (fun[r] (r: {x: int, y: int, ...r}) -> + choose({r with x := zero}, {r with y := one}) + : r . {x: int, y: int, ...r} -> {x: int, y: int, ...r}) + | |}] + +let%expect_test "" = + (* TODO: fix printing here... *) + infer ~env "let f = fun[x](x: {...x}) -> x.t(one) in f({t = succ})"; + [%expect + {| + (let f : b, b . {t: int -> b, ...b} -> b = + fun (x: {t: int -> b, ...b}) -> x.t(one) + in + f({t = succ}) + : int) + | |}] + +let%expect_test "" = + (* TODO: fix printing here... *) + infer ~env "let f = fun[r](x: {...r}) -> x.t(one) in f({t = id})"; + [%expect + {| + (let f : b, b . {t: int -> b, ...b} -> b = + fun (x: {t: int -> b, ...b}) -> x.t(one) + in + f({t = id}) + : int) + | |}] + +let%expect_test "" = + infer ~env "{x = one, x = true}"; + [%expect {| + ({x = one, x = true} : {x: int, x: bool}) + | |}] + +let%expect_test "" = + (* TODO: fix printing here... *) + infer ~env + {| + let f = + fun[r](r: {...r}) -> + let y = r.y + in choose(r, {x = one, x = true}) + in f + |}; + [%expect + {| + let f : b, b . {y: b, ...b} -> ⊤ = + fun (r: {y: b, ...b}) -> + let y : b . b = r.y in choose(r, {x = one, x = true}) + in + f + | |}] + +let%expect_test "" = + infer ~env + {| + fun[r](r: {...r}) -> + choose({r with x = zero}, {x = true, x = one}) + |}; + [%expect + {| + (fun (r: {x: int}) -> + choose({r with x = zero}, {x = true, x = one}) + : {x: int} -> {x: ⊤, x: int}) + | |}] + +let%expect_test "" = + infer ~env + "fun[r](r: {...r}) -> choose({r with x := zero}, {x = true, x = one})"; + [%expect + {| + (fun (r: {x: int, x: int}) -> + choose({r with x := zero}, {x = true, x = one}) + : {x: int, x: int} -> {x: ⊤, x: int}) + | |}] + +let%expect_test "" = + infer ~env "fun[r](r: {...r}) -> choose(r, {x = one, x = true})"; + [%expect + {| + (fun (r: {x: int, x: bool}) -> + choose(r, {x = one, x = true}) + : {x: int, x: bool} -> {x: int, x: bool}) + | |}] + +let%expect_test "" = + infer ~env "fun[r](r: {...r}) -> choose({r with x = zero}, {r with x = true})"; + [%expect + {| + (fun[r] (r: {...r}) -> + choose({r with x = zero}, {r with x = true}) + : r . {...r} -> {x: ⊤, ...r}) + | |}] + +let%expect_test "" = + infer ~env + "fun[r](r: {...r}) -> choose({r with x := zero}, {r with x := true})"; + [%expect + {| + Line 1, characters 48-66: + 1 | fun[r](r: {...r}) -> choose({r with x := zero}, {r with x := true}) + ^^^^^^^^^^^^^^^^^^ + Error: type {x: bool, ..._6} is not a subtype of {x: int, ..._4} + + | |}] + +let%expect_test "" = + infer ~env "choose({x=one},{x=one}).x"; + [%expect {| + (choose({x = one}, {x = one}).x : int) + | |}] + +let%expect_test "" = + infer ~env "choose({x=one,y=null},{y=one,x=one}).y"; + [%expect + {| + (choose({x = one, y = null}, {y = one, x = one}).y : int?) + | |}] diff --git a/bidi_local/ty_subst.ml b/bidi_local/ty_subst.ml new file mode 100644 index 0000000..d46af06 --- /dev/null +++ b/bidi_local/ty_subst.ml @@ -0,0 +1,52 @@ +open! Base +open! Syntax + +type t = { + vars : (var * var) list; + names : (name, var, String.comparator_witness) Map.t; +} + +let empty = { vars = []; names = Map.empty (module String) } + +let add_var subst var ty = { subst with vars = (var, ty) :: subst.vars } + +let add_name subst name ty = + { subst with names = Map.set subst.names ~key:name ~data:ty } + +let find_var subst var = List.Assoc.find ~equal:Var.equal subst.vars var + +let find_name subst name = Map.find subst.names name + +let rec apply_ty ~variance subst ty = + match ty with + | Ty_bot + | Ty_top -> + ty + | Ty_const name -> ( + match find_name subst name with + | None -> ty + | Some v -> + Var.set_variance v variance; + Ty_var v) + | Ty_nullable ty -> Ty_nullable (apply_ty ~variance subst ty) + | Ty_app (a, args) -> + Ty_app + (apply_ty ~variance subst a, List.map args ~f:(apply_ty ~variance subst)) + | Ty_arr (args, b) -> + Ty_arr + ( List.map args ~f:(apply_ty ~variance:(Variance.inv variance) subst), + apply_ty ~variance subst b ) + | Ty_var v -> ( + match Var.ty v with + | Some ty -> apply_ty ~variance subst ty + | None -> ( + match find_var subst v with + | Some v -> + Var.set_variance v variance; + Ty_var v + | None -> ty)) + | Ty_record row -> Ty_record (apply_ty ~variance subst row) + | Ty_row_empty -> ty + | Ty_row_extend ((name, ty), row) -> + Ty_row_extend + ((name, apply_ty ~variance subst ty), apply_ty ~variance subst row) diff --git a/bidi_local/ty_subst.mli b/bidi_local/ty_subst.mli new file mode 100644 index 0000000..a697a34 --- /dev/null +++ b/bidi_local/ty_subst.mli @@ -0,0 +1,14 @@ +(** Type variable substitutions. *) + +open! Base +open! Syntax + +type t + +val empty : t + +val add_var : t -> var -> var -> t + +val add_name : t -> name -> var -> t + +val apply_ty : variance:Variance.t -> t -> ty -> ty diff --git a/bidi_local/type_error.ml b/bidi_local/type_error.ml new file mode 100644 index 0000000..589a9de --- /dev/null +++ b/bidi_local/type_error.ml @@ -0,0 +1,100 @@ +open Import +open Syntax + +type t = + | Error_not_a_subtype of { + sub_loc : Location.t; + sub_ty : ty; + super_loc : Location.t; + super_ty : ty; + } + | Error_not_equal of ty * ty + | Error_recursive_type + | Error_recursive_record_type + | Error_unknown_name of { name : Location.t * name } + | Error_missing_type_annotation of { expr : expr } + | Error_expected_a_function of { loc : Location.t; ty : ty } + | Error_arity_mismatch of { loc : Location.t } + +(* let layout = *) +(* let open Layout in *) +(* function *) +(* | Error_not_a_subtype (ty1, ty2) -> *) +(* let* ty1 = Ty.layout ty1 in *) +(* let* ty2 = Ty.layout ty2 in *) +(* return *) +(* (group *) +(* (string "type" *) +(* ^^ nest 2 (break 1 ^^ ty1) *) +(* ^^ break 1 *) +(* ^^ string "is not a subtype of" *) +(* ^^ nest 2 (break 1 ^^ ty2))) *) +(* | Error_not_equal (ty1, ty2) -> *) +(* let* ty1 = Ty.layout ty1 in *) +(* let* ty2 = Ty.layout ty2 in *) +(* return *) +(* (group *) +(* (string "type" *) +(* ^^ nest 2 (break 1 ^^ ty1) *) +(* ^^ break 1 *) +(* ^^ string "is not equal to" *) +(* ^^ nest 2 (break 1 ^^ ty2))) *) +(* | Error_recursive_type -> return (string "recursive type") *) +(* | Error_recursive_record_type -> return (string "recursive record type") *) +(* | Error_unknown_name name -> return (string "unknown name: " ^^ string name) *) +(* | Error_missing_type_annotation expr -> *) +(* let loc, _ = expr in *) +(* let err = Location.error ~loc "missing type annotation" in *) +(* let* expr = Expr.layout expr in *) +(* Caml.Format.printf "ERROR: %a@." Location.print_report err; *) +(* return (string "missing type annotation: " ^^ expr) *) +(* | Error_expected_a_function ty -> *) +(* let* ty = Ty.layout ty in *) +(* return (string "expected a function but got: " ^^ ty) *) +(* | Error_expected_a_record ty -> *) +(* let* ty = Ty.layout ty in *) +(* return (string "expected a record but got: " ^^ ty) *) +(* | Error_arity_mismatch -> return (string "arity mismatch") *) + +let to_report = function + | Error_not_a_subtype info -> + let doc = + let open Layout in + let* sub_ty = Ty.layout info.sub_ty in + let* super_ty = Ty.layout info.super_ty in + return + (group + (string "type" + ^^ nest 2 (break 1 ^^ bold sub_ty) + ^^ break 1 + ^^ string "is not a subtype of" + ^^ nest 2 (break 1 ^^ bold super_ty))) + in + Location.error ~loc:info.sub_loc (Layout.to_string doc) + | Error_not_equal _ -> Location.error "types are not equal" + | Error_recursive_type -> Location.error "recursive type" + | Error_recursive_record_type -> Location.error "recursive record type" + | Error_unknown_name { name = loc, name } -> + Location.errorf ~loc "'%s' is not defined" name + | Error_missing_type_annotation { expr = loc, _ } -> + Location.error ~loc "missing type annotation" + | Error_expected_a_function { loc; ty = _ty } -> + Location.error ~loc "expected a function" + | Error_arity_mismatch { loc } -> Location.error ~loc "arity mismatch" + +(* include ( *) +(* Showable (struct *) +(* type nonrec t = t *) + +(* let layout v = Layout.render (layout v) *) +(* end) : *) +(* SHOWABLE with type t := t) *) + +exception Type_error of t + +let raise error = raise (Type_error error) + +let raise_not_a_subtype + ?(sub_loc = Location.none) ?(super_loc = Location.none) ~sub_ty ~super_ty () + = + raise (Error_not_a_subtype { sub_loc; sub_ty; super_loc; super_ty }) diff --git a/bidi_local/union_find.ml b/bidi_local/union_find.ml new file mode 100644 index 0000000..1c3a50e --- /dev/null +++ b/bidi_local/union_find.ml @@ -0,0 +1,41 @@ +open! Base + +type 'a loc = Root of 'a | Link of 'a t + +and 'a t = 'a loc ref [@@deriving sexp_of] + +let make value = ref (Root value) + +let rec root p : _ t = + match p.contents with + | Root _ -> p + | Link p' -> + let p'' = root p' in + (* Perform path compression. *) + if not (phys_equal p' p'') then p.contents <- p'.contents; + p'' + +let value p = + match (root p).contents with + | Root value -> value + | Link _ -> assert false + +let union ~f a b = + if phys_equal a b then () + else + let a = root a in + let b = root b in + if phys_equal a b then () + else + match (a.contents, b.contents) with + | Root avalue, Root bvalue -> + a.contents <- Link b; + b.contents <- Root (f avalue bvalue) + | Root _, Link _ + | Link _, Root _ + | Link _, Link _ -> + assert false + +let link ~target p = union ~f:(fun _b target -> target) p target + +let equal a b = phys_equal a b || phys_equal (root a) (root b) diff --git a/bidi_local/union_find.mli b/bidi_local/union_find.mli new file mode 100644 index 0000000..a9ecdfb --- /dev/null +++ b/bidi_local/union_find.mli @@ -0,0 +1,36 @@ +(** Union find. *) + +open! Base + +type 'a t +(** Represents a single element. + + Each element belongs to an equivalence class and each equivalence class has + a value of type ['a] assocated with it. *) + +val make : 'a -> 'a t +(** [make v] creates a new equivalence class consisting of a single element + which is returned to the caller. + + The value [v] is assocated with the equivalence class being created. *) + +val value : 'a t -> 'a +(** [value e] returns the value associated with equivalence class the element + [e] belongs to. *) + +val union : f:('a -> 'a -> 'a) -> 'a t -> 'a t -> unit +(** [union ~f a b] makes elements [a] and [b] belong to the same equivalence + class so that [equal a b] returns [true] afterwards. + + The resulted value associated with the equivalence class is being merged as + specified by the [f] function. *) + +val link : target:'a t -> 'a t -> unit +(** [link a b] is the same as [union a b] but guarantees to link [b] to [a] and + not vice versa. *) + +val equal : 'a t -> 'a t -> bool +(** [equal a b] checks that both elements [a] and [b] belong to the same + equivalence class. *) + +val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t diff --git a/bidi_local/var.ml b/bidi_local/var.ml new file mode 100644 index 0000000..e18d47b --- /dev/null +++ b/bidi_local/var.ml @@ -0,0 +1,175 @@ +open Import +open Syntax + +type t = var + +module Id = MakeId () + +let fresh ?name ?lvl () : var = + let id = Id.fresh () in + Union_find.make + { + id; + name; + lvl; + variance = None; + ty = None; + lower = Ty_bot; + upper = Ty_top; + } + +let refresh ?lvl var = + let id = Id.fresh () in + let var = Union_find.value var in + Union_find.make + { + id; + name = var.name; + lvl; + variance = None; + ty = None; + lower = Ty_bot; + upper = Ty_top; + } + +let reset = Id.reset + +let layout = layout_var' + +include ( + Showable (struct + type nonrec t = t + + let layout v = Layout.render (layout_var' v) + end) : + SHOWABLE with type t := t) + +let sexp_of_t v = Sexp.Atom (show v) + +let equal = Union_find.equal + +let compare a b = + let a' = Union_find.value a in + let b' = Union_find.value b in + Int.compare a'.id b'.id + +let merge_lvl lvl1 lvl2 = + match (lvl1, lvl2) with + | None, None + | Some _, None + | None, Some _ -> + failwith "lvl is not assigned" + | Some lvl1, Some lvl2 -> Some (Int.min lvl1 lvl2) + +let merge_variance variance1 variance2 = + match (variance1, variance2) with + | None, None -> None + | Some v, None + | None, Some v -> + Some v + | Some v1, Some v2 -> Some (Variance.join v1 v2) + +let merge_name v1 v2 = + match (v1.name, v2.name) with + | None, None -> None + | None, n -> n + | n, None -> n + | Some n1, Some n2 -> + Some + (if Option.value v1.lvl ~default:(-1) > Option.value v2.lvl ~default:(-1) + then n2 + else n1) + +let name v = + match (Union_find.value v).name with + | None -> failwith "name is not assigned" + | Some name -> name + +let lvl v = + match (Union_find.value v).lvl with + | None -> failwith "lvl is not assigned" + | Some lvl -> lvl + +let variance v = + match (Union_find.value v).variance with + | None -> failwith "variance is not assigned" + | Some variance -> variance + +let set_variance v variance = + let v' = Union_find.value v in + let new_variance = + match v'.variance with + | None -> variance + | Some variance' -> Variance.join variance' variance + in + v'.variance <- Some new_variance; + if Debug.log_solve then + Caml.Format.printf "VARIANCE %s %s@." (show v) (Variance.show variance) + +(** [occurs_check_adjust_lvl v ty] checks that variable [v] is not + contained within type [ty] and adjust levels of all unbound vars within + the [ty]. *) +let occurs_check_adjust_lvl v = + let v = Union_find.value v in + let rec occurs_check_ty ty' : unit = + match ty' with + | Ty_top + | Ty_bot + | Ty_const _ -> + () + | Ty_arr (args, ret) -> + List.iter args ~f:occurs_check_ty; + occurs_check_ty ret + | Ty_nullable ty -> occurs_check_ty ty + | Ty_app (f, args) -> + occurs_check_ty f; + List.iter args ~f:occurs_check_ty + | Ty_var other_var -> ( + let other_var = Union_find.value other_var in + match other_var.ty with + | Some ty' -> occurs_check_ty ty' + | None -> + if Int.equal other_var.id v.id then + Type_error.raise Error_recursive_type + else v.lvl <- merge_lvl v.lvl other_var.lvl) + | Ty_record row -> occurs_check_ty row + | Ty_row_empty -> () + | Ty_row_extend ((_name, ty), row) -> + occurs_check_ty ty; + occurs_check_ty row + in + occurs_check_ty + +let ty v = + match (Union_find.value v).ty with + | Some (Ty_var _) -> assert false + | Some ty -> Some ty + | None -> None + +let set_ty v ty = + let v' = Union_find.value v in + match (v'.ty, ty) with + | Some _, _ -> failwith "ty is already assigned" + | None, Ty_var _ -> failwith "unable to set ty to another var" + | None, ty -> + if Debug.log_solve then + Caml.Format.printf "SET %s %s@." (show v) (Ty.show ty); + occurs_check_adjust_lvl v ty; + v'.ty <- Some ty + +let is_empty v = Option.is_none (ty v) + +let union ~merge_lower ~merge_upper v1 v2 = + if Debug.log_solve then Caml.Format.printf "UNION %s %s@." (show v1) (show v2); + if equal v1 v2 then () + else + match (ty v1, ty v2) with + | None, None -> + Union_find.union v1 v2 ~f:(fun v1 v2 -> + v1.name <- merge_name v1 v2; + v1.lvl <- merge_lvl v1.lvl v2.lvl; + v1.variance <- merge_variance v1.variance v2.variance; + v1.lower <- merge_lower v1.lower v2.lower; + v1.upper <- merge_upper v1.upper v2.upper; + v1) + | _ -> failwith "cannot unify already resolved variables" diff --git a/bidi_local/var.mli b/bidi_local/var.mli new file mode 100644 index 0000000..72c0637 --- /dev/null +++ b/bidi_local/var.mli @@ -0,0 +1,37 @@ +open! Import +open! Syntax + +type t = var + +val fresh : ?name:string -> ?lvl:lvl -> unit -> t + +val refresh : ?lvl:lvl -> t -> t + +val reset : unit -> unit + +val layout : t -> Layout.layout + +include SHOWABLE with type t := t + +val sexp_of_t : t -> Sexp.t + +val equal : t -> t -> bool + +val compare : t -> t -> int + +val is_empty : t -> bool + +val union : + merge_lower:(ty -> ty -> ty) -> merge_upper:(ty -> ty -> ty) -> t -> t -> unit + +val name : t -> name + +val lvl : t -> lvl + +val ty : t -> ty option + +val set_ty : t -> ty -> unit + +val variance : t -> Variance.t + +val set_variance : t -> Variance.t -> unit diff --git a/hmx/import.ml b/hmx/import.ml new file mode 100644 index 0000000..57aedda --- /dev/null +++ b/hmx/import.ml @@ -0,0 +1,108 @@ +include Base + +module Monad = struct + include Base.Monad + + (** A signature for modules implementing monadic let-syntax. *) + module type MONAD_SYNTAX = sig + type 'a t + + val return : 'a -> 'a t + + val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t + + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t + end + + (** Make a monadic syntax out of the monad. *) + module Make_monad_syntax (P : Base.Monad.S) : + MONAD_SYNTAX with type 'a t := 'a P.t = struct + let return = P.return + + let ( let* ) v f = P.bind v ~f + + let ( let+ ) v f = P.map v ~f + end + + module type S = sig + include Base.Monad.S + + module Monad_syntax : MONAD_SYNTAX with type 'a t := 'a t + end + + module Make (P : Basic) : S with type 'a t := 'a P.t = struct + module Self = Base.Monad.Make (P) + include Self + + module Monad_syntax = Make_monad_syntax (struct + type 'a t = 'a P.t + + include Self + end) + end +end + +module MakeId () = struct + let c = ref 0 + + let fresh () = + Int.incr c; + !c + + let reset () = c := 0 +end + +module type SHOWABLE = sig + type t + + val show : t -> string + + val print : ?label:string -> t -> unit +end + +module Showable (S : sig + type t + + val layout : t -> PPrint.document +end) : SHOWABLE with type t = S.t = struct + type t = S.t + + let show v = + let width = 60 in + let buf = Buffer.create 100 in + PPrint.ToBuffer.pretty 1. width buf (S.layout v); + Buffer.contents buf + + let print ?label v = + match label with + | Some label -> Caml.print_endline (label ^ ": " ^ show v) + | None -> Caml.print_endline (show v) +end + +module type DUMPABLE = sig + type t + + val dump : ?label:string -> t -> unit + + val sdump : ?label:string -> t -> string +end + +module Dumpable (S : sig + type t + + val sexp_of_t : t -> Sexp.t +end) : DUMPABLE with type t = S.t = struct + type t = S.t + + let dump ?label v = + let s = S.sexp_of_t v in + match label with + | None -> Caml.Format.printf "%a@." Sexp.pp_hum s + | Some label -> Caml.Format.printf "%s %a@." label Sexp.pp_hum s + + let sdump ?label v = + let s = S.sexp_of_t v in + match label with + | None -> Caml.Format.asprintf "%a@." Sexp.pp_hum s + | Some label -> Caml.Format.asprintf "%s %a@." label Sexp.pp_hum s +end