Skip to content

Wasm: specialization of number comparisons #1954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions compiler/lib-wasm/gc_target.ml
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,37 @@ module Type = struct
}
])
})

let int_array_type =
register_type "int_array" (fun () ->
return
{ supertype = None
; final = true
; typ = W.Array { mut = true; typ = Value I32 }
})

let bigarray_type =
register_type "bigarray" (fun () ->
let* custom_operations = custom_operations_type in
let* int_array = int_array_type in
let* custom = custom_type in
return
{ supertype = Some custom
; final = true
; typ =
W.Struct
[ { mut = false
; typ = Value (Ref { nullable = false; typ = Type custom_operations })
}
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
; { mut = false
; typ = Value (Ref { nullable = false; typ = Type int_array })
}
; { mut = false; typ = Packed I8 }
; { mut = false; typ = Packed I8 }
; { mut = false; typ = Packed I8 }
]
})
end

module Value = struct
Expand Down Expand Up @@ -1367,6 +1398,56 @@ module Math = struct
let exp2 x = power (return (W.Const (F64 2.))) x
end

module Bigarray = struct
let dim1 a =
let* ty = Type.bigarray_type in
Memory.wasm_array_get
~ty:Type.int_array_type
(Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2)
(Arith.const 0l)

let get ~kind a i =
match kind with
| Typing.Bigarray.Int8_unsigned | Char ->
let* f =
register_import
~import_module:"bindings"
~name:"ta_get_ui8"
(Fun
{ W.params = [ Ref { nullable = false; typ = Extern }; I32 ]
; result = [ I32 ]
})
in
let* ty = Type.bigarray_type in
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
let* i = Value.int_val i in
Value.val_int (return (W.Call (f, [ ta; i ])))
| _ -> assert false

let set ~kind a i v =
match kind with
| Typing.Bigarray.Int8_unsigned | Char ->
let* f =
register_import
~import_module:"bindings"
~name:"ta_set_ui8"
(Fun
{ W.params =
[ Ref { nullable = false; typ = Extern }
; I32
; Ref { nullable = false; typ = I31 }
]
; result = []
})
in
let* ty = Type.bigarray_type in
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
let* i = Value.int_val i in
let* v = cast I31 v in
instr (W.CallInstr (f, [ ta; i; v ]))
| _ -> assert false
end

module JavaScript = struct
let anyref = W.Ref { nullable = true; typ = Any }

Expand Down
158 changes: 154 additions & 4 deletions compiler/lib-wasm/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct
{ live : int array
; in_cps : Effects.in_cps
; deadcode_sentinal : Var.t
; types : Typing.typ Var.Tbl.t
; blocks : block Addr.Map.t
; closures : Closure_conversion.closure Var.Map.t
; global_context : Code_generation.context
Expand Down Expand Up @@ -230,6 +231,39 @@ module Generate (Target : Target_sig.S) = struct
f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z)
| _ -> invalid_arity name l ~expected:3)

let get_type ctx p =
match p with
| Pv x -> Var.Tbl.get ctx.types x
| Pc c -> Typing.constant_type c

let register_comparison name cmp_int cmp_boxed_int cmp_float =
register_prim name `Mutable (fun ctx _ transl_prim_arg l ->
match l with
| [ x; y ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
match get_type ctx x, get_type ctx y with
| Number Int, Number Int -> cmp_int x' y'
| Number Int32, Number Int32 ->
let* x' = Memory.unbox_int32 x' in
let* y' = Memory.unbox_int32 y' in
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
| Number Nativeint, Number Nativeint ->
let* x' = Memory.unbox_nativeint x' in
let* y' = Memory.unbox_nativeint y' in
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
| Number Int64, Number Int64 ->
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
Value.val_int (return (W.BinOp (I64 cmp_boxed_int, x', y')))
| Number Float, Number Float -> float_comparison cmp_float x' y'
| _ ->
let* f = register_import ~name (Fun (Type.primitive_type 2)) in
let* x' = x' in
let* y' = y' in
return (W.Call (f, [ x'; y' ])))
| _ -> invalid_arity name l ~expected:2)

let () =
register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get;
register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get;
Expand Down Expand Up @@ -602,7 +636,117 @@ module Generate (Target : Target_sig.S) = struct
l
~init:(return [])
in
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l)
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l);
register_comparison "caml_greaterthan" (fun y x -> Value.lt x y) (Gt S) Gt;
register_comparison "caml_greaterequal" (fun y x -> Value.le x y) (Ge S) Ge;
register_comparison "caml_lessthan" Value.lt (Lt S) Lt;
register_comparison "caml_lessequal" Value.le (Le S) Le;
register_comparison
"caml_equal"
(fun x y ->
let* x = x in
let* y = y in
Value.val_int (return (W.RefEq (x, y))))
Eq
Eq;
register_comparison
"caml_notequal"
(fun x y ->
let* x = x in
let* y = y in
Value.val_int (return (W.UnOp (I32 Eqz, RefEq (x, y)))))
Ne
Ne;
register_prim "caml_compare" `Mutable (fun ctx _ transl_prim_arg l ->
match l with
| [ x; y ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
match get_type ctx x, get_type ctx y with
| Number Int, Number Int ->
Value.val_int
Arith.(
(Value.int_val y' < Value.int_val x')
- (Value.int_val x' < Value.int_val y'))
| Number Int32, Number Int32 ->
let* f =
register_import ~name:"caml_int32_compare" (Fun (Type.primitive_type 2))
in
let* x' = Memory.unbox_int32 x' in
let* y' = Memory.unbox_int32 y' in
return (W.Call (f, [ x'; y' ]))
| Number Nativeint, Number Nativeint ->
let* f =
register_import
~name:"caml_nativeint_compare"
(Fun (Type.primitive_type 2))
in
let* x' = Memory.unbox_nativeint x' in
let* y' = Memory.unbox_nativeint y' in
return (W.Call (f, [ x'; y' ]))
| Number Int64, Number Int64 ->
let* f =
register_import ~name:"caml_int64_compare" (Fun (Type.primitive_type 2))
in
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
return (W.Call (f, [ x'; y' ]))
| Number Float, Number Float ->
let* f =
register_import ~name:"caml_float_compare" (Fun (Type.primitive_type 2))
in
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
return (W.Call (f, [ x'; y' ]))
| _ ->
let* f =
register_import ~name:"caml_compare" (Fun (Type.primitive_type 2))
in
let* x' = x' in
let* y' = y' in
return (W.Call (f, [ x'; y' ])))
| _ -> invalid_arity "caml_compare" l ~expected:2);
register_prim "caml_ba_get_1" `Mutator (fun ctx context transl_prim_arg l ->
match l with
| [ x; y ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
match get_type ctx x with
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
seq
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
instr (W.Br_if (label_index context bound_error_pc, cond)))
(Bigarray.get ~kind x' y')
| _ ->
let* f =
register_import ~name:"caml_ba_get_1" (Fun (Type.primitive_type 2))
in
let* x' = x' in
let* y' = y' in
return (W.Call (f, [ x'; y' ])))
| _ -> invalid_arity "caml_ba_get_1" l ~expected:2);
register_prim "caml_ba_set_1" `Mutator (fun ctx context transl_prim_arg l ->
match l with
| [ x; y; z ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
let z' = transl_prim_arg z in
match get_type ctx x with
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
seq
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in
Bigarray.set ~kind x' y' z')
Value.unit
| _ ->
let* f =
register_import ~name:"caml_ba_set_1" (Fun (Type.primitive_type 3))
in
let* x' = x' in
let* y' = y' in
let* z' = z' in
return (W.Call (f, [ x'; y'; z' ])))
| _ -> invalid_arity "caml_ba_set_1" l ~expected:3)

let rec translate_expr ctx context x e =
match e with
Expand Down Expand Up @@ -844,7 +988,9 @@ module Generate (Target : Target_sig.S) = struct
| "caml_bytes_set"
| "caml_check_bound"
| "caml_check_bound_gen"
| "caml_check_bound_float" )
| "caml_check_bound_float"
| "caml_ba_get_1"
| "caml_ba_set_1" )
, _ ) ) -> fst n, true
| Let
( _
Expand Down Expand Up @@ -1183,7 +1329,8 @@ module Generate (Target : Target_sig.S) = struct
~should_export
~warn_on_unhandled_effect
*)
~deadcode_sentinal =
~deadcode_sentinal
~types =
global_context.unit_name <- unit_name;
let p, closures = Closure_conversion.f p in
(*
Expand All @@ -1193,6 +1340,7 @@ module Generate (Target : Target_sig.S) = struct
{ live = live_vars
; in_cps
; deadcode_sentinal
; types
; blocks = p.blocks
; closures
; global_context
Expand Down Expand Up @@ -1306,8 +1454,10 @@ let start () = make_context ~value_type:Gc_target.Type.value

let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal =
let t = Timer.make () in
let state, info = Global_flow.f' ~fast:false p in
let types = Typing.f ~state ~info p in
let p = fix_switch_branches p in
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in
if times () then Format.eprintf " code gen.: %a@." Timer.print t;
res

Expand Down
13 changes: 13 additions & 0 deletions compiler/lib-wasm/target_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,19 @@ module type S = sig
val round : expression -> expression
end

module Bigarray : sig
val dim1 : expression -> expression

val get : kind:Typing.Bigarray.kind -> expression -> expression -> expression

val set :
kind:Typing.Bigarray.kind
-> expression
-> expression
-> expression
-> unit Code_generation.t
end

val internal_primitives :
(string
* Primitive.kind
Expand Down
Loading
Loading