Skip to content

Commit 6d1bad3

Browse files
bchetiouiGoogle-ML-Automation
authored andcommitted
Explicitly make out a jax.Array before converting it in lax_numpy.array.
This allows avoiding an unnecessary copy in `lax_internal._convert_element_type` when the original `dtype` is the same as the output `dtype` (the fast path does an instance check, which fails on `np.ndarray`s). PiperOrigin-RevId: 762322412
1 parent 199d9f7 commit 6d1bad3

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5536,6 +5536,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
55365536
# https://github.com/jax-ml/jax/pull/6047. More correct would be to call
55375537
# coerce_to_array on each leaf, but this may have performance implications.
55385538
out = np.asarray(object, dtype=dtype)
5539+
# Ensuring that the output is a `jax.Array` allows avoiding a copy in
5540+
# `lax_internal._convert_element_type` when `out` is a NumPy array.
5541+
out = asarray(out, dtype=dtype)
55395542
elif isinstance(object, Array):
55405543
assert object.aval is not None
55415544
out = _array_copy(object) if copy else object

0 commit comments

Comments
 (0)