Skip to content

Commit 086b456

Browse files
committed
Bigarrays
1 parent 0669be0 commit 086b456

File tree

5 files changed

+317
-9
lines changed

5 files changed

+317
-9
lines changed

compiler/lib-wasm/gc_target.ml

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,38 @@ module Type = struct
422422
}
423423
])
424424
})
425+
426+
let int_array_type =
427+
register_type "int_array" (fun () ->
428+
return
429+
{ supertype = None
430+
; final = true
431+
; typ = W.Array { mut = true; typ = Value I32 }
432+
})
433+
434+
let bigarray_type =
435+
register_type "bigarray" (fun () ->
436+
let* custom_operations = custom_operations_type in
437+
let* int_array = int_array_type in
438+
let* custom = custom_type in
439+
return
440+
{ supertype = Some custom
441+
; final = true
442+
; typ =
443+
W.Struct
444+
[ { mut = false
445+
; typ = Value (Ref { nullable = false; typ = Type custom_operations })
446+
}
447+
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
448+
; { mut = false
449+
; typ = Value (Ref { nullable = false; typ = Type int_array })
450+
}
451+
; { mut = false; typ = Packed I8 }
452+
; { mut = false; typ = Packed I8 }
453+
; { mut = false; typ = Packed I8 }
454+
; { mut = false; typ = Value I32 }
455+
]
456+
})
425457
end
426458

427459
module Value = struct
@@ -1367,6 +1399,56 @@ module Math = struct
13671399
let exp2 x = power (return (W.Const (F64 2.))) x
13681400
end
13691401

1402+
module Bigarray = struct
1403+
let dim1 a =
1404+
let* ty = Type.bigarray_type in
1405+
Memory.wasm_array_get
1406+
~ty:Type.int_array_type
1407+
(Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2)
1408+
(Arith.const 0l)
1409+
1410+
let get ~kind a i =
1411+
match kind with
1412+
| Typing.Bigarray.Int8_unsigned | Char ->
1413+
let* f =
1414+
register_import
1415+
~import_module:"bindings"
1416+
~name:"ta_get_ui8"
1417+
(Fun
1418+
{ W.params = [ Ref { nullable = false; typ = Extern }; I32 ]
1419+
; result = [ I32 ]
1420+
})
1421+
in
1422+
let* ty = Type.bigarray_type in
1423+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1424+
let* i = Value.int_val i in
1425+
Value.val_int (return (W.Call (f, [ ta; i ])))
1426+
| _ -> assert false
1427+
1428+
let set ~kind a i v =
1429+
match kind with
1430+
| Typing.Bigarray.Int8_unsigned | Char ->
1431+
let* f =
1432+
register_import
1433+
~import_module:"bindings"
1434+
~name:"ta_set_ui8"
1435+
(Fun
1436+
{ W.params =
1437+
[ Ref { nullable = false; typ = Extern }
1438+
; I32
1439+
; Ref { nullable = false; typ = I31 }
1440+
]
1441+
; result = []
1442+
})
1443+
in
1444+
let* ty = Type.bigarray_type in
1445+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1446+
let* i = Value.int_val i in
1447+
let* v = cast I31 v in
1448+
instr (W.CallInstr (f, [ ta; i; v ]))
1449+
| _ -> assert false
1450+
end
1451+
13701452
module JavaScript = struct
13711453
let anyref = W.Ref { nullable = true; typ = Any }
13721454

compiler/lib-wasm/generate.ml

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,48 @@ module Generate (Target : Target_sig.S) = struct
705705
let* x' = x' in
706706
let* y' = y' in
707707
return (W.Call (f, [ x'; y' ])))
708-
| _ -> invalid_arity "caml_compare" l ~expected:2)
708+
| _ -> invalid_arity "caml_compare" l ~expected:2);
709+
register_prim "caml_ba_get_1" `Mutator (fun ctx context transl_prim_arg l ->
710+
match l with
711+
| [ x; y ] -> (
712+
let x' = transl_prim_arg x in
713+
let y' = transl_prim_arg y in
714+
match get_type ctx x with
715+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
716+
seq
717+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
718+
instr (W.Br_if (label_index context bound_error_pc, cond)))
719+
(Bigarray.get ~kind x' y')
720+
| _ ->
721+
let* f =
722+
register_import ~name:"caml_ba_get_1" (Fun (Type.primitive_type 2))
723+
in
724+
let* x' = x' in
725+
let* y' = y' in
726+
return (W.Call (f, [ x'; y' ])))
727+
| _ -> invalid_arity "caml_ba_get_1" l ~expected:2);
728+
register_prim "caml_ba_set_1" `Mutator (fun ctx context transl_prim_arg l ->
729+
match l with
730+
| [ x; y; z ] -> (
731+
let x' = transl_prim_arg x in
732+
let y' = transl_prim_arg y in
733+
let z' = transl_prim_arg z in
734+
match get_type ctx x with
735+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
736+
seq
737+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
738+
let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in
739+
Bigarray.set ~kind x' y' z')
740+
Value.unit
741+
| _ ->
742+
let* f =
743+
register_import ~name:"caml_ba_set_1" (Fun (Type.primitive_type 3))
744+
in
745+
let* x' = x' in
746+
let* y' = y' in
747+
let* z' = z' in
748+
return (W.Call (f, [ x'; y'; z' ])))
749+
| _ -> invalid_arity "caml_ba_set_1" l ~expected:3)
709750

710751
let rec translate_expr ctx context x e =
711752
match e with
@@ -942,7 +983,9 @@ module Generate (Target : Target_sig.S) = struct
942983
| "caml_bytes_set"
943984
| "caml_check_bound"
944985
| "caml_check_bound_gen"
945-
| "caml_check_bound_float" )
986+
| "caml_check_bound_float"
987+
| "caml_ba_get_1"
988+
| "caml_ba_set_1" )
946989
, _ ) ) -> fst n, true
947990
| Let
948991
( _

compiler/lib-wasm/target_sig.ml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,19 @@ module type S = sig
252252
val round : expression -> expression
253253
end
254254

255+
module Bigarray : sig
256+
val dim1 : expression -> expression
257+
258+
val get : kind:Typing.Bigarray.kind -> expression -> expression -> expression
259+
260+
val set :
261+
kind:Typing.Bigarray.kind
262+
-> expression
263+
-> expression
264+
-> expression
265+
-> unit Code_generation.t
266+
end
267+
255268
val internal_primitives :
256269
(string
257270
* Primitive.kind

compiler/lib-wasm/typing.ml

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,89 @@ type number =
1111
| Nativeint
1212
| Float
1313

14+
module Bigarray = struct
15+
type kind =
16+
| Float32
17+
| Float64
18+
| Int8_signed
19+
| Int8_unsigned
20+
| Int16_signed
21+
| Int16_unsigned
22+
| Int32
23+
| Int64
24+
| Int
25+
| Nativeint
26+
| Complex32
27+
| Complex64
28+
| Char
29+
| Float16
30+
31+
type layout =
32+
| C
33+
| Fortran
34+
35+
type t =
36+
{ kind : kind
37+
; layout : layout
38+
}
39+
40+
let make ~kind ~layout =
41+
{ kind =
42+
(match kind with
43+
| 0 -> Float32
44+
| 1 -> Float64
45+
| 2 -> Int8_signed
46+
| 3 -> Int8_unsigned
47+
| 4 -> Int16_signed
48+
| 5 -> Int16_unsigned
49+
| 6 -> Int32
50+
| 7 -> Int64
51+
| 8 -> Int
52+
| 9 -> Nativeint
53+
| 10 -> Complex32
54+
| 11 -> Complex64
55+
| 12 -> Char
56+
| 13 -> Float16
57+
| _ -> assert false)
58+
; layout =
59+
(match layout with
60+
| 0 -> C
61+
| 1 -> Fortran
62+
| _ -> assert false)
63+
}
64+
65+
let print f { kind; layout } =
66+
Format.fprintf
67+
f
68+
"bigarray{%s,%s}"
69+
(match kind with
70+
| Float32 -> "float32"
71+
| Float64 -> "float64"
72+
| Int8_signed -> "sint8"
73+
| Int8_unsigned -> "uint8"
74+
| Int16_signed -> "sint16"
75+
| Int16_unsigned -> "uint16"
76+
| Int32 -> "int32"
77+
| Int64 -> "int64"
78+
| Int -> "int"
79+
| Nativeint -> "nativeint"
80+
| Complex32 -> "complex32"
81+
| Complex64 -> "complex64"
82+
| Char -> "char"
83+
| Float16 -> "float16")
84+
(match layout with
85+
| C -> "C"
86+
| Fortran -> "Fortran")
87+
88+
let equal { kind; layout } { kind = kind'; layout = layout' } =
89+
phys_equal kind kind' && phys_equal layout layout'
90+
end
91+
1492
type typ =
1593
| Top
1694
| Number of number
1795
| Tuple of typ array
96+
| Bigarray of Bigarray.t
1897
| Bot
1998

2099
module Domain = struct
@@ -25,9 +104,17 @@ module Domain = struct
25104
| Bot, t | t, Bot -> t
26105
| Number n, Number n' -> if Poly.equal n n' then t else Top
27106
| Tuple t, Tuple t' ->
28-
if Array.length t = Array.length t' then Tuple (Array.map2 ~f:join t t') else Top
107+
let l = Array.length t in
108+
let l' = Array.length t' in
109+
Tuple
110+
(if l = l'
111+
then Array.map2 ~f:join t t'
112+
else
113+
Array.init (max l l') ~f:(fun i ->
114+
if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i)))
115+
| Bigarray b, Bigarray b' when Bigarray.equal b b' -> t
29116
| Top, _ | _, Top -> Top
30-
| Number _, Tuple _ | Tuple _, Number _ -> Top
117+
| (Number _ | Tuple _ | Bigarray _), _ -> Top
31118

32119
let join_set ?(others = false) f s =
33120
if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot
@@ -38,7 +125,8 @@ module Domain = struct
38125
| Number t, Number t' -> Poly.equal t t'
39126
| Tuple t, Tuple t' ->
40127
Array.length t = Array.length t' && Array.for_all2 ~f:equal t t'
41-
| (Top | Tuple _ | Number _ | Bot), _ -> false
128+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
129+
| (Top | Tuple _ | Number _ | Bigarray _ | Bot), _ -> false
42130

43131
let bot = Bot
44132

@@ -47,8 +135,15 @@ module Domain = struct
47135
| _, Top | Bot, _ -> true
48136
| Top, _ | _, Bot -> false
49137
| Number t, Number t' -> Poly.equal t t'
50-
| Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:sub t t'
51-
| Number _, _ | Tuple _, _ -> false
138+
| Tuple t, Tuple t' ->
139+
Array.length t <= Array.length t'
140+
&&
141+
let rec compare t t' i l =
142+
i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l)
143+
in
144+
compare t t' 0 (Array.length t)
145+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
146+
| (Number _ | Tuple _ | Bigarray _), _ -> false
52147

53148
let rec print f t =
54149
match t with
@@ -59,12 +154,30 @@ module Domain = struct
59154
| Number Int64 -> Format.fprintf f "int64"
60155
| Number Nativeint -> Format.fprintf f "nativeint"
61156
| Number Float -> Format.fprintf f "float"
157+
| Bigarray b -> Bigarray.print f b
62158
| Tuple t ->
63159
Format.fprintf
64160
f
65161
"(%a)"
66162
(Format.pp_print_array ~pp_sep:(fun f () -> Format.fprintf f ",") print)
67163
t
164+
165+
let depth_treshold = 4
166+
167+
let rec depth t =
168+
match t with
169+
| Top | Bot | Number _ | Bigarray _ -> 0
170+
| Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0
171+
172+
let rec truncate depth t =
173+
match t with
174+
| Top | Bot | Number _ | Bigarray _ -> t
175+
| Tuple l ->
176+
if depth = 0
177+
then Top
178+
else Tuple (Array.map ~f:(fun t' -> truncate (depth - 1) t') l)
179+
180+
let limit t = if depth t > depth_treshold then truncate depth_treshold t else t
68181
end
69182

70183
let update_deps st { blocks; _ } =
@@ -268,7 +381,7 @@ let propagate st approx x : Domain.t =
268381
match st.state.mutable_fields.(Var.idx x) with
269382
| All_fields -> Top
270383
| Some_fields s when IntSet.mem i s -> Top
271-
| Some_fields _ | No_field -> Var.Tbl.get approx y)
384+
| Some_fields _ | No_field -> Domain.limit (Var.Tbl.get approx y))
272385
lst)
273386
| Field (y, n, _) -> (
274387
match Var.Tbl.get approx y with
@@ -316,7 +429,32 @@ let propagate st approx x : Domain.t =
316429
| Expr (Closure (params, _, _))
317430
when List.length args = List.length params ->
318431
Domain.join_set
319-
(fun y -> Var.Tbl.get approx y)
432+
(fun y ->
433+
match st.state.defs.(Var.idx y) with
434+
| Expr
435+
(Prim (Extern "caml_ba_create", [ Pv kind; Pv layout; _ ]))
436+
-> (
437+
let m =
438+
List.fold_left2
439+
~f:(fun m p a -> Var.Map.add p a m)
440+
~init:Var.Map.empty
441+
params
442+
args
443+
in
444+
try
445+
match
446+
( st.state.defs.(Var.idx (Var.Map.find kind m))
447+
, st.state.defs.(Var.idx (Var.Map.find layout m)) )
448+
with
449+
| Expr (Constant (Int kind)), Expr (Constant (Int layout))
450+
->
451+
Bigarray
452+
(Bigarray.make
453+
~kind:(Targetint.to_int_exn kind)
454+
~layout:(Targetint.to_int_exn layout))
455+
| _ -> raise Not_found
456+
with Not_found -> Var.Tbl.get approx y)
457+
| _ -> Var.Tbl.get approx y)
320458
(Var.Map.find g st.state.return_values)
321459
| Expr (Closure (_, _, _)) ->
322460
(* The function is partially applied or over applied *)
@@ -368,6 +506,10 @@ let print_opt typ f e =
368506
l)
369507
[ Number Int; Number Int32; Number Int64; Number Nativeint; Number Float ]
370508
then Format.fprintf f " OPT"
509+
| Prim (Extern ("caml_ba_get_1" | "caml_ba_set_1"), Pv x :: _) -> (
510+
match Var.Tbl.get typ x with
511+
| Bigarray _ -> Format.fprintf f " OPT"
512+
| _ -> ())
371513
| _ -> ()
372514

373515
let f ~state ~info p =

0 commit comments

Comments
 (0)