Skip to content

Commit 102e62b

Browse files
committed
(Experimental) Integrate Metal PjRt plugin
1 parent 2f3c6ef commit 102e62b

File tree

13 files changed

+220
-30
lines changed

13 files changed

+220
-30
lines changed

exla/c_src/exla/exla.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
131131
build_options.set_use_spmd_partitioning(use_spmd);
132132

133133
bool compile_portable_executable = false;
134-
if (device_id >= 0) {
134+
135+
bool is_mps = (*client)->client()->platform_name() == "METAL";
136+
137+
if (device_id >= 0 && !is_mps) {
135138
compile_portable_executable = true;
136139
build_options.set_device_ordinal(device_id);
137140
}
@@ -728,6 +731,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
728731
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
729732
}
730733

734+
ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
735+
if (argc != 0) {
736+
return exla::nif::error(env, "Bad argument count.");
737+
}
738+
739+
EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetMpsClient(), env);
740+
741+
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
742+
}
743+
731744
ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
732745
if (argc != 1) {
733746
return exla::nif::error(env, "Bad argument count.");
@@ -915,6 +928,7 @@ static ErlNifFunc exla_funcs[] = {
915928
{"get_host_client", 0, get_host_client},
916929
{"get_gpu_client", 2, get_gpu_client},
917930
{"get_tpu_client", 0, get_tpu_client},
931+
{"get_mps_client", 0, get_mps_client},
918932
{"get_c_api_client", 1, get_c_api_client},
919933
{"load_pjrt_plugin", 2, load_pjrt_plugin},
920934
{"get_device_count", 1, get_device_count},

exla/c_src/exla/exla_client.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,30 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {
489489
return new ExlaClient(std::move(client));
490490
}
491491

492+
xla::StatusOr<ExlaClient*> GetMpsClient() {
493+
// The plugin may be compiled for a different version of PjRt C API
494+
// than present in our XLA compilation. By default pjrt::LoadPjrtPlugin
495+
// raises if the version does not match. By setting this environment
496+
// variable, we relax this check to allow different versions, as long
497+
// as they satisfy compatibility constraints.
498+
//
499+
// See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc
500+
setenv("ENABLE_PJRT_COMPATIBILITY", "1", 1);
501+
502+
EXLA_EFFECT_OR_RETURN(pjrt::LoadPjrtPlugin("METAL", "pjrt_plugin_metal.dylib"));
503+
504+
xla::Status status = pjrt::InitializePjrtPlugin("METAL");
505+
506+
if (!status.ok()) {
507+
return status;
508+
}
509+
510+
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
511+
xla::GetCApiClient("METAL"));
512+
513+
return new ExlaClient(std::move(client));
514+
}
515+
492516
xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
493517
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
494518
xla::GetCApiClient(device_type));

exla/c_src/exla/exla_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
110110

111111
xla::StatusOr<ExlaClient*> GetTpuClient();
112112

113+
xla::StatusOr<ExlaClient*> GetMpsClient();
114+
113115
xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
114116
} // namespace exla
115117

exla/lib/exla/client.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ defmodule EXLA.Client do
159159
:tpu ->
160160
EXLA.NIF.get_tpu_client()
161161

162+
:mps ->
163+
EXLA.NIF.get_mps_client()
164+
162165
_ ->
163166
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
164167
end

exla/lib/exla/defn.ex

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,9 @@ defmodule EXLA.Defn do
663663
result =
664664
Value.gather(
665665
tensor,
666-
indices,
666+
# TODO remove conversion (unsigned indices fail)
667+
# Reported in https://github.com/google/jax/issues/21547
668+
to_type(indices, {:s, 32}),
667669
index_vector_dim,
668670
slice_sizes,
669671
offset_dims,
@@ -871,6 +873,10 @@ defmodule EXLA.Defn do
871873
) do
872874
precision = state.precision
873875

876+
# Ensure both have the same type
877+
left = to_type(left, ans.type)
878+
right = to_type(right, ans.type)
879+
874880
Value.dot_general(
875881
left,
876882
right,
@@ -1291,6 +1297,9 @@ defmodule EXLA.Defn do
12911297
defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do
12921298
tensor = to_type(tensor, ans.type)
12931299
slice = to_type(slice, ans.type)
1300+
# TODO remove conversion (unsigned indices fail)
1301+
# Reported in https://github.com/google/jax/issues/21547
1302+
start_indices = Enum.map(start_indices, &to_type(&1, {:s, 32}))
12941303
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
12951304
end
12961305

@@ -1313,7 +1322,9 @@ defmodule EXLA.Defn do
13131322

13141323
Value.gather(
13151324
tensor,
1316-
indices,
1325+
# TODO remove conversion (unsigned indices fail)
1326+
# Reported in https://github.com/google/jax/issues/21547
1327+
to_type(indices, {:s, 32}),
13171328
index_vector_dim,
13181329
slice_sizes,
13191330
offset_dims,
@@ -1341,7 +1352,7 @@ defmodule EXLA.Defn do
13411352
defp to_operator(:sort, [%Value{} = tensor, opts], ans, state) do
13421353
dimension = opts[:axis]
13431354

1344-
op =
1355+
operator =
13451356
case opts[:direction] do
13461357
:asc -> :less
13471358
:desc -> :greater
@@ -1350,7 +1361,7 @@ defmodule EXLA.Defn do
13501361
arg_typespec = Typespec.tensor(ans.type, {})
13511362
arg_typespecs = [arg_typespec, arg_typespec]
13521363

1353-
comp = sort_computation(op, ans.type, arg_typespecs, state)
1364+
comp = sort_computation(operator, ans.type, arg_typespecs, state)
13541365

13551366
Value.sort([tensor], comp, dimension, opts[:stable] == true, [expr_to_typespec(ans)]) |> hd()
13561367
end
@@ -1530,30 +1541,45 @@ defmodule EXLA.Defn do
15301541

15311542
## Computation helpers
15321543

1533-
defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do
1544+
defp sort_computation(
1545+
operator,
1546+
type,
1547+
arg_typespecs,
1548+
%{builder: %EXLA.MLIR.Function{} = function}
1549+
) do
15341550
{region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs)
15351551

15361552
typespec = Typespec.tensor({:pred, 8}, {})
15371553

1538-
op =
1539-
cond do
1540-
Nx.Type.integer?(type) ->
1541-
apply(Value, op, [lhs, rhs, typespec])
1542-
1543-
op == :less ->
1544-
is_nan = Value.is_nan(rhs, typespec)
1545-
Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec)
1546-
1547-
op == :greater ->
1548-
is_nan = Value.is_nan(lhs, typespec)
1549-
Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec)
1554+
{lhs, rhs} =
1555+
if Nx.Type.float?(type) do
1556+
{canonicalize_float_for_sort(lhs), canonicalize_float_for_sort(rhs)}
1557+
else
1558+
{lhs, rhs}
15501559
end
15511560

1561+
op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]])
1562+
15521563
Value.return(function, [op])
15531564
Function.pop_region(function)
15541565
region
15551566
end
15561567

1568+
defp canonicalize_float_for_sort(%Value{function: func} = op) do
1569+
# Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0).
1570+
# See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253
1571+
1572+
op_typespec = Value.get_typespec(op)
1573+
1574+
zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {}))
1575+
zeros = Value.constant(func, [0], op_typespec)
1576+
nans = Value.constant(func, [:nan], op_typespec)
1577+
1578+
pred_typespec = Typespec.tensor({:pred, 8}, {})
1579+
op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec)
1580+
Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec)
1581+
end
1582+
15571583
defp op_computation(
15581584
op,
15591585
arg_typespecs,

exla/lib/exla/mlir/value.ex

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,18 @@ defmodule EXLA.MLIR.Value do
5454
}
5555

5656
for {op, direction} <- @bin_comparison_ops do
57-
def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
58-
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction))
57+
def unquote(op)(
58+
%Value{function: func} = lhs,
59+
%Value{function: func} = rhs,
60+
typespec,
61+
opts \\ []
62+
) do
63+
opts = Keyword.validate!(opts, total_order: false)
64+
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order])
5965
end
6066
end
6167

62-
defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do
68+
defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
6369
%{type: lhs_type} = get_typespec(lhs)
6470
%{type: rhs_type} = get_typespec(rhs)
6571

@@ -69,7 +75,11 @@ defmodule EXLA.MLIR.Value do
6975
attr_comparison_type(:float)
7076

7177
Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
72-
attr_comparison_type(:float)
78+
if total_order? do
79+
attr_comparison_type(:totalorder)
80+
else
81+
attr_comparison_type(:float)
82+
end
7383

7484
true ->
7585
attr_comparison_type(:notype)
@@ -663,9 +673,15 @@ defmodule EXLA.MLIR.Value do
663673
typespecs
664674
) do
665675
result_types = typespecs_to_mlir_types(typespecs)
666-
regions = [on_true, on_false]
667-
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
668-
op(func, "stablehlo.if", [pred], result_types, regions: regions)
676+
677+
# TODO Jax does not support stablehlo.if, they use stablhelo.case instead.
678+
# It most likely makes sense for use to do the same. That said, note that
679+
# stablehlo.case is implemented for Metal, but does not lower reliably.
680+
# Reported in https://github.com/google/jax/issues/21601
681+
682+
regions = [on_false, on_true]
683+
pred = convert(pred, Typespec.tensor({:s, 32}, {}))
684+
op(func, "stablehlo.case", [pred], result_types, regions: regions)
669685
end
670686

671687
def infeed(%Value{function: func} = token, typespecs) do

exla/lib/exla/nif.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ defmodule EXLA.NIF do
5353

5454
def get_tpu_client(), do: :erlang.nif_error(:undef)
5555

56+
def get_mps_client(), do: :erlang.nif_error(:undef)
57+
5658
def get_supported_platforms, do: :erlang.nif_error(:undef)
5759

5860
def get_device_count(_client),

exla/mix.exs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ defmodule EXLA.MixProject do
5151
cuda: [platform: :cuda],
5252
rocm: [platform: :rocm],
5353
tpu: [platform: :tpu],
54+
mps: [platform: :mps],
5455
host: [platform: :host]
5556
],
56-
preferred_clients: [:cuda, :rocm, :tpu, :host]
57+
preferred_clients: [:cuda, :rocm, :tpu, :mps, :host]
5758
]
5859
]
5960
end
@@ -128,11 +129,31 @@ defmodule EXLA.MixProject do
128129
:ok -> File.write!(xla_snapshot_path, xla_archive_path)
129130
{:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}")
130131
end
132+
133+
# TODO should be packed into the XLA archive
134+
download_metal_plugin!(xla_extension_path)
131135
end
132136

133137
{:ok, []}
134138
end
135139

140+
defp download_metal_plugin!(xla_extension_path) do
141+
plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")
142+
143+
wheel_url =
144+
"https://files.pythonhosted.org/packages/d6/4f/f5d128a493b7387fbbe0e6906544214af2a6b86af30302dd6ffb9dc66a74/jax_metal-0.0.7-py3-none-macosx_13_0_arm64.whl"
145+
146+
wheel_path = Path.join(xla_extension_path, "jax_metal.whl")
147+
148+
{_, 0} = System.shell("wget --output-document=#{wheel_path} #{wheel_url}")
149+
{_, 0} = System.shell("unzip #{wheel_path} -d #{xla_extension_path}")
150+
151+
wheel_plugin_path =
152+
Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")
153+
154+
File.cp!(wheel_plugin_path, plugin_path)
155+
end
156+
136157
defp cached_make(_) do
137158
contents =
138159
for path <- Path.wildcard("c_src/**/*"),

exla/test/exla/backend_test.exs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,73 @@ defmodule EXLA.BackendTest do
2727
@skip_mac_arm []
2828
end
2929

30+
if EXLA.Client.default_name() == :mps do
31+
@skip_mps [
32+
# Missing support for "stablehlo.reduce_window".
33+
# Reported in https://github.com/google/jax/issues/21387
34+
window_max: 3,
35+
window_min: 3,
36+
window_sum: 3,
37+
window_product: 3,
38+
window_reduce: 5,
39+
window_scatter_min: 5,
40+
window_scatter_max: 5,
41+
window_mean: 3,
42+
# Argmax/armin fail when a custom :type is passed.
43+
# Reported in https://github.com/google/jax/issues/21577
44+
argmin: 2,
45+
argmax: 2,
46+
# Missing support for general "stablehlo.reduce". Some cases work
47+
# becuase they are special-cased.
48+
# Reported in https://github.com/google/jax/issues/21384
49+
reduce: 4,
50+
# Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
51+
# "stablehlo.cbrt".
52+
# Reported in https://github.com/google/jax/issues/21389
53+
count_leading_zeros: 1,
54+
population_count: 1,
55+
cbrt: 1,
56+
# Matrix multiplication for integers is not supported
57+
dot: 2,
58+
dot: 4,
59+
dot: 6,
60+
covariance: 3,
61+
# (edge case) Put slice with overflowing slice, different behaviour.
62+
# Reported in https://github.com/google/jax/issues/21392
63+
put_slice: 3,
64+
# (edge case) Slice with overflowing index, different behaviour.
65+
# Reported in https://github.com/google/jax/issues/21393
66+
slice: 4,
67+
# (edge case) Top-k wrong behaviour with NaNs.
68+
# Reported in https://github.com/google/jax/issues/21397
69+
top_k: 2,
70+
# Missing support for complex numbers.
71+
# Tracked in https://github.com/google/jax/issues/16416
72+
complex: 2,
73+
conjugate: 1,
74+
conv: 3,
75+
fft: 2,
76+
fft2: 2,
77+
ifft: 2,
78+
ifft2: 2,
79+
imag: 1,
80+
is_infinity: 1,
81+
is_nan: 1,
82+
phase: 1,
83+
real: 1,
84+
sigil_MAT: 2,
85+
# Missing support for float-64.
86+
# Tracked in https://github.com/google/jax/issues/20938
87+
iota: 2,
88+
as_type: 2,
89+
atan2: 2
90+
]
91+
else
92+
@skip_mps []
93+
end
94+
3095
doctest Nx,
31-
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm
96+
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps
3297

3398
test "Nx.to_binary/1" do
3499
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)

0 commit comments

Comments
 (0)