Skip to content

Commit 6d06085

Browse files
Address bug with convolution using Tensorflow, Numpy, Jax backends (#21796)
* add exception
* Update keras/src/backend/tensorflow/nn.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update keras/src/layers/convolutional/conv_test.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* address numpy + make test more generic
* fix jax
* fix pydocs
* fix error msg in tensorflow
* handle only static case
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3973b15 commit 6d06085

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

keras/src/backend/jax/nn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def conv(
355355
feature_group_count = channels // kernel_in_channels
356356
kernel = convert_to_tensor(kernel)
357357
inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
358-
return jax.lax.conv_general_dilated(
358+
result = jax.lax.conv_general_dilated(
359359
inputs,
360360
kernel,
361361
strides,
@@ -364,6 +364,14 @@ def conv(
364364
dimension_numbers=dimension_numbers,
365365
feature_group_count=feature_group_count,
366366
)
367+
if result.size == 0:
368+
raise ValueError(
369+
"The convolution operation resulted in an empty output. "
370+
"This can happen if the input is too small for the given "
371+
"kernel size, strides, dilation rate, and padding mode. "
372+
"Please check the input shape and convolution parameters."
373+
)
374+
return result
367375

368376

369377
def depthwise_conv(

keras/src/backend/numpy/nn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def conv(
404404
f"kernel in_channels {kernel_in_channels}. "
405405
)
406406
feature_group_count = channels // kernel_in_channels
407-
return np.array(
407+
result = np.array(
408408
jax.lax.conv_general_dilated(
409409
inputs,
410410
kernel if is_tensor(kernel) else kernel.numpy(),
@@ -415,6 +415,14 @@ def conv(
415415
feature_group_count=feature_group_count,
416416
)
417417
)
418+
if result.size == 0:
419+
raise ValueError(
420+
"The convolution operation resulted in an empty output. "
421+
"This can happen if the input is too small for the given "
422+
"kernel size, strides, dilation rate, and padding mode. "
423+
"Please check the input shape and convolution parameters."
424+
)
425+
return result
418426

419427

420428
def depthwise_conv(

keras/src/backend/tensorflow/nn.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,14 +310,28 @@ def conv(
310310
):
311311
def _conv():
312312
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
313-
return tf.nn.convolution(
313+
result = tf.nn.convolution(
314314
inputs,
315315
kernel,
316316
strides,
317317
padding.upper(),
318318
data_format=tf_data_format,
319319
dilations=dilation_rate,
320320
)
321+
result_shape = result.shape
322+
if (
323+
result_shape.is_fully_defined()
324+
and math.prod(result_shape.as_list()) == 0
325+
):
326+
raise ValueError(
327+
"The convolution operation resulted in an empty output. "
328+
"Output shape:"
329+
f" {result_shape}. This can happen if the input is too small "
330+
"for the given kernel size, strides, dilation rate, and "
331+
"padding mode. Please check the input shape and convolution "
332+
"parameters."
333+
)
334+
return result
321335

322336
# Certain ops are broken in Tensorflow on CPU only.
323337
# We can work around by compiling the op with XLA.

keras/src/layers/convolutional/conv_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,3 +1095,11 @@ def test_conv_constraints(self):
10951095
)
10961096
layer.build((None, 5, 5, 3))
10971097
self.assertIsInstance(layer.bias.constraint, constraints.NonNeg)
1098+
1099+
def test_conv_raises_exception_on_zero_dims(self):
1100+
x = np.random.rand(3, 4, 4, 4)
1101+
l = layers.Conv2D(6, [5, 5], 1, "valid")
1102+
# The exception type can vary across backends (e.g., ValueError,
1103+
# tf.errors.InvalidArgumentError, RuntimeError).
1104+
with self.assertRaises(Exception):
1105+
l(x)

0 commit comments

Comments
 (0)