From eb106d7c2b3713457b7a6e02938badb4177f099f Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Fri, 27 Dec 2024 21:00:38 +0800 Subject: [PATCH] Use ownership based buffer deallocation (#59) --- bench/sort_benchmark.exs | 8 ++-- guides/programming-with-charms.livemd | 2 +- lib/charms/debug.ex | 2 +- lib/charms/defm/definition.ex | 26 ++++++++++++ lib/charms/defm/expander.ex | 11 +++-- lib/charms/jit.ex | 14 +++++-- lib/charms/pointer.ex | 58 ++++++++++++++++++++++++++- lib/charms/prelude.ex | 36 +++-------------- mix.exs | 4 +- mix.lock | 4 +- test/string_test.exs | 3 +- 11 files changed, 116 insertions(+), 52 deletions(-) diff --git a/bench/sort_benchmark.exs b/bench/sort_benchmark.exs index a2a28af..deaf2e3 100644 --- a/bench/sort_benchmark.exs +++ b/bench/sort_benchmark.exs @@ -6,12 +6,12 @@ ENIFTimSort.sort(arr) Benchee.run( %{ "Enum.sort" => &Enum.sort/1, - "enif_quick_sort" => &ENIFQuickSort.sort(&1), - "enif_merge_sort" => &ENIFMergeSort.sort(&1), - "enif_tim_sort" => &ENIFTimSort.sort(&1) + "enif_quick_sort" => &ENIFQuickSort.sort(&1) }, parallel: 2, - inputs: [10, 100, 1000, 10000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}), + warmup: 1, + time: 3, + inputs: [10, 1000, 100_000, 1_000_000] |> Enum.map(&{"array size #{&1}", &1}) |> Enum.into(%{}), before_scenario: fn i -> Enum.to_list(1..i) |> Enum.shuffle() end diff --git a/guides/programming-with-charms.livemd b/guides/programming-with-charms.livemd index 6e7e4a7..6da2385 100644 --- a/guides/programming-with-charms.livemd +++ b/guides/programming-with-charms.livemd @@ -2,7 +2,7 @@ ```elixir Mix.install([ - {:charms, "~> 0.1.2"} + {:charms, "~> 0.1.3"} ]) ``` diff --git a/lib/charms/debug.ex b/lib/charms/debug.ex index 7348ff3..8366cfe 100644 --- a/lib/charms/debug.ex +++ b/lib/charms/debug.ex @@ -3,7 +3,7 @@ defmodule Charms.Debug do alias Beaver.MLIR def print_ir_pass(op) do - if System.get_env("DEFM_PRINT_IR") == "1" do + if System.get_env("DEFM_PRINT_IR") && !step_print?() do case op do %MLIR.Operation{} -> MLIR.dump!(op) diff --git a/lib/charms/defm/definition.ex b/lib/charms/defm/definition.ex index b4c320d..8e14909 100644 --- a/lib/charms/defm/definition.ex +++ b/lib/charms/defm/definition.ex @@ -167,6 +167,29 @@ defmodule Charms.Defm.Definition do :ok end + @default_visibility "private" + defp declare_enif(ctx, blk, name_str) do + mlir ctx: ctx, blk: blk do + {arg_types, ret_types} = Beaver.ENIF.signature(ctx, String.to_atom(name_str)) + + Func.func _( + sym_name: MLIR.Attribute.string(name_str), + sym_visibility: MLIR.Attribute.string(@default_visibility), + function_type: Type.function(arg_types, ret_types) + ) do + region do + end + end + end + end + + defp declared_required_enif(op) do + mlir ctx: MLIR.context(op), blk: MLIR.Module.body(MLIR.Module.from_operation(op)) do + declare_enif(Beaver.Env.context(), Beaver.Env.block(), "enif_alloc") + declare_enif(Beaver.Env.context(), Beaver.Env.block(), "enif_free") + end + end + # if it is single block with no terminator, add a return defp append_missing_return(func) do with [r] <- Beaver.Walker.regions(func) |> Enum.to_list(), @@ -279,6 +302,9 @@ defmodule Charms.Defm.Definition do {"append_missing_return", "func.func", &append_missing_return/1} ) |> Beaver.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc) + |> Beaver.Composer.append( + {"declared-required-enif", "builtin.module", &declared_required_enif/1} + ) |> Beaver.Composer.append({"check-poison", "builtin.module", &check_poison!/1}) |> MLIR.Transform.canonicalize() |> then(fn op -> diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 7074c06..af29599 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -514,11 +514,8 @@ defmodule Charms.Defm.Expander do term_ptr = Pointer.allocate(Term.t()) size = String.length(attr) size = value index.casts(size) :: i64() - buffer_ptr = Pointer.allocate(i8(), size) - buffer = ptr_to_memref(buffer_ptr, size) - memref.copy(attr, buffer) zero = const 0 :: i32() - enif_binary_to_term(env_ptr, buffer_ptr, size, term_ptr, zero) + enif_binary_to_term(env_ptr, attr, size, term_ptr, zero) Pointer.load(Term.t(), term_ptr) end |> expand_with_bindings(state, env, attr: attr, env_ptr: env_ptr) @@ -920,6 +917,11 @@ defmodule Charms.Defm.Expander do end end + # convert op name ast to a string + defp normalize_dot_op_name(ast) do + ast |> Macro.to_string() |> String.replace([":", " "], "") + end + ## Macro handling # This is going to be the function where you will intercept expansions @@ -1165,6 +1167,7 @@ defmodule Charms.Defm.Expander do defp expand_macro(_meta, Charms.Defm, :op, [call], _callback, state, env) do {call, return_types} = decompose_call_signature(call) {{dialect, _, _}, op, args} = Macro.decompose_call(call) + dialect = normalize_dot_op_name(dialect) op = "#{dialect}.#{op}" {args, state, env} = expand(args, state, env) {return_types, state, env} = expand(return_types, state, env) diff --git a/lib/charms/jit.ex b/lib/charms/jit.ex index 8ae2ea5..5c735cb 100644 --- a/lib/charms/jit.ex +++ b/lib/charms/jit.ex @@ -6,7 +6,6 @@ defmodule Charms.JIT do import Beaver.MLIR.CAPI alias Beaver.MLIR alias __MODULE__.LockedCache - defstruct ctx: nil, engine: nil, owner: true defp jit_of_mod(m) do @@ -14,9 +13,16 @@ defmodule Charms.JIT do m |> MLIR.verify!() + |> MLIR.Transform.canonicalize() + |> Beaver.Composer.append("ownership-based-buffer-deallocation") + |> Beaver.Composer.append("buffer-deallocation-simplification") + |> Beaver.Composer.append("bufferization-lower-deallocations") + |> MLIR.Transform.canonicalize() + |> Charms.Debug.print_ir_pass() |> Beaver.Composer.nested("func.func", "llvm-request-c-wrappers") |> Beaver.Composer.nested("func.func", loop_invariant_code_motion()) |> convert_scf_to_cf + |> convert_cf_to_llvm() |> convert_arith_to_llvm() |> convert_index_to_llvm() |> convert_func_to_llvm() @@ -25,7 +31,7 @@ defmodule Charms.JIT do |> reconcile_unrealized_casts |> Charms.Debug.print_ir_pass() |> Beaver.Composer.run!(print: Charms.Debug.step_print?()) - |> MLIR.ExecutionEngine.create!(opt_level: 3, object_dump: true) + |> MLIR.ExecutionEngine.create!(opt_level: 3, object_dump: true, dirty: :cpu_bound) |> tap(&beaver_raw_jit_register_enif(&1.ref)) end @@ -147,8 +153,8 @@ defmodule Charms.JIT do if jit = LockedCache.get(key), do: jit.engine end - def invoke(%MLIR.ExecutionEngine{ref: ref}, {mod, func, args}) do - beaver_raw_jit_invoke_with_terms(ref, to_string(Charms.Defm.mangling(mod, func)), args) + def invoke(%MLIR.ExecutionEngine{} = engine, {mod, func, args}) do + Beaver.ENIF.invoke(engine, to_string(Charms.Defm.mangling(mod, func)), args) end def destroy(key) do diff --git a/lib/charms/pointer.ex b/lib/charms/pointer.ex index a51e70b..f78ea8f 100644 --- a/lib/charms/pointer.ex +++ b/lib/charms/pointer.ex @@ -29,10 +29,16 @@ defmodule Charms.Pointer do zero = Index.constant(value: Attribute.index(0)) >>> Type.index() case size do - i when is_integer(i) -> + 1 -> MemRef.alloca( loc: loc, operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([0, 0]) + ) >>> Type.memref([1], elem_type) + + i when is_integer(i) -> + MemRef.alloc( + loc: loc, + operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([0, 0]) ) >>> Type.memref([i], elem_type) %MLIR.Value{} -> @@ -43,7 +49,7 @@ defmodule Charms.Pointer do Index.casts(size, loc: loc) >>> Type.index() end - MemRef.alloca(size, + MemRef.alloc(size, loc: loc, operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0]) ) >>> Type.memref([:dynamic], elem_type) @@ -207,4 +213,52 @@ defmodule Charms.Pointer do %Opts{ctx: ctx} = __IR__ ptr_type(elem_t, ctx) end + + @doc false + def extract_raw_pointer(%MLIR.Value{} = ptr, %Opts{ctx: ctx, blk: blk, loc: loc}) do + t = MLIR.Value.type(ptr) + + cond do + MLIR.equal?(~t{!llvm.ptr}.(ctx), t) -> + ptr + + Charms.Pointer.memref_ptr?(t) -> + mlir ctx: ctx, blk: blk do + elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t) + + width = + cond do + MLIR.Type.integer?(elem_t) -> + MLIR.CAPI.mlirIntegerTypeGetWidth(elem_t) |> Beaver.Native.to_term() + + MLIR.Type.float?(elem_t) -> + MLIR.CAPI.mlirFloatTypeGetWidth(elem_t) |> Beaver.Native.to_term() + + true -> + raise ArgumentError, "Expected a shaped type, got #{to_string(t)}" + end + + width = Index.constant(value: Attribute.index(width), loc: loc) >>> Type.index() + ptr_i = MemRef.extract_aligned_pointer_as_index(ptr, loc: loc) >>> Type.index() + [_, offset, _, _] = MemRef.extract_strided_metadata(ptr, loc: loc) >>> :infer + offset = Arith.muli(offset, width, loc: loc) >>> Type.index() + ptr_i = Arith.addi(ptr_i, offset, loc: loc) >>> Type.index() + ptr_i = Arith.index_cast(ptr_i, loc: loc) >>> Type.i64() + LLVM.inttoptr(ptr_i, loc: loc) >>> ~t{!llvm.ptr} + end + + true -> + raise ArgumentError, "Expected a pointer, got #{MLIR.to_string(t)}" + end + end + + defintrinsic copy(source, destination, bytes_count) do + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + source = extract_raw_pointer(source, __IR__) + destination = extract_raw_pointer(destination, __IR__) + + mlir ctx: ctx, blk: blk do + LLVM.intr_memcpy(destination, source, bytes_count, isVolatile: ~a{false}, loc: loc) >>> [] + end + end end diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index 45a1851..264e72f 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -4,7 +4,7 @@ defmodule Charms.Prelude do """ use Charms.Intrinsic alias Charms.Intrinsic.Opts - alias Beaver.MLIR.Dialect.{Arith, Func, LLVM, MemRef, Index} + alias Beaver.MLIR.Dialect.Func @enif_functions Beaver.ENIF.functions() defp literal_to_constant(v, t, %Opts{ctx: ctx, blk: blk, loc: loc}) @@ -52,35 +52,11 @@ defmodule Charms.Prelude do entity |> tap(&IO.puts(MLIR.to_string(&1))) end - def extract_raw_pointer(arg, arg_type, %Opts{ctx: ctx, blk: blk, loc: loc}) do - mlir ctx: ctx, blk: blk do - t = MLIR.Value.type(arg) - - if MLIR.equal?(~t{!llvm.ptr}.(ctx), arg_type) and Charms.Pointer.memref_ptr?(arg) do - elem_t = MLIR.CAPI.mlirShapedTypeGetElementType(t) - - width = - cond do - MLIR.Type.integer?(elem_t) -> - MLIR.CAPI.mlirIntegerTypeGetWidth(elem_t) |> Beaver.Native.to_term() - - MLIR.Type.float?(elem_t) -> - MLIR.CAPI.mlirFloatTypeGetWidth(elem_t) |> Beaver.Native.to_term() - - true -> - raise ArgumentError, "Expected a shaped type, got #{to_string(t)}" - end - - width = Index.constant(value: Attribute.index(width), loc: loc) >>> Type.index() - ptr_i = MemRef.extract_aligned_pointer_as_index(arg, loc: loc) >>> Type.index() - [_, offset, _, _] = MemRef.extract_strided_metadata(arg, loc: loc) >>> :infer - offset = Arith.muli(offset, width, loc: loc) >>> Type.index() - ptr_i = Arith.addi(ptr_i, offset, loc: loc) >>> Type.index() - ptr_i = Arith.index_cast(ptr_i, loc: loc) >>> Type.i64() - LLVM.inttoptr(ptr_i, loc: loc) >>> ~t{!llvm.ptr} - else - arg - end + defp extract_raw_pointer(arg, arg_type, %Opts{ctx: ctx} = opts) do + if MLIR.equal?(~t{!llvm.ptr}.(ctx), arg_type) and Charms.Pointer.memref_ptr?(arg) do + Charms.Pointer.extract_raw_pointer(arg, opts) + else + arg end end diff --git a/mix.exs b/mix.exs index 2500f36..f13271f 100644 --- a/mix.exs +++ b/mix.exs @@ -4,7 +4,7 @@ defmodule Charms.MixProject do def project do [ app: :charms, - version: "0.1.3-dev", + version: "0.1.4-dev", elixir: "~> 1.17", start_permanent: Mix.env() == :prod, elixirc_paths: elixirc_paths(Mix.env()), @@ -75,7 +75,7 @@ defmodule Charms.MixProject do defp deps do [ {:ex_doc, ">= 0.0.0", only: :dev, runtime: false}, - {:beaver, "~> 0.4.0"}, + {:beaver, "~> 0.4.1"}, {:benchee, "~> 1.0", only: :dev}, {:credo, "~> 1.7", only: [:dev, :test], runtime: false} ] diff --git a/mix.lock b/mix.lock index 66dd140..3b46a06 100644 --- a/mix.lock +++ b/mix.lock @@ -1,12 +1,12 @@ %{ - "beaver": {:hex, :beaver, "0.4.0", "82014114c6c54efd6583ec4dbd8df6b7c9559ee190fcd6a2658dd688c2e8da63", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kinda, "~> 0.9.3", [hex: :kinda, repo: "hexpm", optional: false]}], "hexpm", "2e336f4ab2f7088562943bb153ee981baca5156c3e429066d563734193761a18"}, + "beaver": {:hex, :beaver, "0.4.1", "9c18b5308a68484597d8d80a3c4a537aa3fe2f7b10dad77698b986ec8118da4b", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kinda, "~> 0.9.3", [hex: :kinda, repo: "hexpm", optional: false]}], "hexpm", "64ec3286bbf95135932c8ae2637b22c6845eb1f6ae19155790a33de5f31ec4b8"}, "benchee": {:hex, :benchee, "1.3.1", "c786e6a76321121a44229dde3988fc772bca73ea75170a73fd5f4ddf1af95ccf", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "76224c58ea1d0391c8309a8ecbfe27d71062878f59bd41a390266bf4ac1cc56d"}, "bunt": {:hex, :bunt, "1.0.0", "081c2c665f086849e6d57900292b3a161727ab40431219529f13c4ddcf3e7a44", [:mix], [], "hexpm", "dc5f86aa08a5f6fa6b8096f0735c4e76d54ae5c9fa2c143e5a1fc7c1cd9bb6b5"}, "credo": {:hex, :credo, "1.7.10", "6e64fe59be8da5e30a1b96273b247b5cf1cc9e336b5fd66302a64b25749ad44d", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "71fbc9a6b8be21d993deca85bf151df023a3097b01e09a2809d460348561d8cd"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.42", "f23d856f41919f17cd06a493923a722d87a2d684f143a1e663c04a2b93100682", [:mix], [], "hexpm", "6915b6ca369b5f7346636a2f41c6a6d78b5af419d61a611079189233358b8b8b"}, "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, - "ex_doc": {:hex, :ex_doc, "0.35.1", "de804c590d3df2d9d5b8aec77d758b00c814b356119b3d4455e4b8a8687aecaf", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "2121c6402c8d44b05622677b761371a759143b958c6c19f6558ff64d0aed40df"}, + "ex_doc": {:hex, :ex_doc, "0.36.1", "4197d034f93e0b89ec79fac56e226107824adcce8d2dd0a26f5ed3a95efc36b1", [:mix], [{:earmark_parser, "~> 1.4.42", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "d7d26a7cf965dacadcd48f9fa7b5953d7d0cfa3b44fa7a65514427da44eafd89"}, "file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"}, "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "kinda": {:hex, :kinda, "0.9.4", "007e25491bcd3af8a95e0179d9044362dc336920bbe5dbd6515196a5e938b201", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "a5ec71839edf88e52f3e68a2c2210b9a9a00ee414ff369e9801b1184b3844447"}, diff --git a/test/string_test.exs b/test/string_test.exs index 4a45df0..b4f6e49 100644 --- a/test/string_test.exs +++ b/test/string_test.exs @@ -12,8 +12,7 @@ defmodule StringTest do term_ptr = Pointer.allocate(Term.t()) size = value index.casts(String.length(str)) :: i64() d_ptr = enif_make_new_binary(env, size, term_ptr) - m = ptr_to_memref(d_ptr, size) - memref.copy(str, m) + Pointer.copy(str, d_ptr, size) Pointer.load(Term.t(), term_ptr) end end