Pool functions reduce over batch dimension and not last dimension #4494

Open
simonschoelly opened this issue Jan 21, 2025 · 1 comment
simonschoelly commented Jan 21, 2025

For a multidimensional tensor, the pooling functions seem to reduce over the batch dimension (the first one) but don't allow reducing over the last dimension. This might be on purpose, but it is not at all clear from the documentation. I don't need a workaround, as I actually want to reduce over one of the middle dimensions, but I thought I should still report it.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux, 6.12.4-arch1-1
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
flax: 0.10.2
jax: 0.4.38
jaxlib: 0.4.38
  • Python version: 3.12.8
  • GPU/TPU model and memory: NVIDIA GeForce RTX 2060 6144MiB (There is another NVIDIA TITAN Xp GPU on that machine but flax/jax is using the first one).
  • CUDA version (if applicable): 12.6.r12.6

Problem you have encountered:

When using a pooling function such as avg_pool or max_pool from either flax.nnx or flax.linen, specifying a window whose first entry is > 1 causes the reduction to be done over the batch dimension.

Furthermore, the window_shape and strides arguments don't accept tuples that span all dimensions of the input, so one cannot pool over the last dimension.

What you expected to happen:

Either pooling should run separately for each batch entry, or one should be able to specify a window with as many dimensions as the whole tensor.
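
As an illustration of the first expectation, here is a minimal sketch (mine, not a suggested fix) that applies the pool per batch entry with jax.vmap, so the window can never touch the batch dimension:

import jax
import jax.numpy as jnp
from flax import nnx

x = jnp.float32(range(24)).reshape((2, 3, 4))

# Map an unbatched pool over the first axis; each entry has shape (3, 4).
y = jax.vmap(lambda xi: nnx.avg_pool(xi, window_shape=(2,), padding='VALID'))(x)
print(y.shape)  # (2, 2, 4) - the window was applied within each batch entry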

Logs, error messages, etc:

Steps to reproduce:

import jax.numpy as jnp
from flax import nnx

# create a tensor of shape (2, 3, 4) - i.e. with a batch size of 2
x = jnp.float32(range(24)).reshape((2,3,4))

print(x[0,0,0]) # 0.0
print(x[1,0,0]) # 12.0

x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')

# This shows that the reduction happened over the batch dimension
print(x_reduced.shape) # (1, 3, 4)

# This shows that the first reduced value is the average of x[0,0,0] and x[1,0,0]
print(x_reduced[0,0,0]) # 6.0

# Trying to reduce over the last dimension raises an exception
x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')

The message of the exception is:

AssertionError                            Traceback (most recent call last)
Cell In[23], line 1
----> 1 x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:97, in avg_pool(inputs, window_shape, strides, padding, count_include_pad)
     79 def avg_pool(
     80   inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
     81 ):
     82   """Pools the input by taking the average over a window.
     83 
     84   Args:
   (...)
     95     The average for each window slice.
     96   """
---> 97   y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
     98   if count_include_pad:
     99     y = y / np.prod(window_shape)

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:62, in pool(inputs, init, reduce_fn, window_shape, strides, padding)
     59   dims = (1,) + dims
     60   is_single_input = True
---> 62 assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
     63 if not isinstance(padding, str):
     64   padding = tuple(map(tuple, padding))

AssertionError: len((2, 3, 4)) != len((1, 1, 4, 1))
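
For context, my reading of the traceback (not authoritative): pool() appends a singleton features entry to window_shape, so a window that already spans all input dimensions ends up one entry too long:

# Simplified from flax/linen/pooling.py, as visible in the traceback:
window_shape = (1, 1, 4)
dims = window_shape + (1,)  # features entry appended -> (1, 1, 4, 1)
# inputs.ndim == 3 but len(dims) == 4, hence the AssertionError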
cgarciae (Collaborator) commented
Hi @simonschoelly, the issue is that the pooling functions don't operate over the last dimension: the last axis is treated as the features dimension, since most commonly you reduce over the time / space dimensions. Maybe you could add a dummy features dimension:

x = jnp.float32(range(24)).reshape((2, 3, 4, 1))  # (batch, spatial, spatial, features)
x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')
print(x_reduced.shape)  # (2, 2, 4, 1) - the batch dimension is preserved
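
Extending that suggestion (a sketch of my own, not part of the original reply): once the dummy features axis is in place, every original dimension except the batch one becomes poolable, including the last and middle dimensions the issue asks about:

x = jnp.float32(range(24)).reshape((2, 3, 4))[..., None]  # (2, 3, 4, 1)

# Pool over the original last dimension (size 4):
print(nnx.avg_pool(x, window_shape=(1, 4), padding='VALID').shape)  # (2, 3, 1, 1)

# Pool over the middle dimension (size 3), as the issue author wanted:
print(nnx.avg_pool(x, window_shape=(3, 1), padding='VALID').shape)  # (2, 1, 4, 1)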
