Commit 6ed9095

Copybara import of the project:
--
9b7dde2 by Daniel Bershatsky <[email protected]>:

Support tuple of axis in `softmax_cross_entropy_with_integer_labels`

--
68ebabd by Daniel Bershatsky <[email protected]>:

Adjust according to review comments

COPYBARA_INTEGRATE_REVIEW=#1165 from daskol:fix/softmax_cross_entropy_with_integer_labels 68ebabd
PiperOrigin-RevId: 714273561
daskol authored and OptaxDev committed Jan 11, 2025
1 parent 5a3b829 commit 6ed9095
Showing 2 changed files with 91 additions and 5 deletions.
60 changes: 55 additions & 5 deletions optax/losses/_classification.py
@@ -15,6 +15,7 @@
"""Classification losses."""

import functools
import operator
from typing import Optional, Union

import chex
@@ -23,6 +24,24 @@
from optax import projections


def canonicalize_axis(axis, ndim):
"""Vendored version of :func:`numpy.lib.array_utils.normalize_axis_index`.
"""
if 0 <= (axis := operator.index(axis)) < ndim:
return axis
elif -ndim <= axis < 0:
return axis + ndim
else:
raise ValueError(f'axis {axis} is out of bounds for array of '
f'dimension {ndim}')


def canonicalize_axes(axes, ndim) -> tuple[int, ...]:
"""Vendored version of :func:`numpy.lib.array_utils.normalize_axis_tuple`.
"""
return tuple(canonicalize_axis(x, ndim) for x in axes)
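
# Illustrative behavior of these vendored helpers (an editorial sketch, not
# part of this commit's diff):
#   canonicalize_axis(-1, ndim=4)       -> 3
#   canonicalize_axis(2, ndim=4)        -> 2
#   canonicalize_axis(4, ndim=4)        -> raises ValueError (out of bounds)
#   canonicalize_axes((-2, -1), ndim=4) -> (2, 3)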


def sigmoid_binary_cross_entropy(
logits,
labels,
@@ -273,7 +292,7 @@ def softmax_cross_entropy(
def softmax_cross_entropy_with_integer_labels(
logits: chex.Array,
labels: chex.Array,
-    axis: Union[int, None] = -1,
+    axis: Union[int, tuple[int, ...]] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +316,10 @@ def softmax_cross_entropy_with_integer_labels(
labels: Integers specifying the correct class for each input, with shape
``[batch_size]``. Class labels are assumed to be between 0 and
``num_classes - 1`` inclusive.
-    axis: Axis along which to compute.
+    axis: Axis or axes along which to compute. If a tuple of axes is passed,
+      then ``num_classes`` must equal the total number of elements across the
+      ``axis`` dimensions, and each label is interpreted as a flat index into
+      the ``logits`` slice spanned by those axes.
where: Elements to include in the computation.
Returns:
@@ -313,6 +335,21 @@ def softmax_cross_entropy_with_integer_labels(
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761297 2.951799 ]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
  >>> # example: batch shape (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
  >>> # element indices into a slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
... logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.458669 0.45866907]]
References:
`Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
Wikipedia
@@ -329,9 +366,22 @@ def softmax_cross_entropy_with_integer_labels(
"""
chex.assert_type([logits], float)
chex.assert_type([labels], int)
-  if axis is not None and not isinstance(axis, int):
-    raise ValueError(f'axis = {axis} is unsupported. Provide an int or None.')
+  if isinstance(axis, int):
+    axis = canonicalize_axis(axis, logits.ndim)
+  elif isinstance(axis, tuple):
+    # Move all "feature" dimensions to the end, preserving their relative
+    # order, then flatten them into a single trailing class dimension.
+    logit_axis = canonicalize_axes(axis, logits.ndim)
+    batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
+    axis = len(batch_axis)
+    logits = logits.transpose(batch_axis + logit_axis)
+    logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1,))
+    if where is not None:
+      where = where.transpose(batch_axis + logit_axis)
+      where = where.reshape(where.shape[:len(batch_axis)] + (-1,))
+  else:
+    raise ValueError('Keyword argument \'axis\' must be of type \'int\' or '
+                     f'\'tuple[int, ...]\', but the actual type is {type(axis)}.')
# This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
# we avoid subtracting the normalizer from all values, just from the values
# for the correct labels.
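
A quick standalone check of what the new tuple-of-axes branch computes (an
editorial sketch, not part of the commit): passing axis=(2, 3) should match
manually flattening those dimensions into a single trailing class axis.

import jax.numpy as jnp
import numpy as np
import optax

shape = (1, 2, 3, 4)  # batch shape (1, 2), class dims (3, 4)
logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
labels = jnp.array([[5, 7]])  # flat indices into each (3, 4) slice

# Tuple-of-axes form introduced by this commit.
out_tuple = optax.softmax_cross_entropy_with_integer_labels(
    logits, labels, axis=(2, 3))

# Equivalent manual flattening of the class dimensions.
out_flat = optax.softmax_cross_entropy_with_integer_labels(
    logits.reshape(1, 2, 12), labels)

np.testing.assert_allclose(out_tuple, out_flat, rtol=1e-6)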
36 changes: 36 additions & 0 deletions optax/losses/_classification_test.py
@@ -247,6 +247,42 @@ def test_axis(self, shape, axis):
)
np.testing.assert_allclose(x, y, atol=1e-4)

@parameterized.parameters(
{'axis': (1, 3), 'shape': (2, 3, 4, 5)},
{'axis': (3, 2), 'shape': (2, 3, 4, 5)},
{'axis': (2, 3), 'shape': (2, 3, 4, 5)},
{'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
{'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
{'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
{'axis': (0, 1, 3), 'shape': (2, 3, 4, 5)},
{'axis': (-4, -3, -1), 'shape': (2, 3, 4, 5)},
)
def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
# Canonicalize axis and calculate shapes.
ndim = len(shape)
logits_axis = tuple((x + ndim) % ndim for x in axis)
labels_axis = tuple(x for x in range(ndim) if x not in logits_axis)
# Obtain shapes of batch and logits subspaces.
logits_shape = tuple(shape[x] for x in logits_axis)
labels_shape = tuple(shape[x] for x in labels_axis)
    num_classes: int = np.prod(logits_shape).item()

key = jax.random.key(42)
keys = jax.random.split(key, 2)
logits = jax.random.uniform(keys[0], labels_shape + (num_classes,))
labels = jax.random.randint(keys[1], labels_shape, 0, num_classes - 1)

fn = _classification.softmax_cross_entropy_with_integer_labels
desired = fn(logits, labels)

    # Apply the inverse permutation to recover an array of shape `shape`.
perm = labels_axis + logits_axis
perm_inv = tuple(i for i, _ in sorted(enumerate(perm), key=lambda x: x[1]))
logits = logits.reshape(labels_shape + logits_shape).transpose(perm_inv)
assert logits.shape == shape
actual = fn(logits, labels, axis)
np.testing.assert_allclose(actual, desired)


class SigmoidCrossEntropyTest(parameterized.TestCase):

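
A side note on the test's inverse-permutation step (an editorial sketch, not
part of the commit): the sorted(enumerate(...)) construction above is
equivalent to np.argsort, and transposing by perm and then by its inverse is
a round trip.

import numpy as np

perm = (0, 2, 1, 3)
perm_inv = tuple(i for i, _ in sorted(enumerate(perm), key=lambda x: x[1]))
assert perm_inv == tuple(np.argsort(perm))  # both invert the permutation

# Transposing by perm and then by perm_inv restores the original axes.
x = np.arange(24).reshape(2, 3, 4, 1)
assert np.array_equal(x.transpose(perm).transpose(perm_inv), x)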
