Skip to content

Commit

Permalink
feat: Refactor to compcore
Browse files Browse the repository at this point in the history
  • Loading branch information
spotandjake committed Feb 12, 2025
1 parent c79f8fe commit 2691ae7
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 64 deletions.
101 changes: 66 additions & 35 deletions compiler/src/codegen/compcore.re
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ open Grain_utils;
open Comp_utils;
open Comp_wasm_prim;

module StringSet = Set.Make(String);

let sources: ref(list((Expression.t, Grain_parsing.Location.t))) = ref([]);

/** Environment */
Expand All @@ -22,6 +24,7 @@ type codegen_env = {
/* Allocated closures which need backpatching */
backpatches: ref(list((Expression.t, closure_data))),
required_imports: list(import),
raw_resolutions: ref(StringSet.t),
global_import_resolutions: Hashtbl.t(string, string),
func_import_resolutions: Hashtbl.t(string, string),
compilation_mode: Config.compilation_mode,
Expand Down Expand Up @@ -90,6 +93,7 @@ let init_codegen_env =
},
backpatches: ref([]),
required_imports: [],
raw_resolutions: ref(StringSet.empty),
global_import_resolutions,
func_import_resolutions,
compilation_mode: Normal,
Expand Down Expand Up @@ -2782,6 +2786,14 @@ and compile_instr = (wasm_mod, env, instr) =>
compiled_args,
Type.create(Array.of_list(List.map(wasm_type, retty))),
);
} else if (StringSet.mem(func_name, env.raw_resolutions^)) {
// Deduplicated imports; call resolved name directly
Expression.Call.make(
wasm_mod,
resolved_name,
compiled_args,
Type.create(Array.of_list(List.map(wasm_type, retty))),
);
} else {
// Raw function resolved to Grain function; inject closure argument
let closure_global = resolve_global(~env, func_name);
Expand Down Expand Up @@ -3070,7 +3082,7 @@ let compute_table_size = (env, {function_table_elements}) => {
List.length(function_table_elements);
};

let compile_imports = (wasm_mod, env, {imports}) => {
let compile_imports = (wasm_mod, env, {imports}, import_map) => {
let compile_module_name = name =>
fun
| MImportWasm => name
Expand All @@ -3081,7 +3093,6 @@ let compile_imports = (wasm_mod, env, {imports}) => {
| (MImportGrain, MGlobalImport(_)) => "GRAIN$EXPORT$" ++ name
| _ => name
};

let compile_import = ({mimp_id, mimp_mod, mimp_name, mimp_type, mimp_kind}) => {
let module_name = compile_module_name(mimp_mod, mimp_kind);
let item_name = compile_import_name(mimp_name, mimp_kind, mimp_type);
Expand All @@ -3090,37 +3101,52 @@ let compile_imports = (wasm_mod, env, {imports}) => {
| MImportGrain => get_grain_imported_name(mimp_mod, mimp_id)
| MImportWasm => Ident.unique_name(mimp_id)
};
switch (mimp_kind, mimp_type) {
| (MImportGrain, MGlobalImport(ty, mut)) =>
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
wasm_type(ty),
mut,
)
| (_, MFuncImport(args, ret)) =>
let proc_list = l =>
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
Import.add_function_import(
wasm_mod,
internal_name,
module_name,
item_name,
proc_list(args),
proc_list(ret),
);
| (_, MGlobalImport(typ, mut)) =>
let typ = wasm_type(typ);
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
typ,
mut,
);
let imp_key = (module_name, item_name, mimp_kind, mimp_type);
switch (Hashtbl.find_opt(import_map, imp_key)) {
| Some(name) when mimp_kind == MImportWasm =>
// Deduplicate wasm imports by resolving them to the previously imported name
let linked_name = linked_name(~env, internal_name);
switch (mimp_type) {
| MFuncImport(_, _) =>
Hashtbl.add(env.func_import_resolutions, linked_name, name)
| MGlobalImport(_, _) =>
Hashtbl.add(env.global_import_resolutions, linked_name, name)
};
env.raw_resolutions := StringSet.add(linked_name, env.raw_resolutions^);
| _ =>
Hashtbl.add(import_map, imp_key, internal_name);
switch (mimp_kind, mimp_type) {
| (MImportGrain, MGlobalImport(ty, mut)) =>
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
wasm_type(ty),
mut,
)
| (_, MFuncImport(args, ret)) =>
let proc_list = l =>
Type.create @@ Array.of_list @@ List.map(wasm_type, l);
Import.add_function_import(
wasm_mod,
internal_name,
module_name,
item_name,
proc_list(args),
proc_list(ret),
);
| (_, MGlobalImport(typ, mut)) =>
let typ = wasm_type(typ);
Import.add_global_import(
wasm_mod,
internal_name,
module_name,
item_name,
typ,
mut,
);
};
};
};

Expand Down Expand Up @@ -3472,12 +3498,14 @@ let compile_wasm_module =
Type.funcref,
);

let import_map = Hashtbl.create(10);

let compile_one = (dep_id, prog: mash_code) => {
let env = {...env, dep_id, compilation_mode: prog.compilation_mode};
ignore @@ compile_imports(wasm_mod, env, prog, import_map);
ignore @@ compile_globals(wasm_mod, env, prog);
ignore @@ compile_functions(wasm_mod, env, prog);
ignore @@ compile_exports(wasm_mod, env, prog);
ignore @@ compile_imports(wasm_mod, env, prog);
ignore @@ compile_tables(wasm_mod, env, prog);
};

Expand All @@ -3499,7 +3527,10 @@ let compile_wasm_module =

validate_module(~name?, wasm_mod);

Optimize_mod.optimize(wasm_mod);
switch (Config.profile^) {
| Some(Release) => Optimize_mod.optimize(wasm_mod)
| None => ()
};
wasm_mod;
};

Expand Down
6 changes: 1 addition & 5 deletions compiler/src/codegen/linkedtree.re
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ let link = main_mashtree => {
(resolved_module, import.mimp_name),
);
let import_name =
Printf.sprintf(
"%s_%d",
Ident.unique_name(import.mimp_id),
dep_id^,
);
internal_name(Ident.unique_name(import.mimp_id), dep_id^);
Option.iter(
global =>
Hashtbl.add(global_import_resolutions, import_name, global),
Expand Down
42 changes: 19 additions & 23 deletions compiler/src/codegen/optimize_mod.re
Original file line number Diff line number Diff line change
Expand Up @@ -227,29 +227,25 @@ let optimize =
~shrink_level=default_shrink_level,
wasm_mod,
) => {
let passes =
switch (Config.profile^) {
| Some(Release) =>
List.concat([
default_global_optimization_pre_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
default_function_optimization_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
default_global_optimization_post_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
])
| None => [Passes.duplicate_import_elimination]
};
// Translation of https://github.com/WebAssembly/binaryen/blob/version_107/src/passes/pass.cpp#L441-L445
let default_optimizations_passes =
List.concat([
default_global_optimization_pre_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
default_function_optimization_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
default_global_optimization_post_passes(
~optimize_level,
~shrink_level,
wasm_mod,
),
]);

Module.run_passes(wasm_mod, passes);
Module.run_passes(wasm_mod, default_optimizations_passes);
};
2 changes: 1 addition & 1 deletion compiler/test/suites/basic_functionality.re
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => {
~config_fn=smallestFileConfig,
"smallest_grain_program",
"",
6507,
6540,
);
});

0 comments on commit 2691ae7

Please sign in to comment.