Releases: jax-ml/jax
JAX v0.5.0
As of this release, JAX now uses effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.
Breaking changes
- jax_threefry_partitionable is now enabled by default (see the update note).
- This release drops support for Mac x86 wheels. Mac ARM of course remains
  supported. For a recent discussion, see #22936. Two key factors motivated this decision:
  - The Mac x86 build (only) has a number of test failures and crashes. We
    would rather ship no release than a broken release.
  - Mac x86 hardware is end-of-life and cannot easily be obtained by
    developers at this point, so it is difficult for us to fix this kind of
    problem even if we wanted to.
  We are open to re-adding support for Mac x86 if the community is willing
  to help support that platform: in particular, we would need the JAX test
  suite to pass cleanly on Mac x86 before we could ship releases again.
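If you need to reproduce random numbers generated before this release, the flag can be switched back temporarily with jax.config.update. The following is a minimal sketch under that assumption; the seed and shape are arbitrary, and the two draws are not asserted to agree or to differ.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Draw with the default in JAX v0.5.0 and later (partitionable Threefry).
new_bits = jax.random.uniform(key, (4,))

# Temporarily restore the pre-0.5.0 behavior, e.g. to reproduce an old run.
jax.config.update("jax_threefry_partitionable", False)
old_bits = jax.random.uniform(key, (4,))

# Switch back to the new default afterwards.
jax.config.update("jax_threefry_partitionable", True)

print(new_bits)
print(old_bits)
```

Treat the False setting as a migration aid only; code that depends on exact random values should be updated to the new stream.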
Changes:
- The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
  supported version until June 2025.
- The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
  supported version until June 2025.
- jax.numpy.einsum now defaults to optimize='auto' rather than
  optimize='optimal'. This avoids exponentially scaling trace time in
  the case of many arguments (#25214).
- jax.numpy.linalg.solve no longer supports batched 1D arguments
  on the right-hand side. To recover the previous behavior in these cases,
  use solve(a, b[..., None]).squeeze(-1).
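The two migrations described above might look like this in practice. This is a minimal sketch; the shapes, the seed, and the 5 * I shift used to keep the random matrices well-conditioned are arbitrary choices for illustration.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)

# einsum: request the old exhaustive contraction-path search explicitly.
x = jnp.asarray(rng.normal(size=(3, 4)).astype(np.float32))
y = jnp.asarray(rng.normal(size=(4, 5)).astype(np.float32))
z = jnp.asarray(rng.normal(size=(5, 2)).astype(np.float32))
out_auto = jnp.einsum("ij,jk,kl->il", x, y, z)  # new default, optimize='auto'
out_optimal = jnp.einsum("ij,jk,kl->il", x, y, z, optimize="optimal")

# solve: a batch of 8 systems, each a 3x3 matrix with a length-3 right-hand side.
a = jnp.asarray(rng.normal(size=(8, 3, 3)).astype(np.float32)) + 5 * jnp.eye(3)
b = jnp.asarray(rng.normal(size=(8, 3)).astype(np.float32))
# Treat each b as a column vector, solve, then drop the trailing axis again.
sol = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
```

Both einsum calls compute the same contraction; only the path search (and therefore trace time) differs.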
New Features
- jax.numpy.fft.fftn, jax.numpy.fft.rfftn, jax.numpy.fft.ifftn, and
  jax.numpy.fft.irfftn now support transforms in more than 3 dimensions,
  which was previously the limit. See #25606 for more details.
- Support added for user-defined state in the FFI via the new
  jax.ffi.register_ffi_type_id function.
- The AOT lowering .as_text() method now supports the debug_info option
  to include debugging information, e.g., source location, in the output.
Deprecations
- From jax.interpreters.xla, abstractify and pytype_aval_mappings
  are now deprecated, having been replaced by symbols of the same name
  in jax.core.
- jax.scipy.special.lpmn and jax.scipy.special.lpmn_values
  are deprecated, following their deprecation in SciPy v1.15.0. There are
  no plans to replace these deprecated functions with new APIs.
- The jax.extend.ffi submodule was moved to jax.ffi, and the
  previous import path is deprecated.
Deletions
- The jax_enable_memories flag has been deleted, and the behavior of that flag
  is on by default.
- From jax.lib.xla_client, the previously deprecated Device and
  XlaRuntimeError symbols have been removed; instead use jax.Device
  and jax.errors.JaxRuntimeError respectively.
- The jax.experimental.array_api module has been removed after being
  deprecated in JAX v0.4.32. Since that release, jax.numpy supports
  the array API directly.
JAX v0.4.38
Changes:
- jax.tree.flatten_with_path and jax.tree.map_with_path are added
  as shortcuts of the corresponding tree_util functions.
Deprecations
- A number of APIs in the internal jax.core namespace have been deprecated.
  Most were no-ops, were little-used, or can be replaced by APIs of the same
  name in jax.extend.core; see the documentation for jax.extend
  for information on the compatibility guarantees of these semi-public extensions.
- Several previously deprecated APIs have been removed, including:
  - from jax.core: check_eqn, check_type, check_valid_jaxtype, and
    non_negative_dim.
  - from jax.lib.xla_bridge: xla_client and default_backend.
  - from jax.lib.xla_client: _xla and bfloat16.
  - from jax.numpy: round_.
New Features
- jax.export.export can be used for device-polymorphic export with
  shardings constructed with jax.sharding.AbstractMesh.
  See the jax.export documentation.
- Added jax.lax.split. This is a primitive version of
  jax.numpy.split, added because it yields a more compact
  transpose during automatic differentiation.
JAX v0.4.37
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
Bug fixes
- Fixed a bug where jit would error if an argument was named f (#25329).
- Fixed a bug that would throw an index out of range error in
  jax.lax.while_loop if the user registered a pytree node class with
  different aux data for flatten and flatten_with_path.
- Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
JAX v0.4.36
Breaking Changes
- This release lands "stackless", an internal change to JAX's tracing
  machinery. We made trace dispatch purely a function of context rather than a
  function of both context and data. This let us delete a lot of machinery for
  managing data-dependent tracing: levels, sublevels, post_process_call,
  new_base_main, custom_bind, and so on. The change should only affect
  users that use JAX internals. If you do use JAX internals then you may need to
  update your code (see c36e1f7 for clues about how to do this). There might
  also be version skew issues with JAX libraries that do this. If you find this
  change breaks your non-JAX-internals-using code then try the
  config.jax_data_dependent_tracing_fallback flag as a workaround, and if
  you need help updating your code then please file a bug.
- Calling jax.experimental.jax2tf.convert with native_serialization=False
  or with enable_xla=False has been deprecated since July 2024, with
  JAX version 0.4.31. Support for these use cases has now been removed. jax2tf
  with native serialization will still be supported.
- In jax.interpreters.xla, the xb, xc, and xe symbols have been removed
  after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge,
  xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.
- The deprecated module jax.experimental.export has been removed. It was replaced
  by jax.export in JAX v0.4.30. See the migration guide
  for information on migrating to the new API.
- The initial argument to jax.nn.softmax and jax.nn.log_softmax
  has been removed, after being deprecated in v0.4.27.
- Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key)
  now raises an error. Previously, this returned a scalar object array.
- The following deprecated methods and functions in jax.export have
  been removed:
  - jax.export.DisabledSafetyCheck.shape_assertions: it had no effect
    already.
  - jax.export.Exported.lowering_platforms: use platforms.
  - jax.export.Exported.mlir_module_serialization_version:
    use calling_convention_version.
  - jax.export.Exported.uses_shape_polymorphism:
    use uses_global_constants.
  - the lowering_platforms kwarg for jax.export.export: use
    platforms instead.
- The kwargs symbolic_scope and symbolic_constraints from
  jax.export.symbolic_args_specs have been removed. They were
  deprecated in June 2024. Use scope and constraints instead.
- Hashing of tracers, which has been deprecated since version 0.4.30, now
  results in a TypeError.
- Refactor: the JAX build CLI (build/build.py) now uses a subcommand structure and
  replaces the previous build.py usage. Run python build/build.py --help for
  more details. Brief overview of the new subcommand options:
  - build: Builds JAX wheel packages. For example,
    python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
  - requirements_update: Updates requirements_lock.txt files.
- jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional
  inputs. To recover the previous behavior, you can call jax.numpy.ravel
  on the function inputs.
- jax.scipy.special.gamma and jax.scipy.special.gammasgn now
  return NaN for negative integer inputs, to match the behavior of SciPy from
  scipy/scipy#21827.
- jax.clear_backends was removed after being deprecated in v0.4.26.
- We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
  calls for which we guarantee export stability. This is because this custom call
  relies on Triton IR, which is not guaranteed to be stable. If you need
  to export code that uses this custom call, you can use the disabled_checks
  parameter. See more details in the documentation.
New Features
- jax.jit got a new compiler_options: dict[str, Any] argument, for
  passing compilation options to XLA. For the moment it's undocumented and
  may be in flux.
- jax.tree_util.register_dataclass now allows metadata fields to be
  declared inline via dataclasses.field. See the function documentation
  for examples.
- Added jax.numpy.put_along_axis.
- jax.lax.linalg.eig and the related jax.numpy functions
  (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now
  supported on GPU. See #24663 for more details.
- Added two new configuration flags, jax_exec_time_optimization_effort
  and jax_memory_fitting_effort, to control the amount of effort the compiler
  spends minimizing execution time and memory usage, respectively. Valid values
  are between -1.0 and 1.0; the default is 0.0.
Bug fixes
- Fixed a bug where the GPU implementations of LU and QR decomposition would
  result in an indexing overflow for batch sizes close to int32 max. See
  #24843 for more details.
Deprecations
- jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated;
  use jax.Array instead.
- jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError
  instead.
JAX v0.4.35
Breaking Changes
- jax.numpy.isscalar now returns True for any array-like object with
  zero dimensions. Previously it only returned True for zero-dimensional
  array-like objects with a weak dtype.
- jax.experimental.host_callback has been deprecated since March 2024, with
  JAX version 0.4.26. It has now been removed.
  See #20385 for a discussion of alternatives.
Changes:
- jax.lax.FftType was introduced as a public name for the enum of FFT
  operations. The semi-public API jax.lib.xla_client.FftType has been
  deprecated.
- TPU: JAX now installs TPU support from the libtpu package rather than
  libtpu-nightly. For the next few releases JAX will pin an empty version of
  libtpu-nightly as well as libtpu to ease the transition; that dependency
  will be removed in Q1 2025.
Deprecations:
- The semi-public API jax.lib.xla_client.PaddingType has been deprecated.
  No JAX APIs consume this type, so there is no replacement.
- The default behavior of jax.pure_callback and jax.extend.ffi.ffi_call
  under vmap has been deprecated, and so has the vectorized parameter to
  those functions. The vmap_method parameter should be used instead for
  better-defined behavior. See the discussion in #23881 for more details.
- The semi-public API jax.lib.xla_client.register_custom_call_target has
  been deprecated. Use the JAX FFI instead.
- The semi-public APIs jax.lib.xla_client.dtype_to_etype,
  jax.lib.xla_client.ops, jax.lib.xla_client.shape_from_pyval,
  jax.lib.xla_client.PrimitiveType, jax.lib.xla_client.Shape,
  jax.lib.xla_client.XlaBuilder, and jax.lib.xla_client.XlaComputation
  have been deprecated. Use StableHLO instead.
JAX v0.4.34
New Functionality
- This release includes wheels for Python 3.13. Free-threading mode is not yet
  supported.
- jax.errors.JaxRuntimeError has been added as a public alias for the
  formerly private XlaRuntimeError type.
Breaking changes
- The jax_pmap_no_rank_reduction flag is set to True by default.
  - array[0] on a pmap result now introduces a reshape (use array[0:1]
    instead).
  - The per-shard shape (accessible via jax_array.addressable_shards or
    jax_array.addressable_data(0)) now has a leading (1, ...). Update code
    that directly accesses shards accordingly. The rank of the per-shard shape
    now matches that of the global shape, which is the same behavior as jit.
    This avoids costly reshapes when passing results from pmap into jit.
- jax.experimental.host_callback has been deprecated since March 2024, with
  JAX version 0.4.26. Now we set the default value of the
  --jax_host_callback_legacy configuration value to False, which means that
  if your code uses jax.experimental.host_callback APIs, those API calls
  will be implemented in terms of the new jax.experimental.io_callback API.
  If this breaks your code, for a very limited time, you can set
  --jax_host_callback_legacy to True. Soon we will remove that
  configuration option, so you should instead transition to using the
  new JAX callback APIs. See #20385 for a discussion.
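As a hedged sketch of one common migration away from host_callback (logging an intermediate value from inside jit, which id_tap was often used for), jax.experimental.io_callback can be called with ordered=True. The function log_value and the list logged are illustrative names, not part of any API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

logged = []

def log_value(x):
    # Runs on the host and receives a NumPy array.
    logged.append(float(x))

@jax.jit
def step(x):
    y = x * 2.0
    # The callback returns nothing, so result_shape_dtypes is None;
    # ordered=True keeps callbacks in program order.
    io_callback(log_value, None, y, ordered=True)
    return y + 1.0

out = step(jnp.float32(3.0))
jax.effects_barrier()  # wait for pending host callbacks before reading `logged`
```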
Deprecations
- In jax.numpy.trim_zeros, non-arraylike arguments or arraylike
  arguments with ndim != 1 are now deprecated, and in the future will result
  in an error.
- Internal pretty-printing tools jax.core.pp_* have been removed, after
  being deprecated in JAX v0.4.30.
- jax.lib.xla_client.Device is deprecated; use jax.Device instead.
- jax.lib.xla_client.XlaRuntimeError has been deprecated. Use
  jax.errors.JaxRuntimeError instead.
Deletion:
- jax.xla_computation is deleted. It has been 3 months since its deprecation
  in the JAX 0.4.30 release. Please use the AOT APIs to get the same
  functionality as jax.xla_computation.
  - jax.xla_computation(fn)(*args, **kwargs) can be replaced with
    jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
  - You can also use the .out_info property of jax.stages.Lowered to get the
    output information (like tree structure, shape and dtype).
  - For cross-backend lowering, you can replace
    jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
    jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
- jax.ShapeDtypeStruct no longer accepts the named_shape argument.
  The argument was only used by xmap, which was removed in 0.4.31.
- jax.tree.map(f, None, non-None), which previously emitted a
  DeprecationWarning, now raises an error. None is only a tree-prefix of itself.
  To preserve the current behavior, you can ask jax.tree.map to treat None
  as a leaf value by writing:
  jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
- jax.sharding.XLACompatibleSharding has been removed. Please use
  jax.sharding.Sharding.
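The AOT replacements listed above can be exercised on a small function. This sketch assumes lowering for the default (local) platform only, so the cross-backend trace(...).lower(lowering_platforms=...) path is not shown; fn, f, a and b are illustrative names. It also shows the is_leaf workaround for None in jax.tree.map.

```python
import jax
import jax.numpy as jnp

def fn(x, y):
    return {"sum": x + y, "prod": x * y}

x = jnp.ones((3,), jnp.float32)
y = jnp.full((3,), 2.0, jnp.float32)

# Replacement for jax.xla_computation(fn)(x, y): lower, then ask for HLO.
lowered = jax.jit(fn).lower(x, y)
hlo_text = lowered.compiler_ir("hlo").as_hlo_text()

# Output tree structure, shapes and dtypes, available without compiling.
out_info = lowered.out_info

# jax.tree.map with None in the first tree: treat None as a leaf explicitly.
def f(u, v):
    return u + v

a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 5.0}
merged = jax.tree.map(
    lambda u, v: None if u is None else f(u, v), a, b, is_leaf=lambda u: u is None
)
```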
Bug fixes
- Fixed a bug where jax.numpy.cumsum would produce incorrect outputs
  if a non-boolean input was provided and dtype=bool was specified.
- Edited the implementation of jax.numpy.ldexp to get the correct gradient.
JAX release v0.4.33
This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of libtpu-nightly.
This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
Jaxlib release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
JAX release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
Jaxlib release v0.4.31
jaxlib-v0.4.31 jaxlib version 0.4.31