Skip to content

Commit 2335dde

Browse files
committed
Bigarrays
1 parent 159ebfa commit 2335dde

File tree

5 files changed

+313
-9
lines changed

5 files changed

+313
-9
lines changed

compiler/lib-wasm/gc_target.ml

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,38 @@ module Type = struct
420420
}
421421
])
422422
})
423+
424+
let int_array_type =
425+
register_type "int_array" (fun () ->
426+
return
427+
{ supertype = None
428+
; final = true
429+
; typ = W.Array { mut = true; typ = Value I32 }
430+
})
431+
432+
let bigarray_type =
433+
register_type "bigarray" (fun () ->
434+
let* custom_operations = custom_operations_type in
435+
let* int_array = int_array_type in
436+
let* custom = custom_type in
437+
return
438+
{ supertype = Some custom
439+
; final = true
440+
; typ =
441+
W.Struct
442+
[ { mut = false
443+
; typ = Value (Ref { nullable = false; typ = Type custom_operations })
444+
}
445+
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
446+
; { mut = false
447+
; typ = Value (Ref { nullable = false; typ = Type int_array })
448+
}
449+
; { mut = false; typ = Packed I8 }
450+
; { mut = false; typ = Packed I8 }
451+
; { mut = false; typ = Packed I8 }
452+
; { mut = false; typ = Value I32 }
453+
]
454+
})
423455
end
424456

425457
module Value = struct
@@ -1366,6 +1398,56 @@ module Math = struct
13661398
let exp2 x = power (return (W.Const (F64 2.))) x
13671399
end
13681400

1401+
module Bigarray = struct
1402+
let dim1 a =
1403+
let* ty = Type.bigarray_type in
1404+
Memory.wasm_array_get
1405+
~ty:Type.int_array_type
1406+
(Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2)
1407+
(Arith.const 0l)
1408+
1409+
let get ~kind a i =
1410+
match kind with
1411+
| Typing.Bigarray.Int8_unsigned | Char ->
1412+
let* f =
1413+
register_import
1414+
~import_module:"bindings"
1415+
~name:"ta_get_ui8"
1416+
(Fun
1417+
{ W.params = [ Ref { nullable = false; typ = Extern }; I32 ]
1418+
; result = [ I32 ]
1419+
})
1420+
in
1421+
let* ty = Type.bigarray_type in
1422+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1423+
let* i = Value.int_val i in
1424+
Value.val_int (return (W.Call (f, [ ta; i ])))
1425+
| _ -> assert false
1426+
1427+
let set ~kind a i v =
1428+
match kind with
1429+
| Typing.Bigarray.Int8_unsigned | Char ->
1430+
let* f =
1431+
register_import
1432+
~import_module:"bindings"
1433+
~name:"ta_set_ui8"
1434+
(Fun
1435+
{ W.params =
1436+
[ Ref { nullable = false; typ = Extern }
1437+
; I32
1438+
; Ref { nullable = false; typ = I31 }
1439+
]
1440+
; result = []
1441+
})
1442+
in
1443+
let* ty = Type.bigarray_type in
1444+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1445+
let* i = Value.int_val i in
1446+
let* v = cast I31 v in
1447+
instr (W.CallInstr (f, [ ta; i; v ]))
1448+
| _ -> assert false
1449+
end
1450+
13691451
module JavaScript = struct
13701452
let anyref = W.Ref { nullable = true; typ = Any }
13711453

compiler/lib-wasm/generate.ml

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,44 @@ module Generate (Target : Target_sig.S) = struct
698698
let* x' = x' in
699699
let* y' = y' in
700700
return (W.Call (f, [ x'; y' ])))
701-
| _ -> invalid_arity "caml_compare" l ~expected:2)
701+
| _ -> invalid_arity "caml_compare" l ~expected:2);
702+
register_prim "caml_ba_get_1" `Mutator (fun ctx context transl_prim_arg l ->
703+
match l with
704+
| [ x; y ] -> (
705+
let x' = transl_prim_arg x in
706+
let y' = transl_prim_arg y in
707+
match get_type ctx x with
708+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
709+
seq
710+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
711+
instr (W.Br_if (label_index context bound_error_pc, cond)))
712+
(Bigarray.get ~kind x' y')
713+
| _ ->
714+
let* f = register_import ~name:"caml_ba_get_1" (Fun (func_type 2)) in
715+
let* x' = x' in
716+
let* y' = y' in
717+
return (W.Call (f, [ x'; y' ])))
718+
| _ -> invalid_arity "caml_ba_get_1" l ~expected:2);
719+
register_prim "caml_ba_set_1" `Mutator (fun ctx context transl_prim_arg l ->
720+
match l with
721+
| [ x; y; z ] -> (
722+
let x' = transl_prim_arg x in
723+
let y' = transl_prim_arg y in
724+
let z' = transl_prim_arg z in
725+
match get_type ctx x with
726+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
727+
seq
728+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
729+
let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in
730+
Bigarray.set ~kind x' y' z')
731+
Value.unit
732+
| _ ->
733+
let* f = register_import ~name:"caml_ba_set_1" (Fun (func_type 3)) in
734+
let* x' = x' in
735+
let* y' = y' in
736+
let* z' = z' in
737+
return (W.Call (f, [ x'; y'; z' ])))
738+
| _ -> invalid_arity "caml_ba_set_1" l ~expected:3)
702739

703740
let rec translate_expr ctx context x e =
704741
match e with
@@ -933,7 +970,9 @@ module Generate (Target : Target_sig.S) = struct
933970
| "caml_bytes_set"
934971
| "caml_check_bound"
935972
| "caml_check_bound_gen"
936-
| "caml_check_bound_float" )
973+
| "caml_check_bound_float"
974+
| "caml_ba_get_1"
975+
| "caml_ba_set_1" )
937976
, _ ) ) -> fst n, true
938977
| Let
939978
( _

compiler/lib-wasm/target_sig.ml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,19 @@ module type S = sig
246246
val round : expression -> expression
247247
end
248248

249+
module Bigarray : sig
250+
val dim1 : expression -> expression
251+
252+
val get : kind:Typing.Bigarray.kind -> expression -> expression -> expression
253+
254+
val set :
255+
kind:Typing.Bigarray.kind
256+
-> expression
257+
-> expression
258+
-> expression
259+
-> unit Code_generation.t
260+
end
261+
249262
val internal_primitives :
250263
(string
251264
* Primitive.kind

compiler/lib-wasm/typing.ml

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,89 @@ type number =
1111
| Nativeint
1212
| Float
1313

14+
module Bigarray = struct
15+
type kind =
16+
| Float32
17+
| Float64
18+
| Int8_signed
19+
| Int8_unsigned
20+
| Int16_signed
21+
| Int16_unsigned
22+
| Int32
23+
| Int64
24+
| Int
25+
| Nativeint
26+
| Complex32
27+
| Complex64
28+
| Char
29+
| Float16
30+
31+
type layout =
32+
| C
33+
| Fortran
34+
35+
type t =
36+
{ kind : kind
37+
; layout : layout
38+
}
39+
40+
let make ~kind ~layout =
41+
{ kind =
42+
(match kind with
43+
| 0 -> Float32
44+
| 1 -> Float64
45+
| 2 -> Int8_signed
46+
| 3 -> Int8_unsigned
47+
| 4 -> Int16_signed
48+
| 5 -> Int16_unsigned
49+
| 6 -> Int32
50+
| 7 -> Int64
51+
| 8 -> Int
52+
| 9 -> Nativeint
53+
| 10 -> Complex32
54+
| 11 -> Complex64
55+
| 12 -> Char
56+
| 13 -> Float16
57+
| _ -> assert false)
58+
; layout =
59+
(match layout with
60+
| 0 -> C
61+
| 1 -> Fortran
62+
| _ -> assert false)
63+
}
64+
65+
let print f { kind; layout } =
66+
Format.fprintf
67+
f
68+
"bigarray{%s,%s}"
69+
(match kind with
70+
| Float32 -> "float32"
71+
| Float64 -> "float64"
72+
| Int8_signed -> "sint8"
73+
| Int8_unsigned -> "uint8"
74+
| Int16_signed -> "sint16"
75+
| Int16_unsigned -> "uint16"
76+
| Int32 -> "int32"
77+
| Int64 -> "int64"
78+
| Int -> "int"
79+
| Nativeint -> "nativeint"
80+
| Complex32 -> "complex32"
81+
| Complex64 -> "complex64"
82+
| Char -> "char"
83+
| Float16 -> "float16")
84+
(match layout with
85+
| C -> "C"
86+
| Fortran -> "Fortran")
87+
88+
let equal { kind; layout } { kind = kind'; layout = layout' } =
89+
phys_equal kind kind' && phys_equal layout layout'
90+
end
91+
1492
type typ =
1593
| Top
1694
| Number of number
1795
| Tuple of typ array
96+
| Bigarray of Bigarray.t
1897
| Bot
1998

2099
module Domain = struct
@@ -25,9 +104,17 @@ module Domain = struct
25104
| Bot, t | t, Bot -> t
26105
| Number n, Number n' -> if Poly.equal n n' then t else Top
27106
| Tuple t, Tuple t' ->
28-
if Array.length t = Array.length t' then Tuple (Array.map2 ~f:join t t') else Top
107+
let l = Array.length t in
108+
let l' = Array.length t' in
109+
Tuple
110+
(if l = l'
111+
then Array.map2 ~f:join t t'
112+
else
113+
Array.init (max l l') ~f:(fun i ->
114+
if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i)))
115+
| Bigarray b, Bigarray b' when Bigarray.equal b b' -> t
29116
| Top, _ | _, Top -> Top
30-
| Number _, Tuple _ | Tuple _, Number _ -> Top
117+
| (Number _ | Tuple _ | Bigarray _), _ -> Top
31118

32119
let join_set ?(others = false) f s =
33120
if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot
@@ -38,7 +125,8 @@ module Domain = struct
38125
| Number t, Number t' -> Poly.equal t t'
39126
| Tuple t, Tuple t' ->
40127
Array.length t = Array.length t' && Array.for_all2 ~f:equal t t'
41-
| (Top | Tuple _ | Number _ | Bot), _ -> false
128+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
129+
| (Top | Tuple _ | Number _ | Bigarray _ | Bot), _ -> false
42130

43131
let bot = Bot
44132

@@ -47,8 +135,15 @@ module Domain = struct
47135
| _, Top | Bot, _ -> true
48136
| Top, _ | _, Bot -> false
49137
| Number t, Number t' -> Poly.equal t t'
50-
| Tuple t, Tuple t' -> Array.length t = Array.length t' && Array.for_all2 ~f:sub t t'
51-
| Number _, _ | Tuple _, _ -> false
138+
| Tuple t, Tuple t' ->
139+
Array.length t <= Array.length t'
140+
&&
141+
let rec compare t t' i l =
142+
i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l)
143+
in
144+
compare t t' 0 (Array.length t)
145+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
146+
| (Number _ | Tuple _ | Bigarray _), _ -> false
52147

53148
let rec print f t =
54149
match t with
@@ -59,12 +154,30 @@ module Domain = struct
59154
| Number Int64 -> Format.fprintf f "int64"
60155
| Number Nativeint -> Format.fprintf f "nativeint"
61156
| Number Float -> Format.fprintf f "float"
157+
| Bigarray b -> Bigarray.print f b
62158
| Tuple t ->
63159
Format.fprintf
64160
f
65161
"(%a)"
66162
(Format.pp_print_array ~pp_sep:(fun f () -> Format.fprintf f ",") print)
67163
t
164+
165+
let depth_treshold = 4
166+
167+
let rec depth t =
168+
match t with
169+
| Top | Bot | Number _ | Bigarray _ -> 0
170+
| Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0
171+
172+
let rec truncate depth t =
173+
match t with
174+
| Top | Bot | Number _ | Bigarray _ -> t
175+
| Tuple l ->
176+
if depth = 0
177+
then Top
178+
else Tuple (Array.map ~f:(fun t' -> truncate (depth - 1) t') l)
179+
180+
let limit t = if depth t > depth_treshold then truncate depth_treshold t else t
68181
end
69182

70183
let update_deps st { blocks; _ } =
@@ -268,7 +381,7 @@ let propagate st approx x : Domain.t =
268381
match st.state.mutable_fields.(Var.idx x) with
269382
| All_fields -> Top
270383
| Some_fields s when IntSet.mem i s -> Top
271-
| Some_fields _ | No_field -> Var.Tbl.get approx y)
384+
| Some_fields _ | No_field -> Domain.limit (Var.Tbl.get approx y))
272385
lst)
273386
| Field (y, n, _) -> (
274387
match Var.Tbl.get approx y with
@@ -316,7 +429,32 @@ let propagate st approx x : Domain.t =
316429
| Expr (Closure (params, _, _))
317430
when List.length args = List.length params ->
318431
Domain.join_set
319-
(fun y -> Var.Tbl.get approx y)
432+
(fun y ->
433+
match st.state.defs.(Var.idx y) with
434+
| Expr
435+
(Prim (Extern "caml_ba_create", [ Pv kind; Pv layout; _ ]))
436+
-> (
437+
let m =
438+
List.fold_left2
439+
~f:(fun m p a -> Var.Map.add p a m)
440+
~init:Var.Map.empty
441+
params
442+
args
443+
in
444+
try
445+
match
446+
( st.state.defs.(Var.idx (Var.Map.find kind m))
447+
, st.state.defs.(Var.idx (Var.Map.find layout m)) )
448+
with
449+
| Expr (Constant (Int kind)), Expr (Constant (Int layout))
450+
->
451+
Bigarray
452+
(Bigarray.make
453+
~kind:(Targetint.to_int_exn kind)
454+
~layout:(Targetint.to_int_exn layout))
455+
| _ -> raise Not_found
456+
with Not_found -> Var.Tbl.get approx y)
457+
| _ -> Var.Tbl.get approx y)
320458
(Var.Map.find g st.state.return_values)
321459
| Expr (Closure (_, _, _)) ->
322460
(* The function is partially applied or over applied *)
@@ -368,6 +506,10 @@ let print_opt typ f e =
368506
l)
369507
[ Number Int; Number Int32; Number Int64; Number Nativeint; Number Float ]
370508
then Format.fprintf f " OPT"
509+
| Prim (Extern ("caml_ba_get_1" | "caml_ba_set_1"), Pv x :: _) -> (
510+
match Var.Tbl.get typ x with
511+
| Bigarray _ -> Format.fprintf f " OPT"
512+
| _ -> ())
371513
| _ -> ()
372514

373515
let f ~state ~info p =

0 commit comments

Comments
 (0)