@@ -144,6 +144,66 @@ def compute_output_shape(self, input_shape):
144144 self .assertEqual (out ["2" ]["11" ].shape , (2 , 3 ))
145145 self .assertEqual (out ["2" ]["22" ].shape , (2 , 3 ))
146146
def test_custom_layer_with_mask_parameter(self):
    """Check that a `*_mask` call argument reaches `compute_output_shape`.

    Regression test for issue #21154. The custom layer takes an
    `attention_mask` tensor and declares a required
    `attention_mask_shape` parameter on `compute_output_shape`; the test
    verifies both the resulting output shape and the mask shape that was
    actually handed to `compute_output_shape`.
    """
    # Every `attention_mask_shape` value seen by `compute_output_shape`,
    # captured so the forwarded shape can be asserted on below.
    received_mask_shapes = []

    class CustomLayerWithMask(layers.Layer):
        def call(self, x, attention_mask):
            return x * attention_mask

        def compute_output_shape(self, x_shape, attention_mask_shape):
            # `attention_mask_shape` has no default value, so it must be
            # supplied by the caller; record it for later inspection.
            received_mask_shapes.append(attention_mask_shape)
            return x_shape

    layer = CustomLayerWithMask()
    x = backend.KerasTensor((2, 3))
    attention_mask = backend.KerasTensor((2, 3))

    # Symbolic call on KerasTensors: this should work without errors.
    output = layer(x, attention_mask=attention_mask)
    self.assertEqual(output.shape, (2, 3))

    # Test compute_output_spec as well.
    output_spec = layer.compute_output_spec(
        x, attention_mask=attention_mask
    )
    self.assertEqual(output_spec.shape, (2, 3))

    # The mask shape must have been forwarded, and with the right value.
    self.assertTrue(received_mask_shapes)
    for mask_shape in received_mask_shapes:
        self.assertIsNotNone(mask_shape)
        self.assertEqual(tuple(mask_shape), (2, 3))
170+
def test_mask_parameter_exclusions(self):
    """Check that `*_mask` arguments still produce `*_mask_shape` values.

    Regression test for issue #21154: only the argument named exactly
    `mask` is meant to be left out of the shapes passed to
    `compute_output_shape`; `attention_mask` and `padding_mask` should
    each contribute a shape.

    The shape parameters default to `None`, so a missing shape would not
    raise on its own. The received values are therefore recorded and
    asserted explicitly.
    """
    # (attention_mask_shape, padding_mask_shape) pairs seen by
    # `compute_output_shape`, in call order.
    received_shapes = []

    class LayerWithMultipleMasks(layers.Layer):
        def call(
            self,
            x,
            mask=None,
            attention_mask=None,
            padding_mask=None,
        ):
            result = x
            if mask is not None:
                result = result * mask
            if attention_mask is not None:
                result = result * attention_mask
            if padding_mask is not None:
                result = result * padding_mask
            return result

        def compute_output_shape(
            self,
            x_shape,
            attention_mask_shape=None,
            padding_mask_shape=None,
        ):
            # Note: there is deliberately no `mask_shape` parameter here,
            # since `mask` is the one argument expected to be excluded.
            received_shapes.append(
                (attention_mask_shape, padding_mask_shape)
            )
            return x_shape

    layer = LayerWithMultipleMasks()
    x = backend.KerasTensor((2, 3))
    mask = backend.KerasTensor((2, 3))
    attention_mask = backend.KerasTensor((2, 3))
    padding_mask = backend.KerasTensor((2, 3))

    # Symbolic call on KerasTensors: this should work without errors.
    output = layer(
        x,
        mask=mask,
        attention_mask=attention_mask,
        padding_mask=padding_mask,
    )
    self.assertEqual(output.shape, (2, 3))

    # Test compute_output_spec as well.
    output_spec = layer.compute_output_spec(
        x,
        mask=mask,
        attention_mask=attention_mask,
        padding_mask=padding_mask,
    )
    self.assertEqual(output_spec.shape, (2, 3))

    # Both `*_mask` shapes must have been forwarded rather than left at
    # their `None` defaults.
    self.assertTrue(received_shapes)
    for attention_mask_shape, padding_mask_shape in received_shapes:
        self.assertIsNotNone(attention_mask_shape)
        self.assertIsNotNone(padding_mask_shape)
        self.assertEqual(tuple(attention_mask_shape), (2, 3))
        self.assertEqual(tuple(padding_mask_shape), (2, 3))
206+
147207 def test_positional_arg_error (self ):
148208 class SomeLayer (layers .Layer ):
149209 def call (self , x , bool_arg ):
0 commit comments