These are the release notes for JAX.
- New features:
- Adds support for fast traceback collection.
- Adds preliminary support for on-device heap profiling.
- Implements `np.nextafter` for `bfloat16` types (see the sketch below).
- Complex128 support for FFTs on CPU and GPU.
- Bugfixes:
- Improved float64 `tanh` accuracy on GPU.
- float64 scatters on GPU are much faster.
- Complex matrix multiplication on CPU should be much faster.
- Stable sorts on CPU should actually be stable now.
- Concurrency bug fix in CPU backend.
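A minimal sketch of the `np.nextafter` entry above; the values are illustrative and the printed result is simply the next representable bfloat16 above 1.0.

```python
import jax.numpy as jnp

# nextafter now works for bfloat16 inputs.
x = jnp.array(1.0, dtype=jnp.bfloat16)
y = jnp.array(2.0, dtype=jnp.bfloat16)
print(jnp.nextafter(x, y))  # smallest bfloat16 strictly greater than 1.0
```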
- GitHub commits.
- New features:
- `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive #3318.
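A minimal sketch of `lax.switch`; the branch functions and values are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# lax.switch selects one of several branch functions by integer index.
branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]

def apply_branch(index, x):
    return lax.switch(index, branches, x)

print(apply_branch(0, jnp.float32(3.0)))  # 4.0
print(apply_branch(2, jnp.float32(3.0)))  # -3.0
```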
- GitHub commits.
- New features:
- `lax.cond` supports a single-operand form, taken as the argument to both branches #2993 (see the sketch after this list).
- Notable changes:
- The format of the transforms keyword for the `jax.experimental.host_callback.id_tap` primitive has changed #3132.
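A minimal sketch of the single-operand `lax.cond` form mentioned above; `safe_reciprocal` is a hypothetical example function.

```python
import jax.numpy as jnp
from jax import lax

# The single operand x is passed to whichever branch is selected.
def safe_reciprocal(x):
    return lax.cond(x != 0, lambda v: 1.0 / v, lambda v: 0.0 * v, x)

print(safe_reciprocal(jnp.float32(4.0)))  # 0.25
print(safe_reciprocal(jnp.float32(0.0)))  # 0.0
```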
- GitHub commits.
- New features:
- Support for reduction over subsets of a pmapped axis using `axis_index_groups` #2382 (see the sketch below).
- Experimental support for printing and calling host-side Python functions from compiled code. See `id_print` and `id_tap` (#3006).
- Notable changes:
- The visibility of names exported from :py:mod:`jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
- Fixes crash for outfeed.
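A minimal sketch of `axis_index_groups`, assuming four local devices; the grouping is illustrative.

```python
import jax
import jax.numpy as jnp

# Assumes 4 devices: devices {0, 1} and {2, 3} each form an independent psum group.
def grouped_sum(x):
    return jax.lax.psum(x, axis_name="i", axis_index_groups=[[0, 1], [2, 3]])

xs = jnp.arange(4.0)
out = jax.pmap(grouped_sum, axis_name="i")(xs)
# Expected: [1., 1., 5., 5.] -- each device sees the sum over its own group.
```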
- GitHub commits.
- New features:
- Support for `in_axes=None` on :func:`pmap` #2896 (see the sketch below).
- Fixes crash for linear algebra functions on Mac OS X (#432).
- Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
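A minimal sketch of `in_axes=None` with :func:`pmap`; the function and values are illustrative.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# xs is mapped over devices; w is unmapped (in_axes=None) and broadcast to each device.
def scaled(x, w):
    return x * w

xs = jnp.arange(float(n))
w = jnp.float32(10.0)
out = jax.pmap(scaled, in_axes=(0, None))(xs, w)
```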
- GitHub commits.
- New features:
- Differentiation of determinants of singular matrices #2809 (see the sketch below).
- Bug fixes:
- Fix :func:`odeint` differentiation with respect to time of ODEs with time-dependent dynamics #2817, also add ODE CI testing.
- Fix :func:`lax_linalg.qr` differentiation #2867.
- Fixes segfault: jax-ml#2755
- Plumb is_stable option on Sort HLO through to Python.
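A minimal sketch of differentiating the determinant of a singular matrix, per the #2809 entry above; the matrix is illustrative.

```python
import jax
import jax.numpy as jnp

# Rank-deficient 2x2 matrix: det(a) == 0.
a = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])

# d(det)/dA_ij is the cofactor C_ij, which stays finite even when det(A) == 0.
g = jax.grad(jnp.linalg.det)(a)
print(g)  # mathematically [[4., -2.], [-2., 1.]]
```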
- GitHub commits.
- New features:
- Add syntactic sugar for functional indexed updates #2684.
- Add :func:`jax.numpy.linalg.multi_dot` #2726 (see the sketch below).
- Add :func:`jax.numpy.unique` #2760.
- Add :func:`jax.numpy.rint` #2724.
- Add more primitive rules for :func:`jax.experimental.jet`.
- Bug fixes:
- Fix :func:`logaddexp` and :func:`logaddexp2` differentiation at zero #2107.
- Improve memory usage in reverse-mode autodiff without :func:`jit` #2719.
- Better errors:
- Improves error message for reverse-mode differentiation of :func:`lax.while_loop` #2129.
- Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
- Bugfix for `batch_group_count` convolutions.
- Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
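A minimal sketch of :func:`jax.numpy.linalg.multi_dot` from the #2726 entry above; the shapes are illustrative.

```python
import jax.numpy as jnp

# multi_dot picks an efficient multiplication order for a chain of matrices.
a = jnp.ones((10, 100))
b = jnp.ones((100, 5))
c = jnp.ones((5, 50))

out = jnp.linalg.multi_dot([a, b, c])  # same result as a @ b @ c
print(out.shape)  # (10, 50)
```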
- GitHub commits.
- Added `jax.custom_jvp` and `jax.custom_vjp` from #2026; see the tutorial notebook and the sketch below. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
- Add `scipy.sparse.linalg.cg` #2566.
- Changed how Tracers are printed to show more useful information for debugging #2591.
- Made `jax.numpy.isclose` handle `nan` and `inf` correctly #2501.
- Added several new rules for `jax.experimental.jet` #2537.
- Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
- Fix some missing cases of broadcasting in `jax.numpy.einsum` #2512.
- Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan #2596 and make `reduce_prod` differentiable to arbitrary order #2597.
- Add `batch_group_count` to `conv_general_dilated` #2635.
- Add docstring for `test_util.check_grads` #2656.
- Add `callback_transform` #2665.
- Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
- Fixed a performance regression for Resnet-50 on GPU.
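A minimal sketch of `jax.custom_jvp` in the spirit of the tutorial notebook; `log1pexp` is the usual illustrative example.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    t, = tangents
    ans = log1pexp(x)
    # Derivative sigmoid(x), written in a numerically stable form.
    return ans, t / (1.0 + jnp.exp(-x))

print(jax.grad(log1pexp)(100.0))  # ~1.0, without overflow in the derivative
```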
- GitHub commits.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs, when differentiated, will produce nans where they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
- Added an `all_gather` parallel convenience function.
- More type annotations in core code.
- jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- GitHub commits.
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
- GitHub commits.
- New features:
- :py:func:`jax.pmap` has a `static_broadcasted_argnums` argument which allows the user to specify arguments that should be treated as compile-time constants and broadcast to all devices. It works analogously to `static_argnums` in :py:func:`jax.jit` (see the sketch below).
- Improved error messages for when tracers are mistakenly saved in global state.
- Added :py:func:`jax.nn.one_hot` utility function.
- Added :py:mod:`jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
- Added more sanity checking to arguments of :py:func:`jax.lax.broadcast_in_dim`.
- The minimum jaxlib version is now 0.1.41.
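A minimal sketch of `static_broadcasted_argnums` with :py:func:`jax.pmap`; the function and constant are illustrative.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# `power` (argument 1) is treated as a compile-time constant and broadcast to every device.
def f(x, power):
    return x ** power

xs = jnp.arange(float(n))
out = jax.pmap(f, static_broadcasted_argnums=(1,))(xs, 3)
```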
- Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multihost GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
- GitHub commits.
- Breaking changes
- The minimum jaxlib version is now 0.1.38.
- Simplified :py:class:`Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs` fields. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also added a new field `call_primitive` to primitives.
- New features:
- Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes (jax-ml#2091)
- JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch (see the sketch below).
- JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
- JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
- Added the `JAX_SKIP_SLOW_TESTS` environment variable to skip tests known to be slow.
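A minimal sketch of the DLPack round trip; the array is illustrative, and the capsule could equally be handed to PyTorch or CuPy.

```python
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)    # export as a DLPack capsule, no copy
y = jax.dlpack.from_dlpack(capsule)  # import back into JAX, still no copy
```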
- Updates XLA.
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
- Breaking changes
- JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
- New features
- Forward-mode automatic differentiation (jvp) of while loop (jax-ml#1980)
- New NumPy and SciPy functions (see the FFT sketch at the end of these notes):
- :py:func:`jax.numpy.fft.fft2`
- :py:func:`jax.numpy.fft.ifft2`
- :py:func:`jax.numpy.fft.rfft`
- :py:func:`jax.numpy.fft.irfft`
- :py:func:`jax.numpy.fft.rfft2`
- :py:func:`jax.numpy.fft.irfft2`
- :py:func:`jax.numpy.fft.rfftn`
- :py:func:`jax.numpy.fft.irfftn`
- :py:func:`jax.numpy.fft.fftfreq`
- :py:func:`jax.numpy.fft.rfftfreq`
- :py:func:`jax.numpy.linalg.matrix_rank`
- :py:func:`jax.numpy.linalg.matrix_power`
- :py:func:`jax.scipy.special.betainc`
- Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
- With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.
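A minimal sketch using a few of the new FFT functions listed above; the signal is illustrative.

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 8)              # a small real-valued signal
spectrum = jnp.fft.rfft(x)                 # real-input FFT
recovered = jnp.fft.irfft(spectrum, n=8)   # inverse transform back to length 8
freqs = jnp.fft.rfftfreq(8, d=1.0 / 8)     # matching frequency bins
```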