Unify extract_patches to support both 2D and 3D patches #21980
base: master
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -616,42 +616,98 @@ def extract_patches( | |
| padding="valid", | ||
| data_format=None, | ||
| ): | ||
| """Extracts patches from the image(s). | ||
| """Extracts patches from the image(s) or volume(s). | ||
|
|
||
| This function supports both 2D and 3D patch extraction based on the | ||
| `size` argument length, similar to how `keras.ops.conv` handles | ||
| different dimensions. | ||
|
|
||
| Args: | ||
| images: Input image or batch of images. Must be 3D or 4D. | ||
| size: Patch size int or tuple (patch_height, patch_width) | ||
| strides: strides along height and width. If not specified, or | ||
| if `None`, it defaults to the same value as `size`. | ||
| dilation_rate: This is the input stride, specifying how far two | ||
| consecutive patch samples are in the input. For value other than 1, | ||
| strides must be 1. NOTE: `strides > 1` is not supported in | ||
| conjunction with `dilation_rate > 1` | ||
| images: Input image/volume or batch of images/volumes. | ||
| For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`. | ||
| For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`. | ||
| size: Patch size as int or tuple. | ||
| Length 2 tuple `(patch_height, patch_width)` for 2D patches. | ||
| Length 3 tuple `(patch_depth, patch_height, patch_width)` for | ||
| 3D patches. | ||
| strides: Strides for patch extraction. If not specified, defaults | ||
| to `size` (non-overlapping patches). | ||
| dilation_rate: Dilation rate for patch extraction. Note that | ||
| `dilation_rate > 1` is not supported with `strides > 1`. | ||
| padding: The type of padding algorithm to use: `"same"` or `"valid"`. | ||
| data_format: A string specifying the data format of the input tensor. | ||
| It can be either `"channels_last"` or `"channels_first"`. | ||
| `"channels_last"` corresponds to inputs with shape | ||
| `(batch, height, width, channels)`, while `"channels_first"` | ||
| corresponds to inputs with shape `(batch, channels, height, width)`. | ||
| If not specified, the value will default to | ||
| `keras.config.image_data_format`. | ||
| If not specified, defaults to `keras.config.image_data_format`. | ||
|
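The `data_format` wording above can be made concrete with a small shape-splitting sketch. `split_shape` is a hypothetical helper written for this discussion, not part of Keras, and it assumes that an unbatched `channels_first` input is laid out as `(channels, ...)`, which the docstring does not state explicitly.

```python
def split_shape(shape, spatial_rank, data_format="channels_last"):
    """Split a shape tuple into (batch, spatial, channels).

    Hypothetical helper for illustration only; not a Keras function.
    `spatial_rank` is 2 for image patches and 3 for volume patches.
    """
    if data_format not in ("channels_last", "channels_first"):
        raise ValueError(f"Unknown data_format: {data_format!r}")

    shape = tuple(shape)
    if len(shape) == spatial_rank + 2:
        batch, rest = shape[0], shape[1:]  # batched input
    elif len(shape) == spatial_rank + 1:
        batch, rest = None, shape  # single image or volume
    else:
        raise ValueError(
            f"Expected rank {spatial_rank + 1} or {spatial_rank + 2}, "
            f"received shape {shape}"
        )

    if data_format == "channels_last":
        spatial, channels = rest[:-1], rest[-1]
    else:
        # Assumption: the channel axis comes first once the batch axis is removed.
        channels, spatial = rest[0], rest[1:]
    return batch, spatial, channels


print(split_shape((2, 20, 20, 3), spatial_rank=2))
print(split_shape((2, 3, 10, 10, 10), spatial_rank=3, data_format="channels_first"))
```

The same rank check is why a 4D input is ambiguous on its own: it is a batch of images when `size` has length 2 and a single volume when `size` has length 3.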
Comment on lines +626 to +640 (Contributor):
The docstrings for
images: Input image/volume or batch of images/volumes. Assumes
`channels_last` data format.
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
size: The size of the patches to extract.
- If `size` is an integer, 2D patches of size `(size, size)` are
extracted.
- If `size` is a tuple of 2 integers, 2D patches of size
`(patch_height, patch_width)` are extracted.
- If `size` is a tuple of 3 integers, 3D patches of size
`(patch_depth, patch_height, patch_width)` are extracted.
strides: Strides for patch extraction. If not specified, defaults
to `size` (non-overlapping patches).
dilation_rate: Dilation rate for patch extraction. Note that
`dilation_rate > 1` is not supported with `strides > 1`.
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
data_format: A string specifying the data format of the input tensor.
It can be either `"channels_last"` or `"channels_first"`.
`"channels_last"` corresponds to inputs with shape
`(batch, ..., channels)`, while `"channels_first"` corresponds
to inputs with shape `(batch, channels, ...)`. If not specified,
defaults to `keras.config.image_data_format`.
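The shapes quoted in the docstring examples, such as `(2, 4, 4, 75)` and `(2, 3, 3, 3, 81)`, follow from ordinary sliding-window arithmetic. The sketch below is a rough check of that arithmetic, not the Keras implementation: `expected_patch_shape` is a hypothetical name, it assumes `channels_last` input and `dilation_rate=1`, and the `"same"` branch uses the common `ceil(dim / stride)` convention as an assumption.

```python
import math


def expected_patch_shape(input_shape, size, strides=None, padding="valid"):
    """Predict the output shape of patch extraction (illustrative only).

    Assumes channels_last layout and dilation_rate == 1.
    """
    # An int size means square 2D patches, as the review comment describes.
    size = (size, size) if isinstance(size, int) else tuple(size)
    if strides is None:
        strides = size  # default: non-overlapping patches
    elif isinstance(strides, int):
        strides = (strides,) * len(size)
    else:
        strides = tuple(strides)

    rank = len(size)
    input_shape = tuple(input_shape)
    channels = input_shape[-1]
    spatial = input_shape[-(rank + 1):-1]
    batch = input_shape[:-(rank + 1)]  # empty tuple when unbatched

    if padding == "valid":
        grid = tuple((d - k) // s + 1 for d, k, s in zip(spatial, size, strides))
    else:
        # Assumption: "same" padding yields ceil(dim / stride) positions.
        grid = tuple(math.ceil(d / s) for d, s in zip(spatial, strides))

    # Each patch is flattened into the last axis together with the channels.
    return batch + grid + (channels * math.prod(size),)


print(expected_patch_shape((2, 20, 20, 3), (5, 5)))
print(expected_patch_shape((20, 20, 3), (3, 3), (1, 1)))
print(expected_patch_shape((2, 10, 10, 10, 3), (3, 3, 3)))
```

With `"valid"` padding each spatial axis contributes `(dim - patch) // stride + 1` positions, and the last axis holds `channels * prod(size)` values per patch.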
||
|
|
||
| Returns: | ||
| Extracted patches 3D (if not batched) or 4D (if batched) | ||
| Extracted patches with shape depending on input and `size`: | ||
| - 2D patches: 3D (unbatched) or 4D (batched) | ||
| - 3D patches: 4D (unbatched) or 5D (batched) | ||
|
|
||
| Examples: | ||
|
|
||
| >>> # 2D patches from batch of images | ||
| >>> image = np.random.random( | ||
| ... (2, 20, 20, 3) | ||
| ... ).astype("float32") # batch of 2 RGB images | ||
| ... ).astype("float32") | ||
| >>> patches = keras.ops.image.extract_patches(image, (5, 5)) | ||
| >>> patches.shape | ||
| (2, 4, 4, 75) | ||
| >>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image | ||
|
|
||
| >>> # 2D patches from single image | ||
| >>> image = np.random.random((20, 20, 3)).astype("float32") | ||
| >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1)) | ||
| >>> patches.shape | ||
| (18, 18, 27) | ||
|
|
||
| >>> # 3D patches from batch of volumes | ||
| >>> volumes = np.random.random( | ||
| ... (2, 10, 10, 10, 3) | ||
| ... ).astype("float32") | ||
| >>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3)) | ||
| >>> patches.shape | ||
| (2, 3, 3, 3, 81) | ||
|
|
||
| >>> # 3D patches from single volume | ||
| >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") | ||
| >>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3)) | ||
| >>> patches.shape | ||
| (3, 3, 3, 81) | ||
| """ | ||
| # Validate size argument | ||
| if not isinstance(size, int): | ||
| if not isinstance(size, (tuple, list)): | ||
| raise TypeError( | ||
| "Invalid `size` argument. Expected an int or a tuple. " | ||
| f"Received: size={size} of type {type(size).__name__}" | ||
| ) | ||
| if len(size) not in (2, 3): | ||
| raise ValueError( | ||
| "Invalid `size` argument. Expected a tuple of length 2 or 3. " | ||
| f"Received: size={size} with length {len(size)}" | ||
| ) | ||
|
|
||
| # Determine 2D vs 3D based on size argument | ||
| if not isinstance(size, int) and len(size) == 3: | ||
hertschuh marked this conversation as resolved.
|
||
| # 3D patch extraction | ||
| if any_symbolic_tensors((images,)): | ||
| return ExtractPatches3D( | ||
| size=size, | ||
| strides=strides, | ||
| dilation_rate=dilation_rate, | ||
| padding=padding, | ||
| data_format=data_format, | ||
| ).symbolic_call(images) | ||
| return _extract_patches_3d( | ||
| images, | ||
| size, | ||
| strides, | ||
| dilation_rate, | ||
| padding, | ||
| data_format=data_format, | ||
| ) | ||
|
|
||
| # 2D patch extraction (default) | ||
| if any_symbolic_tensors((images,)): | ||
| return ExtractPatches( | ||
| size=size, | ||
|
|
||
Small tweak: mention that if the size is an int it will do 2D patches:
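For readers following the thread, the int-versus-tuple rule in the diff can be sketched as a standalone helper. This is an illustration of the validation and dispatch logic shown in the diff, not the reviewer's suggested wording and not code from the PR; `patch_rank` is a hypothetical name.

```python
def patch_rank(size):
    """Return 2 or 3 depending on `size`, mirroring the checks in the diff.

    Hypothetical helper for illustration; not part of Keras.
    """
    if isinstance(size, int):
        # An int is treated as a square 2D patch: (size, size).
        return 2
    if not isinstance(size, (tuple, list)):
        raise TypeError(
            "Invalid `size` argument. Expected an int or a tuple. "
            f"Received: size={size} of type {type(size).__name__}"
        )
    if len(size) not in (2, 3):
        raise ValueError(
            "Invalid `size` argument. Expected a tuple of length 2 or 3. "
            f"Received: size={size} with length {len(size)}"
        )
    return len(size)


for candidate in (5, (5, 5), (3, 3, 3)):
    print(candidate, "->", f"{patch_rank(candidate)}D patches")
```

Only a length-3 `size` selects the 3D path; an int or a length-2 tuple falls through to the 2D default, which is the behavior the comment asks the docstring to mention.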