Skip to content

Commit de0c2cc

Browse files
committed
More precise return types
1 parent db598fc commit de0c2cc

File tree

9 files changed

+283
-76
lines changed

9 files changed

+283
-76
lines changed

compiler/lib-wasm/code_generation.ml

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
199199
(* I31, struct, array and none have no other subtype *)
200200
| _, (I31 | Type _ | Struct | Array | None_) -> false, st
201201

202+
(*ZZZ*)
203+
let rec type_index_lub ty ty' st =
204+
if Var.equal ty ty'
205+
then Some ty
206+
else
207+
let type_field = Var.Hashtbl.find st.context.types ty in
208+
match type_field.supertype with
209+
| None -> None
210+
| Some ty -> (
211+
match type_index_lub ty ty' st with
212+
| Some ty -> Some ty
213+
| None -> (
214+
let type_field = Var.Hashtbl.find st.context.types ty' in
215+
match type_field.supertype with
216+
| None -> None
217+
| Some ty' -> type_index_lub ty ty' st))
218+
219+
let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
220+
match ty, ty' with
221+
| (Func | Extern), _ | _, (Func | Extern) -> assert false
222+
| None_, _ -> return ty'
223+
| _, None_ | Struct, Struct | Array, Array -> return ty
224+
| Any, _ | _, Any -> return W.Any
225+
| Eq, _
226+
| _, Eq
227+
| (Struct | Array | Type _), I31
228+
| I31, (Struct | Array | Type _)
229+
| Struct, Array
230+
| Array, Struct -> return (Eq : W.heap_type)
231+
| Struct, Type t | Type t, Struct -> (
232+
fun st ->
233+
let type_field = Var.Hashtbl.find st.context.types t in
234+
match type_field.typ with
235+
| Struct _ -> W.Struct, st
236+
| Array _ | Func _ -> W.Eq, st)
237+
| Array, Type t | Type t, Array -> (
238+
fun st ->
239+
let type_field = Var.Hashtbl.find st.context.types t in
240+
match type_field.typ with
241+
| Array _ -> W.Struct, st
242+
| Struct _ | Func _ -> W.Eq, st)
243+
| Type t, Type t' -> (
244+
let* r = fun st -> type_index_lub t t' st, st in
245+
match r with
246+
| Some t'' -> return (Type t'' : W.heap_type)
247+
| None -> (
248+
fun st ->
249+
let type_field = Var.Hashtbl.find st.context.types t in
250+
let type_field' = Var.Hashtbl.find st.context.types t' in
251+
match type_field.typ, type_field'.typ with
252+
| Struct _, Struct _ -> (Struct : W.heap_type), st
253+
| Array _, Array _ -> W.Array, st
254+
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
255+
| I31, I31 -> return W.I31
256+
257+
let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
258+
match ty, ty' with
259+
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
260+
let* typ = heap_type_lub typ typ' in
261+
return (W.Ref { nullable = nullable || nullable'; typ })
262+
| _ -> assert false
263+
202264
let register_global name ?exported_name ?(constant = false) typ init st =
203265
st.context.other_fields <-
204266
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
@@ -703,13 +765,28 @@ let push e =
703765
instr (Push e')
704766
| _ -> instr (Push e)
705767

768+
let blk' ty l st =
769+
let instrs = st.instrs in
770+
let (), st = l { st with instrs = [] } in
771+
let ty, st =
772+
match st.instrs with
773+
| Push e :: _ ->
774+
(let* ty' = expression_type e in
775+
match ty' with
776+
| None -> return ty
777+
| Some ty' -> return { ty with W.result = [ ty' ] })
778+
st
779+
| _ -> ty, st
780+
in
781+
(List.rev st.instrs, ty), { st with instrs }
782+
706783
let loop ty l =
707-
let* instrs = blk l in
708-
instr (Loop (ty, instrs))
784+
let* instrs, ty' = blk' ty l in
785+
instr (Loop (ty', instrs))
709786

710787
let block ty l =
711-
let* instrs = blk l in
712-
instr (Block (ty, instrs))
788+
let* instrs, ty' = blk' ty l in
789+
instr (Block (ty', instrs))
713790

714791
let block_expr ty l =
715792
let* instrs = blk l in
@@ -782,7 +859,7 @@ let init_code context = instrs context.init_code
782859

783860
let function_body ~context ~param_names ~body =
784861
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
785-
let (), st = body st in
862+
let res, st = body st in
786863
let local_count, body = st.var_count, List.rev st.instrs in
787864
let local_types = Array.make local_count (Var.fresh (), None) in
788865
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
@@ -800,4 +877,10 @@ let function_body ~context ~param_names ~body =
800877
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
801878
|> Array.to_list
802879
in
803-
locals, body
880+
locals, res, body
881+
882+
let eval ~context e =
883+
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
884+
let r, st = e st in
885+
assert (st.var_count = 0 && List.is_empty st.instrs);
886+
r

compiler/lib-wasm/code_generation.mli

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ val register_type : string -> (unit -> type_def t) -> Wasm_ast.var t
160160

161161
val heap_type_sub : Wasm_ast.heap_type -> Wasm_ast.heap_type -> bool t
162162

163+
val value_type_lub : Wasm_ast.value_type -> Wasm_ast.value_type -> Wasm_ast.value_type t
164+
163165
val register_import :
164166
?import_module:string -> name:string -> Wasm_ast.import_desc -> Wasm_ast.var t
165167

@@ -202,8 +204,8 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t
202204
val function_body :
203205
context:context
204206
-> param_names:Code.Var.t list
205-
-> body:unit t
206-
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
207+
-> body:'a t
208+
-> (Wasm_ast.var * Wasm_ast.value_type) list * 'a * Wasm_ast.instruction list
207209

208210
val variable_type : Code.Var.t -> Wasm_ast.value_type option t
209211

@@ -214,3 +216,5 @@ val array_placeholder : Code.Var.t -> expression
214216
val default_value :
215217
Wasm_ast.value_type
216218
-> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t
219+
220+
val eval : context:context -> 'a t -> 'a

compiler/lib-wasm/curry.ml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ module Make (Target : Target_sig.S) = struct
9595
loop m [] f None
9696
in
9797
let param_names = args @ [ f ] in
98-
let locals, body = function_body ~context ~param_names ~body in
98+
let locals, _, body = function_body ~context ~param_names ~body in
9999
W.Function
100100
{ name
101101
; exported_name = None
102-
; typ = None
102+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
103103
; signature = Type.func_type 1
104104
; param_names
105105
; locals
@@ -130,11 +130,11 @@ module Make (Target : Target_sig.S) = struct
130130
push (Closure.curry_allocate ~cps:false ~arity m ~f:name' ~closure:f ~arg:x)
131131
in
132132
let param_names = [ x; f ] in
133-
let locals, body = function_body ~context ~param_names ~body in
133+
let locals, _, body = function_body ~context ~param_names ~body in
134134
W.Function
135135
{ name
136136
; exported_name = None
137-
; typ = None
137+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
138138
; signature = Type.func_type 1
139139
; param_names
140140
; locals
@@ -181,11 +181,11 @@ module Make (Target : Target_sig.S) = struct
181181
loop m [] f None
182182
in
183183
let param_names = args @ [ f ] in
184-
let locals, body = function_body ~context ~param_names ~body in
184+
let locals, _, body = function_body ~context ~param_names ~body in
185185
W.Function
186186
{ name
187187
; exported_name = None
188-
; typ = None
188+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
189189
; signature = Type.func_type 2
190190
; param_names
191191
; locals
@@ -220,11 +220,11 @@ module Make (Target : Target_sig.S) = struct
220220
instr (W.Return (Some c))
221221
in
222222
let param_names = [ x; cont; f ] in
223-
let locals, body = function_body ~context ~param_names ~body in
223+
let locals, _, body = function_body ~context ~param_names ~body in
224224
W.Function
225225
{ name
226226
; exported_name = None
227-
; typ = None
227+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
228228
; signature = Type.func_type 2
229229
; param_names
230230
; locals
@@ -264,7 +264,7 @@ module Make (Target : Target_sig.S) = struct
264264
build_applies (load f) l)
265265
in
266266
let param_names = l @ [ f ] in
267-
let locals, body = function_body ~context ~param_names ~body in
267+
let locals, _, body = function_body ~context ~param_names ~body in
268268
W.Function
269269
{ name
270270
; exported_name = None
@@ -311,7 +311,7 @@ module Make (Target : Target_sig.S) = struct
311311
push (call ~cps:true ~arity:2 (load f) [ x; iterate ]))
312312
in
313313
let param_names = l @ [ f ] in
314-
let locals, body = function_body ~context ~param_names ~body in
314+
let locals, _, body = function_body ~context ~param_names ~body in
315315
W.Function
316316
{ name
317317
; exported_name = None
@@ -346,11 +346,11 @@ module Make (Target : Target_sig.S) = struct
346346
instr (W.Return (Some e))
347347
in
348348
let param_names = l @ [ f ] in
349-
let locals, body = function_body ~context ~param_names ~body in
349+
let locals, _, body = function_body ~context ~param_names ~body in
350350
W.Function
351351
{ name
352352
; exported_name = None
353-
; typ = None
353+
; typ = Some (eval ~context (Type.function_type ~cps arity))
354354
; signature = Type.func_type arity
355355
; param_names
356356
; locals

compiler/lib-wasm/gc_target.ml

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,35 @@ module Type = struct
205205
let primitive_type n =
206206
{ W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] }
207207

208-
let func_type n = primitive_type (n + 1)
209-
210-
let function_type ~cps n =
211-
let n = if cps then n + 1 else n in
212-
register_type (Printf.sprintf "function_%d" n) (fun () ->
213-
return { supertype = None; final = true; typ = W.Func (func_type n) })
208+
let func_type ?(ret = value) n =
209+
{ W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ ret ] }
210+
211+
let rec function_type ~cps ?ret n =
212+
let n' = if cps then n + 1 else n in
213+
let ret_str =
214+
match ret with
215+
| None -> ""
216+
| Some (W.Ref { nullable = false; typ }) -> (
217+
match typ with
218+
| Eq -> "_eq" (*ZZZ remove ret in that case*)
219+
| I31 -> "_i31"
220+
| Struct -> "_struct"
221+
| Array -> "_array"
222+
| None_ -> "_none"
223+
| Type v -> (
224+
match Code.Var.get_name v with
225+
| None -> assert false
226+
| Some name -> "_" ^ name)
227+
| _ -> assert false)
228+
| _ -> assert false
229+
in
230+
register_type (Printf.sprintf "function_%d%s" n' ret_str) (fun () ->
231+
match ret with
232+
| None -> return { supertype = None; final = false; typ = W.Func (func_type n') }
233+
| Some ret ->
234+
let* super = function_type ~cps n in
235+
return
236+
{ supertype = Some super; final = false; typ = W.Func (func_type ~ret n') })
214237

215238
let closure_common_fields ~cps =
216239
let* fun_ty = function_type ~cps 1 in

0 commit comments

Comments
 (0)