Skip to content

Commit

Permalink
Support tuple of axis in softmax_cross_entropy_with_integer_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
daskol committed Jan 2, 2025
1 parent 1e08bcc commit 9b7dde2
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
44 changes: 42 additions & 2 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax import projections

if np.__version__.startswith('1.'):
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
else:
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple


def sigmoid_binary_cross_entropy(
logits,
Expand Down Expand Up @@ -273,7 +279,7 @@ def softmax_cross_entropy(
def softmax_cross_entropy_with_integer_labels(
logits: chex.Array,
labels: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
axis: Union[int, tuple[int, ...]] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes softmax cross entropy between the logits and integer labels.
Expand All @@ -297,7 +303,10 @@ def softmax_cross_entropy_with_integer_labels(
labels: Integers specifying the correct class for each input, with shape
``[batch_size]``. Class labels are assumed to be between 0 and
``num_classes - 1`` inclusive.
axis: Axis or axes along which to compute.
axis: Axis or axes along which to compute. If a tuple of axes is passed
then ``num_classes`` must match the total number of elements in ``axis``
dimensions and a label is interpreted as a flat index in a ``logits``
slice of shape ``logits[axis]``.
where: Elements to include in the computation.
Returns:
Expand All @@ -313,6 +322,21 @@ def softmax_cross_entropy_with_integer_labels(
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761297 2.951799 ]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # elements indices in slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
... logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.458669 0.45866907]]
References:
`Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
Wikipedia
Expand All @@ -329,6 +353,22 @@ def softmax_cross_entropy_with_integer_labels(
"""
chex.assert_type([logits], float)
chex.assert_type([labels], int)
if isinstance(axis, int):
axis = normalize_axis_index(axis, logits.ndim)
elif isinstance(axis, tuple):
# Move all "feature" dimensions to the end preserving axis ordering and
# subsequent flattening "feature" dimensions to a single one.
logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits')
batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
axis = len(batch_axis)
logits = logits.transpose(batch_axis + logit_axis)
logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1, ))
if where is not None:
where = where.transpose(batch_axis + logit_axis)
where = where.reshape(where.shape[:len(batch_axis)] + (-1, ))
else:
raise ValueError('Keyword argument \'axis\' must be of type \'int\' or '
f'\'tuple[int, ...]\' but actual type is {type(axis)}.')
# This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
# we avoid subtracting the normalizer from all values, just from the values
# for the correct labels.
Expand Down
34 changes: 34 additions & 0 deletions optax/losses/_classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,40 @@ def test_axis(self, shape, axis):
)
np.testing.assert_allclose(x, y, atol=1e-4)

@parameterized.parameters(
{'axis': (1, 3), 'shape': (2, 3, 4, 5)},
{'axis': (3, 2), 'shape': (2, 3, 4, 5)},
{'axis': (2, 3), 'shape': (2, 3, 4, 5)},
{'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
{'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
{'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
)
def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
# Canonicalize axis and calculate shapes.
ndim = len(shape)
logits_axis = tuple((x + ndim) % ndim for x in axis)
labels_axis = tuple(x for x in range(ndim) if x not in logits_axis)
# Obtain shapes of batch and logits subspaces.
logits_shape = tuple(shape[x] for x in logits_axis)
labels_shape = tuple(shape[x] for x in labels_axis)
num_classes: float = np.prod(logits_shape).item()

key = jax.random.key(42)
keys = jax.random.split(key, 2)
logits = jax.random.uniform(keys[0], labels_shape + (num_classes, ))
labels = jax.random.randint(keys[1], labels_shape, 0, num_classes - 1)

fn = _classification.softmax_cross_entropy_with_integer_labels
desired = fn(logits, labels)

# Apply inverse axes permutation to obtain an array of `shape` shape.
logits = logits \
.reshape(labels_shape + logits_shape) \
.transpose(labels_axis + logits_axis)
assert logits.shape == shape
actual = fn(logits, labels, axis)
np.testing.assert_allclose(actual, desired)


class SigmoidCrossEntropyTest(parameterized.TestCase):

Expand Down

0 comments on commit 9b7dde2

Please sign in to comment.