Skip to content

Commit

Permalink
feat(compiler): Deduplicate foreign imports (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
spotandjake authored Feb 13, 2025
1 parent a76df88 commit e8a3ed2
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 40 deletions.
97 changes: 63 additions & 34 deletions compiler/src/codegen/compcore.re
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ open Grain_utils;
open Comp_utils;
open Comp_wasm_prim;

module StringSet = Set.Make(String);

let sources: ref(list((Expression.t, Grain_parsing.Location.t))) = ref([]);

/** Environment */
Expand All @@ -22,6 +24,7 @@ type codegen_env = {
/* Allocated closures which need backpatching */
backpatches: ref(list((Expression.t, closure_data))),
required_imports: list(import),
foreign_import_resolutions: ref(StringSet.t),
global_import_resolutions: Hashtbl.t(string, string),
func_import_resolutions: Hashtbl.t(string, string),
compilation_mode: Config.compilation_mode,
Expand Down Expand Up @@ -90,6 +93,7 @@ let init_codegen_env =
},
backpatches: ref([]),
required_imports: [],
foreign_import_resolutions: ref(StringSet.empty),
global_import_resolutions,
func_import_resolutions,
compilation_mode: Normal,
Expand Down Expand Up @@ -2782,6 +2786,14 @@ and compile_instr = (wasm_mod, env, instr) =>
compiled_args,
Type.create(Array.of_list(List.map(wasm_type, retty))),
);
} else if (StringSet.mem(func_name, env.foreign_import_resolutions^)) {
// Deduplicated imports; call resolved name directly
Expression.Call.make(
wasm_mod,
resolved_name,
compiled_args,
Type.create(Array.of_list(List.map(wasm_type, retty))),
);
} else {
// Raw function resolved to Grain function; inject closure argument
let closure_global = resolve_global(~env, func_name);
Expand Down Expand Up @@ -3070,7 +3082,7 @@ let compute_table_size = (env, {function_table_elements}) => {
List.length(function_table_elements);
};

let compile_imports = (wasm_mod, env, {imports}) => {
let compile_imports = (wasm_mod, env, {imports}, import_map) => {
let compile_module_name = name =>
fun
| MImportWasm => name
Expand All @@ -3081,7 +3093,6 @@ let compile_imports = (wasm_mod, env, {imports}) => {
| (MImportGrain, MGlobalImport(_)) => "GRAIN$EXPORT$" ++ name
| _ => name
};

let compile_import = ({mimp_id, mimp_mod, mimp_name, mimp_type, mimp_kind}) => {
let module_name = compile_module_name(mimp_mod, mimp_kind);
let item_name = compile_import_name(mimp_name, mimp_kind, mimp_type);
Expand All @@ -3090,37 +3101,53 @@ let compile_imports = (wasm_mod, env, {imports}) => {
| MImportGrain => get_grain_imported_name(mimp_mod, mimp_id)
| MImportWasm => Ident.unique_name(mimp_id)
};
switch (mimp_kind, mimp_type) {
| (MImportGrain, MGlobalImport(ty, mut)) =>
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
wasm_type(ty),
mut,
)
| (_, MFuncImport(args, ret)) =>
let proc_list = l =>
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
Import.add_function_import(
wasm_mod,
internal_name,
module_name,
item_name,
proc_list(args),
proc_list(ret),
);
| (_, MGlobalImport(typ, mut)) =>
let typ = wasm_type(typ);
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
typ,
mut,
);
let import_key = (module_name, item_name, mimp_kind, mimp_type);
switch (Hashtbl.find_opt(import_map, import_key)) {
| Some(name) when mimp_kind == MImportWasm =>
// Deduplicate wasm imports by resolving them to the previously imported name
let linked_name = linked_name(~env, internal_name);
switch (mimp_type) {
| MFuncImport(_, _) =>
Hashtbl.add(env.func_import_resolutions, linked_name, name)
| MGlobalImport(_, _) =>
Hashtbl.add(env.global_import_resolutions, linked_name, name)
};
env.foreign_import_resolutions :=
StringSet.add(linked_name, env.foreign_import_resolutions^);
| _ =>
Hashtbl.add(import_map, import_key, internal_name);
switch (mimp_kind, mimp_type) {
| (MImportGrain, MGlobalImport(ty, mut)) =>
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
wasm_type(ty),
mut,
)
| (_, MFuncImport(args, ret)) =>
let proc_list = l =>
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
Import.add_function_import(
wasm_mod,
internal_name,
module_name,
item_name,
proc_list(args),
proc_list(ret),
);
| (_, MGlobalImport(typ, mut)) =>
let typ = wasm_type(typ);
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
typ,
mut,
);
};
};
};

Expand Down Expand Up @@ -3472,12 +3499,14 @@ let compile_wasm_module =
Type.funcref,
);

let import_map = Hashtbl.create(10);

let compile_one = (dep_id, prog: mash_code) => {
let env = {...env, dep_id, compilation_mode: prog.compilation_mode};
ignore @@ compile_imports(wasm_mod, env, prog, import_map);
ignore @@ compile_globals(wasm_mod, env, prog);
ignore @@ compile_functions(wasm_mod, env, prog);
ignore @@ compile_exports(wasm_mod, env, prog);
ignore @@ compile_imports(wasm_mod, env, prog);
ignore @@ compile_tables(wasm_mod, env, prog);
};

Expand Down
6 changes: 1 addition & 5 deletions compiler/src/codegen/linkedtree.re
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ let link = main_mashtree => {
(resolved_module, import.mimp_name),
);
let import_name =
Printf.sprintf(
"%s_%d",
Ident.unique_name(import.mimp_id),
dep_id^,
);
internal_name(Ident.unique_name(import.mimp_id), dep_id^);
Option.iter(
global =>
Hashtbl.add(global_import_resolutions, import_name, global),
Expand Down
2 changes: 1 addition & 1 deletion compiler/test/suites/basic_functionality.re
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => {
~config_fn=smallestFileConfig,
"smallest_grain_program",
"",
6507,
6540,
);
});
48 changes: 48 additions & 0 deletions compiler/test/suites/includes.re
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,52 @@ describe("includes", ({test, testSkip}) => {
"from \"reprovideContents\" include ReprovideContents; use ReprovideContents.{ type OtherT as Other }; print({ x: 1 }: Other)",
"{\n x: 1\n}\n",
);
/* Duplicate imports */
test("dedupe_includes", ({expect}) => {
let name = "dedupe_includes";
let outfile = wasmfile(name);
ignore @@
compile(
~hook=Grain.Compile.stop_after_assembled,
name,
{|
module DeDupeIncludes
// Ensures test is only included once
foreign wasm test: WasmI32 => WasmI32 from "env"
let test2 = test
foreign wasm test: WasmI32 => WasmI32 from "env"
@unsafe
let _ = {
test(1n)
test2(1n)
}
|},
);
let ic = open_in_bin(outfile);
let sections = Grain_utils.Wasm_utils.get_wasm_sections(ic);
close_in(ic);
let import_section =
List.find_map(
(sec: Grain_utils.Wasm_utils.wasm_bin_section) =>
switch (sec) {
| {sec_type: Import(imports)} => Some(imports)
| _ => None
},
sections,
);
expect.option(import_section).toBeSome();
expect.int(List.length(Option.get(import_section))).toBe(2);
// Runtime printing import
expect.list(Option.get(import_section)).toContainEqual((
WasmFunction,
"wasi_snapshot_preview1",
"fd_write",
));
// Test import
expect.list(Option.get(import_section)).toContainEqual((
WasmFunction,
"env",
"test",
));
});
});

0 comments on commit e8a3ed2

Please sign in to comment.