remove SelfAttention test and warning filter
chiamp committed Nov 7, 2023
1 parent cbf7bea commit 2ace777
Showing 5 changed files with 11 additions and 35 deletions.
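The diffs below replace the deprecated nn.SelfAttention wrapper with nn.MultiHeadDotProductAttention, which performs self-attention when only the query is passed and which now takes the attention mask as a keyword argument. A minimal sketch of the new call pattern, with illustrative shapes and hyperparameters that are not taken from the repository:

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((4, 6, 32))                      # (batch, length, features); illustrative shape
attn = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=32)
variables = attn.init(jax.random.key(0), x)   # only the query is given, so keys/values default to it (self-attention)
mask = nn.make_causal_mask(jnp.ones((4, 6)))  # (batch, 1, length, length) causal mask
y = attn.apply(variables, x, mask=mask)       # the mask is now passed by keyword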
4 changes: 2 additions & 2 deletions examples/lm1b/models.py
@@ -243,7 +243,7 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
nn.initializers.ones, ('embed',)
),
)(inputs)
- x = nn.SelfAttention(
+ x = nn.MultiHeadDotProductAttention(
num_heads=config.num_heads,
dtype=config.dtype,
qkv_features=config.qkv_dim,
@@ -256,7 +256,7 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
dropout_rate=config.attention_dropout_rate,
deterministic=config.deterministic,
decode=config.decode,
- )(x, decoder_mask)
+ )(x, mask=decoder_mask)
x = nn.Dropout(rate=config.dropout_rate)(
x, deterministic=config.deterministic
)
2 changes: 1 addition & 1 deletion examples/nlp_seq/models.py
@@ -173,7 +173,7 @@ def __call__(self, inputs, deterministic):
# Attention block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=config.dtype)(inputs)
- x = nn.SelfAttention(
+ x = nn.MultiHeadDotProductAttention(
num_heads=config.num_heads,
dtype=config.dtype,
qkv_features=config.qkv_dim,
8 changes: 4 additions & 4 deletions examples/wmt/models.py
@@ -217,7 +217,7 @@ def __call__(self, inputs, encoder_mask=None):
# Attention block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=config.dtype)(inputs)
- x = nn.SelfAttention(
+ x = nn.MultiHeadDotProductAttention(
num_heads=config.num_heads,
dtype=config.dtype,
qkv_features=config.qkv_dim,
@@ -227,7 +227,7 @@ def __call__(self, inputs, encoder_mask=None):
broadcast_dropout=False,
dropout_rate=config.attention_dropout_rate,
deterministic=config.deterministic,
- )(x, encoder_mask)
+ )(x, mask=encoder_mask)

x = nn.Dropout(rate=config.dropout_rate)(
x, deterministic=config.deterministic
@@ -270,7 +270,7 @@ def __call__(
# Decoder block.
assert targets.ndim == 3
x = nn.LayerNorm(dtype=config.dtype)(targets)
- x = nn.SelfAttention(
+ x = nn.MultiHeadDotProductAttention(
num_heads=config.num_heads,
dtype=config.dtype,
qkv_features=config.qkv_dim,
@@ -281,7 +281,7 @@ def __call__(
dropout_rate=config.attention_dropout_rate,
deterministic=config.deterministic,
decode=config.decode,
- )(x, decoder_mask)
+ )(x, mask=decoder_mask)
x = nn.Dropout(rate=config.dropout_rate)(
x, deterministic=config.deterministic
)
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -122,8 +122,6 @@ filterwarnings = [
"ignore:`flax.traverse_util.Traversal` will be deprecated.*:DeprecationWarning",
# Deprecated legacy checkpoint - just want to keep the tests running for a while
"ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning",
- # DeprecationWarning: SelfAttention will be deprecated soon.
- "ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning",
# DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.
"ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning",
# DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed
30 changes: 4 additions & 26 deletions tests/linen/linen_attention_test.py
@@ -33,7 +33,7 @@ class AttentionTest(parameterized.TestCase):
def test_multihead_self_attention(self):
rng = random.key(0)
x = jnp.ones((4, 6, 5))
- sa_module = nn.SelfAttention(
+ sa_module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
@@ -47,7 +47,7 @@ def test_multihead_self_attention(self):
def test_dtype_infer(self):
rng = random.key(0)
x = jnp.ones((4, 6, 5), jnp.complex64)
- sa_module = nn.SelfAttention(
+ sa_module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
@@ -186,7 +186,7 @@ def test_decoding(self, spatial_shape, attn_dims):
inputs = random.normal(
key1, (bs,) + spatial_shape + (num_heads * num_features,)
)
- module = nn.SelfAttention(
+ module = nn.MultiHeadDotProductAttention(
num_heads=num_heads,
qkv_features=num_heads * num_features,
precision=lax.Precision.HIGHEST,
@@ -198,7 +198,7 @@ def test_decoding(self, spatial_shape, attn_dims):
initial_vars = decode_module.init(key2, inputs)
state, params = pop(initial_vars, 'params')
causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
- y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
+ y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, mask=y))(
inputs, causal_mask
)

@@ -263,28 +263,6 @@ def get_receptive_field_1d(pos):
'autoregressive self-attention.'
)

- def test_multihead_self_attention_equality(self):
- rng = random.key(0)
- q = jnp.ones((4, 2, 3, 5))
- module_kwargs = {
- 'num_heads': 8,
- 'qkv_features': 16,
- 'kernel_init': initializers.ones,
- 'bias_init': initializers.zeros,
- 'deterministic': False,
- }
- sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs)
- sa_module1 = nn.SelfAttention(**module_kwargs)
- y0, v0 = sa_module0.init_with_output(rng, q)
- with self.assertWarnsRegex(
- DeprecationWarning, 'SelfAttention will be deprecated soon.'
- ):
- y1, v1 = sa_module1.init_with_output(rng, q)
- self.assertTrue((y0 == y1).all())
- self.assertTrue(
- jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))
- )
-
def test_multihead_kv_args(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
