Skip to content

Commit

Permalink
Collect referenced modules and init them together (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Oct 6, 2024
1 parent ac333bb commit 5943fa0
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 50 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ end

- run the benchmarks of sorting algorithms
```sh
mix run bench/sort.exs
mix run bench/sort_benchmark.exs
mix run bench/list_add_benchmark.exs
```
2 changes: 1 addition & 1 deletion bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFQuickSort do
@moduledoc false
use Charms, init: false
use Charms
alias Charms.{Pointer, Term, Env}

defm swap(a :: Pointer.t(), b :: Pointer.t()) do
Expand Down
2 changes: 1 addition & 1 deletion bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFTimSort do
@moduledoc false
use Charms, init: false
use Charms
alias Charms.{Pointer, Term, Env}

defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do
Expand Down
5 changes: 0 additions & 5 deletions bench/list_add_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
mod = AddTwoIntVec
Charms.JIT.init(mod)

a = b = Enum.to_list(1..10)
AddTwoIntVec.add(a, b, :err_msg)
AddTwoIntVec.dummy_load_no_make(a, b, :err_msg)
Expand All @@ -26,5 +23,3 @@ Benchee.run(
{a, b}
end
)

Charms.JIT.destroy(mod)
7 changes: 0 additions & 7 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])

arr = Enum.to_list(1..10000) |> Enum.shuffle()
ENIFQuickSort.sort(arr, :arg_err)
ENIFMergeSort.sort(arr, :arg_err)
Expand All @@ -23,7 +20,3 @@ Benchee.run(
Enum.to_list(1..i) |> Enum.shuffle()
end
)

Charms.JIT.destroy(ENIFMergeSort)
Charms.JIT.destroy(ENIFQuickSort)
Charms.JIT.destroy(ENIFTimSort)
2 changes: 1 addition & 1 deletion bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defmodule AddTwoIntVec do
use Charms, init: false
use Charms
alias Charms.{SIMD, Term, Pointer}

defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do
Expand Down
13 changes: 12 additions & 1 deletion lib/charms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,21 @@ defmodule Charms do

defmacro __before_compile__(_env) do
quote do
@ir @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
{ir, referenced_modules} = @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
@ir ir
@referenced_modules referenced_modules

@doc false
def __ir__ do
@ir
end

@doc false
def referenced_modules do
@referenced_modules
end

defoverridable referenced_modules: 0
end
end

Expand Down
34 changes: 33 additions & 1 deletion lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ defmodule Charms.Defm do
:ok
end

defp referenced_modules(module) do
Beaver.Walker.postwalk(module, MapSet.new(), fn
%MLIR.Operation{} = op, acc ->
with "func.call" <- MLIR.Operation.name(op),
callee when not is_nil(callee) <- Beaver.Walker.attributes(op)["callee"] do
case callee |> to_string do
"@Elixir." <> _ = name ->
acc |> MapSet.put(extract_mangled_mod(name))

_ ->
acc
end
|> then(&{op, &1})
else
_ ->
{op, acc}
end

ir, acc ->
{ir, acc}
end)
|> then(fn {_, acc} -> MapSet.to_list(acc) end)
end

@doc false
def compile_definitions(definitions) do
import MLIR.Transforms
Expand Down Expand Up @@ -186,12 +210,20 @@ defmodule Charms.Defm do
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> canonicalize
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)
|> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)})
|> tap(fn _ -> MLIR.Context.destroy(ctx) end)
end

@doc false
def mangling(mod, func) do
Module.concat(mod, func)
end

defp extract_mangled_mod("@" <> name) do
name
|> String.split(".")
|> then(&Enum.take(&1, length(&1) - 1))
|> Enum.join(".")
|> String.to_atom()
end
end
52 changes: 31 additions & 21 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -90,36 +90,46 @@ defmodule Charms.JIT do
|> then(&{:ok, &1})
end

def init(module, opts \\ [])
defp collect_modules(module, acc \\ [])

def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
init(module, opts)
end
defp collect_modules(module, acc) when is_atom(module) do
if module in acc do
acc
else
acc = [module | acc]

def init(module, opts) when is_atom(module) do
name = opts[:name] || module
opts = Keyword.put_new(opts, :name, name)
init([module], opts)
module.referenced_modules()
|> Enum.reduce(acc, fn m, acc ->
collect_modules(m, acc)
end)
end
end

def init(modules, opts) do
modules = modules |> List.wrap()
defp collect_modules(module, acc), do: [module | acc]

case {opts[:name], modules} do
{name, [_]} when not is_nil(name) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
def init(module, opts \\ [])

{nil, modules} when modules != [] ->
[key | tail] = modules
{:ok, jit} = __MODULE__.LockedCache.run(key, fn -> do_init(modules) end)
def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
init(module, opts)
end

for module <- tail,
do:
__MODULE__.LockedCache.run(module, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
def init(module, opts) do
name = opts[:name] || module

{name, modules} when not is_nil(name) and is_list(modules) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
{modules, jit} =
__MODULE__.LockedCache.run(name, fn ->
modules = collect_modules(module)
{:ok, jit} = do_init(modules)
{modules, jit}
end)

# modules will be nil if cache is hit
for m when is_atom(module) <- modules || [],
module != m do
__MODULE__.LockedCache.run(m, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
end

{:ok, jit}
end

@doc """
Expand Down
1 change: 0 additions & 1 deletion test/const_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ defmodule ConstTest do
one = const 1.0 :: unranked_tensor(f64())
end
end
|> Charms.JIT.init()
end
end
end
12 changes: 5 additions & 7 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ defmodule DefmTest do
assert Charms.JIT.invoke(engine, {AddTwoInt, :add, [1, "", :arg_err]}) == :arg_err
assert Charms.JIT.invoke(engine, &AddTwoInt.add/3, [1, 2, :arg_err]) == 3
assert Charms.JIT.invoke(engine, &AddTwoInt.add/3, [1, "", :arg_err]) == :arg_err
:ok = Charms.JIT.destroy(:add_int)
assert :ok = Charms.JIT.destroy(:add_int)

Charms.JIT.init(AddTwoInt)
assert AddTwoInt.add(1, 2, :arg_err) == 3
Charms.JIT.destroy(AddTwoInt)
assert :ok = Charms.JIT.destroy(AddTwoInt)
end

test "quick sort" do
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])
assert ENIFQuickSort.sort(:what, :arg_err) == :arg_err
arr = [5, 4, 3, 2, 1]
assert ENIFQuickSort.sort(arr, :arg_err) == Enum.sort(arr)
Expand All @@ -59,8 +57,8 @@ defmodule DefmTest do
assert ENIFMergeSort.sort(arr, :arg_err) == Enum.sort(arr)
end

:ok = Charms.JIT.destroy(ENIFQuickSort)
:noop = Charms.JIT.destroy(ENIFMergeSort)
:ok = Charms.JIT.destroy(ENIFTimSort)
assert :ok = Charms.JIT.destroy(ENIFQuickSort)
assert :noop = Charms.JIT.destroy(ENIFMergeSort)
assert :ok = Charms.JIT.destroy(ENIFTimSort)
end
end
2 changes: 1 addition & 1 deletion test/expander_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ defmodule POCTest do
{:ok, %Charms.JIT{}} = Charms.JIT.init(m, name: :return_this)
engine = Charms.JIT.engine(:return_this)
assert Charms.JIT.invoke(engine, {ReturnPassedArg, :bar, [:identical]}) == :identical
:ok = Charms.JIT.destroy(:return_this)
assert :ok = Charms.JIT.destroy(:return_this)
end)
end

Expand Down
2 changes: 1 addition & 1 deletion test/macro_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ defmodule MacroTest do
test "expand macro" do
assert 100 == CallMacroMod.term_roundtrip1(100)
assert 200 == CallMacroMod.term_roundtrip2(200)
Charms.JIT.destroy(CallMacroMod)
assert :ok = Charms.JIT.destroy(CallMacroMod)
end
end
1 change: 0 additions & 1 deletion test/vec_add_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ defmodule VecAddTest do
use ExUnit.Case, async: true

test "vec add" do
{:ok, _} = Charms.JIT.init(AddTwoIntVec)
a = 1..8 |> Enum.to_list()
b = List.duplicate(1, 8)
assert AddTwoIntVec.add(a, b, :err) == Enum.to_list(2..9)
Expand Down

0 comments on commit 5943fa0

Please sign in to comment.