Skip to content

Commit 8a851f1

Browse files
committed
Bump jax-metal to 0.1.0
1 parent 102e62b commit 8a851f1

File tree

3 files changed

+5
-12
lines changed

3 files changed

+5
-12
lines changed

exla/lib/exla/defn.ex

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,7 @@ defmodule EXLA.Defn do
663663
result =
664664
Value.gather(
665665
tensor,
666-
# TODO remove conversion (unsigned indices fail)
667-
# Reported in https://github.com/google/jax/issues/21547
668-
to_type(indices, {:s, 32}),
666+
indices,
669667
index_vector_dim,
670668
slice_sizes,
671669
offset_dims,
@@ -1297,9 +1295,6 @@ defmodule EXLA.Defn do
12971295
defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do
12981296
tensor = to_type(tensor, ans.type)
12991297
slice = to_type(slice, ans.type)
1300-
# TODO remove conversion (unsigned indices fail)
1301-
# Reported in https://github.com/google/jax/issues/21547
1302-
start_indices = Enum.map(start_indices, &to_type(&1, {:s, 32}))
13031298
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
13041299
end
13051300

@@ -1322,9 +1317,7 @@ defmodule EXLA.Defn do
13221317

13231318
Value.gather(
13241319
tensor,
1325-
# TODO remove conversion (unsigned indices fail)
1326-
# Reported in https://github.com/google/jax/issues/21547
1327-
to_type(indices, {:s, 32}),
1320+
indices,
13281321
index_vector_dim,
13291322
slice_sizes,
13301323
offset_dims,

exla/mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ defmodule EXLA.MixProject do
141141
plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")
142142

143143
wheel_url =
144-
"https://files.pythonhosted.org/packages/d6/4f/f5d128a493b7387fbbe0e6906544214af2a6b86af30302dd6ffb9dc66a74/jax_metal-0.0.7-py3-none-macosx_13_0_arm64.whl"
144+
"https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl"
145145

146146
wheel_path = Path.join(xla_extension_path, "jax_metal.whl")
147147

exla/test/exla/backend_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ defmodule EXLA.BackendTest do
3939
window_scatter_min: 5,
4040
window_scatter_max: 5,
4141
window_mean: 3,
42-
# Argmax/armin fail when a custom :type is passed.
43-
# Reported in https://github.com/google/jax/issues/21577
42+
# (edge case) Argmax/argmin return wrong value in case of NaN.
43+
# Reported in https://github.com/google/jax/issues/21821
4444
argmin: 2,
4545
argmax: 2,
4646
# Missing support for general "stablehlo.reduce". Some cases work

0 commit comments

Comments
 (0)