Skip to content

Commit

Permalink
add test_scan_negative_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 8, 2023
1 parent 870847e commit b198cff
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,35 @@ def __call__(self, c, b, xs):
np.testing.assert_allclose(c[0], c2[0], atol=1e-7)
np.testing.assert_allclose(c[1], c2[1], atol=1e-7)

def test_scan_negative_axes(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, _, x):
x = nn.Dense(4)(x)
return None, x

class Bar(nn.Module):
@nn.compact
def __call__(self, x):
_, x = nn.scan(
Foo,
variable_broadcast='params',
split_rngs=dict(params=False),
in_axes=1,
out_axes=-1,
)()(None, x)
return x

y, variables = Bar().init_with_output(
{'params': jax.random.PRNGKey(0)},
jax.random.normal(jax.random.PRNGKey(1), shape=[1, 2, 3]),
)
params = variables['params']

self.assertEqual(y.shape, (1, 4, 2))
self.assertEqual(params['ScanFoo_0']['Dense_0']['kernel'].shape, (3, 4))
self.assertEqual(params['ScanFoo_0']['Dense_0']['bias'].shape, (4,))

def test_multiscope_lifting_simple(self):
class Counter(nn.Module):
@nn.compact
Expand Down

0 comments on commit b198cff

Please sign in to comment.