From 9c334eba6bfbe8d6f0b7050dd71606a4b4394d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Vouillon?= Date: Mon, 12 May 2025 17:09:55 +0200 Subject: [PATCH 1/3] Global flow analysis: keep track of which fields are mutable --- compiler/lib/global_flow.ml | 94 ++++++++++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 22 deletions(-) diff --git a/compiler/lib/global_flow.ml b/compiler/lib/global_flow.ml index 1e167a47ef..483f941ad9 100644 --- a/compiler/lib/global_flow.ml +++ b/compiler/lib/global_flow.ml @@ -93,16 +93,21 @@ type escape_status = | Escape_constant (* Escapes but we know the value is not modified *) | No +type mutable_fields = + | No_field + | Some_fields of IntSet.t + | All_fields + type state = { vars : Var.ISet.t (* Set of all veriables considered *) ; deps : Var.t list Var.Tbl.t (* Dependency between variables *) ; defs : def array (* Definition of each variable *) ; variable_may_escape : escape_status array (* Any value bound to this variable may escape *) - ; variable_possibly_mutable : Var.ISet.t + ; variable_mutable_fields : mutable_fields array (* Any value bound to this variable may be mutable *) ; may_escape : escape_status array (* This value may escape *) - ; possibly_mutable : Var.ISet.t (* This value may be mutable *) + ; mutable_fields : mutable_fields array (* This value may be mutable *) ; return_values : Var.Set.t Var.Map.t (* Set of variables holding return values of each function *) ; functions_from_returned_value : Var.t list Var.Hashtbl.t @@ -162,7 +167,14 @@ let cont_deps blocks st ?ignore (pc, args) = let do_escape st level x = st.variable_may_escape.(Var.idx x) <- level -let possibly_mutable st x = Var.ISet.add st.variable_possibly_mutable x +let possibly_mutable st x = st.variable_mutable_fields.(Var.idx x) <- All_fields + +let field_possibly_mutable st x n = + match st.variable_mutable_fields.(Var.idx x) with + | No_field -> st.variable_mutable_fields.(Var.idx x) <- Some_fields (IntSet.singleton n) + | Some_fields s -> + st.variable_mutable_fields.(Var.idx x) <- Some_fields (IntSet.add n s) + | All_fields -> () let expr_deps blocks st x e = match e with @@ -267,7 +279,10 @@ let program_deps st { start; blocks; _ } = add_expr_def st x e; expr_deps blocks st x e | Assign (x, y) -> add_assign_def st x y - | Set_field (x, _, _, y) | Array_set (x, _, y) -> + | Set_field (x, n, _, y) -> + field_possibly_mutable st x n; + do_escape st Escape y + | Array_set (x, _, y) -> possibly_mutable st x; do_escape st Escape y | Event _ | Offset_ref _ -> ()); @@ -360,7 +375,7 @@ module Domain = struct Array.iter ~f:(fun y -> variable_escape ~update ~st ~approx s y) a; match s, mut with | Escape, Maybe_mutable -> - Var.ISet.add st.possibly_mutable x; + st.mutable_fields.(Var.idx x) <- All_fields; update ~children:true x | (Escape_constant | No), _ | Escape, Immutable -> ()) | Expr (Closure (params, _, _)) -> @@ -405,18 +420,28 @@ module Domain = struct s (if o then others else bot) - let mark_mutable ~update ~st a = + let mark_mutable ~update ~st a mutable_fields = match a with | Top -> () | Values { known; _ } -> Var.Set.iter (fun x -> match st.defs.(Var.idx x) with - | Expr (Block (_, _, _, Maybe_mutable)) -> - if not (Var.ISet.mem st.possibly_mutable x) - then ( - Var.ISet.add st.possibly_mutable x; - update ~children:true x) + | Expr (Block (_, _, _, Maybe_mutable)) -> ( + match st.mutable_fields.(Var.idx x), mutable_fields with + | _, No_field -> () + | No_field, _ -> + st.mutable_fields.(Var.idx x) <- mutable_fields; + update ~children:true x + | Some_fields s, Some_fields s' -> + if IntSet.exists (fun i -> not (IntSet.mem i s)) s' + then ( + st.mutable_fields.(Var.idx x) <- Some_fields (IntSet.union s s'); + update ~children:true x) + | Some_fields _, All_fields -> + st.mutable_fields.(Var.idx x) <- All_fields; + update ~children:true x + | All_fields, _ -> ()) | Expr (Block (_, _, _, Immutable)) | Expr (Closure _) -> () | Phi _ | Expr _ -> assert false) known @@ -452,7 +477,12 @@ let propagate st ~update approx x = | Some tags -> List.mem ~eq:Int.equal t tags | None -> true -> let t = a.(n) in - let m = Var.ISet.mem st.possibly_mutable z in + let m = + match st.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields s -> IntSet.mem n s + | All_fields -> true + in if not m then add_dep st x z; add_dep st x t; let a = Var.Tbl.get approx t in @@ -480,7 +510,11 @@ let propagate st ~update approx x = (fun z -> match st.defs.(Var.idx z) with | Expr (Block (_, lst, _, _)) -> - let m = Var.ISet.mem st.possibly_mutable z in + let m = + match st.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields _ | All_fields -> true + in if not m then add_dep st x z; Array.iter ~f:(fun t -> add_dep st x t) lst; let a = @@ -574,8 +608,9 @@ let propagate st ~update approx x = (match st.variable_may_escape.(Var.idx x) with | (Escape | Escape_constant) as s -> Domain.approx_escape ~update ~st ~approx s res | No -> ()); - if Var.ISet.mem st.variable_possibly_mutable x - then Domain.mark_mutable ~update ~st res; + (match st.variable_mutable_fields.(Var.idx x) with + | No_field -> () + | (Some_fields _ | All_fields) as s -> Domain.mark_mutable ~update ~st res s); res | Top -> Top @@ -653,9 +688,9 @@ let f ~fast p = let deps = Var.Tbl.make () [] in let defs = Array.make nv undefined in let variable_may_escape = Array.make nv No in - let variable_possibly_mutable = Var.ISet.empty () in + let variable_mutable_fields = Array.make nv No_field in let may_escape = Array.make nv No in - let possibly_mutable = Var.ISet.empty () in + let mutable_fields = Array.make nv No_field in let functions_from_returned_value = Var.Hashtbl.create 128 in Var.Map.iter (fun f s -> Var.Set.iter (fun x -> add_to_list functions_from_returned_value x f) s) @@ -667,9 +702,9 @@ let f ~fast p = ; return_values = rets ; functions_from_returned_value ; variable_may_escape - ; variable_possibly_mutable + ; variable_mutable_fields ; may_escape - ; possibly_mutable + ; mutable_fields ; known_cases = Var.Hashtbl.create 16 ; applied_functions = VarPairTbl.create 16 ; fast @@ -698,13 +733,28 @@ let f ~fast p = match a with | Top -> Format.fprintf f "top" | Values _ -> + let print_mutable_fields f s = + match s with + | No_field -> Format.fprintf f "no" + | Some_fields s -> + Format.fprintf + f + "{%a}" + (Format.pp_print_list + ~pp_sep:(fun f () -> Format.fprintf f ", ") + (fun f i -> Format.fprintf f "%d" i)) + (IntSet.elements s) + | All_fields -> Format.fprintf f "yes" + in Format.fprintf f - "%a mut:%b vmut:%b vesc:%s esc:%s" + "%a mut:%a vmut:%a vesc:%s esc:%s" (print_approx st) a - (Var.ISet.mem st.possibly_mutable x) - (Var.ISet.mem st.variable_possibly_mutable x) + print_mutable_fields + st.mutable_fields.(Var.idx x) + print_mutable_fields + st.variable_mutable_fields.(Var.idx x) (match st.variable_may_escape.(Var.idx x) with | Escape -> "Y" | Escape_constant -> "y" From fe49f37b5f3413438b61d4b02968f53f7fe91215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Vouillon?= Date: Wed, 23 Apr 2025 18:33:40 +0200 Subject: [PATCH 2/3] Wasm: specialization of number comparisons --- compiler/lib-wasm/generate.ml | 113 +++++++++- compiler/lib-wasm/typing.ml | 395 ++++++++++++++++++++++++++++++++++ compiler/lib-wasm/typing.mli | 17 ++ compiler/lib/global_flow.ml | 17 +- compiler/lib/global_flow.mli | 33 +++ runtime/js/compare.js | 2 +- 6 files changed, 566 insertions(+), 11 deletions(-) create mode 100644 compiler/lib-wasm/typing.ml create mode 100644 compiler/lib-wasm/typing.mli diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 561f8cf3fd..0ba16df0c3 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct { live : int array ; in_cps : Effects.in_cps ; deadcode_sentinal : Var.t + ; types : Typing.typ Var.Tbl.t ; blocks : block Addr.Map.t ; closures : Closure_conversion.closure Var.Map.t ; global_context : Code_generation.context @@ -230,6 +231,39 @@ module Generate (Target : Target_sig.S) = struct f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z) | _ -> invalid_arity name l ~expected:3) + let get_type ctx p = + match p with + | Pv x -> Var.Tbl.get ctx.types x + | Pc c -> Typing.constant_type c + + let register_comparison name cmp_int cmp_boxed_int cmp_float = + register_prim name `Mutable (fun ctx _ transl_prim_arg l -> + match l with + | [ x; y ] -> ( + let x' = transl_prim_arg x in + let y' = transl_prim_arg y in + match get_type ctx x, get_type ctx y with + | Number Int, Number Int -> cmp_int x' y' + | Number Int32, Number Int32 -> + let* x' = Memory.unbox_int32 x' in + let* y' = Memory.unbox_int32 y' in + Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y'))) + | Number Nativeint, Number Nativeint -> + let* x' = Memory.unbox_nativeint x' in + let* y' = Memory.unbox_nativeint y' in + Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y'))) + | Number Int64, Number Int64 -> + let* x' = Memory.unbox_int64 x' in + let* y' = Memory.unbox_int64 y' in + Value.val_int (return (W.BinOp (I64 cmp_boxed_int, x', y'))) + | Number Float, Number Float -> float_comparison cmp_float x' y' + | _ -> + let* f = register_import ~name (Fun (Type.primitive_type 2)) in + let* x' = x' in + let* y' = y' in + return (W.Call (f, [ x'; y' ]))) + | _ -> invalid_arity name l ~expected:2) + let () = register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get; register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get; @@ -602,7 +636,76 @@ module Generate (Target : Target_sig.S) = struct l ~init:(return []) in - Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l) + Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l); + register_comparison "caml_greaterthan" (fun y x -> Value.lt x y) (Gt S) Gt; + register_comparison "caml_greaterequal" (fun y x -> Value.le x y) (Ge S) Ge; + register_comparison "caml_lessthan" Value.lt (Lt S) Lt; + register_comparison "caml_lessequal" Value.le (Le S) Le; + register_comparison + "caml_equal" + (fun x y -> + let* x = x in + let* y = y in + Value.val_int (return (W.RefEq (x, y)))) + Eq + Eq; + register_comparison + "caml_notequal" + (fun x y -> + let* x = x in + let* y = y in + Value.val_int (return (W.UnOp (I32 Eqz, RefEq (x, y))))) + Ne + Ne; + register_prim "caml_compare" `Mutable (fun ctx _ transl_prim_arg l -> + match l with + | [ x; y ] -> ( + let x' = transl_prim_arg x in + let y' = transl_prim_arg y in + match get_type ctx x, get_type ctx y with + | Number Int, Number Int -> + Value.val_int + Arith.( + (Value.int_val y' < Value.int_val x') + - (Value.int_val x' < Value.int_val y')) + | Number Int32, Number Int32 -> + let* f = + register_import ~name:"caml_int32_compare" (Fun (Type.primitive_type 2)) + in + let* x' = Memory.unbox_int32 x' in + let* y' = Memory.unbox_int32 y' in + return (W.Call (f, [ x'; y' ])) + | Number Nativeint, Number Nativeint -> + let* f = + register_import + ~name:"caml_nativeint_compare" + (Fun (Type.primitive_type 2)) + in + let* x' = Memory.unbox_nativeint x' in + let* y' = Memory.unbox_nativeint y' in + return (W.Call (f, [ x'; y' ])) + | Number Int64, Number Int64 -> + let* f = + register_import ~name:"caml_int64_compare" (Fun (Type.primitive_type 2)) + in + let* x' = Memory.unbox_int64 x' in + let* y' = Memory.unbox_int64 y' in + return (W.Call (f, [ x'; y' ])) + | Number Float, Number Float -> + let* f = + register_import ~name:"caml_float_compare" (Fun (Type.primitive_type 2)) + in + let* x' = Memory.unbox_int64 x' in + let* y' = Memory.unbox_int64 y' in + return (W.Call (f, [ x'; y' ])) + | _ -> + let* f = + register_import ~name:"caml_compare" (Fun (Type.primitive_type 2)) + in + let* x' = x' in + let* y' = y' in + return (W.Call (f, [ x'; y' ]))) + | _ -> invalid_arity "caml_compare" l ~expected:2) let rec translate_expr ctx context x e = match e with @@ -1183,7 +1286,8 @@ module Generate (Target : Target_sig.S) = struct ~should_export ~warn_on_unhandled_effect *) - ~deadcode_sentinal = + ~deadcode_sentinal + ~types = global_context.unit_name <- unit_name; let p, closures = Closure_conversion.f p in (* @@ -1193,6 +1297,7 @@ module Generate (Target : Target_sig.S) = struct { live = live_vars ; in_cps ; deadcode_sentinal + ; types ; blocks = p.blocks ; closures ; global_context @@ -1306,8 +1411,10 @@ let start () = make_context ~value_type:Gc_target.Type.value let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal = let t = Timer.make () in + let state, info = Global_flow.f' ~fast:false p in + let types = Typing.f ~state ~info p in let p = fix_switch_branches p in - let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in + let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in if times () then Format.eprintf " code gen.: %a@." Timer.print t; res diff --git a/compiler/lib-wasm/typing.ml b/compiler/lib-wasm/typing.ml new file mode 100644 index 0000000000..56f4e7101c --- /dev/null +++ b/compiler/lib-wasm/typing.ml @@ -0,0 +1,395 @@ +open! Stdlib +open Code +open Global_flow + +let debug = Debug.find "typing" + +type number = + | Int + | Int32 + | Int64 + | Nativeint + | Float + +type typ = + | Top + | Number of number + | Tuple of typ array + | Bot + +module Domain = struct + type t = typ + + let rec join t t' = + match t, t' with + | Bot, t | t, Bot -> t + | Number n, Number n' -> if Poly.equal n n' then t else Top + | Tuple t, Tuple t' -> + if Array.length t = Array.length t' then Tuple (Array.map2 ~f:join t t') else Top + | Top, _ | _, Top -> Top + | Number _, Tuple _ | Tuple _, Number _ -> Top + + let join_set ?(others = false) f s = + if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot + + let rec equal t t' = + match t, t' with + | Top, Top | Bot, Bot -> true + | Number t, Number t' -> Poly.equal t t' + | Tuple t, Tuple t' -> + Array.length t = Array.length t' && Array.for_all2 ~f:equal t t' + | (Top | Tuple _ | Number _ | Bot), _ -> false + + let bot = Bot + + let rec sub t t' = + match t, t' with + | _, Top | Bot, _ -> true + | Top, _ | _, Bot -> false + | Number t, Number t' -> Poly.equal t t' + | Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:sub t t' + | Number _, _ | Tuple _, _ -> false + + let rec print f t = + match t with + | Top -> Format.fprintf f "top" + | Bot -> Format.fprintf f "bot" + | Number Int -> Format.fprintf f "int" + | Number Int32 -> Format.fprintf f "int32" + | Number Int64 -> Format.fprintf f "int64" + | Number Nativeint -> Format.fprintf f "nativeint" + | Number Float -> Format.fprintf f "float" + | Tuple t -> + Format.fprintf + f + "(%a)" + (Format.pp_print_array ~pp_sep:(fun f () -> Format.fprintf f ",") print) + t +end + +let update_deps st { blocks; _ } = + let add_dep st x y = Var.Tbl.set st.deps y (x :: Var.Tbl.get st.deps y) in + Addr.Map.iter + (fun _ block -> + List.iter block.body ~f:(fun i -> + match i with + | Let (x, Block (_, lst, _, _)) -> Array.iter ~f:(fun y -> add_dep st x y) lst + | _ -> ())) + blocks + +type st = + { state : state + ; info : info + } + +let rec constant_type (c : constant) = + match c with + | Int _ -> Number Int + | Int32 _ -> Number Int32 + | Int64 _ -> Number Int64 + | NativeInt _ -> Number Nativeint + | Float _ -> Number Float + | Tuple (_, a, _) -> Tuple (Array.map ~f:constant_type a) + | _ -> Top + +let prim_type prim = + match prim with + | "%int_add" + | "%int_sub" + | "%int_mul" + | "%int_div" + | "%int_mod" + | "%direct_int_mul" + | "%direct_int_div" + | "%direct_int_mod" + | "%int_and" + | "%int_or" + | "%int_xor" + | "%int_lsl" + | "%int_lsr" + | "%int_asr" + | "%int_neg" + | "caml_greaterthan" + | "caml_greaterequal" + | "caml_lessthan" + | "caml_lessequal" + | "caml_equal" + | "caml_compare" -> Number Int + | "caml_int32_bswap" -> Number Int32 + | "caml_nativeint_bswap" -> Number Nativeint + | "caml_int64_bswap" -> Number Int64 + | "caml_int32_compare" -> Number Int + | "caml_nativeint_compare" -> Number Int + | "caml_int64_compare" -> Number Int + | "caml_string_get32" -> Number Int32 + | "caml_string_get64" -> Number Int64 + | "caml_bytes_get32" -> Number Int32 + | "caml_bytes_get64" -> Number Int64 + | "caml_bytes_set32" -> Number Int + | "caml_bytes_set64" -> Number Int + | "caml_lxm_next" -> Number Int64 + | "caml_ba_uint8_get32" -> Number Int32 + | "caml_ba_uint8_get64" -> Number Int64 + | "caml_ba_uint8_set32" -> Number Int + | "caml_ba_uint8_set64" -> Number Int + | "caml_nextafter_float" -> Number Float + | "caml_classify_float" -> Number Int + | "caml_ldexp_float" -> Number Float + | "caml_erf_float" -> Number Float + | "caml_erfc_float" -> Number Float + | "caml_float_compare" -> Number Int + | "caml_floatarray_unsafe_get" -> Number Float + | "caml_bytes_unsafe_get" -> Number Int + | "caml_string_unsafe_get" -> Number Int + | "caml_bytes_get" -> Number Int + | "caml_string_get" -> Number Int + | "caml_ml_string_length" -> Number Int + | "caml_ml_bytes_length" -> Number Int + | "%direct_obj_tag" -> Number Int + | "caml_add_float" + | "caml_sub_float" + | "caml_mul_float" + | "caml_div_float" + | "caml_copysign_float" -> Number Float + | "caml_signbit_float" -> Number Float + | "caml_neg_float" + | "caml_abs_float" + | "caml_ceil_float" + | "caml_floor_float" + | "caml_trunc_float" + | "caml_round_float" + | "caml_sqrt_float" -> Number Float + | "caml_eq_float" + | "caml_neq_float" + | "caml_ge_float" + | "caml_le_float" + | "caml_gt_float" + | "caml_lt_float" + | "caml_int_of_float" -> Number Int + | "caml_float_of_int" + | "caml_cos_float" + | "caml_sin_float" + | "caml_tan_float" + | "caml_acos_float" + | "caml_asin_float" + | "caml_atan_float" + | "caml_atan2_float" + | "caml_cosh_float" + | "caml_sinh_float" + | "caml_tanh_float" + | "caml_acosh_float" + | "caml_asinh_float" + | "caml_atanh_float" + | "caml_cbrt_float" + | "caml_exp_float" + | "caml_exp2_float" + | "caml_log_float" + | "caml_expm1_float" + | "caml_log1p_float" + | "caml_log2_float" + | "caml_log10_float" + | "caml_power_float" + | "caml_hypot_float" + | "caml_fmod_float" -> Number Float + | "caml_int32_bits_of_float" -> Number Int32 + | "caml_int32_float_of_bits" -> Number Float + | "caml_int32_of_float" -> Number Int32 + | "caml_int32_to_float" -> Number Float + | "caml_int32_neg" + | "caml_int32_add" + | "caml_int32_sub" + | "caml_int32_mul" + | "caml_int32_and" + | "caml_int32_or" + | "caml_int32_xor" + | "caml_int32_div" -> Number Int32 + | "caml_int32_mod" + | "caml_int32_shift_left" + | "caml_int32_shift_right" + | "caml_int32_shift_right_unsigned" -> Number Int32 + | "caml_int32_to_int" -> Number Int + | "caml_int32_of_int" -> Number Int32 + | "caml_nativeint_of_int32" -> Number Nativeint + | "caml_nativeint_to_int32" -> Number Int32 + | "caml_int64_bits_of_float" -> Number Int64 + | "caml_int64_float_of_bits" -> Number Float + | "caml_int64_of_float" -> Number Int64 + | "caml_int64_to_float" -> Number Float + | "caml_int64_neg" + | "caml_int64_add" + | "caml_int64_sub" + | "caml_int64_mul" + | "caml_int64_and" + | "caml_int64_or" + | "caml_int64_xor" + | "caml_int64_div" + | "caml_int64_mod" + | "caml_int64_shift_left" + | "caml_int64_shift_right" + | "caml_int64_shift_right_unsigned" -> Number Int64 + | "caml_int64_to_int" -> Number Int + | "caml_int64_of_int" -> Number Int64 + | "caml_int64_to_int32" -> Number Int32 + | "caml_int64_of_int32" -> Number Int64 + | "caml_int64_to_nativeint" -> Number Nativeint + | "caml_int64_of_nativeint" -> Number Int64 + | "caml_nativeint_bits_of_float" -> Number Nativeint + | "caml_nativeint_float_of_bits" -> Number Float + | "caml_nativeint_of_float" -> Number Nativeint + | "caml_nativeint_to_float" -> Number Float + | "caml_nativeint_neg" + | "caml_nativeint_add" + | "caml_nativeint_sub" + | "caml_nativeint_mul" + | "caml_nativeint_and" + | "caml_nativeint_or" + | "caml_nativeint_xor" + | "caml_nativeint_div" + | "caml_nativeint_mod" + | "caml_nativeint_shift_left" + | "caml_nativeint_shift_right" + | "caml_nativeint_shift_right_unsigned" -> Number Nativeint + | "caml_nativeint_to_int" -> Number Int + | "caml_nativeint_of_int" -> Number Nativeint + | "caml_int_compare" -> Number Int + | _ -> Top + +let propagate st approx x : Domain.t = + match st.state.defs.(Var.idx x) with + | Phi { known; others } -> Domain.join_set ~others (fun y -> Var.Tbl.get approx y) known + | Expr e -> ( + match e with + | Constant c -> constant_type c + | Closure _ -> Top + | Block (_, lst, _, _) -> + Tuple + (Array.mapi + ~f:(fun i y -> + match st.state.mutable_fields.(Var.idx x) with + | All_fields -> Top + | Some_fields s when IntSet.mem i s -> Top + | Some_fields _ | No_field -> Var.Tbl.get approx y) + lst) + | Field (y, n, _) -> ( + match Var.Tbl.get approx y with + | Tuple t -> if n < Array.length t then t.(n) else Bot + | Top -> Top + | _ -> Bot) + | Prim + ( Extern ("caml_check_bound" | "caml_check_bound_float" | "caml_check_bound_gen") + , [ Pv y; _ ] ) -> Var.Tbl.get approx y + | Prim ((Array_get | Extern "caml_array_unsafe_get"), [ Pv y; _ ]) -> ( + match Var.Tbl.get st.info.info_approximation y with + | Values { known; others } -> + Domain.join_set + ~others + (fun z -> + match st.state.defs.(Var.idx z) with + | Expr (Block (_, lst, _, _)) -> + let m = + match st.state.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields _ | All_fields -> true + in + if m + then Top + else + Array.fold_left + ~f:(fun acc t -> Domain.join (Var.Tbl.get approx t) acc) + ~init:Domain.bot + lst + | Expr (Closure _) -> Bot + | Phi _ | Expr _ -> assert false) + known + | Top -> Top) + | Prim (Array_get, _) -> Top + | Prim ((Vectlength | Not | IsInt | Eq | Neq | Lt | Le | Ult), _) -> Number Int + | Prim (Extern prim, _) -> prim_type prim + | Special _ -> Top + | Apply { f; args; _ } -> ( + match Var.Tbl.get st.info.info_approximation f with + | Values { known; others } -> + Domain.join_set + ~others + (fun g -> + match st.state.defs.(Var.idx g) with + | Expr (Closure (params, _, _)) + when List.length args = List.length params -> + Domain.join_set + (fun y -> Var.Tbl.get approx y) + (Var.Map.find g st.state.return_values) + | Expr (Closure (_, _, _)) -> + (* The function is partially applied or over applied *) + Top + | Expr (Block _) -> Bot + | Phi _ | Expr _ -> assert false) + known + | Top -> Top)) + +module G = Dgraph.Make_Imperative (Var) (Var.ISet) (Var.Tbl) +module Solver = G.Solver (Domain) + +let solver st = + let associated_list h x = try Var.Hashtbl.find h x with Not_found -> [] in + let g = + { G.domain = st.state.vars + ; G.iter_children = + (fun f x -> + List.iter ~f (Var.Tbl.get st.state.deps x); + List.iter + ~f:(fun g -> List.iter ~f (associated_list st.state.function_call_sites g)) + (associated_list st.state.functions_from_returned_value x)) + } + in + Solver.f () g (propagate st) + +let print_opt typ f e = + match e with + | Prim + ( Extern + ( "caml_greaterthan" + | "caml_greaterequal" + | "caml_lessthan" + | "caml_lessequal" + | "caml_equal" + | "caml_compare" ) + , l ) -> + if + List.exists + ~f:(fun t' -> + List.for_all + ~f:(fun p -> + let t = + match p with + | Pc c -> constant_type c + | Pv x -> Var.Tbl.get typ x + in + Domain.sub t t') + l) + [ Number Int; Number Int32; Number Int64; Number Nativeint; Number Float ] + then Format.fprintf f " OPT" + | _ -> () + +let f ~state ~info p = + update_deps state p; + let typ = solver { state; info } in + if debug () + then ( + Var.ISet.iter + (fun x -> + match state.defs.(Var.idx x) with + | Expr _ -> () + | Phi _ -> + let t = Var.Tbl.get typ x in + if not (Domain.equal t Top) + then Format.eprintf "%a: %a@." Var.print x Domain.print t) + state.vars; + Print.program + Format.err_formatter + (fun _ i -> + match i with + | Instr (Let (x, e)) -> + Format.asprintf "{%a}%a" Domain.print (Var.Tbl.get typ x) (print_opt typ) e + | _ -> "") + p); + typ diff --git a/compiler/lib-wasm/typing.mli b/compiler/lib-wasm/typing.mli new file mode 100644 index 0000000000..74bdf568ca --- /dev/null +++ b/compiler/lib-wasm/typing.mli @@ -0,0 +1,17 @@ +type number = + | Int + | Int32 + | Int64 + | Nativeint + | Float + +type typ = + | Top + | Number of number + | Tuple of typ array + | Bot + +val constant_type : Code.constant -> typ + +val f : + state:Global_flow.state -> info:Global_flow.info -> Code.program -> typ Code.Var.Tbl.t diff --git a/compiler/lib/global_flow.ml b/compiler/lib/global_flow.ml index 483f941ad9..1772b52582 100644 --- a/compiler/lib/global_flow.ml +++ b/compiler/lib/global_flow.ml @@ -679,7 +679,7 @@ type info = ; info_return_vals : Var.Set.t Var.Map.t } -let f ~fast p = +let f' ~fast p = let t = Timer.make () in let t1 = Timer.make () in let rets = return_values p in @@ -773,12 +773,15 @@ let f ~fast p = | Escape_constant | Escape -> Var.ISet.add info_may_escape (Var.of_idx i) | No -> ()) may_escape; - { info_defs = defs - ; info_approximation = approximation - ; info_variable_may_escape - ; info_may_escape - ; info_return_vals = rets - } + ( st + , { info_defs = defs + ; info_approximation = approximation + ; info_variable_may_escape + ; info_may_escape + ; info_return_vals = rets + } ) + +let f ~fast p = snd (f' ~fast p) let exact_call info f n = match Var.Tbl.get info.info_approximation f with diff --git a/compiler/lib/global_flow.mli b/compiler/lib/global_flow.mli index 61f5dbfb6a..4581569bb9 100644 --- a/compiler/lib/global_flow.mli +++ b/compiler/lib/global_flow.mli @@ -44,8 +44,41 @@ type info = ; info_return_vals : Var.Set.t Var.Map.t } +type mutable_fields = + | No_field + | Some_fields of Stdlib.IntSet.t + | All_fields + +type state = + { vars : Var.ISet.t (* Set of all veriables considered *) + ; deps : Var.t list Var.Tbl.t (* Dependency between variables *) + ; defs : def array (* Definition of each variable *) + ; variable_may_escape : escape_status array + (* Any value bound to this variable may escape *) + ; variable_mutable_fields : mutable_fields array + (* Any value bound to this variable may be mutable *) + ; may_escape : escape_status array (* This value may escape *) + ; mutable_fields : mutable_fields array (* This value may be mutable *) + ; return_values : Var.Set.t Var.Map.t + (* Set of variables holding return values of each function *) + ; functions_from_returned_value : Var.t list Var.Hashtbl.t + (* Functions associated to each return value *) + ; known_cases : int list Var.Hashtbl.t + (* Possible tags for a block after a [switch]. This is used to + get a more precise approximation of the effect of a field + access [Field] *) + ; applied_functions : (Var.t * Var.t, unit) Hashtbl.t + (* Functions that have been already considered at a call site. + This is to avoid repeated computations *) + ; function_call_sites : Var.t list Var.Hashtbl.t + (* Known call sites of each functions *) + ; fast : bool + } + val f : fast:bool -> Code.program -> info +val f' : fast:bool -> Code.program -> state * info + val exact_call : info -> Var.t -> int -> bool val function_arity : info -> Var.t -> int option diff --git a/runtime/js/compare.js b/runtime/js/compare.js index 0aa1289d93..7ccde88b71 100644 --- a/runtime/js/compare.js +++ b/runtime/js/compare.js @@ -251,7 +251,7 @@ function caml_compare_val(a, b, total) { b = b[i]; } } -//Provides: caml_compare (const, const) +//Provides: caml_compare mutable (const, const) //Requires: caml_compare_val function caml_compare(a, b) { return caml_compare_val(a, b, true); From 00764eb8ec4cd90215875358eaffd6abdb1cf673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Vouillon?= Date: Mon, 12 May 2025 17:09:41 +0200 Subject: [PATCH 3/3] Bigarrays --- compiler/lib-wasm/gc_target.ml | 81 +++++++++++++++++ compiler/lib-wasm/generate.ml | 47 +++++++++- compiler/lib-wasm/target_sig.ml | 13 +++ compiler/lib-wasm/typing.ml | 156 ++++++++++++++++++++++++++++++-- compiler/lib-wasm/typing.mli | 28 ++++++ 5 files changed, 316 insertions(+), 9 deletions(-) diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index 2cebbd38d6..70f08878d5 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -422,6 +422,37 @@ module Type = struct } ]) }) + + let int_array_type = + register_type "int_array" (fun () -> + return + { supertype = None + ; final = true + ; typ = W.Array { mut = true; typ = Value I32 } + }) + + let bigarray_type = + register_type "bigarray" (fun () -> + let* custom_operations = custom_operations_type in + let* int_array = int_array_type in + let* custom = custom_type in + return + { supertype = Some custom + ; final = true + ; typ = + W.Struct + [ { mut = false + ; typ = Value (Ref { nullable = false; typ = Type custom_operations }) + } + ; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) } + ; { mut = false + ; typ = Value (Ref { nullable = false; typ = Type int_array }) + } + ; { mut = false; typ = Packed I8 } + ; { mut = false; typ = Packed I8 } + ; { mut = false; typ = Packed I8 } + ] + }) end module Value = struct @@ -1367,6 +1398,56 @@ module Math = struct let exp2 x = power (return (W.Const (F64 2.))) x end +module Bigarray = struct + let dim1 a = + let* ty = Type.bigarray_type in + Memory.wasm_array_get + ~ty:Type.int_array_type + (Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2) + (Arith.const 0l) + + let get ~kind a i = + match kind with + | Typing.Bigarray.Int8_unsigned | Char -> + let* f = + register_import + ~import_module:"bindings" + ~name:"ta_get_ui8" + (Fun + { W.params = [ Ref { nullable = false; typ = Extern }; I32 ] + ; result = [ I32 ] + }) + in + let* ty = Type.bigarray_type in + let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in + let* i = Value.int_val i in + Value.val_int (return (W.Call (f, [ ta; i ]))) + | _ -> assert false + + let set ~kind a i v = + match kind with + | Typing.Bigarray.Int8_unsigned | Char -> + let* f = + register_import + ~import_module:"bindings" + ~name:"ta_set_ui8" + (Fun + { W.params = + [ Ref { nullable = false; typ = Extern } + ; I32 + ; Ref { nullable = false; typ = I31 } + ] + ; result = [] + }) + in + let* ty = Type.bigarray_type in + let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in + let* i = Value.int_val i in + let* v = cast I31 v in + instr (W.CallInstr (f, [ ta; i; v ])) + | _ -> assert false +end + module JavaScript = struct let anyref = W.Ref { nullable = true; typ = Any } diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 0ba16df0c3..fea622e5f4 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -705,7 +705,48 @@ module Generate (Target : Target_sig.S) = struct let* x' = x' in let* y' = y' in return (W.Call (f, [ x'; y' ]))) - | _ -> invalid_arity "caml_compare" l ~expected:2) + | _ -> invalid_arity "caml_compare" l ~expected:2); + register_prim "caml_ba_get_1" `Mutator (fun ctx context transl_prim_arg l -> + match l with + | [ x; y ] -> ( + let x' = transl_prim_arg x in + let y' = transl_prim_arg y in + match get_type ctx x with + | Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } -> + seq + (let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in + instr (W.Br_if (label_index context bound_error_pc, cond))) + (Bigarray.get ~kind x' y') + | _ -> + let* f = + register_import ~name:"caml_ba_get_1" (Fun (Type.primitive_type 2)) + in + let* x' = x' in + let* y' = y' in + return (W.Call (f, [ x'; y' ]))) + | _ -> invalid_arity "caml_ba_get_1" l ~expected:2); + register_prim "caml_ba_set_1" `Mutator (fun ctx context transl_prim_arg l -> + match l with + | [ x; y; z ] -> ( + let x' = transl_prim_arg x in + let y' = transl_prim_arg y in + let z' = transl_prim_arg z in + match get_type ctx x with + | Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } -> + seq + (let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in + let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in + Bigarray.set ~kind x' y' z') + Value.unit + | _ -> + let* f = + register_import ~name:"caml_ba_set_1" (Fun (Type.primitive_type 3)) + in + let* x' = x' in + let* y' = y' in + let* z' = z' in + return (W.Call (f, [ x'; y'; z' ]))) + | _ -> invalid_arity "caml_ba_set_1" l ~expected:3) let rec translate_expr ctx context x e = match e with @@ -947,7 +988,9 @@ module Generate (Target : Target_sig.S) = struct | "caml_bytes_set" | "caml_check_bound" | "caml_check_bound_gen" - | "caml_check_bound_float" ) + | "caml_check_bound_float" + | "caml_ba_get_1" + | "caml_ba_set_1" ) , _ ) ) -> fst n, true | Let ( _ diff --git a/compiler/lib-wasm/target_sig.ml b/compiler/lib-wasm/target_sig.ml index 1182b30424..3c7464fd79 100644 --- a/compiler/lib-wasm/target_sig.ml +++ b/compiler/lib-wasm/target_sig.ml @@ -252,6 +252,19 @@ module type S = sig val round : expression -> expression end + module Bigarray : sig + val dim1 : expression -> expression + + val get : kind:Typing.Bigarray.kind -> expression -> expression -> expression + + val set : + kind:Typing.Bigarray.kind + -> expression + -> expression + -> expression + -> unit Code_generation.t + end + val internal_primitives : (string * Primitive.kind diff --git a/compiler/lib-wasm/typing.ml b/compiler/lib-wasm/typing.ml index 56f4e7101c..a80e96f45b 100644 --- a/compiler/lib-wasm/typing.ml +++ b/compiler/lib-wasm/typing.ml @@ -11,10 +11,89 @@ type number = | Nativeint | Float +module Bigarray = struct + type kind = + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Complex32 + | Complex64 + | Char + | Float16 + + type layout = + | C + | Fortran + + type t = + { kind : kind + ; layout : layout + } + + let make ~kind ~layout = + { kind = + (match kind with + | 0 -> Float32 + | 1 -> Float64 + | 2 -> Int8_signed + | 3 -> Int8_unsigned + | 4 -> Int16_signed + | 5 -> Int16_unsigned + | 6 -> Int32 + | 7 -> Int64 + | 8 -> Int + | 9 -> Nativeint + | 10 -> Complex32 + | 11 -> Complex64 + | 12 -> Char + | 13 -> Float16 + | _ -> assert false) + ; layout = + (match layout with + | 0 -> C + | 1 -> Fortran + | _ -> assert false) + } + + let print f { kind; layout } = + Format.fprintf + f + "bigarray{%s,%s}" + (match kind with + | Float32 -> "float32" + | Float64 -> "float64" + | Int8_signed -> "sint8" + | Int8_unsigned -> "uint8" + | Int16_signed -> "sint16" + | Int16_unsigned -> "uint16" + | Int32 -> "int32" + | Int64 -> "int64" + | Int -> "int" + | Nativeint -> "nativeint" + | Complex32 -> "complex32" + | Complex64 -> "complex64" + | Char -> "char" + | Float16 -> "float16") + (match layout with + | C -> "C" + | Fortran -> "Fortran") + + let equal { kind; layout } { kind = kind'; layout = layout' } = + phys_equal kind kind' && phys_equal layout layout' +end + type typ = | Top | Number of number | Tuple of typ array + | Bigarray of Bigarray.t | Bot module Domain = struct @@ -25,9 +104,17 @@ module Domain = struct | Bot, t | t, Bot -> t | Number n, Number n' -> if Poly.equal n n' then t else Top | Tuple t, Tuple t' -> - if Array.length t = Array.length t' then Tuple (Array.map2 ~f:join t t') else Top + let l = Array.length t in + let l' = Array.length t' in + Tuple + (if l = l' + then Array.map2 ~f:join t t' + else + Array.init (max l l') ~f:(fun i -> + if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i))) + | Bigarray b, Bigarray b' when Bigarray.equal b b' -> t | Top, _ | _, Top -> Top - | Number _, Tuple _ | Tuple _, Number _ -> Top + | (Number _ | Tuple _ | Bigarray _), _ -> Top let join_set ?(others = false) f s = if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot @@ -38,7 +125,8 @@ module Domain = struct | Number t, Number t' -> Poly.equal t t' | Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:equal t t' - | (Top | Tuple _ | Number _ | Bot), _ -> false + | Bigarray b, Bigarray b' -> Bigarray.equal b b' + | (Top | Tuple _ | Number _ | Bigarray _ | Bot), _ -> false let bot = Bot @@ -47,8 +135,15 @@ module Domain = struct | _, Top | Bot, _ -> true | Top, _ | _, Bot -> false | Number t, Number t' -> Poly.equal t t' - | Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:sub t t' - | Number _, _ | Tuple _, _ -> false + | Tuple t, Tuple t' -> + Array.length t <= Array.length t' + && + let rec compare t t' i l = + i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l) + in + compare t t' 0 (Array.length t) + | Bigarray b, Bigarray b' -> Bigarray.equal b b' + | (Number _ | Tuple _ | Bigarray _), _ -> false let rec print f t = match t with @@ -59,12 +154,30 @@ module Domain = struct | Number Int64 -> Format.fprintf f "int64" | Number Nativeint -> Format.fprintf f "nativeint" | Number Float -> Format.fprintf f "float" + | Bigarray b -> Bigarray.print f b | Tuple t -> Format.fprintf f "(%a)" (Format.pp_print_array ~pp_sep:(fun f () -> Format.fprintf f ",") print) t + + let depth_treshold = 4 + + let rec depth t = + match t with + | Top | Bot | Number _ | Bigarray _ -> 0 + | Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0 + + let rec truncate depth t = + match t with + | Top | Bot | Number _ | Bigarray _ -> t + | Tuple l -> + if depth = 0 + then Top + else Tuple (Array.map ~f:(fun t' -> truncate (depth - 1) t') l) + + let limit t = if depth t > depth_treshold then truncate depth_treshold t else t end let update_deps st { blocks; _ } = @@ -268,7 +381,7 @@ let propagate st approx x : Domain.t = match st.state.mutable_fields.(Var.idx x) with | All_fields -> Top | Some_fields s when IntSet.mem i s -> Top - | Some_fields _ | No_field -> Var.Tbl.get approx y) + | Some_fields _ | No_field -> Domain.limit (Var.Tbl.get approx y)) lst) | Field (y, n, _) -> ( match Var.Tbl.get approx y with @@ -316,7 +429,32 @@ let propagate st approx x : Domain.t = | Expr (Closure (params, _, _)) when List.length args = List.length params -> Domain.join_set - (fun y -> Var.Tbl.get approx y) + (fun y -> + match st.state.defs.(Var.idx y) with + | Expr + (Prim (Extern "caml_ba_create", [ Pv kind; Pv layout; _ ])) + -> ( + let m = + List.fold_left2 + ~f:(fun m p a -> Var.Map.add p a m) + ~init:Var.Map.empty + params + args + in + try + match + ( st.state.defs.(Var.idx (Var.Map.find kind m)) + , st.state.defs.(Var.idx (Var.Map.find layout m)) ) + with + | Expr (Constant (Int kind)), Expr (Constant (Int layout)) + -> + Bigarray + (Bigarray.make + ~kind:(Targetint.to_int_exn kind) + ~layout:(Targetint.to_int_exn layout)) + | _ -> raise Not_found + with Not_found -> Var.Tbl.get approx y) + | _ -> Var.Tbl.get approx y) (Var.Map.find g st.state.return_values) | Expr (Closure (_, _, _)) -> (* The function is partially applied or over applied *) @@ -368,6 +506,10 @@ let print_opt typ f e = l) [ Number Int; Number Int32; Number Int64; Number Nativeint; Number Float ] then Format.fprintf f " OPT" + | Prim (Extern ("caml_ba_get_1" | "caml_ba_set_1"), Pv x :: _) -> ( + match Var.Tbl.get typ x with + | Bigarray _ -> Format.fprintf f " OPT" + | _ -> ()) | _ -> () let f ~state ~info p = diff --git a/compiler/lib-wasm/typing.mli b/compiler/lib-wasm/typing.mli index 74bdf568ca..85b34aa8b7 100644 --- a/compiler/lib-wasm/typing.mli +++ b/compiler/lib-wasm/typing.mli @@ -5,10 +5,38 @@ type number = | Nativeint | Float +module Bigarray : sig + type kind = + | Float32 + | Float64 + | Int8_signed + | Int8_unsigned + | Int16_signed + | Int16_unsigned + | Int32 + | Int64 + | Int + | Nativeint + | Complex32 + | Complex64 + | Char + | Float16 + + type layout = + | C + | Fortran + + type t = + { kind : kind + ; layout : layout + } +end + type typ = | Top | Number of number | Tuple of typ array + | Bigarray of Bigarray.t | Bot val constant_type : Code.constant -> typ