Skip to content

Commit 0429d73

Browse files
AntoineSimoulinNicolasHugCallidior
authored
[release/0.24] cherrypicks for keypoints fix (#9238)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Björn Barz <[email protected]>
1 parent b919bd0 commit 0429d73

File tree

8 files changed

+523
-22
lines changed

8 files changed

+523
-22
lines changed

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ Miscellaneous
413413
v2.RandomErasing
414414
v2.Lambda
415415
v2.SanitizeBoundingBoxes
416+
v2.SanitizeKeyPoints
416417
v2.ClampBoundingBoxes
417418
v2.ClampKeyPoints
418419
v2.UniformTemporalSubsample
@@ -427,6 +428,7 @@ Functionals
427428
v2.functional.normalize
428429
v2.functional.erase
429430
v2.functional.sanitize_bounding_boxes
431+
v2.functional.sanitize_keypoints
430432
v2.functional.clamp_bounding_boxes
431433
v2.functional.clamp_keypoints
432434
v2.functional.uniform_temporal_subsample
@@ -530,6 +532,7 @@ Developer tools
530532
v2.query_size
531533
v2.query_chw
532534
v2.get_bounding_boxes
535+
v2.get_keypoints
533536

534537

535538
V1 API Reference

test/test_transforms_v2.py

Lines changed: 325 additions & 12 deletions
Large diffs are not rendered by default.

torchvision/transforms/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@
5151
LinearTransformation,
5252
Normalize,
5353
SanitizeBoundingBoxes,
54+
SanitizeKeyPoints,
5455
ToDtype,
5556
)
5657
from ._temporal import UniformTemporalSubsample
5758
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58-
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
59+
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
5960

6061
from ._deprecated import ToTensor # usort: skip

torchvision/transforms/v2/_misc.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
from torchvision import transforms as _transforms, tv_tensors
1111
from torchvision.transforms.v2 import functional as F, Transform
1212

13-
from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
13+
from ._utils import (
14+
_parse_labels_getter,
15+
_setup_number_or_seq,
16+
_setup_size,
17+
get_bounding_boxes,
18+
get_keypoints,
19+
has_any,
20+
is_pure_tensor,
21+
)
1422

1523

1624
# TODO: do we want/need to expose this?
@@ -459,3 +467,93 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
459467
return output
460468
else:
461469
return tv_tensors.wrap(output, like=inpt)
470+
471+
472+
class SanitizeKeyPoints(Transform):
    """Remove keypoints outside of the image area and their corresponding labels (if any).

    This transform removes keypoints or groups of keypoints and their associated labels that
    have coordinates outside of their corresponding image.
    If you would instead like to clamp such keypoints to the image edges, use
    :class:`~torchvision.transforms.v2.ClampKeyPoints`.

    It is recommended to call it at the end of a pipeline, before passing the
    input to the models.

    Keypoints can be passed as a set of individual keypoints or as a set of objects
    (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
    When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform
    will only remove entire groups, not individual keypoints within a group.

    Args:
        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
            (or anything else that needs to be sanitized along with the keypoints).
            If set to the string ``"default"``, this will try to find a "labels" key in the input (case-insensitive), if
            the input is a dict or it is a tuple whose second element is a dict.

            It can also be a callable that takes the same input as the transform, and returns either:

            - A single tensor (the labels)
            - A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints.

            If ``labels_getter`` is None (the default), then only keypoints are sanitized.
    """

    def __init__(
        self,
        labels_getter: Union[Callable[[Any], Any], str, None] = None,
    ) -> None:
        super().__init__()
        # Keep the user-supplied value as-is and store the parsed callable separately.
        self.labels_getter = labels_getter
        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs: Any) -> Any:
        """Sanitize the keypoints found in ``inputs`` together with their labels.

        Args:
            inputs: The sample. A single positional argument is used as-is;
                several positional arguments are treated as one tuple.

        Returns:
            The sample with the same structure, where the ``KeyPoints`` entry and
            every label tensor have been indexed with the validity mask.

        Raises:
            ValueError: If the labels are neither a tensor nor a tuple/list of
                tensors, or if the first dimension of a label tensor does not
                match the first dimension of the keypoints.
        """
        inputs = inputs if len(inputs) > 1 else inputs[0]

        labels = self._labels_getter(inputs)
        if labels is not None:
            msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
            if isinstance(labels, torch.Tensor):
                # Normalize a single tensor to a 1-tuple so both cases share the code below.
                labels = (labels,)
            elif isinstance(labels, (tuple, list)):
                for entry in labels:
                    if not isinstance(entry, torch.Tensor):
                        # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
                        raise ValueError(msg.format(type=type(entry)))
            else:
                raise ValueError(msg.format(type=type(labels)))

        flat_inputs, spec = tree_flatten(inputs)
        points = get_keypoints(flat_inputs)

        if labels is not None:
            for label in labels:
                if points.shape[0] != label.shape[0]:
                    raise ValueError(
                        f"Number of keypoints (shape={points.shape}) must match the number of labels. "
                        f"Found labels with shape={label.shape}."
                    )

        valid = F._misc._get_sanitize_keypoints_mask(
            points,
            canvas_size=points.canvas_size,
        )

        params = dict(valid=valid, labels=labels)
        flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs]

        return tree_unflatten(flat_outputs, spec)

    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        """Apply the validity mask to a single flattened input.

        Args:
            inpt: One leaf of the flattened sample.
            params: Dict holding ``"valid"`` (the boolean mask) and ``"labels"``
                (``None`` or the sequence of label tensors).

        Returns:
            ``inpt`` unchanged if it is neither a label nor a ``KeyPoints``
            object; otherwise ``inpt`` indexed with the mask (re-wrapped like
            ``inpt`` for keypoints).
        """
        # Labels are recognized by identity: they are the very objects returned by the labels getter.
        is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
        is_keypoints = isinstance(inpt, tv_tensors.KeyPoints)

        if not (is_label or is_keypoints):
            return inpt

        output = inpt[params["valid"]]

        if is_label:
            return output
        else:
            return tv_tensors.wrap(output, like=inpt)

torchvision/transforms/v2/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,18 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes:
165165
raise ValueError("No bounding boxes were found in the sample")
166166

167167

168+
def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints:
    """Return the keypoints in the input.

    Assumes only one ``KeyPoints`` object is present; the first one found is returned.

    Args:
        flat_inputs: The flattened sample to search.

    Raises:
        ValueError: If ``flat_inputs`` contains no ``KeyPoints`` object.
    """
    # This assumes there is only one KeyPoints object per sample as per the general convention
    try:
        return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints))
    except StopIteration:
        # The StopIteration is an implementation detail; keep it out of the user-facing traceback.
        raise ValueError("No keypoints were found in the sample") from None
178+
179+
168180
def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
169181
"""Return Channel, Height, and Width."""
170182
chws = {

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@
156156
normalize_image,
157157
normalize_video,
158158
sanitize_bounding_boxes,
159+
sanitize_keypoints,
159160
to_dtype,
160161
to_dtype_image,
161162
to_dtype_video,

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from torchvision.utils import _log_api_usage_once
2626

27-
from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format
27+
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
2828

2929
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
3030

@@ -71,7 +71,7 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, i
7171
shape = keypoints.shape
7272
keypoints = keypoints.clone().reshape(-1, 2)
7373
keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
74-
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
74+
return keypoints.reshape(shape)
7575

7676

7777
@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -159,7 +159,7 @@ def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int
159159
shape = keypoints.shape
160160
keypoints = keypoints.clone().reshape(-1, 2)
161161
keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
162-
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
162+
return keypoints.reshape(shape)
163163

164164

165165
def vertical_flip_bounding_boxes(
@@ -1026,7 +1026,7 @@ def _affine_keypoints_with_expand(
10261026
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
10271027
canvas_size = (new_height, new_width)
10281028

1029-
out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape)
1029+
out_keypoints = transformed_points.reshape(original_shape)
10301030
out_keypoints = out_keypoints.to(original_dtype)
10311031

10321032
return out_keypoints, canvas_size
@@ -1695,7 +1695,7 @@ def pad_keypoints(
16951695
left, right, top, bottom = _parse_pad_padding(padding)
16961696
pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
16971697
canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right)
1698-
return clamp_keypoints(keypoints + pad, canvas_size), canvas_size
1698+
return keypoints + pad, canvas_size
16991699

17001700

17011701
@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -1817,7 +1817,7 @@ def crop_keypoints(
18171817
keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
18181818
canvas_size = (height, width)
18191819

1820-
return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
1820+
return keypoints, canvas_size
18211821

18221822

18231823
@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2047,7 +2047,7 @@ def perspective_keypoints(
20472047
numer_points = torch.matmul(points, theta1.T)
20482048
denom_points = torch.matmul(points, theta2.T)
20492049
transformed_points = numer_points.div_(denom_points)
2050-
return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape)
2050+
return transformed_points.to(keypoints.dtype).reshape(original_shape)
20512051

20522052

20532053
@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2376,7 +2376,7 @@ def elastic_keypoints(
23762376
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
23772377
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
23782378

2379-
return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape)
2379+
return transformed_points.to(keypoints.dtype).reshape(original_shape)
23802380

23812381

23822382
@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)

torchvision/transforms/v2/functional/_misc.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,76 @@ def _get_sanitize_bounding_boxes_mask(
442442
valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h)
443443
valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h)
444444
return valid
445+
446+
447+
def sanitize_keypoints(
    key_points: torch.Tensor,
    canvas_size: Optional[tuple[int, int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remove keypoints outside of the image area.

    This function removes keypoints or groups of keypoints that have coordinates
    outside of their corresponding image, and returns the mask it used so that
    associated tensors (e.g. labels) can be filtered the same way.
    If you would instead like to clamp such keypoints to the image edges, use
    :class:`~torchvision.transforms.v2.ClampKeyPoints`.

    It is recommended to call it at the end of a pipeline, before passing the
    input to the models.

    Keypoints can be passed as a set of individual keypoints or as a set of objects
    (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
    When groups of keypoints are passed (i.e., an at least 3-dimensional tensor),
    this function will only remove entire groups, not individual keypoints within a group.

    Args:
        key_points (Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The keypoints to be sanitized.
        canvas_size (tuple of int, optional): The canvas_size of the keypoints
            (size of the corresponding image/video).
            Must be left as ``None`` if ``key_points`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.

    Returns:
        out (tuple of Tensors): The subset of valid keypoints, and the corresponding indexing mask.
        The mask can then be used to subset other tensors (e.g. labels) that are associated with the keypoints.

    Raises:
        ValueError: If ``key_points`` is a pure tensor and ``canvas_size`` is ``None``, if
            ``key_points`` is neither a pure tensor nor a ``KeyPoints`` object, or if
            ``canvas_size`` is given together with a ``KeyPoints`` object.
    """
    # When scripting, the input is always handled through the pure-tensor branch.
    if torch.jit.is_scripting() or is_pure_tensor(key_points):
        if canvas_size is None:
            raise ValueError(
                "canvas_size cannot be None if key_points is a pure tensor. "
                "Set it to an appropriate value or pass key_points as a tv_tensors.KeyPoints object."
            )
        valid = _get_sanitize_keypoints_mask(
            key_points,
            canvas_size=canvas_size,
        )
        key_points = key_points[valid]
    else:
        if not isinstance(key_points, tv_tensors.KeyPoints):
            raise ValueError("key_points must be a tv_tensors.KeyPoints instance or a pure tensor.")
        if canvas_size is not None:
            raise ValueError(
                "canvas_size must be None when key_points is a tv_tensors.KeyPoints instance. "
                f"Got canvas_size={canvas_size}. "
                "Leave it to None or pass key_points as a pure tensor."
            )
        # The canvas size is read from the KeyPoints object itself.
        valid = _get_sanitize_keypoints_mask(
            key_points,
            canvas_size=key_points.canvas_size,
        )
        # Re-wrap the indexed result like the input KeyPoints object.
        key_points = tv_tensors.wrap(key_points[valid], like=key_points)

    return key_points, valid
503+
504+
505+
def _get_sanitize_keypoints_mask(
506+
key_points: torch.Tensor,
507+
canvas_size: tuple[int, int],
508+
) -> torch.Tensor:
509+
510+
h, w = canvas_size
511+
512+
x, y = key_points[..., 0], key_points[..., 1]
513+
valid = (x >= 0) & (x < w) & (y >= 0) & (y < h)
514+
515+
valid = valid.flatten(start_dim=1).all(dim=1) if valid.ndim > 1 else valid
516+
517+
return valid

0 commit comments

Comments
 (0)