For a multidimensional tensor, the pooling functions seem to reduce over the batch dimension (the first one) but don't allow reducing over the last dimension. This might be intentional, but it is not at all clear from the documentation. I don't need a workaround, as I actually want to reduce over one of the middle dimensions, but I thought I should still report it.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux, 6.12.4-arch1-1
Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
flax: 0.10.2
jax: 0.4.38
jaxlib: 0.4.38
Python version: 3.12.8
GPU/TPU model and memory: NVIDIA GeForce RTX 2060, 6144 MiB (There is another NVIDIA TITAN Xp GPU on that machine, but flax/jax is using the first one).
CUDA version (if applicable): 12.6.r12.6
Problem you have encountered:
When using a pooling function such as `avg_pool` or `max_pool` from either `flax.nnx` or `flax.linen`, specifying a window with size > 1 in the first dimension causes the reduction to be done over the batch dimension.
Furthermore, the `window_shape` and `strides` arguments don't allow a tuple that spans all dimensions, so one cannot pool over the last dimension.
What you expected to happen:
Either pooling should run separately for each batch entry, or one should be able to specify a window of the same dimensionality as the whole tensor.
Logs, error messages, etc:
Steps to reproduce:
```python
import jax.numpy as jnp
from flax import nnx

# create a tensor of shape (2, 3, 4) - i.e. with a batch size of 2
x = jnp.float32(range(24)).reshape((2, 3, 4))
print(x[0, 0, 0])  # 0.0
print(x[1, 0, 0])  # 12.0

x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')

# This shows that the reduction happened over the batch dimension
print(x_reduced.shape)  # (1, 3, 4)
# This shows that the first reduced value is the average of x[0,0,0] and x[1,0,0]
print(x_reduced[0, 0, 0])  # 6.0

# Trying to reduce over the last dimension gives an exception
x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')
```
The message of the exception is:

```
AssertionError                            Traceback (most recent call last)
Cell In[23], line 1
----> 1 x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:97, in avg_pool(inputs, window_shape, strides, padding, count_include_pad)
     79 def avg_pool(
     80   inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
     81 ):
     82   """Pools the input by taking the average over a window.
     83
     84   Args:
    (...)
     95     The average for each window slice.
     96   """
---> 97   y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
     98   if count_include_pad:
     99     y = y / np.prod(window_shape)

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:62, in pool(inputs, init, reduce_fn, window_shape, strides, padding)
     59   dims = (1,) + dims
     60   is_single_input = True
---> 62 assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
     63 if not isinstance(padding, str):
     64   padding = tuple(map(tuple, padding))

AssertionError: len((2, 3, 4)) != len((1, 1, 4, 1))
```
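For reference, the failing assertion follows from how `pool` infers batch dimensions from the window length: it assumes the input is laid out as `(*batch, *spatial, features)` and the window only ever covers the spatial axes. A rough sketch of that shape logic (paraphrased from the flax 0.10.x source shown in the traceback, not a verbatim excerpt):

```python
# Paraphrase of the shape inference in flax/linen/pooling.py (flax 0.10.x):
# inputs are assumed to be laid out as (*batch, *spatial, features), and the
# window only covers the spatial axes, never the trailing features axis.
inputs_shape = (2, 3, 4)

window_shape = (2, 1)
num_batch_dims = len(inputs_shape) - (len(window_shape) + 1)
print(num_batch_dims)  # 0 -> no batch axis: (2, 3) are spatial and 4 is
                       # features, so window_shape=(2, 1) pools over the axis
                       # the reporter intended as the batch dimension.

window_shape = (1, 1, 4)
num_batch_dims = len(inputs_shape) - (len(window_shape) + 1)
print(num_batch_dims)  # -1 -> the window is longer than the available spatial
                       # axes, and the ndim assertion above then fails.
```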
Hi @simonschoelly, the issue is that pooling layers don't operate over the last dimension as most commonly you reduce over the time / space dimensions. Maybe you could add a dummy features dimension:
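A minimal sketch of that workaround, assuming the `x` from the reproduction above (the exact code here is illustrative, not a verbatim library recipe):

```python
import jax.numpy as jnp
from flax import nnx

x = jnp.float32(range(24)).reshape((2, 3, 4))

# Append a dummy features axis: (2, 3, 4) -> (2, 3, 4, 1).
# The pooling helpers never pool over the last (features) axis, so the dummy
# axis lets the window cover the original last dimension. Note the first axis
# is now treated as spatial too, but the size-1 window leaves it untouched.
x_expanded = x[..., None]

x_reduced = nnx.avg_pool(x_expanded, window_shape=(1, 1, 4), padding='VALID')
print(x_reduced.shape)        # (2, 3, 1, 1)
print(x_reduced[0, 0, 0, 0])  # 1.5 == mean of x[0, 0, :]

# Drop the dummy axis again if needed.
x_reduced = x_reduced[..., 0]
print(x_reduced.shape)        # (2, 3, 1)
```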