From e8a3ed2a04713748ba7c577a04ae8f588844430c Mon Sep 17 00:00:00 2001 From: Spotandjake <40705786+spotandjake@users.noreply.github.com> Date: Wed, 12 Feb 2025 19:26:40 -0500 Subject: [PATCH] feat(compiler): Deduplicate foreign imports (#2233) --- compiler/src/codegen/compcore.re | 97 +++++++++++++-------- compiler/src/codegen/linkedtree.re | 6 +- compiler/test/suites/basic_functionality.re | 2 +- compiler/test/suites/includes.re | 48 ++++++++++ 4 files changed, 113 insertions(+), 40 deletions(-) diff --git a/compiler/src/codegen/compcore.re b/compiler/src/codegen/compcore.re index c16ccf80f..c552b5475 100644 --- a/compiler/src/codegen/compcore.re +++ b/compiler/src/codegen/compcore.re @@ -8,6 +8,8 @@ open Grain_utils; open Comp_utils; open Comp_wasm_prim; +module StringSet = Set.Make(String); + let sources: ref(list((Expression.t, Grain_parsing.Location.t))) = ref([]); /** Environment */ @@ -22,6 +24,7 @@ type codegen_env = { /* Allocated closures which need backpatching */ backpatches: ref(list((Expression.t, closure_data))), required_imports: list(import), + foreign_import_resolutions: ref(StringSet.t), global_import_resolutions: Hashtbl.t(string, string), func_import_resolutions: Hashtbl.t(string, string), compilation_mode: Config.compilation_mode, @@ -90,6 +93,7 @@ let init_codegen_env = }, backpatches: ref([]), required_imports: [], + foreign_import_resolutions: ref(StringSet.empty), global_import_resolutions, func_import_resolutions, compilation_mode: Normal, @@ -2782,6 +2786,14 @@ and compile_instr = (wasm_mod, env, instr) => compiled_args, Type.create(Array.of_list(List.map(wasm_type, retty))), ); + } else if (StringSet.mem(func_name, env.foreign_import_resolutions^)) { + // Deduplicated imports; call resolved name directly + Expression.Call.make( + wasm_mod, + resolved_name, + compiled_args, + Type.create(Array.of_list(List.map(wasm_type, retty))), + ); } else { // Raw function resolved to Grain function; inject closure argument let closure_global = resolve_global(~env, func_name); @@ -3070,7 +3082,7 @@ let compute_table_size = (env, {function_table_elements}) => { List.length(function_table_elements); }; -let compile_imports = (wasm_mod, env, {imports}) => { +let compile_imports = (wasm_mod, env, {imports}, import_map) => { let compile_module_name = name => fun | MImportWasm => name @@ -3081,7 +3093,6 @@ let compile_imports = (wasm_mod, env, {imports}) => { | (MImportGrain, MGlobalImport(_)) => "GRAIN$EXPORT$" ++ name | _ => name }; - let compile_import = ({mimp_id, mimp_mod, mimp_name, mimp_type, mimp_kind}) => { let module_name = compile_module_name(mimp_mod, mimp_kind); let item_name = compile_import_name(mimp_name, mimp_kind, mimp_type); @@ -3090,37 +3101,53 @@ let compile_imports = (wasm_mod, env, {imports}) => { | MImportGrain => get_grain_imported_name(mimp_mod, mimp_id) | MImportWasm => Ident.unique_name(mimp_id) }; - switch (mimp_kind, mimp_type) { - | (MImportGrain, MGlobalImport(ty, mut)) => - Import.add_global_import( - wasm_mod, - internal_name, - module_name, - item_name, - wasm_type(ty), - mut, - ) - | (_, MFuncImport(args, ret)) => - let proc_list = l => - Type.create @@ Array.of_list @@ List.map(wasm_type, l); - Import.add_function_import( - wasm_mod, - internal_name, - module_name, - item_name, - proc_list(args), - proc_list(ret), - ); - | (_, MGlobalImport(typ, mut)) => - let typ = wasm_type(typ); - Import.add_global_import( - wasm_mod, - internal_name, - module_name, - item_name, - typ, - mut, - ); + let import_key = (module_name, item_name, mimp_kind, mimp_type); + switch (Hashtbl.find_opt(import_map, import_key)) { + | Some(name) when mimp_kind == MImportWasm => + // Deduplicate wasm imports by resolving them to the previously imported name + let linked_name = linked_name(~env, internal_name); + switch (mimp_type) { + | MFuncImport(_, _) => + Hashtbl.add(env.func_import_resolutions, linked_name, name) + | MGlobalImport(_, _) => + Hashtbl.add(env.global_import_resolutions, linked_name, name) + }; + env.foreign_import_resolutions := + StringSet.add(linked_name, env.foreign_import_resolutions^); + | _ => + Hashtbl.add(import_map, import_key, internal_name); + switch (mimp_kind, mimp_type) { + | (MImportGrain, MGlobalImport(ty, mut)) => + Import.add_global_import( + wasm_mod, + internal_name, + module_name, + item_name, + wasm_type(ty), + mut, + ) + | (_, MFuncImport(args, ret)) => + let proc_list = l => + Type.create @@ Array.of_list @@ List.map(wasm_type, l); + Import.add_function_import( + wasm_mod, + internal_name, + module_name, + item_name, + proc_list(args), + proc_list(ret), + ); + | (_, MGlobalImport(typ, mut)) => + let typ = wasm_type(typ); + Import.add_global_import( + wasm_mod, + internal_name, + module_name, + item_name, + typ, + mut, + ); + }; }; }; @@ -3472,12 +3499,14 @@ let compile_wasm_module = Type.funcref, ); + let import_map = Hashtbl.create(10); + let compile_one = (dep_id, prog: mash_code) => { let env = {...env, dep_id, compilation_mode: prog.compilation_mode}; + ignore @@ compile_imports(wasm_mod, env, prog, import_map); ignore @@ compile_globals(wasm_mod, env, prog); ignore @@ compile_functions(wasm_mod, env, prog); ignore @@ compile_exports(wasm_mod, env, prog); - ignore @@ compile_imports(wasm_mod, env, prog); ignore @@ compile_tables(wasm_mod, env, prog); }; diff --git a/compiler/src/codegen/linkedtree.re b/compiler/src/codegen/linkedtree.re index 1f9f00563..d39d2fa6f 100644 --- a/compiler/src/codegen/linkedtree.re +++ b/compiler/src/codegen/linkedtree.re @@ -109,11 +109,7 @@ let link = main_mashtree => { (resolved_module, import.mimp_name), ); let import_name = - Printf.sprintf( - "%s_%d", - Ident.unique_name(import.mimp_id), - dep_id^, - ); + internal_name(Ident.unique_name(import.mimp_id), dep_id^); Option.iter( global => Hashtbl.add(global_import_resolutions, import_name, global), diff --git a/compiler/test/suites/basic_functionality.re b/compiler/test/suites/basic_functionality.re index 83e3d0575..b0949fccd 100644 --- a/compiler/test/suites/basic_functionality.re +++ b/compiler/test/suites/basic_functionality.re @@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => { ~config_fn=smallestFileConfig, "smallest_grain_program", "", - 6507, + 6540, ); }); diff --git a/compiler/test/suites/includes.re b/compiler/test/suites/includes.re index f7ec1ff34..f49fc321e 100644 --- a/compiler/test/suites/includes.re +++ b/compiler/test/suites/includes.re @@ -207,4 +207,52 @@ describe("includes", ({test, testSkip}) => { "from \"reprovideContents\" include ReprovideContents; use ReprovideContents.{ type OtherT as Other }; print({ x: 1 }: Other)", "{\n x: 1\n}\n", ); + /* Duplicate imports */ + test("dedupe_includes", ({expect}) => { + let name = "dedupe_includes"; + let outfile = wasmfile(name); + ignore @@ + compile( + ~hook=Grain.Compile.stop_after_assembled, + name, + {| + module DeDupeIncludes + // Ensures test is only included once + foreign wasm test: WasmI32 => WasmI32 from "env" + let test2 = test + foreign wasm test: WasmI32 => WasmI32 from "env" + @unsafe + let _ = { + test(1n) + test2(1n) + } + |}, + ); + let ic = open_in_bin(outfile); + let sections = Grain_utils.Wasm_utils.get_wasm_sections(ic); + close_in(ic); + let import_section = + List.find_map( + (sec: Grain_utils.Wasm_utils.wasm_bin_section) => + switch (sec) { + | {sec_type: Import(imports)} => Some(imports) + | _ => None + }, + sections, + ); + expect.option(import_section).toBeSome(); + expect.int(List.length(Option.get(import_section))).toBe(2); + // Runtime printing import + expect.list(Option.get(import_section)).toContainEqual(( + WasmFunction, + "wasi_snapshot_preview1", + "fd_write", + )); + // Test import + expect.list(Option.get(import_section)).toContainEqual(( + WasmFunction, + "env", + "test", + )); + }); });