From 5ab98a7ed7fd76ba8bc61759ac85519b139797d9 Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Thu, 19 Dec 2024 22:08:32 +0800 Subject: [PATCH] Compile pointer to memref (#57) --- bench/enif_merge_sort.ex | 5 +- bench/enif_quick_sort.ex | 34 +++---- bench/enif_tim_sort.ex | 16 +-- bench/sort_benchmark.exs | 8 +- bench/sort_util.ex | 45 ++++----- bench/vec_add_int_list.ex | 10 +- lib/charms.ex | 12 ++- lib/charms/constant.ex | 5 +- lib/charms/defm.ex | 5 - lib/charms/defm/definition.ex | 36 ++++--- lib/charms/defm/expander.ex | 73 ++++++-------- lib/charms/intrinsic.ex | 3 +- lib/charms/jit.ex | 27 +---- lib/charms/pointer.ex | 181 +++++++++++++++++++++++++++++----- lib/charms/prelude.ex | 111 +++++++++++++++------ test/add_two_int_test.exs | 59 +++++++++++ test/defm_test.exs | 102 ++++++------------- 17 files changed, 446 insertions(+), 286 deletions(-) create mode 100644 test/add_two_int_test.exs diff --git a/bench/enif_merge_sort.ex b/bench/enif_merge_sort.ex index 0ea52ad..34774a1 100644 --- a/bench/enif_merge_sort.ex +++ b/bench/enif_merge_sort.ex @@ -3,11 +3,10 @@ defmodule ENIFMergeSort do use Charms alias Charms.{Pointer, Term} - defm do_sort(arr :: Pointer.t(), l :: i32(), r :: i32()) do + defm do_sort(arr :: Pointer.t(Term.t()), l :: i32(), r :: i32()) do if l < r do two = const 2 :: i32() - m = op arith.divsi(l + r, two) :: i32() - m = result_at(m, 0) + m = value arith.divsi(l + r, two) :: i32() do_sort(arr, l, m) do_sort(arr, m + 1, r) SortUtil.merge(arr, l, m, r) diff --git a/bench/enif_quick_sort.ex b/bench/enif_quick_sort.ex index f53850b..89d8b32 100644 --- a/bench/enif_quick_sort.ex +++ b/bench/enif_quick_sort.ex @@ -3,38 +3,34 @@ defmodule ENIFQuickSort do use Charms alias Charms.{Pointer, Term} - defm swap(a :: Pointer.t(), b :: Pointer.t()) do - tmp = Pointer.allocate(Term.t()) - val_a = Pointer.load(Term.t(), a) - val_b = Pointer.load(Term.t(), b) - Pointer.store(val_b, tmp) + defm swap(a :: Pointer.t(Term.t()), b :: Pointer.t(Term.t())) do + val_a = Pointer.load(a) + val_b = Pointer.load(b) + Pointer.store(val_b, a) Pointer.store(val_a, b) - val_tmp = Pointer.load(Term.t(), tmp) - Pointer.store(val_tmp, a) end - defm partition(arr :: Pointer.t(), low :: i32(), high :: i32()) :: i32() do - pivot_ptr = Pointer.element_ptr(Term.t(), arr, high) - pivot = Pointer.load(Term.t(), pivot_ptr) + defm partition(arr :: Pointer.t(Term.t()), low :: i32(), high :: i32()) :: i32() do + pivot_ptr = Pointer.element_ptr(arr, high) + pivot = Pointer.load(pivot_ptr) i_ptr = Pointer.allocate(i32()) Pointer.store(low - 1, i_ptr) - start = Pointer.element_ptr(Term.t(), arr, low) + start = Pointer.element_ptr(arr, low) - for_loop {element, j} <- {Term.t(), start, high - low} do + for_loop {element, j} <- {start, high - low} do if enif_compare(element, pivot) < 0 do - i = Pointer.load(i32(), i_ptr) + 1 + i = Pointer.load(i_ptr) + 1 Pointer.store(i, i_ptr) - j = value index.casts(j) :: i32() - swap(Pointer.element_ptr(Term.t(), arr, i), Pointer.element_ptr(Term.t(), start, j)) + swap(Pointer.element_ptr(arr, i), Pointer.element_ptr(start, j)) end end - i = Pointer.load(i32(), i_ptr) - swap(Pointer.element_ptr(Term.t(), arr, i + 1), Pointer.element_ptr(Term.t(), arr, high)) + i = Pointer.load(i_ptr) + swap(Pointer.element_ptr(arr, i + 1), Pointer.element_ptr(arr, high)) func.return(i + 1) end - defm do_sort(arr :: Pointer.t(), low :: i32(), high :: i32()) do + defm do_sort(arr :: Pointer.t(Term.t()), low :: i32(), high :: i32()) do if low < high do pi = partition(arr, low, high) do_sort(arr, low, pi - 1) @@ -49,7 +45,7 @@ defmodule ENIFQuickSort do if enif_get_list_length(env, list, len_ptr) != 0 do movable_list_ptr = Pointer.allocate(Term.t()) Pointer.store(list, movable_list_ptr) - len = Pointer.load(i32(), len_ptr) + len = Pointer.load(len_ptr) arr = Pointer.allocate(Term.t(), len) SortUtil.copy_terms(env, movable_list_ptr, arr) zero = const 0 :: i32() diff --git a/bench/enif_tim_sort.ex b/bench/enif_tim_sort.ex index 50267e5..6b16a33 100644 --- a/bench/enif_tim_sort.ex +++ b/bench/enif_tim_sort.ex @@ -3,12 +3,12 @@ defmodule ENIFTimSort do use Charms alias Charms.{Pointer, Term} - defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do + defm insertion_sort(arr :: Pointer.t(Term.t()), left :: i32(), right :: i32()) do start_i = left + 1 - start = Pointer.element_ptr(Term.t(), arr, start_i) + start = Pointer.element_ptr(arr, start_i) n = right - start_i + 1 - for_loop {temp, i} <- {Term.t(), start, n} do + for_loop {temp, i} <- {start, n} do i = value index.casts(i) :: i32() i = i + start_i j_ptr = Pointer.allocate(i32()) @@ -16,25 +16,25 @@ defmodule ENIFTimSort do while( Pointer.load(i32(), j_ptr) >= left && - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), arr, Pointer.load(i32(), j_ptr))) > + Pointer.load(Pointer.element_ptr(arr, Pointer.load(i32(), j_ptr))) > temp ) do j = Pointer.load(i32(), j_ptr) Pointer.store( - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), arr, j)), - Pointer.element_ptr(Term.t(), arr, j + 1) + Pointer.load(Pointer.element_ptr(arr, j)), + Pointer.element_ptr(arr, j + 1) ) Pointer.store(j - 1, j_ptr) end j = Pointer.load(i32(), j_ptr) - Pointer.store(temp, Pointer.element_ptr(Term.t(), arr, j + 1)) + Pointer.store(temp, Pointer.element_ptr(arr, j + 1)) end end - defm tim_sort(arr :: Pointer.t(), n :: i32()) do + defm tim_sort(arr :: Pointer.t(Term.t()), n :: i32()) do run = const 32 :: i32() i_ptr = Pointer.allocate(i32()) zero = const 0 :: i32() diff --git a/bench/sort_benchmark.exs b/bench/sort_benchmark.exs index 08dd5eb..a2a28af 100644 --- a/bench/sort_benchmark.exs +++ b/bench/sort_benchmark.exs @@ -10,12 +10,8 @@ Benchee.run( "enif_merge_sort" => &ENIFMergeSort.sort(&1), "enif_tim_sort" => &ENIFTimSort.sort(&1) }, - inputs: %{ - "array size 10" => 10, - "array size 100" => 100, - "array size 1000" => 1000, - "array size 10000" => 10000 - }, + parallel: 2, + inputs: [10, 100, 1000, 10000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}), before_scenario: fn i -> Enum.to_list(1..i) |> Enum.shuffle() end diff --git a/bench/sort_util.ex b/bench/sort_util.ex index c32711b..ab9c5c4 100644 --- a/bench/sort_util.ex +++ b/bench/sort_util.ex @@ -1,8 +1,9 @@ defmodule SortUtil do + @moduledoc false use Charms alias Charms.{Pointer, Term} - defm copy_terms(env, movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do + defm copy_terms(env, movable_list_ptr :: Pointer.t(Term.t()), arr :: Pointer.t(Term.t())) do head = Pointer.allocate(Term.t()) zero = const 0 :: i32() i_ptr = Pointer.allocate(i32()) @@ -11,36 +12,32 @@ defmodule SortUtil do while( enif_get_list_cell( env, - Pointer.load(Term.t(), movable_list_ptr), + Pointer.load(movable_list_ptr), head, movable_list_ptr ) > 0 ) do - head_val = Pointer.load(Term.t(), head) - i = Pointer.load(i32(), i_ptr) - ith_term_ptr = Pointer.element_ptr(Term.t(), arr, i) + head_val = Pointer.load(head) + i = Pointer.load(i_ptr) + ith_term_ptr = Pointer.element_ptr(arr, i) Pointer.store(head_val, ith_term_ptr) Pointer.store(i + 1, i_ptr) end end - defm merge(arr :: Pointer.t(), l :: i32(), m :: i32(), r :: i32()) do + defm merge(arr :: Pointer.t(Term.t()), l :: i32(), m :: i32(), r :: i32()) do n1 = m - l + 1 n2 = r - m left_temp = Pointer.allocate(Term.t(), n1) right_temp = Pointer.allocate(Term.t(), n2) - for_loop {element, i} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, l), n1} do - i = op index.casts(i) :: i32() - i = result_at(i, 0) - Pointer.store(element, Pointer.element_ptr(Term.t(), left_temp, i)) + for_loop {element, i} <- {Pointer.element_ptr(arr, l), n1} do + Pointer.store(element, Pointer.element_ptr(left_temp, i)) end - for_loop {element, j} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, m + 1), n2} do - j = op index.casts(j) :: i32() - j = result_at(j, 0) - Pointer.store(element, Pointer.element_ptr(Term.t(), right_temp, j)) + for_loop {element, j} <- {Pointer.element_ptr(arr, m + 1), n2} do + Pointer.store(element, Pointer.element_ptr(right_temp, j)) end i_ptr = Pointer.allocate(i32()) @@ -57,20 +54,20 @@ defmodule SortUtil do j = Pointer.load(i32(), j_ptr) k = Pointer.load(i32(), k_ptr) - left_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)) - right_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)) + left_term = Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i)) + right_term = Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j)) if enif_compare(left_term, right_term) <= 0 do Pointer.store( - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)), - Pointer.element_ptr(Term.t(), arr, k) + Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i)), + Pointer.element_ptr(arr, k) ) Pointer.store(i + 1, i_ptr) else Pointer.store( - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)), - Pointer.element_ptr(Term.t(), arr, k) + Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j)), + Pointer.element_ptr(arr, k) ) Pointer.store(j + 1, j_ptr) @@ -84,8 +81,8 @@ defmodule SortUtil do k = Pointer.load(i32(), k_ptr) Pointer.store( - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)), - Pointer.element_ptr(Term.t(), arr, k) + Pointer.load(Term.t(), Pointer.element_ptr(left_temp, i)), + Pointer.element_ptr(arr, k) ) Pointer.store(i + 1, i_ptr) @@ -97,8 +94,8 @@ defmodule SortUtil do k = Pointer.load(i32(), k_ptr) Pointer.store( - Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)), - Pointer.element_ptr(Term.t(), arr, k) + Pointer.load(Term.t(), Pointer.element_ptr(right_temp, j)), + Pointer.element_ptr(arr, k) ) Pointer.store(j + 1, j_ptr) diff --git a/bench/vec_add_int_list.ex b/bench/vec_add_int_list.ex index 5f66b9a..4443d0c 100644 --- a/bench/vec_add_int_list.ex +++ b/bench/vec_add_int_list.ex @@ -5,18 +5,16 @@ defmodule AddTwoIntVec do defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do i_ptr = Pointer.allocate(i32()) - # TODO: remove the const here, when pointer's type can be inferred - Pointer.store(const(0 :: i32()), i_ptr) + zero = const 0 :: Pointer.element_type(i_ptr) + Pointer.store(zero, i_ptr) init = SIMD.new(SIMD.t(i32(), 8), [0, 0, 0, 0, 0, 0, 0, 0]) Enum.reduce(l, init, fn x, acc -> v_ptr = Pointer.allocate(i32()) enif_get_int(env, x, v_ptr) - i = Pointer.load(i32(), i_ptr) + i = Pointer.load(i_ptr) Pointer.store(i + 1, i_ptr) - - Pointer.load(i32(), v_ptr) - |> vector.insertelement(acc, i) + Pointer.load(v_ptr) |> vector.insertelement(acc, i) end) end diff --git a/lib/charms.ex b/lib/charms.ex index 41de02e..075a8ea 100644 --- a/lib/charms.ex +++ b/lib/charms.ex @@ -48,7 +48,9 @@ defmodule Charms do defmacro __before_compile__(env) do defm_decls = Module.get_attribute(env.module, :defm) || [] - {ir, referenced_modules} = defm_decls |> Enum.reverse() |> Charms.Defm.Definition.compile() + + {ir, referenced_modules, required_intrinsic_modules} = + defm_decls |> Enum.reverse() |> Charms.Defm.Definition.compile() # create uses in Elixir, to disallow loop reference r = @@ -58,10 +60,18 @@ defmodule Charms do end end + i = + for r <- required_intrinsic_modules, r != env.module do + quote do + unquote(r).__use_intrinsic__ + end + end + quote do @ir unquote(ir) @referenced_modules unquote(referenced_modules) unquote_splicing(r) + unquote_splicing(i) @ir_hash [ :erlang.phash2(@ir) diff --git a/lib/charms/constant.ex b/lib/charms/constant.ex index 24d72ed..c7497e1 100644 --- a/lib/charms/constant.ex +++ b/lib/charms/constant.ex @@ -22,7 +22,10 @@ defmodule Charms.Constant do true -> loc = Beaver.Deferred.create(loc, ctx) - raise CompileError, Charms.Diagnostic.meta_from_loc(loc) ++ [description: "Not a supported type for constant, #{to_string(t)}"] + + raise CompileError, + Charms.Diagnostic.meta_from_loc(loc) ++ + [description: "Not a supported type for constant, #{to_string(t)}"] end end end diff --git a/lib/charms/defm.ex b/lib/charms/defm.ex index 672fed4..613eba0 100644 --- a/lib/charms/defm.ex +++ b/lib/charms/defm.ex @@ -21,11 +21,6 @@ defmodule Charms.Defm do """ defmacro value(_expr), do: :implemented_in_expander - @doc """ - syntax sugar to create an MLIR value from an Elixir value - """ - defmacro const(_), do: :implemented_in_expander - @doc """ call a local function with return """ diff --git a/lib/charms/defm/definition.ex b/lib/charms/defm/definition.ex index 09afbc5..b4c320d 100644 --- a/lib/charms/defm/definition.ex +++ b/lib/charms/defm/definition.ex @@ -259,11 +259,17 @@ defmodule Charms.Defm.Definition do mlir_expander = %Charms.Defm.Expander{mlir_expander | return_types: return_types} - for %__MODULE__{env: env, call: call, ret_types: ret_types, body: body} <- definitions do - quote(do: unquote(call) :: unquote(ret_types)) - |> then("e(do: defm(unquote(&1), unquote(body)))) - |> Charms.Defm.Expander.expand_to_mlir(env, mlir_expander) - end + required_intrinsic_modules = + for %__MODULE__{env: env, call: call, ret_types: ret_types, body: body} <- definitions, + reduce: MapSet.new() do + required_intrinsic_modules -> + {_, state, _} = + quote(do: unquote(call) :: unquote(ret_types)) + |> then("e(do: defm(unquote(&1), unquote(body)))) + |> Charms.Defm.Expander.expand_to_mlir(env, mlir_expander) + + MapSet.union(state.mlir.required_intrinsic_modules, required_intrinsic_modules) + end end m @@ -284,7 +290,9 @@ defmodule Charms.Defm.Definition do raise_compile_error(__ENV__, msg) end end) - |> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)}) + |> then( + &{MLIR.to_string(&1, bytecode: true), referenced_modules(&1), required_intrinsic_modules} + ) end @doc """ @@ -307,22 +315,22 @@ defmodule Charms.Defm.Definition do {:ok, do_compile(ctx, definitions)} rescue err -> - {:error, err} + {:error, err, __STACKTRACE__} end end, fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end ) case {res, msg} do - {{:ok, {mlir, mods}}, nil} -> - MLIR.Context.destroy(ctx) - {mlir, mods} + {{:ok, {mlir, mods, i_mods}}, nil} -> + # MLIR.Context.destroy(ctx) + {mlir, mods, i_mods} - {_, {:ok, d_msg}} -> - raise CompileError, d_msg + {{:error, _, st}, {:ok, d_msg}} -> + reraise CompileError, d_msg, st - {{:error, err}, _} -> - raise err + {{:error, err, st}, _} -> + reraise err, st end end diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index cecb3c8..45093a5 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -19,7 +19,7 @@ defmodule Charms.Defm.Expander do """ use Beaver alias MLIR.Attribute - alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index, Arith, UB, LLVM} + alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index, Arith, UB} require Func import Charms.Diagnostic, only: :macros # Define the environment we will use for expansion. @@ -34,6 +34,7 @@ defmodule Charms.Defm.Expander do region: nil, enif_env: nil, dependence_modules: Map.new(), + required_intrinsic_modules: MapSet.new(), return_types: Map.new() @env %{ @@ -466,6 +467,7 @@ defmodule Charms.Defm.Expander do defp expand_intrinsics(loc, module, intrinsic_impl, args, state, env) do {args, state, env} = expand(args, state, env) + state = update_in(state.mlir.required_intrinsic_modules, &MapSet.put(&1, module)) v = apply(module, intrinsic_impl, [ @@ -496,8 +498,7 @@ defmodule Charms.Defm.Expander do defp expand_magic_macros(loc, {module, fun, arity} = mfa, args, state, env) do cond do - Code.ensure_loaded?(module) and function_exported?(module, :__intrinsics__, 2) and - module.__intrinsics__(fun, arity) -> + export_intrinsics?(module, fun, arity) -> intrinsic_impl = module.__intrinsics__(fun, arity) expand_intrinsics(loc, module, intrinsic_impl, args, state, env) @@ -685,7 +686,6 @@ defmodule Charms.Defm.Expander do @intrinsics Charms.Kernel.macro_intrinsics() ++ Charms.Kernel.intrinsics() defp expand({fun, _meta, args}, state, env) when fun in @intrinsics do - {args, state, env} = expand(args, state, env) loc = MLIR.Location.from_env(env) try do @@ -781,6 +781,13 @@ defmodule Charms.Defm.Expander do end end + ## Auxiliary containers + defp expand({:"::", meta, [param, type]}, state, env) do + {param, state, env} = expand(param, state, env) + {type, state, env} = expand(type, state, env) + {{:"::", meta, [param, type]}, state, env} + end + ## Imported or local call defp expand({fun, meta, args}, state, env) when is_atom(fun) and is_list(args) do @@ -965,11 +972,14 @@ defmodule Charms.Defm.Expander do {arg_types, state, env} = arg_types |> expand(state, env) if i = Enum.find_index(arg_types, &(!is_struct(&1, MLIR.Type))) do - raise_compile_error(env, "invalid argument type ##{i + 1}") + raise_compile_error( + env, + "invalid argument type ##{i + 1}, #{inspect(Enum.at(arg_types, i))}" + ) end if Enum.find(List.wrap(ret_types), &(!is_struct(&1, MLIR.Type))) do - raise_compile_error(env, "invalid return type") + raise_compile_error(env, "invalid return type, #{inspect(ret_types)}") end ft = Type.function(arg_types, ret_types, ctx: Beaver.Env.context()) @@ -1112,10 +1122,11 @@ defmodule Charms.Defm.Expander do end defp expand_macro(_meta, Charms.Defm, :for_loop, [expr, [do: body]], _callback, state, env) do - {:<-, _, [{element, index}, {:{}, _, [t, ptr, len]}]} = expr + {:<-, _, [{element, index}, {ptr, len}]} = expr {len, state, env} = expand(len, state, env) - {t, state, env} = expand(t, state, env) {ptr, state, env} = expand(ptr, state, env) + t = MLIR.Value.type(ptr) |> MLIR.CAPI.mlirShapedTypeGetElementType() + loc = MLIR.Location.from_env(env) v = mlir ctx: state.mlir.ctx, blk: state.mlir.blk do @@ -1124,18 +1135,10 @@ defmodule Charms.Defm.Expander do upper_bound = Index.casts(len) >>> Type.index() step = Index.constant(value: Attribute.index(1)) >>> Type.index() - SCF.for [lower_bound, upper_bound, step] do + SCF.for [lower_bound, upper_bound, step, loc: loc] do region do block _body(index_val >>> Type.index()) do - index_casted = Index.casts(index_val) >>> Type.i64() - - element_ptr = - LLVM.getelementptr(ptr, index_casted, - elem_type: t, - rawConstantIndices: ~a{array} - ) >>> ~t{!llvm.ptr} - - element_val = LLVM.load(element_ptr) >>> t + element_val = MemRef.load(ptr, index_val, loc: loc) >>> t state = put_mlir_var(state, element, element_val) state = put_mlir_var(state, index, index_val) expand(body, put_in(state.mlir.blk, Beaver.Env.block()), env) @@ -1193,29 +1196,6 @@ defmodule Charms.Defm.Expander do expand_call_of_types(call, [], state, env) end - defp expand_macro( - meta, - Charms.Defm, - :const, - [{:"::", type_meta, [value, type]}], - _callback, - state, - env - ) do - env = %{env | line: type_meta[:line] || meta[:line] || env.line} - - {value, state, env} = expand(value, state, env) - {type, state, env} = expand(type, state, env) - - value = - mlir ctx: state.mlir.ctx, blk: state.mlir.blk do - loc = MLIR.Location.from_env(env) - Charms.Constant.from_literal(value, type, state.mlir.ctx, state.mlir.blk, loc) - end - - {value, state, env} - end - defp expand_macro(meta, module, fun, args, callback, state, env) do expand_macro_callback(meta, module, fun, args, callback, state, env) end @@ -1254,19 +1234,24 @@ defmodule Charms.Defm.Expander do ## Helpers + defp export_intrinsics?(module, fun, arity) do + match?({:module, _}, Code.ensure_compiled(module)) and Code.ensure_loaded?(module) and + function_exported?(module, :__intrinsics__, 2) and module.__intrinsics__(fun, arity) + end + defp expand_remote(_meta, module, fun, args, state, env) do # A compiler may want to emit a :remote_function trace in here. state = update_in(state.remotes, &[{module, fun, length(args)} | &1]) - {args, state, env} = expand_list(args, state, env) loc = MLIR.Location.from_env(env) cond do - Code.ensure_loaded?(module) and function_exported?(module, :__intrinsics__, 2) and - module.__intrinsics__(fun, length(args)) -> + export_intrinsics?(module, fun, length(args)) -> intrinsic_impl = module.__intrinsics__(fun, length(args)) expand_intrinsics(loc, module, intrinsic_impl, args, state, env) module in [MLIR.Type] -> + {args, state, env} = expand_list(args, state, env) + if fun in [:unranked_tensor, :complex, :vector] do args else diff --git a/lib/charms/intrinsic.ex b/lib/charms/intrinsic.ex index 183dc6e..faf3659 100644 --- a/lib/charms/intrinsic.ex +++ b/lib/charms/intrinsic.ex @@ -7,7 +7,7 @@ defmodule Charms.Intrinsic do end @moduledoc """ - Behaviour to define intrinsic functions. + Define intrinsic functions. """ alias Beaver @type ir_return :: MLIR.Value.t() | MLIR.Operation.t() @@ -20,6 +20,7 @@ defmodule Charms.Intrinsic do @before_compile Charms.Intrinsic import Charms.Intrinsic, only: :macros Module.register_attribute(__MODULE__, :intrinsic, accumulate: true) + def __use_intrinsic__, do: nil end end diff --git a/lib/charms/jit.ex b/lib/charms/jit.ex index b46a19d..8ae2ea5 100644 --- a/lib/charms/jit.ex +++ b/lib/charms/jit.ex @@ -20,8 +20,8 @@ defmodule Charms.JIT do |> convert_arith_to_llvm() |> convert_index_to_llvm() |> convert_func_to_llvm() - |> Beaver.Composer.append("convert-vector-to-llvm{reassociate-fp-reductions}") |> Beaver.Composer.append("finalize-memref-to-llvm") + |> Beaver.Composer.append("convert-vector-to-llvm{reassociate-fp-reductions}") |> reconcile_unrealized_casts |> Charms.Debug.print_ir_pass() |> Beaver.Composer.run!(print: Charms.Debug.step_print?()) @@ -82,30 +82,7 @@ defmodule Charms.JIT do raise ArgumentError, "Unexpected module type: #{inspect(other)}" end) |> then(fn op -> - {res, msg} = - MLIR.Context.with_diagnostics( - ctx, - fn -> - try do - {:ok, op |> merge_modules() |> jit_of_mod()} - rescue - err -> - {:error, err, __STACKTRACE__} - end - end, - fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end - ) - - case {res, msg} do - {{:ok, jit}, nil} -> - jit - - {{:error, _, st}, {:ok, d_msg}} -> - reraise CompileError, d_msg, st - - {{:error, err, st}, _} -> - reraise err, st - end + op |> merge_modules() |> jit_of_mod() end) |> then( &%__MODULE__{ diff --git a/lib/charms/pointer.ex b/lib/charms/pointer.ex index 4d6847d..0de5a64 100644 --- a/lib/charms/pointer.ex +++ b/lib/charms/pointer.ex @@ -1,12 +1,14 @@ defmodule Charms.Pointer do @moduledoc """ Intrinsic module to work with pointers. + + Charms.Pointer should be the "smart pointer" not just comes with lifetime management, but also SIMD and Tensor support. """ - alias Charms.Pointer + use Beaver use Charms.Intrinsic alias Charms.Intrinsic.Opts alias Beaver.MLIR.{Type} - alias Beaver.MLIR.Dialect.{LLVM} + alias Beaver.MLIR.Dialect.{MemRef, Index, Arith} @doc """ Allocates a single element of the given `elem_type`, returning a pointer to it. @@ -21,27 +23,32 @@ defmodule Charms.Pointer do Allocates an array of `size` elements of the given `elem_type`, returning a pointer to it. """ defintrinsic allocate(elem_type, size) do - %Opts{ctx: ctx} = __IR__ + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + + mlir ctx: ctx, blk: blk do + zero = Index.constant(value: Attribute.index(0)) >>> Type.index() - cast = case size do i when is_integer(i) -> - quote bind_quoted: [size: i] do - const size :: i64() - end + MemRef.alloca( + loc: loc, + operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([0, 0]) + ) >>> Type.memref([i], elem_type) %MLIR.Value{} -> - if MLIR.equal?(MLIR.Value.type(size), Type.i64(ctx: ctx)) do - size - else - quote bind_quoted: [size: size] do - value arith.extsi(size) :: i64() + size = + if Type.index?(MLIR.Value.type(size)) do + size + else + Index.casts(size, loc: loc) >>> Type.index() end - end - end - quote bind_quoted: [elem_type: elem_type, size: cast] do - value llvm.alloca(size, elem_type: elem_type) :: Pointer.t() + MemRef.alloca(size, + loc: loc, + operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0]) + ) >>> Type.memref([:dynamic], elem_type) + end + |> offset_ptr(elem_type, zero, ctx, blk, loc) end end @@ -49,8 +56,38 @@ defmodule Charms.Pointer do Loads a value of `type` from the given pointer `ptr`. """ defintrinsic load(type, ptr) do - quote bind_quoted: [type: type, ptr: ptr] do - value llvm.load(ptr) :: type + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + + if MLIR.equal?(MLIR.Value.type(ptr), ~t{!llvm.ptr}) do + quote bind_quoted: [type: type, ptr: ptr] do + value llvm.load(ptr) :: type + end + else + mlir ctx: ctx, blk: blk do + zero = Index.constant(value: Attribute.index(0), loc: loc) >>> Type.index() + MemRef.load(ptr, zero, loc: loc) >>> type + end + end + end + + @doc false + def memref_ptr?(%MLIR.Type{} = t) do + MLIR.CAPI.mlirTypeIsAMemRef(t) |> Beaver.Native.to_term() + end + + def memref_ptr?(%MLIR.Value{} = ptr) do + MLIR.Value.type(ptr) |> memref_ptr?() + end + + defintrinsic load(%MLIR.Value{} = ptr) do + t = MLIR.Value.type(ptr) + + if memref_ptr?(t) do + quote do + Charms.Pointer.load(unquote(MLIR.CAPI.mlirShapedTypeGetElementType(t)), unquote(ptr)) + end + else + raise ArgumentError, "Pointer is not typed, use load/2 to specify the pointer type" end end @@ -58,22 +95,107 @@ defmodule Charms.Pointer do Stores a value `val` at the given pointer `ptr`. """ defintrinsic store(val, ptr) do - quote bind_quoted: [val: val, ptr: ptr] do - llvm.store(val, ptr) + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + + mlir ctx: ctx, blk: blk do + zero = Index.constant(value: Attribute.index(0)) >>> Type.index() + MemRef.store(val, ptr, zero, loc: loc) >>> [] + end + end + + defp ptr_type(elem_type, ctx) do + layout = + MLIR.CAPI.mlirStridedLayoutAttrGet( + ctx, + MLIR.CAPI.mlirShapedTypeGetDynamicStrideOrOffset(), + 1, + Beaver.Native.array([1], Beaver.Native.I64) + ) + + Type.memref([:dynamic], elem_type, layout: layout, ctx: ctx) + end + + # cast ptr to a pointer of the given element type with offset + defp offset_ptr(ptr, elem_type, offset, ctx, blk, loc) do + mlir ctx: ctx, blk: blk do + d = MLIR.CAPI.mlirShapedTypeGetDynamicStrideOrOffset() |> Beaver.Native.to_term() + static_offsets_or_sizes = Attribute.dense_array([d], Beaver.Native.I64, ctx: ctx) + static_strides = Attribute.dense_array([1], Beaver.Native.I64, ctx: ctx) + + if MLIR.null?(static_offsets_or_sizes) do + raise ArgumentError, "Failed to create dense array" + end + + [_, offset_extracted, size, _stride] = + MemRef.extract_strided_metadata(ptr, loc: loc) >>> :infer + + offset = Arith.addi(offset_extracted, offset, loc: loc) >>> Type.index() + + MemRef.reinterpret_cast(ptr, offset, size, + operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 1, 1, 0]), + static_offsets: static_offsets_or_sizes, + static_sizes: static_offsets_or_sizes, + static_strides: static_strides, + loc: loc + ) >>> ptr_type(elem_type, ctx) + end + end + + defintrinsic element_ptr(%MLIR.Type{} = elem_type, ptr, n) do + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + + t = MLIR.Value.type(ptr) + elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t) + + if not MLIR.equal?(elem_t, elem_type) do + raise ArgumentError, + "Expected a pointer of type #{MLIR.to_string(elem_type)}, got #{MLIR.to_string(t)}" + end + + mlir ctx: ctx, blk: blk do + n = + case n do + i when is_integer(i) -> + Index.constant(value: Attribute.index(i)) >>> Type.index() + + %MLIR.Value{} -> + if Type.index?(MLIR.Value.type(n)) do + n + else + Index.casts(n, loc: loc) >>> Type.index() + end + end + + offset_ptr(ptr, elem_type, n, ctx, blk, loc) end end @doc """ Gets the element pointer of `elem_type` for the given base pointer `ptr` and index `n`. """ - defintrinsic element_ptr(elem_type, ptr, n) do - %Opts{ctx: ctx, blk: blk} = __IR__ + defintrinsic element_ptr(%MLIR.Value{} = ptr, n) do + t = MLIR.Value.type(ptr) - mlir ctx: ctx, blk: blk do - LLVM.getelementptr(ptr, n, - elem_type: elem_type, - rawConstantIndices: ~a{array} - ) >>> ~t{!llvm.ptr} + if memref_ptr?(t) do + quote do + Charms.Pointer.element_ptr( + unquote(MLIR.CAPI.mlirShapedTypeGetElementType(t)), + unquote(ptr), + unquote(n) + ) + end + else + raise ArgumentError, "Pointer is not typed, use element_ptr/3 to specify the pointer type" + end + end + + defintrinsic element_type(%MLIR.Value{} = ptr) do + t = MLIR.Value.type(ptr) + + if memref_ptr?(t) do + MLIR.CAPI.mlirShapedTypeGetElementType(t) + else + raise ArgumentError, "Pointer is not typed, element_type/1 expects a typed pointer" end end @@ -84,4 +206,9 @@ defmodule Charms.Pointer do %Opts{ctx: ctx} = __IR__ Beaver.Deferred.create(~t{!llvm.ptr}, ctx) end + + defintrinsic t(elem_t) do + %Opts{ctx: ctx} = __IR__ + ptr_type(elem_t, ctx) + end end diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index 4867adc..45a1851 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -4,26 +4,17 @@ defmodule Charms.Prelude do """ use Charms.Intrinsic alias Charms.Intrinsic.Opts - alias Beaver.MLIR.Dialect.{Arith, Func} + alias Beaver.MLIR.Dialect.{Arith, Func, LLVM, MemRef, Index} @enif_functions Beaver.ENIF.functions() - defp wrap_arg({i, t}, %Opts{ctx: ctx, blk: blk}) when is_integer(i) do + defp literal_to_constant(v, t, %Opts{ctx: ctx, blk: blk, loc: loc}) + when is_integer(v) or is_float(v) do mlir ctx: ctx, blk: blk do - case i do - %MLIR.Value{} -> - i - - i when is_integer(i) -> - if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do - Arith.constant(value: Attribute.integer(t, i)) >>> t - else - raise ArgumentError, "Not an integer type, #{to_string(t)}" - end - end + Charms.Constant.from_literal(v, t, ctx, blk, loc) end end - defp wrap_arg({v, _}, _) do + defp literal_to_constant(v, _, _) do v end @@ -45,6 +36,15 @@ defmodule Charms.Prelude do MLIR.Value.type(value) end + defintrinsic const(ast) do + {:"::", _type_meta, [value, type]} = ast + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + + mlir ctx: ctx, blk: blk do + Charms.Constant.from_literal(value, type, ctx, blk, loc) + end + end + @doc """ Dump the MLIR entity at compile time with `IO.puts/1` """ @@ -52,27 +52,78 @@ defmodule Charms.Prelude do entity |> tap(&IO.puts(MLIR.to_string(&1))) end + def extract_raw_pointer(arg, arg_type, %Opts{ctx: ctx, blk: blk, loc: loc}) do + mlir ctx: ctx, blk: blk do + t = MLIR.Value.type(arg) + + if MLIR.equal?(~t{!llvm.ptr}.(ctx), arg_type) and Charms.Pointer.memref_ptr?(arg) do + elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t) + + width = + cond do + MLIR.Type.integer?(elem_t) -> + MLIR.CAPI.mlirIntegerTypeGetWidth(elem_t) |> Beaver.Native.to_term() + + MLIR.Type.float?(elem_t) -> + MLIR.CAPI.mlirFloatTypeGetWidth(elem_t) |> Beaver.Native.to_term() + + true -> + raise ArgumentError, "Expected a shaped type, got #{to_string(t)}" + end + + width = Index.constant(value: Attribute.index(width), loc: loc) >>> Type.index() + ptr_i = MemRef.extract_aligned_pointer_as_index(arg, loc: loc) >>> Type.index() + [_, offset, _, _] = MemRef.extract_strided_metadata(arg, loc: loc) >>> :infer + offset = Arith.muli(offset, width, loc: loc) >>> Type.index() + ptr_i = Arith.addi(ptr_i, offset, loc: loc) >>> Type.index() + ptr_i = Arith.index_cast(ptr_i, loc: loc) >>> Type.i64() + LLVM.inttoptr(ptr_i, loc: loc) >>> ~t{!llvm.ptr} + else + arg + end + end + end + + defp preprocess_args(args, arg_types, []) do + for {arg, arg_type} <- args |> Enum.zip(arg_types) do + if not MLIR.equal?(MLIR.Value.type(arg), arg_type) do + raise ArgumentError, + "Expected a value of type #{MLIR.to_string(arg_type)}, got #{MLIR.to_string(MLIR.Value.type(arg))}" + end + + arg + end + end + + defp preprocess_args(args, arg_types, [preprocessor | tail]) do + for {arg, arg_type} <- args |> Enum.zip(arg_types) do + preprocessor.(arg, arg_type) + end + |> preprocess_args(arg_types, tail) + end + + defp call_enif(name, args, %Opts{ctx: ctx, blk: blk, loc: loc} = opts) do + {arg_types, ret_types} = Beaver.ENIF.signature(ctx, name) + + args = + preprocess_args(args, arg_types, [ + &literal_to_constant(&1, &2, opts), + &extract_raw_pointer(&1, &2, opts) + ]) + + mlir ctx: ctx, blk: blk do + Func.call(args, callee: Attribute.flat_symbol_ref(name), loc: loc) >>> ret_types + end + end + signature_ctx = MLIR.Context.create() for name <- @enif_functions do - {arg_types, _} = Beaver.ENIF.signature(signature_ctx, name) - args = Macro.generate_arguments(length(arg_types), __MODULE__) + arity = Beaver.ENIF.signature(signature_ctx, name) |> elem(0) |> length() + args = Macro.generate_arguments(arity, __MODULE__) defintrinsic unquote(name)(unquote_splicing(args)) do - opts = %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ - {arg_types, ret_types} = Beaver.ENIF.signature(ctx, unquote(name)) - args = [unquote_splicing(args)] |> Enum.zip(arg_types) |> Enum.map(&wrap_arg(&1, opts)) - - mlir ctx: ctx, blk: blk do - Func.call(args, callee: Attribute.flat_symbol_ref("#{unquote(name)}"), loc: loc) >>> - case ret_types do - [ret] -> - ret - - [] -> - [] - end - end + call_enif(unquote(name), unquote(args), __IR__) end end diff --git a/test/add_two_int_test.exs b/test/add_two_int_test.exs new file mode 100644 index 0000000..d8e6a2f --- /dev/null +++ b/test/add_two_int_test.exs @@ -0,0 +1,59 @@ +defmodule AddTwoIntTest do + use ExUnit.Case, async: true + + test "add two integers" do + defmodule AddTwoInt do + use Charms, init: false + alias Charms.{Pointer, Term} + + defm add_or_error_with_cond_br(env, a, b, error) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + arg_err = + block do + func.return(error) + end + + cond_br enif_get_int(env, a, ptr_a) != 0 do + cond_br 0 != enif_get_int(env, b, ptr_b) do + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + sum = value llvm.add(a, b) :: i32() + sum = sum / 1 + sum = sum + 1 - 1 + term = enif_make_int(env, sum) + func.return(term) + else + ^arg_err + end + else + ^arg_err + end + end + + defm add(env, a, b) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do + enif_make_badarg(env) + else + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + enif_make_int(env, a + b) + end + end + end + + assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) + assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) + engine = Charms.JIT.engine(:add_int) + assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR") + assert AddTwoInt.add(1, 2).(engine) == 3 + assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end + assert AddTwoInt.add_or_error_with_cond_br(1, 2, :arg_err).(engine) == 3 + assert AddTwoInt.add_or_error_with_cond_br(1, "", :arg_err).(engine) == :arg_err + assert :ok = Charms.JIT.destroy(:add_int) + end +end diff --git a/test/defm_test.exs b/test/defm_test.exs index f9a3c84..b37f1c7 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -1,47 +1,3 @@ -defmodule AddTwoInt do - use Charms, init: false - alias Charms.{Pointer, Term} - - defm add_or_error_with_cond_br(env, a, b, error) :: Term.t() do - ptr_a = Pointer.allocate(i32()) - ptr_b = Pointer.allocate(i32()) - - arg_err = - block do - func.return(error) - end - - cond_br enif_get_int(env, a, ptr_a) != 0 do - cond_br 0 != enif_get_int(env, b, ptr_b) do - a = Pointer.load(i32(), ptr_a) - b = Pointer.load(i32(), ptr_b) - sum = value llvm.add(a, b) :: i32() - sum = sum / 1 - sum = sum + 1 - 1 - term = enif_make_int(env, sum) - func.return(term) - else - ^arg_err - end - else - ^arg_err - end - end - - defm add(env, a, b) :: Term.t() do - ptr_a = Pointer.allocate(i32()) - ptr_b = Pointer.allocate(i32()) - - if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do - enif_make_badarg(env) - else - a = Pointer.load(i32(), ptr_a) - b = Pointer.load(i32(), ptr_b) - enif_make_int(env, a + b) - end - end -end - defmodule DefmTest do import ExUnit.CaptureIO use ExUnit.Case, async: true @@ -87,38 +43,26 @@ defmodule DefmTest do assert 1 = ReferrerMod.term_roundtrip(1) end - test "add two integers" do - assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) - assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) - engine = Charms.JIT.engine(:add_int) - assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR") - assert AddTwoInt.add(1, 2).(engine) == 3 - assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end - assert AddTwoInt.add_or_error_with_cond_br(1, 2, :arg_err).(engine) == 3 - assert AddTwoInt.add_or_error_with_cond_br(1, "", :arg_err).(engine) == :arg_err - assert :ok = Charms.JIT.destroy(:add_int) - end + describe "different sorts" do + for s <- [ENIFTimSort, ENIFMergeSort, ENIFQuickSort] do + test "#{s}" do + s = unquote(s) + arr = [5, 4, 3, 2, 1] + assert s.sort(arr) == Enum.sort(arr) + assert_raise ArgumentError, "list expected", fn -> s.sort(:what) end - test "quick sort" do - assert_raise ArgumentError, "list expected", fn -> ENIFQuickSort.sort(:what) end + assert {:cached, %Charms.JIT{}} = + Charms.JIT.init(s, name: s.__ir_digest__()) - arr = [5, 4, 3, 2, 1] - assert ENIFQuickSort.sort(arr) == Enum.sort(arr) - - assert {:cached, %Charms.JIT{}} = - Charms.JIT.init(ENIFQuickSort, name: ENIFQuickSort.__ir_digest__()) + for i <- 0..1000 do + arr = 0..i |> Enum.shuffle() + assert s.sort(arr) == Enum.sort(arr) + end - for i <- 0..1000 do - arr = 0..i |> Enum.shuffle() - assert ENIFTimSort.sort(arr) == Enum.sort(arr) - assert ENIFQuickSort.sort(arr) == Enum.sort(arr) - assert ENIFMergeSort.sort(arr) == Enum.sort(arr) + assert :ok = Charms.JIT.destroy(s.__ir_digest__()) + assert :noop = Charms.JIT.destroy(SortUtil.__ir_digest__()) + end end - - assert :ok = Charms.JIT.destroy(ENIFQuickSort.__ir_digest__()) - assert :ok = Charms.JIT.destroy(ENIFMergeSort.__ir_digest__()) - assert :ok = Charms.JIT.destroy(ENIFTimSort.__ir_digest__()) - assert :noop = Charms.JIT.destroy(SortUtil.__ir_digest__()) end describe "different calls" do @@ -220,4 +164,18 @@ defmodule DefmTest do end end) =~ ~r"block argument.+i64" end + + test "enif type mismatch" do + assert_raise ArgumentError, ~r/Expected a value of type i32, got f32/, fn -> + defmodule MismatchEnifType do + use Charms + alias Charms.Term + + defm foo(env) :: Term.t() do + zero = const 0.0 :: f32() + enif_make_int(env, zero) + end + end + end + end end