From 3c5e04e0ae1b87c95031dbac9a28099d68dd9320 Mon Sep 17 00:00:00 2001
From: Daniel Bershatsky <daniel.bershatsky@gmail.com>
Date: Fri, 3 Jan 2025 00:11:03 +0300
Subject: [PATCH] Support tuple of axes in
 `softmax_cross_entropy_with_integer_labels`

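When `axis` is a tuple, the listed dimensions are treated as a single
flattened class dimension: they are moved to the end of the array
(preserving the given order) and reshaped into one axis of size
`num_classes`, so each integer label indexes a flat position within that
slice. A minimal usage sketch (shapes are illustrative; it mirrors the
docstring example added below):

    import jax.numpy as jnp
    import optax

    # Batch dims (1, 2); class dims (3, 4) are flattened into 12 classes.
    logits = jnp.arange(24, dtype=jnp.float32).reshape(1, 2, 3, 4)
    labels = jnp.ravel_multi_index(
        (jnp.array([[1, 2]]), jnp.array([[1, 3]])), (3, 4))
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, labels, axis=(2, 3))  # loss has shape (1, 2)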
---
 optax/losses/_classification.py      | 44 ++++++++++++++++++++++++++--
 optax/losses/_classification_test.py | 34 +++++++++++++++++++++
 2 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 087e1ae4f..e13c876c9 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -20,8 +20,14 @@
 import chex
 import jax
 import jax.numpy as jnp
+import numpy as np
 from optax import projections
 
+if np.__version__.startswith('1.'):
+  from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
+else:
+  from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
+
 
 def sigmoid_binary_cross_entropy(
     logits,
@@ -273,7 +279,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-    axis: Union[int, tuple[int, ...], None] = -1,
+    axis: Union[int, tuple[int, ...]] = -1,
     where: Union[chex.Array, None] = None,
 ) -> chex.Array:
   r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +303,10 @@ def softmax_cross_entropy_with_integer_labels(
     labels: Integers specifying the correct class for each input, with shape
       ``[batch_size]``. Class labels are assumed to be between 0 and
       ``num_classes - 1`` inclusive.
-    axis: Axis or axes along which to compute.
+    axis: Axis or axes along which to compute. If a tuple of axes is passed,
+      then ``num_classes`` must equal the total number of elements in the
+      ``axis`` dimensions, and each label is interpreted as a flat index into
+      a ``logits`` slice formed by those dimensions.
     where: Elements to include in the computation.
 
   Returns:
@@ -313,6 +322,21 @@ def softmax_cross_entropy_with_integer_labels(
     >>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
     [0.2761297 2.951799 ]
 
+    >>> import jax.numpy as jnp
+    >>> import numpy as np
+    >>> import optax
+    >>> # Example: batch dims (1, 2), num_classes = 12 (i.e. 3 * 4).
+    >>> shape = (1, 2, 3, 4)
+    >>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
+    >>> # Element indices within each slice of shape (3, 4).
+    >>> ix = jnp.array([[1, 2]])
+    >>> jx = jnp.array([[1, 3]])
+    >>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
+    >>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
+    ...     logits, labels, axis=(2, 3))
+    >>> print(cross_entropy)
+    [[6.458669   0.45866907]]
+
   References:
     `Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
     Wikipedia
@@ -329,6 +353,22 @@ def softmax_cross_entropy_with_integer_labels(
   """
   chex.assert_type([logits], float)
   chex.assert_type([labels], int)
+  if isinstance(axis, int):
+    axis = normalize_axis_index(axis, logits.ndim)
+  elif isinstance(axis, tuple):
+    # Move all "feature" dimensions to the end preserving axis ordering and
+    # subsequent flattening "feature" dimensions to a single one.
+    logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits')
+    batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
+    axis = len(batch_axis)
+    logits = logits.transpose(batch_axis + logit_axis)
+    logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1, ))
+    if where is not None:
+      where = where.transpose(batch_axis + logit_axis)
+      where = where.reshape(where.shape[:len(batch_axis)] + (-1, ))
+  else:
+    raise ValueError('Keyword argument \'axis\' must be of type \'int\' or '
+                     f'\'tuple[int, ...]\' but actual type is {type(axis)}.')
   # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
   # we avoid subtracting the normalizer from all values, just from the values
   # for the correct labels.
diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index 7b6321618..2fc132712 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -247,6 +247,40 @@ def test_axis(self, shape, axis):
     )
     np.testing.assert_allclose(x, y, atol=1e-4)
 
+  @parameterized.parameters(
+      {'axis': (1, 3), 'shape': (2, 3, 4, 5)},
+      {'axis': (3, 2), 'shape': (2, 3, 4, 5)},
+      {'axis': (2, 3), 'shape': (2, 3, 4, 5)},
+      {'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
+      {'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
+      {'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
+  )
+  def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
+    # Canonicalize axis and calculate shapes.
+    ndim = len(shape)
+    logits_axis = tuple((x + ndim) % ndim for x in axis)
+    labels_axis = tuple(x for x in range(ndim) if x not in logits_axis)
+    # Obtain shapes of batch and logits subspaces.
+    logits_shape = tuple(shape[x] for x in logits_axis)
+    labels_shape = tuple(shape[x] for x in labels_axis)
+    num_classes: int = np.prod(logits_shape).item()
+
+    key = jax.random.key(42)
+    keys = jax.random.split(key, 2)
+    logits = jax.random.uniform(keys[0], labels_shape + (num_classes, ))
+    labels = jax.random.randint(keys[1], labels_shape, 0, num_classes)
+
+    fn = _classification.softmax_cross_entropy_with_integer_labels
+    desired = fn(logits, labels)
+
+    # Apply the inverse axes permutation to recover an array of shape `shape`.
+    inverse_perm = tuple(np.argsort(labels_axis + logits_axis))
+    logits = logits.reshape(labels_shape + logits_shape)
+    logits = logits.transpose(inverse_perm)
+    assert logits.shape == shape
+    actual = fn(logits, labels, axis)
+    np.testing.assert_allclose(actual, desired)
+
 
 class SigmoidCrossEntropyTest(parameterized.TestCase):