diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index 8c51ca3d..2ffd347d 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -100,7 +100,7 @@ def transpose_from_front(ax, xs): def trans(x): if ax < 0: - pax = x.ndim - ax + pax = x.ndim + ax else: pax = ax assert pax < x.ndim