Skip to content

Commit

Permalink
Use ownership based buffer deallocation
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper committed Dec 27, 2024
1 parent 73c4646 commit 5c79644
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 48 deletions.
8 changes: 4 additions & 4 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ ENIFTimSort.sort(arr)
Benchee.run(
%{
"Enum.sort" => &Enum.sort/1,
"enif_quick_sort" => &ENIFQuickSort.sort(&1),
"enif_merge_sort" => &ENIFMergeSort.sort(&1),
"enif_tim_sort" => &ENIFTimSort.sort(&1)
"enif_quick_sort" => &ENIFQuickSort.sort(&1)
},
parallel: 2,
inputs: [10, 100, 1000, 10000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}),
warmup: 1,
time: 3,
inputs: [10, 1000, 100_000, 1_000_000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}),
before_scenario: fn i ->
Enum.to_list(1..i) |> Enum.shuffle()
end
Expand Down
2 changes: 1 addition & 1 deletion lib/charms/debug.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Charms.Debug do
alias Beaver.MLIR

def print_ir_pass(op) do
if System.get_env("DEFM_PRINT_IR") == "1" do
if System.get_env("DEFM_PRINT_IR") && !step_print?() do
case op do
%MLIR.Operation{} ->
MLIR.dump!(op)
Expand Down
26 changes: 26 additions & 0 deletions lib/charms/defm/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,29 @@ defmodule Charms.Defm.Definition do
:ok
end

@default_visibility "private"
defp declare_enif(ctx, blk, name_str) do
mlir ctx: ctx, blk: blk do
{arg_types, ret_types} = Beaver.ENIF.signature(ctx, String.to_atom(name_str))

Func.func _(
sym_name: MLIR.Attribute.string(name_str),
sym_visibility: MLIR.Attribute.string(@default_visibility),
function_type: Type.function(arg_types, ret_types)
) do
region do
end
end
end
end

defp declared_required_enif(op) do
mlir ctx: MLIR.context(op), blk: MLIR.Module.body(MLIR.Module.from_operation(op)) do
declare_enif(Beaver.Env.context(), Beaver.Env.block(), "enif_alloc")
declare_enif(Beaver.Env.context(), Beaver.Env.block(), "enif_free")
end
end

# if it is single block with no terminator, add a return
defp append_missing_return(func) do
with [r] <- Beaver.Walker.regions(func) |> Enum.to_list(),
Expand Down Expand Up @@ -279,6 +302,9 @@ defmodule Charms.Defm.Definition do
{"append_missing_return", "func.func", &append_missing_return/1}
)
|> Beaver.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> Beaver.Composer.append(
{"declared-required-enif", "builtin.module", &declared_required_enif/1}
)
|> Beaver.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> MLIR.Transform.canonicalize()
|> then(fn op ->
Expand Down
11 changes: 7 additions & 4 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,8 @@ defmodule Charms.Defm.Expander do
term_ptr = Pointer.allocate(Term.t())
size = String.length(attr)
size = value index.casts(size) :: i64()
buffer_ptr = Pointer.allocate(i8(), size)
buffer = ptr_to_memref(buffer_ptr, size)
memref.copy(attr, buffer)
zero = const 0 :: i32()
enif_binary_to_term(env_ptr, buffer_ptr, size, term_ptr, zero)
enif_binary_to_term(env_ptr, attr, size, term_ptr, zero)
Pointer.load(Term.t(), term_ptr)
end
|> expand_with_bindings(state, env, attr: attr, env_ptr: env_ptr)
Expand Down Expand Up @@ -920,6 +917,11 @@ defmodule Charms.Defm.Expander do
end
end

# convert op name ast to a string
defp normalize_dot_op_name(ast) do
ast |> Macro.to_string() |> String.replace([":", " "], "")
end

## Macro handling

# This is going to be the function where you will intercept expansions
Expand Down Expand Up @@ -1165,6 +1167,7 @@ defmodule Charms.Defm.Expander do
defp expand_macro(_meta, Charms.Defm, :op, [call], _callback, state, env) do
{call, return_types} = decompose_call_signature(call)
{{dialect, _, _}, op, args} = Macro.decompose_call(call)
dialect = normalize_dot_op_name(dialect)
op = "#{dialect}.#{op}"
{args, state, env} = expand(args, state, env)
{return_types, state, env} = expand(return_types, state, env)
Expand Down
45 changes: 45 additions & 0 deletions lib/charms/defm/pass/use_enif_malloc.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
defmodule Charms.Defm.Pass.UseEnifMalloc do
@moduledoc false
use Beaver
use MLIR.Pass, on: "builtin.module"
alias MLIR.Dialect.LLVM
import Beaver.Pattern

defpat replace_alloc(benefit: 10) do
size = value()
ptr_t = type()
{op, _} = LLVM.call(size, callee: Attribute.flat_symbol_ref("malloc")) >>> {:op, [ptr_t]}

rewrite op do
r =
LLVM.call(size,
callee: Attribute.flat_symbol_ref("enif_alloc"),
operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0]),
op_bundle_sizes: ~a{array<i32>}
) >>> ptr_t

replace(op, with: r)
end
end

defpat replace_free(benefit: 10) do
ptr = value()
{op, _} = LLVM.call(ptr, callee: Attribute.flat_symbol_ref("free")) >>> {:op, []}

rewrite op do
{enif_free, _} =
LLVM.call(ptr,
callee: Attribute.flat_symbol_ref("enif_free"),
operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0]),
op_bundle_sizes: ~a{array<i32>}
) >>> {:op, []}

replace(op, with: enif_free)
end
end

def run(op) do
module = MLIR.Module.from_operation(op)
MLIR.apply!(module, [replace_alloc(), replace_free()])
end
end
11 changes: 9 additions & 2 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,33 @@ defmodule Charms.JIT do
import Beaver.MLIR.CAPI
alias Beaver.MLIR
alias __MODULE__.LockedCache

defstruct ctx: nil, engine: nil, owner: true

defp jit_of_mod(m) do
import Beaver.MLIR.{Conversion, Transform}

m
|> MLIR.verify!()
|> MLIR.Transform.canonicalize()
|> Beaver.Composer.append("ownership-based-buffer-deallocation")
|> Beaver.Composer.append("buffer-deallocation-simplification")
|> Beaver.Composer.append("bufferization-lower-deallocations")
|> MLIR.Transform.canonicalize()
|> Charms.Debug.print_ir_pass()
|> Beaver.Composer.nested("func.func", "llvm-request-c-wrappers")
|> Beaver.Composer.nested("func.func", loop_invariant_code_motion())
|> convert_scf_to_cf
|> convert_cf_to_llvm()
|> convert_arith_to_llvm()
|> convert_index_to_llvm()
|> convert_func_to_llvm()
|> Beaver.Composer.append("finalize-memref-to-llvm")
|> Beaver.Composer.append("convert-vector-to-llvm{reassociate-fp-reductions}")
|> reconcile_unrealized_casts
|> Beaver.Composer.append(Charms.Defm.Pass.UseEnifMalloc)
|> Charms.Debug.print_ir_pass()
|> Beaver.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.ExecutionEngine.create!(opt_level: 3, object_dump: true)
|> MLIR.ExecutionEngine.create!(opt_level: 3, object_dump: true, dirty: :cpu_bound)
|> tap(&beaver_raw_jit_register_enif(&1.ref))
end

Expand Down
58 changes: 56 additions & 2 deletions lib/charms/pointer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ defmodule Charms.Pointer do
zero = Index.constant(value: Attribute.index(0)) >>> Type.index()

case size do
i when is_integer(i) ->
1 ->
MemRef.alloca(
loc: loc,
operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([0, 0])
) >>> Type.memref([1], elem_type)

i when is_integer(i) ->
MemRef.alloc(
loc: loc,
operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([0, 0])
) >>> Type.memref([i], elem_type)

%MLIR.Value{} ->
Expand All @@ -43,7 +49,7 @@ defmodule Charms.Pointer do
Index.casts(size, loc: loc) >>> Type.index()
end

MemRef.alloca(size,
MemRef.alloc(size,
loc: loc,
operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0])
) >>> Type.memref([:dynamic], elem_type)
Expand Down Expand Up @@ -207,4 +213,52 @@ defmodule Charms.Pointer do
%Opts{ctx: ctx} = __IR__
ptr_type(elem_t, ctx)
end

@doc false
def extract_raw_pointer(%MLIR.Value{} = ptr, %Opts{ctx: ctx, blk: blk, loc: loc}) do
t = MLIR.Value.type(ptr)

cond do
MLIR.equal?(~t{!llvm.ptr}.(ctx), t) ->
ptr

Charms.Pointer.memref_ptr?(t) ->
mlir ctx: ctx, blk: blk do
elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t)

width =
cond do
MLIR.Type.integer?(elem_t) ->
MLIR.CAPI.mlirIntegerTypeGetWidth(elem_t) |> Beaver.Native.to_term()

MLIR.Type.float?(elem_t) ->
MLIR.CAPI.mlirFloatTypeGetWidth(elem_t) |> Beaver.Native.to_term()

true ->
raise ArgumentError, "Expected a shaped type, got #{to_string(t)}"
end

width = Index.constant(value: Attribute.index(width), loc: loc) >>> Type.index()
ptr_i = MemRef.extract_aligned_pointer_as_index(ptr, loc: loc) >>> Type.index()
[_, offset, _, _] = MemRef.extract_strided_metadata(ptr, loc: loc) >>> :infer
offset = Arith.muli(offset, width, loc: loc) >>> Type.index()
ptr_i = Arith.addi(ptr_i, offset, loc: loc) >>> Type.index()
ptr_i = Arith.index_cast(ptr_i, loc: loc) >>> Type.i64()
LLVM.inttoptr(ptr_i, loc: loc) >>> ~t{!llvm.ptr}
end

true ->
raise ArgumentError, "Expected a pointer, got #{MLIR.to_string(t)}"
end
end

defintrinsic copy(source, destination, bytes_count) do
%Opts{ctx: ctx, blk: blk, loc: loc} = __IR__
source = extract_raw_pointer(source, __IR__)
destination = extract_raw_pointer(destination, __IR__)

mlir ctx: ctx, blk: blk do
LLVM.intr_memcpy(destination, source, bytes_count, isVolatile: ~a{false}, loc: loc) >>> []
end
end
end
36 changes: 6 additions & 30 deletions lib/charms/prelude.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Charms.Prelude do
"""
use Charms.Intrinsic
alias Charms.Intrinsic.Opts
alias Beaver.MLIR.Dialect.{Arith, Func, LLVM, MemRef, Index}
alias Beaver.MLIR.Dialect.Func
@enif_functions Beaver.ENIF.functions()

defp literal_to_constant(v, t, %Opts{ctx: ctx, blk: blk, loc: loc})
Expand Down Expand Up @@ -52,35 +52,11 @@ defmodule Charms.Prelude do
entity |> tap(&IO.puts(MLIR.to_string(&1)))
end

def extract_raw_pointer(arg, arg_type, %Opts{ctx: ctx, blk: blk, loc: loc}) do
mlir ctx: ctx, blk: blk do
t = MLIR.Value.type(arg)

if MLIR.equal?(~t{!llvm.ptr}.(ctx), arg_type) and Charms.Pointer.memref_ptr?(arg) do
elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t)

width =
cond do
MLIR.Type.integer?(elem_t) ->
MLIR.CAPI.mlirIntegerTypeGetWidth(elem_t) |> Beaver.Native.to_term()

MLIR.Type.float?(elem_t) ->
MLIR.CAPI.mlirFloatTypeGetWidth(elem_t) |> Beaver.Native.to_term()

true ->
raise ArgumentError, "Expected a shaped type, got #{to_string(t)}"
end

width = Index.constant(value: Attribute.index(width), loc: loc) >>> Type.index()
ptr_i = MemRef.extract_aligned_pointer_as_index(arg, loc: loc) >>> Type.index()
[_, offset, _, _] = MemRef.extract_strided_metadata(arg, loc: loc) >>> :infer
offset = Arith.muli(offset, width, loc: loc) >>> Type.index()
ptr_i = Arith.addi(ptr_i, offset, loc: loc) >>> Type.index()
ptr_i = Arith.index_cast(ptr_i, loc: loc) >>> Type.i64()
LLVM.inttoptr(ptr_i, loc: loc) >>> ~t{!llvm.ptr}
else
arg
end
defp extract_raw_pointer(arg, arg_type, %Opts{ctx: ctx} = opts) do
if MLIR.equal?(~t{!llvm.ptr}.(ctx), arg_type) and Charms.Pointer.memref_ptr?(arg) do
Charms.Pointer.extract_raw_pointer(arg, opts)
else
arg
end
end

Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ defmodule Charms.MixProject do
defp deps do
[
{:ex_doc, ">= 0.0.0", only: :dev, runtime: false},
{:beaver, "~> 0.4.0"},
{:beaver, "~> 0.4.1"},
{:benchee, "~> 1.0", only: :dev},
{:credo, "~> 1.7", only: [:dev, :test], runtime: false}
]
Expand Down
4 changes: 2 additions & 2 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
%{
"beaver": {:hex, :beaver, "0.4.0", "82014114c6c54efd6583ec4dbd8df6b7c9559ee190fcd6a2658dd688c2e8da63", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kinda, "~> 0.9.3", [hex: :kinda, repo: "hexpm", optional: false]}], "hexpm", "2e336f4ab2f7088562943bb153ee981baca5156c3e429066d563734193761a18"},
"beaver": {:hex, :beaver, "0.4.1", "9c18b5308a68484597d8d80a3c4a537aa3fe2f7b10dad77698b986ec8118da4b", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kinda, "~> 0.9.3", [hex: :kinda, repo: "hexpm", optional: false]}], "hexpm", "64ec3286bbf95135932c8ae2637b22c6845eb1f6ae19155790a33de5f31ec4b8"},
"benchee": {:hex, :benchee, "1.3.1", "c786e6a76321121a44229dde3988fc772bca73ea75170a73fd5f4ddf1af95ccf", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "76224c58ea1d0391c8309a8ecbfe27d71062878f59bd41a390266bf4ac1cc56d"},
"bunt": {:hex, :bunt, "1.0.0", "081c2c665f086849e6d57900292b3a161727ab40431219529f13c4ddcf3e7a44", [:mix], [], "hexpm", "dc5f86aa08a5f6fa6b8096f0735c4e76d54ae5c9fa2c143e5a1fc7c1cd9bb6b5"},
"credo": {:hex, :credo, "1.7.10", "6e64fe59be8da5e30a1b96273b247b5cf1cc9e336b5fd66302a64b25749ad44d", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "71fbc9a6b8be21d993deca85bf151df023a3097b01e09a2809d460348561d8cd"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark_parser": {:hex, :earmark_parser, "1.4.42", "f23d856f41919f17cd06a493923a722d87a2d684f143a1e663c04a2b93100682", [:mix], [], "hexpm", "6915b6ca369b5f7346636a2f41c6a6d78b5af419d61a611079189233358b8b8b"},
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
"ex_doc": {:hex, :ex_doc, "0.35.1", "de804c590d3df2d9d5b8aec77d758b00c814b356119b3d4455e4b8a8687aecaf", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "2121c6402c8d44b05622677b761371a759143b958c6c19f6558ff64d0aed40df"},
"ex_doc": {:hex, :ex_doc, "0.36.1", "4197d034f93e0b89ec79fac56e226107824adcce8d2dd0a26f5ed3a95efc36b1", [:mix], [{:earmark_parser, "~> 1.4.42", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "d7d26a7cf965dacadcd48f9fa7b5953d7d0cfa3b44fa7a65514427da44eafd89"},
"file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"},
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
"kinda": {:hex, :kinda, "0.9.4", "007e25491bcd3af8a95e0179d9044362dc336920bbe5dbd6515196a5e938b201", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "a5ec71839edf88e52f3e68a2c2210b9a9a00ee414ff369e9801b1184b3844447"},
Expand Down
3 changes: 1 addition & 2 deletions test/string_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ defmodule StringTest do
term_ptr = Pointer.allocate(Term.t())
size = value index.casts(String.length(str)) :: i64()
d_ptr = enif_make_new_binary(env, size, term_ptr)
m = ptr_to_memref(d_ptr, size)
memref.copy(str, m)
Pointer.copy(str, d_ptr, size)
Pointer.load(Term.t(), term_ptr)
end
end
Expand Down

0 comments on commit 5c79644

Please sign in to comment.