JAX's type promotion rules (i.e., the result of :func:`jax.numpy.promote_types` for each pair of types) are given by the following table, where, for example:

- "b1" means `np.bool_`,
- "i2" means `np.int16`,
- "u4" means `np.uint32`,
- "bf" means `np.bfloat16`,
- "f2" means `np.float16`, and
- "c8" means `np.complex128`.

A short example of reading these entries programmatically follows the table.
|    | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
|----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|
| b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| u8 | u8 | u8 | u8 | u8 | u8 | f8 | f8 | f8 | f8 | bf | f2 | f4 | f8 | c4 | c8 |
| i1 | i1 | i2 | i4 | i8 | f8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| i2 | i2 | i2 | i4 | i8 | f8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| i4 | i4 | i4 | i4 | i8 | f8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| i8 | i8 | i8 | i8 | i8 | f8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c4 | c8 |
| bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c4 | c8 |
| f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c4 | c8 |
| f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c4 | c8 |
| f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c8 | c8 |
| c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c8 | c4 | c8 |
| c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 |
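Individual cells of the table can be read off directly with :func:`jax.numpy.promote_types`, treating the first argument as the row type and the second as the column type. A minimal sketch checking a few entries:

```python
import jax.numpy as jnp

# Look up a few cells of the table: promote_types(<row type>, <column type>).
print(jnp.promote_types(jnp.bool_, jnp.uint32))  # b1 x u4 -> u4 (uint32)
print(jnp.promote_types(jnp.uint8, jnp.int8))    # u1 x i1 -> i2 (int16)
print(jnp.promote_types(jnp.uint16, jnp.int32))  # u2 x i4 -> i4 (int32)
```

Promotion is symmetric, so swapping the two arguments gives the same result, which is why the table itself is symmetric about its diagonal.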
JAX's type promotion rules differ from those of NumPy, as given by :func:`numpy.promote_types`, in a number of the cells above. There are two key differences:

- When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type. Accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy's promotion rules are too willing to over-promote to 64-bit types, which is problematic for a system designed to run on accelerators. JAX instead uses floating point promotion rules that are better suited to modern accelerator devices and are less aggressive about promoting floating point types; the rules it uses for floating-point types are similar to those used by PyTorch. (A comparison with NumPy's behavior is sketched after this list.)
- JAX supports the non-standard `bfloat16` 16-bit floating point type (`jax.numpy.bfloat16`), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 `float16`, with which `bfloat16` promotes to `float32`. (The second sketch after this list checks this rule.)
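To illustrate the first difference, the sketch below compares :func:`numpy.promote_types` and :func:`jax.numpy.promote_types` on an integer/half-precision pair; NumPy widens to 64 bits while JAX keeps the floating-point operand's type:

```python
import numpy as np
import jax.numpy as jnp

# NumPy widens int32 x float16 all the way to float64.
print(np.promote_types(np.int32, np.float16))     # float64
# JAX prefers the floating-point type: i4 x f2 -> f2 in the table above.
print(jnp.promote_types(jnp.int32, jnp.float16))  # float16
```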
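For the second difference, a small check of the `bfloat16` behavior described above, assuming `bfloat16` is available through `jax.numpy` as in current JAX releases:

```python
import jax.numpy as jnp

# bf x f2 -> f4: neither 16-bit format can exactly represent the other,
# so the result widens to float32 (see the bf row of the table).
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
# bf x f4 -> f4: bfloat16 values are exactly representable in float32.
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32
```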