Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 73 additions & 17 deletions keras/src/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,42 +616,98 @@ def extract_patches(
padding="valid",
data_format=None,
):
"""Extracts patches from the image(s).
"""Extracts patches from the image(s) or volume(s).

This function supports both 2D and 3D patch extraction based on the
`size` argument length, similar to how `keras.ops.conv` handles
different dimensions.

Args:
images: Input image or batch of images. Must be 3D or 4D.
size: Patch size int or tuple (patch_height, patch_width)
strides: strides along height and width. If not specified, or
if `None`, it defaults to the same value as `size`.
dilation_rate: This is the input stride, specifying how far two
consecutive patch samples are in the input. For value other than 1,
strides must be 1. NOTE: `strides > 1` is not supported in
conjunction with `dilation_rate > 1`
images: Input image/volume or batch of images/volumes.
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
size: Patch size as int or tuple.
An int or a length 2 tuple `(patch_height, patch_width)` for 2D patches.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small tweak: mention that an int `size` also produces 2D patches:

    size: Patch size as int or tuple.
        Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.

Length 3 tuple `(patch_depth, patch_height, patch_width)` for
3D patches.
strides: Strides for patch extraction. If not specified, defaults
to `size` (non-overlapping patches).
dilation_rate: Dilation rate for patch extraction. Note that
`dilation_rate > 1` is not supported with `strides > 1`.
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
data_format: A string specifying the data format of the input tensor.
It can be either `"channels_last"` or `"channels_first"`.
`"channels_last"` corresponds to inputs with shape
`(batch, height, width, channels)`, while `"channels_first"`
corresponds to inputs with shape `(batch, channels, height, width)`.
If not specified, the value will default to
`keras.config.image_data_format`.
If not specified, defaults to `keras.config.image_data_format`.
Comment on lines +626 to +640
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstrings for images, size, and data_format could be more explicit to improve clarity, especially with the new 3D support. The current images docstring assumes channels_last without stating it, the size docstring is ambiguous about integer values, and the data_format docstring is a bit sparse. I suggest clarifying these points for a better user experience.

        images: Input image/volume or batch of images/volumes. Assumes
            `channels_last` data format.
            For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
            For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
        size: The size of the patches to extract.
            - If `size` is an integer, 2D patches of size `(size, size)` are
              extracted.
            - If `size` is a tuple of 2 integers, 2D patches of size
              `(patch_height, patch_width)` are extracted.
            - If `size` is a tuple of 3 integers, 3D patches of size
              `(patch_depth, patch_height, patch_width)` are extracted.
        strides: Strides for patch extraction. If not specified, defaults
            to `size` (non-overlapping patches).
        dilation_rate: Dilation rate for patch extraction. Note that
            `dilation_rate > 1` is not supported with `strides > 1`.
        padding: The type of padding algorithm to use: `"same"` or `"valid"`.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, ..., channels)`, while `"channels_first"` corresponds
            to inputs with shape `(batch, channels, ...)`. If not specified,
            defaults to `keras.config.image_data_format`.


Returns:
Extracted patches 3D (if not batched) or 4D (if batched)
Extracted patches with shape depending on input and `size`:
- 2D patches: 3D (unbatched) or 4D (batched)
- 3D patches: 4D (unbatched) or 5D (batched)

Examples:

>>> # 2D patches from batch of images
>>> image = np.random.random(
... (2, 20, 20, 3)
... ).astype("float32") # batch of 2 RGB images
... ).astype("float32")
>>> patches = keras.ops.image.extract_patches(image, (5, 5))
>>> patches.shape
(2, 4, 4, 75)
>>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image

>>> # 2D patches from single image
>>> image = np.random.random((20, 20, 3)).astype("float32")
>>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
>>> patches.shape
(18, 18, 27)

>>> # 3D patches from batch of volumes
>>> volumes = np.random.random(
... (2, 10, 10, 10, 3)
... ).astype("float32")
>>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
>>> patches.shape
(2, 3, 3, 3, 81)

>>> # 3D patches from single volume
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
>>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
>>> patches.shape
(3, 3, 3, 81)
"""
# Validate size argument
if not isinstance(size, int):
if not isinstance(size, (tuple, list)):
raise TypeError(
"Invalid `size` argument. Expected an int or a tuple. "
f"Received: size={size} of type {type(size).__name__}"
)
if len(size) not in (2, 3):
raise ValueError(
"Invalid `size` argument. Expected a tuple of length 2 or 3. "
f"Received: size={size} with length {len(size)}"
)

# Determine 2D vs 3D based on size argument
if not isinstance(size, int) and len(size) == 3:
# 3D patch extraction
if any_symbolic_tensors((images,)):
return ExtractPatches3D(
size=size,
strides=strides,
dilation_rate=dilation_rate,
padding=padding,
data_format=data_format,
).symbolic_call(images)
return _extract_patches_3d(
images,
size,
strides,
dilation_rate,
padding,
data_format=data_format,
)

# 2D patch extraction (default)
if any_symbolic_tensors((images,)):
return ExtractPatches(
size=size,
Expand Down
22 changes: 16 additions & 6 deletions keras/src/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,19 +2357,29 @@ def test_affine_transform_invalid_transform_rank(self):
kimage.affine_transform(images, invalid_transform)

def test_extract_patches_invalid_size(self):
size = (3, 3, 3) # Invalid size, too many dimensions
size = "5" # Invalid size type
image = np.random.uniform(size=(2, 20, 20, 3))
with self.assertRaisesRegex(
TypeError, "Expected an int or a tuple of length 2"
):
with self.assertRaisesRegex(TypeError, "Expected an int or a tuple"):
kimage.extract_patches(image, size)

size = "5" # Invalid size type
size = (3, 3, 3, 3) # Invalid size, too many dimensions
with self.assertRaisesRegex(
TypeError, "Expected an int or a tuple of length 2"
ValueError, "Expected a tuple of length 2 or 3"
):
kimage.extract_patches(image, size)

def test_extract_patches_unified_3d(self):
    """A length-3 `size` routes `extract_patches` to 3D extraction."""
    # With the default strides (equal to `size`) each 20-long axis yields
    # 20 / 5 = 4 patches, and every patch flattens to 5 * 5 * 5 * 3 = 375
    # values.
    patch_size = (5, 5, 5)

    # Batched input, channels_last layout.
    batched_volumes = np.random.uniform(size=(2, 20, 20, 20, 3)).astype(
        "float32"
    )
    batched_patches = kimage.extract_patches(batched_volumes, patch_size)
    self.assertEqual(batched_patches.shape, (2, 4, 4, 4, 375))

    # Unbatched input: a single volume without the leading batch axis.
    single_volume = np.random.uniform(size=(20, 20, 20, 3)).astype("float32")
    single_patches = kimage.extract_patches(single_volume, patch_size)
    self.assertEqual(single_patches.shape, (4, 4, 4, 375))

def test_map_coordinates_invalid_coordinates_rank(self):
# Test mismatched dim of coordinates
image = np.random.uniform(size=(10, 10, 3))
Expand Down