Skip to content

Commit

Permalink
Compile pointer to memref (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Dec 19, 2024
1 parent 37abfda commit 5ab98a7
Show file tree
Hide file tree
Showing 17 changed files with 446 additions and 286 deletions.
5 changes: 2 additions & 3 deletions bench/enif_merge_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ defmodule ENIFMergeSort do
use Charms
alias Charms.{Pointer, Term}

defm do_sort(arr :: Pointer.t(), l :: i32(), r :: i32()) do
defm do_sort(arr :: Pointer.t(Term.t()), l :: i32(), r :: i32()) do
if l < r do
two = const 2 :: i32()
m = op arith.divsi(l + r, two) :: i32()
m = result_at(m, 0)
m = value arith.divsi(l + r, two) :: i32()
do_sort(arr, l, m)
do_sort(arr, m + 1, r)
SortUtil.merge(arr, l, m, r)
Expand Down
34 changes: 15 additions & 19 deletions bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,34 @@ defmodule ENIFQuickSort do
use Charms
alias Charms.{Pointer, Term}

defm swap(a :: Pointer.t(), b :: Pointer.t()) do
tmp = Pointer.allocate(Term.t())
val_a = Pointer.load(Term.t(), a)
val_b = Pointer.load(Term.t(), b)
Pointer.store(val_b, tmp)
defm swap(a :: Pointer.t(Term.t()), b :: Pointer.t(Term.t())) do
val_a = Pointer.load(a)
val_b = Pointer.load(b)
Pointer.store(val_b, a)
Pointer.store(val_a, b)
val_tmp = Pointer.load(Term.t(), tmp)
Pointer.store(val_tmp, a)
end

defm partition(arr :: Pointer.t(), low :: i32(), high :: i32()) :: i32() do
pivot_ptr = Pointer.element_ptr(Term.t(), arr, high)
pivot = Pointer.load(Term.t(), pivot_ptr)
defm partition(arr :: Pointer.t(Term.t()), low :: i32(), high :: i32()) :: i32() do
pivot_ptr = Pointer.element_ptr(arr, high)
pivot = Pointer.load(pivot_ptr)
i_ptr = Pointer.allocate(i32())
Pointer.store(low - 1, i_ptr)
start = Pointer.element_ptr(Term.t(), arr, low)
start = Pointer.element_ptr(arr, low)

for_loop {element, j} <- {Term.t(), start, high - low} do
for_loop {element, j} <- {start, high - low} do
if enif_compare(element, pivot) < 0 do
i = Pointer.load(i32(), i_ptr) + 1
i = Pointer.load(i_ptr) + 1
Pointer.store(i, i_ptr)
j = value index.casts(j) :: i32()
swap(Pointer.element_ptr(Term.t(), arr, i), Pointer.element_ptr(Term.t(), start, j))
swap(Pointer.element_ptr(arr, i), Pointer.element_ptr(start, j))
end
end

i = Pointer.load(i32(), i_ptr)
swap(Pointer.element_ptr(Term.t(), arr, i + 1), Pointer.element_ptr(Term.t(), arr, high))
i = Pointer.load(i_ptr)
swap(Pointer.element_ptr(arr, i + 1), Pointer.element_ptr(arr, high))
func.return(i + 1)
end

defm do_sort(arr :: Pointer.t(), low :: i32(), high :: i32()) do
defm do_sort(arr :: Pointer.t(Term.t()), low :: i32(), high :: i32()) do
if low < high do
pi = partition(arr, low, high)
do_sort(arr, low, pi - 1)
Expand All @@ -49,7 +45,7 @@ defmodule ENIFQuickSort do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
len = Pointer.load(len_ptr)
arr = Pointer.allocate(Term.t(), len)
SortUtil.copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
Expand Down
16 changes: 8 additions & 8 deletions bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,38 @@ defmodule ENIFTimSort do
use Charms
alias Charms.{Pointer, Term}

defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do
defm insertion_sort(arr :: Pointer.t(Term.t()), left :: i32(), right :: i32()) do
start_i = left + 1
start = Pointer.element_ptr(Term.t(), arr, start_i)
start = Pointer.element_ptr(arr, start_i)
n = right - start_i + 1

for_loop {temp, i} <- {Term.t(), start, n} do
for_loop {temp, i} <- {start, n} do
i = value index.casts(i) :: i32()
i = i + start_i
j_ptr = Pointer.allocate(i32())
Pointer.store(i - 1, j_ptr)

while(
Pointer.load(i32(), j_ptr) >= left &&
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), arr, Pointer.load(i32(), j_ptr))) >
Pointer.load(Pointer.element_ptr(arr, Pointer.load(i32(), j_ptr))) >
temp
) do
j = Pointer.load(i32(), j_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), arr, j)),
Pointer.element_ptr(Term.t(), arr, j + 1)
Pointer.load(Pointer.element_ptr(arr, j)),
Pointer.element_ptr(arr, j + 1)
)

Pointer.store(j - 1, j_ptr)
end

j = Pointer.load(i32(), j_ptr)
Pointer.store(temp, Pointer.element_ptr(Term.t(), arr, j + 1))
Pointer.store(temp, Pointer.element_ptr(arr, j + 1))
end
end

defm tim_sort(arr :: Pointer.t(), n :: i32()) do
defm tim_sort(arr :: Pointer.t(Term.t()), n :: i32()) do
run = const 32 :: i32()
i_ptr = Pointer.allocate(i32())
zero = const 0 :: i32()
Expand Down
8 changes: 2 additions & 6 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ Benchee.run(
"enif_merge_sort" => &ENIFMergeSort.sort(&1),
"enif_tim_sort" => &ENIFTimSort.sort(&1)
},
inputs: %{
"array size 10" => 10,
"array size 100" => 100,
"array size 1000" => 1000,
"array size 10000" => 10000
},
parallel: 2,
inputs: [10, 100, 1000, 10000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}),
before_scenario: fn i ->
Enum.to_list(1..i) |> Enum.shuffle()
end
Expand Down
45 changes: 21 additions & 24 deletions bench/sort_util.ex
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
defmodule SortUtil do
@moduledoc false
use Charms
alias Charms.{Pointer, Term}

defm copy_terms(env, movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
defm copy_terms(env, movable_list_ptr :: Pointer.t(Term.t()), arr :: Pointer.t(Term.t())) do
head = Pointer.allocate(Term.t())
zero = const 0 :: i32()
i_ptr = Pointer.allocate(i32())
Expand All @@ -11,36 +12,32 @@ defmodule SortUtil do
while(
enif_get_list_cell(
env,
Pointer.load(Term.t(), movable_list_ptr),
Pointer.load(movable_list_ptr),
head,
movable_list_ptr
) > 0
) do
head_val = Pointer.load(Term.t(), head)
i = Pointer.load(i32(), i_ptr)
ith_term_ptr = Pointer.element_ptr(Term.t(), arr, i)
head_val = Pointer.load(head)
i = Pointer.load(i_ptr)
ith_term_ptr = Pointer.element_ptr(arr, i)
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end
end

defm merge(arr :: Pointer.t(), l :: i32(), m :: i32(), r :: i32()) do
defm merge(arr :: Pointer.t(Term.t()), l :: i32(), m :: i32(), r :: i32()) do
n1 = m - l + 1
n2 = r - m

left_temp = Pointer.allocate(Term.t(), n1)
right_temp = Pointer.allocate(Term.t(), n2)

for_loop {element, i} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, l), n1} do
i = op index.casts(i) :: i32()
i = result_at(i, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), left_temp, i))
for_loop {element, i} <- {Pointer.element_ptr(arr, l), n1} do
Pointer.store(element, Pointer.element_ptr(left_temp, i))
end

for_loop {element, j} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, m + 1), n2} do
j = op index.casts(j) :: i32()
j = result_at(j, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), right_temp, j))
for_loop {element, j} <- {Pointer.element_ptr(arr, m + 1), n2} do
Pointer.store(element, Pointer.element_ptr(right_temp, j))
end

i_ptr = Pointer.allocate(i32())
Expand All @@ -57,20 +54,20 @@ defmodule SortUtil do
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

left_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i))
right_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j))
left_term = Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i))
right_term = Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j))

if enif_compare(left_term, right_term) <= 0 do
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i)),
Pointer.element_ptr(arr, k)
)

Pointer.store(i + 1, i_ptr)
else
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j)),
Pointer.element_ptr(arr, k)
)

Pointer.store(j + 1, j_ptr)
Expand All @@ -84,8 +81,8 @@ defmodule SortUtil do
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i)),
Pointer.element_ptr(arr, k)
)

Pointer.store(i + 1, i_ptr)
Expand All @@ -97,8 +94,8 @@ defmodule SortUtil do
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j)),
Pointer.element_ptr(arr, k)
)

Pointer.store(j + 1, j_ptr)
Expand Down
10 changes: 4 additions & 6 deletions bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ defmodule AddTwoIntVec do

defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do
i_ptr = Pointer.allocate(i32())
# TODO: remove the const here, when pointer's type can be inferred
Pointer.store(const(0 :: i32()), i_ptr)
zero = const 0 :: Pointer.element_type(i_ptr)
Pointer.store(zero, i_ptr)
init = SIMD.new(SIMD.t(i32(), 8), [0, 0, 0, 0, 0, 0, 0, 0])

Enum.reduce(l, init, fn x, acc ->
v_ptr = Pointer.allocate(i32())
enif_get_int(env, x, v_ptr)
i = Pointer.load(i32(), i_ptr)
i = Pointer.load(i_ptr)
Pointer.store(i + 1, i_ptr)

Pointer.load(i32(), v_ptr)
|> vector.insertelement(acc, i)
Pointer.load(v_ptr) |> vector.insertelement(acc, i)
end)
end

Expand Down
12 changes: 11 additions & 1 deletion lib/charms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ defmodule Charms do

defmacro __before_compile__(env) do
defm_decls = Module.get_attribute(env.module, :defm) || []
{ir, referenced_modules} = defm_decls |> Enum.reverse() |> Charms.Defm.Definition.compile()

{ir, referenced_modules, required_intrinsic_modules} =
defm_decls |> Enum.reverse() |> Charms.Defm.Definition.compile()

# create uses in Elixir, to disallow loop reference
r =
Expand All @@ -58,10 +60,18 @@ defmodule Charms do
end
end

i =
for r <- required_intrinsic_modules, r != env.module do
quote do
unquote(r).__use_intrinsic__
end
end

quote do
@ir unquote(ir)
@referenced_modules unquote(referenced_modules)
unquote_splicing(r)
unquote_splicing(i)

@ir_hash [
:erlang.phash2(@ir)
Expand Down
5 changes: 4 additions & 1 deletion lib/charms/constant.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ defmodule Charms.Constant do

true ->
loc = Beaver.Deferred.create(loc, ctx)
raise CompileError, Charms.Diagnostic.meta_from_loc(loc) ++ [description: "Not a supported type for constant, #{to_string(t)}"]

raise CompileError,
Charms.Diagnostic.meta_from_loc(loc) ++
[description: "Not a supported type for constant, #{to_string(t)}"]
end
end
end
Expand Down
5 changes: 0 additions & 5 deletions lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ defmodule Charms.Defm do
"""
defmacro value(_expr), do: :implemented_in_expander

@doc """
syntax sugar to create an MLIR value from an Elixir value
"""
defmacro const(_), do: :implemented_in_expander

@doc """
call a local function with return
"""
Expand Down
36 changes: 22 additions & 14 deletions lib/charms/defm/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,17 @@ defmodule Charms.Defm.Definition do

mlir_expander = %Charms.Defm.Expander{mlir_expander | return_types: return_types}

for %__MODULE__{env: env, call: call, ret_types: ret_types, body: body} <- definitions do
quote(do: unquote(call) :: unquote(ret_types))
|> then(&quote(do: defm(unquote(&1), unquote(body))))
|> Charms.Defm.Expander.expand_to_mlir(env, mlir_expander)
end
required_intrinsic_modules =
for %__MODULE__{env: env, call: call, ret_types: ret_types, body: body} <- definitions,
reduce: MapSet.new() do
required_intrinsic_modules ->
{_, state, _} =
quote(do: unquote(call) :: unquote(ret_types))
|> then(&quote(do: defm(unquote(&1), unquote(body))))
|> Charms.Defm.Expander.expand_to_mlir(env, mlir_expander)

MapSet.union(state.mlir.required_intrinsic_modules, required_intrinsic_modules)
end
end

m
Expand All @@ -284,7 +290,9 @@ defmodule Charms.Defm.Definition do
raise_compile_error(__ENV__, msg)
end
end)
|> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)})
|> then(
&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1), required_intrinsic_modules}
)
end

@doc """
Expand All @@ -307,22 +315,22 @@ defmodule Charms.Defm.Definition do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
{:error, err, __STACKTRACE__}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

case {res, msg} do
{{:ok, {mlir, mods}}, nil} ->
MLIR.Context.destroy(ctx)
{mlir, mods}
{{:ok, {mlir, mods, i_mods}}, nil} ->
# MLIR.Context.destroy(ctx)
{mlir, mods, i_mods}

{_, {:ok, d_msg}} ->
raise CompileError, d_msg
{{:error, _, st}, {:ok, d_msg}} ->
reraise CompileError, d_msg, st

{{:error, err}, _} ->
raise err
{{:error, err, st}, _} ->
reraise err, st
end
end

Expand Down
Loading

0 comments on commit 5ab98a7

Please sign in to comment.