
Commit 807ace0

Fix custom layer parameters ending in _mask (#21154)
- Changed get_shapes_dict to only exclude the 'mask' parameter, not all *_mask parameters
- Allows custom layers to use parameters like attention_mask and padding_mask
- Added comprehensive tests for _mask parameter handling
- Maintains backward compatibility with Keras masking

Fixes #21154
1 parent 5c925c0 commit 807ace0
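
For illustration, a minimal sketch of the pattern this change enables, exercised symbolically the way the new tests do (the layer name ScaledByMask and the shapes are illustrative, not taken from the commit):

import keras
from keras import layers

class ScaledByMask(layers.Layer):
    # Hypothetical example layer: `attention_mask` is an ordinary tensor
    # argument, not the reserved Keras `mask`.
    def call(self, x, attention_mask):
        return x * attention_mask

    def compute_output_shape(self, x_shape, attention_mask_shape):
        # Before this fix, attention_mask_shape never arrived here on
        # symbolic calls, so the call below raised an error.
        return x_shape

layer = ScaledByMask()
x = keras.KerasTensor((2, 3))
attention_mask = keras.KerasTensor((2, 3))
print(layer(x, attention_mask=attention_mask).shape)  # (2, 3)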

2 files changed: 62 additions & 2 deletions


keras/src/layers/layer.py

Lines changed: 2 additions & 2 deletions
@@ -1895,8 +1895,8 @@ def get_shapes_dict(call_spec):
     """
     shapes_dict = {}
     for k, v in call_spec.tensor_arguments_dict.items():
-        if k == "mask" or k.endswith("_mask"):
-            # Do not include mask tensors in shapes dict
+        if k == "mask":
+            # Do not include the 'mask' tensor in shapes dict (for Keras masking)
             continue
         if k == "kwargs" or k == "args":
             # Do not include catch-alls in shapes dict
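
get_shapes_dict feeds compute_output_shape by mapping each tensor argument name to a `<name>_shape` entry, as the new tests show. A simplified standalone sketch of the new exclusion rule (this is not the actual Keras internal function; the CallSpec plumbing and shape standardization are omitted):

def shapes_dict_sketch(tensor_arguments):
    # tensor_arguments: dict mapping call() argument names to tensors,
    # mirroring call_spec.tensor_arguments_dict.
    shapes = {}
    for name, tensor in tensor_arguments.items():
        if name == "mask":
            # Reserved for Keras masking: never forwarded as mask_shape.
            continue
        if name in ("args", "kwargs"):
            # Catch-alls carry no single shape.
            continue
        shapes[f"{name}_shape"] = tuple(tensor.shape)
    return shapes

With arguments named x, mask, and attention_mask, this now yields x_shape and attention_mask_shape. Before the fix, the endswith("_mask") test also dropped attention_mask, so an overridden compute_output_shape(x_shape, attention_mask_shape) could never be invoked symbolically.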

keras/src/layers/layer_test.py

Lines changed: 60 additions & 0 deletions
@@ -144,6 +144,66 @@ def compute_output_shape(self, input_shape):
         self.assertEqual(out["2"]["11"].shape, (2, 3))
         self.assertEqual(out["2"]["22"].shape, (2, 3))
 
+    def test_custom_layer_with_mask_parameter(self):
+        # Test that custom layer parameters ending in _mask work correctly
+        # with compute_output_shape. Regression test for issue #21154.
+
+        class CustomLayerWithMask(layers.Layer):
+            def call(self, x, attention_mask):
+                return x * attention_mask
+
+            def compute_output_shape(self, x_shape, attention_mask_shape):
+                # Use the mask shape in computation to ensure it's needed
+                return x_shape
+
+        layer = CustomLayerWithMask()
+        x = backend.KerasTensor((2, 3))
+        attention_mask = backend.KerasTensor((2, 3))
+
+        # This should work without errors
+        output = layer(x, attention_mask=attention_mask)
+        self.assertEqual(output.shape, (2, 3))
+
+        # Test compute_output_spec as well
+        output_spec = layer.compute_output_spec(x, attention_mask=attention_mask)
+        self.assertEqual(output_spec.shape, (2, 3))
+
+    def test_mask_parameter_exclusions(self):
+        # Test that only the 'mask' parameter is excluded from shapes_dict,
+        # not all parameters ending with '_mask'. Regression test for issue #21154.
+
+        class LayerWithMultipleMasks(layers.Layer):
+            def call(self, x, mask=None, attention_mask=None, padding_mask=None):
+                result = x
+                if mask is not None:
+                    result = result * mask
+                if attention_mask is not None:
+                    result = result * attention_mask
+                if padding_mask is not None:
+                    result = result * padding_mask
+                return result
+
+            def compute_output_shape(self, x_shape, attention_mask_shape=None, padding_mask_shape=None):
+                # Note: 'mask' should not appear here as it's excluded,
+                # but attention_mask_shape and padding_mask_shape should be available
+                return x_shape
+
+        layer = LayerWithMultipleMasks()
+        x = backend.KerasTensor((2, 3))
+        mask = backend.KerasTensor((2, 3))
+        attention_mask = backend.KerasTensor((2, 3))
+        padding_mask = backend.KerasTensor((2, 3))
+
+        # This should work without errors
+        output = layer(x, mask=mask, attention_mask=attention_mask, padding_mask=padding_mask)
+        self.assertEqual(output.shape, (2, 3))
+
+        # Test compute_output_spec as well
+        output_spec = layer.compute_output_spec(
+            x, mask=mask, attention_mask=attention_mask, padding_mask=padding_mask
+        )
+        self.assertEqual(output_spec.shape, (2, 3))
+
     def test_positional_arg_error(self):
         class SomeLayer(layers.Layer):
             def call(self, x, bool_arg):
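
The "backward compatibility with Keras masking" bullet refers to the built-in mask propagation pipeline, which is what keeps the plain `mask` argument reserved. A hedged background sketch (this model and layer are illustrative, not part of the commit): a mask emitted by Embedding(mask_zero=True) is delivered to a downstream layer through the `mask` argument, which is why that one name stays out of the shapes dict.

import numpy as np
import keras
from keras import layers, ops

class MaskedMeanPool(layers.Layer):
    # Illustrative layer: averages timesteps, ignoring masked-out ones.
    supports_masking = True

    def call(self, x, mask=None):
        # `mask` is supplied here by Keras mask propagation, not by the
        # caller, and remains excluded from the shapes dict.
        if mask is None:
            return ops.mean(x, axis=1)
        mask = ops.cast(mask, x.dtype)[:, :, None]
        return ops.sum(x * mask, axis=1) / ops.sum(mask, axis=1)

inputs = keras.Input((4,), dtype="int32")
embedded = layers.Embedding(10, 8, mask_zero=True)(inputs)  # emits a mask
pooled = MaskedMeanPool()(embedded)                         # receives it as `mask`
model = keras.Model(inputs, pooled)
print(model(np.array([[1, 2, 0, 0]], dtype="int32")).shape)  # (1, 8)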
