Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,10 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
def make_keypoints(canvas_size=DEFAULT_SIZE, *, clamping_mode="soft", num_points=4, dtype=None, device="cpu"):
y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size)
return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size, clamping_mode=clamping_mode)


def make_bounding_boxes(
Expand Down
79 changes: 66 additions & 13 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):

def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True):
canvas_size = new_canvas_size or keypoints.canvas_size
clamping_mode = keypoints.clamping_mode

def affine_keypoints(keypoints):
dtype = keypoints.dtype
Expand All @@ -652,7 +653,7 @@ def affine_keypoints(keypoints):
)

if clamp:
output = F.clamp_keypoints(output, canvas_size=canvas_size)
output = F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode)
else:
dtype = output.dtype

Expand All @@ -661,6 +662,7 @@ def affine_keypoints(keypoints):
return tv_tensors.KeyPoints(
torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape),
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)


Expand Down Expand Up @@ -2084,7 +2086,6 @@ def test_functional(self, make_input):
(F.rotate_image, tv_tensors.Image),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
(F.rotate_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -3309,7 +3310,6 @@ def test_functional(self, make_input):
(F.elastic_image, tv_tensors.Image),
(F.elastic_mask, tv_tensors.Mask),
(F.elastic_video, tv_tensors.Video),
(F.elastic_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -4414,7 +4414,6 @@ def test_functional(self, make_input):
(F.resized_crop_image, tv_tensors.Image),
(F.resized_crop_mask, tv_tensors.Mask),
(F.resized_crop_video, tv_tensors.Video),
(F.resized_crop_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -5325,6 +5324,7 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo

def _reference_perspective_keypoints(self, keypoints, *, startpoints, endpoints):
canvas_size = keypoints.canvas_size
clamping_mode = keypoints.clamping_mode
dtype = keypoints.dtype
device = keypoints.device

Expand Down Expand Up @@ -5361,16 +5361,16 @@ def perspective_keypoints(keypoints):
)

# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
return F.clamp_keypoints(
output,
canvas_size=canvas_size,
).to(dtype=dtype, device=device)
return F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode).to(
dtype=dtype, device=device
)

return tv_tensors.KeyPoints(
torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(
keypoints.shape
),
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)

@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
Expand Down Expand Up @@ -5733,32 +5733,85 @@ def test_error(self):


class TestClampKeyPoints:
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, dtype, device):
keypoints = make_keypoints(dtype=dtype, device=device)
def test_kernel(self, clamping_mode, dtype, device):
keypoints = make_keypoints(dtype=dtype, device=device, clamping_mode=clamping_mode)
check_kernel(
F.clamp_keypoints,
keypoints,
canvas_size=keypoints.canvas_size,
clamping_mode=clamping_mode,
)

def test_functional(self):
check_functional(F.clamp_keypoints, make_keypoints())
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
def test_functional(self, clamping_mode):
check_functional(F.clamp_keypoints, make_keypoints(clamping_mode=clamping_mode))

def test_errors(self):
input_tv_tensor = make_keypoints()
input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)

with pytest.raises(ValueError, match="`canvas_size` has to be passed"):
with pytest.raises(ValueError, match="`canvas_size` and `clamping_mode` have to be passed."):
F.clamp_keypoints(input_pure_tensor, canvas_size=None)

with pytest.raises(ValueError, match="`canvas_size` must not be passed"):
F.clamp_keypoints(input_tv_tensor, canvas_size=input_tv_tensor.canvas_size)
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
F.clamp_keypoints(input_tv_tensor, clamping_mode="bad")
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
transforms.ClampKeyPoints(clamping_mode="bad")(input_tv_tensor)

def test_transform(self):
check_transform(transforms.ClampKeyPoints(), make_keypoints())

@pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto"))
@pytest.mark.parametrize("pass_pure_tensor", (True, False))
@pytest.mark.parametrize("fn", [F.clamp_keypoints, transform_cls_to_functional(transforms.ClampKeyPoints)])
def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
# This test checks 2 things:
# - That passing clamping_mode="auto" to the clamp_keypoints
# functional (or to the class) relies on the keypoints' `.clamping_mode`
# attribute
# - That clamping happens when it should, and only when it should, i.e.
# when the clamping mode is not None. It doesn't validate the
# numerical results, only that clamping happened. For that, we create
# keypoints with large coordinates (100) inside of a small 10x10 image.

if pass_pure_tensor and fn is not F.clamp_keypoints:
# Only the functional supports pure tensors, not the class
return
if pass_pure_tensor and clamping_mode == "auto":
# cannot leave clamping_mode="auto" when passing pure tensor
return

keypoints = tv_tensors.KeyPoints(
[[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
)
expected_clamped_output = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this line is redundant and can be removed.

torch.tensor([[0, 10], [0, 10]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
)
expected_clamped_output = (
torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
)

if pass_pure_tensor:
out = fn(
keypoints.as_subclass(torch.Tensor),
canvas_size=keypoints.canvas_size,
clamping_mode=clamping_mode,
)
else:
out = fn(keypoints, clamping_mode=clamping_mode)

clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode
if clamping_mode_prevailing is None:
assert_equal(keypoints, out) # should be a pass-through
else:
assert_equal(out, expected_clamped_output)


class TestInvert:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32])
Expand Down
21 changes: 16 additions & 5 deletions torchvision/transforms/v2/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
from torchvision.tv_tensors import CLAMPING_MODE_TYPE


class ConvertBoundingBoxFormat(Transform):
Expand Down Expand Up @@ -46,17 +46,27 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t
class ClampKeyPoints(Transform):
"""Clamp keypoints to their corresponding image dimensions.

The clamping is done according to the keypoints' ``canvas_size`` meta-data.
Args:
clamping_mode: Default is "auto" which relies on the input keypoints'
``clamping_mode`` attribute.
The clamping is done according to the keypoints' ``canvas_size`` meta-data.
Read more in :ref:`clamping_mode_tuto`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clamping_mode_tuto in the docs currently only covers bounding boxes and would need to be updated as well.

for more details on how to use this transform.

"""

def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
super().__init__()
self.clamping_mode = clamping_mode

_transformed_types = (tv_tensors.KeyPoints,)

def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
return F.clamp_keypoints(inpt) # type: ignore[return-value]
return F.clamp_keypoints(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value]


class SetClampingMode(Transform):
"""Sets the ``clamping_mode`` attribute of the bounding boxes for future transforms.
"""Sets the ``clamping_mode`` attribute of the bounding boxes and keypoints for future transforms.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be useful to allow setting the clamping modes of bounding boxes and keypoints to different values by passing a dictionary.

For example:

  • SetClampingMode("soft") sets the clamping mode of both bounding boxes and keypoints to "soft".
  • SetClampingMode({tv_tensors.BoundingBoxes: "hard", tv_tensors.KeyPoints: "soft"}) sets the clamping mode of bounding boxes to "hard" and that of keypoints to "soft".
  • SetClampingMode({tv_tensors.BoundingBoxes: "hard"}) sets the clamping mode of bounding boxes to "hard" and leaves that of keypoints unchanged.




Expand All @@ -73,9 +83,10 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
if self.clamping_mode not in (None, "soft", "hard"):
raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}")

_transformed_types = (tv_tensors.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints)

def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
# This method works for both `tv_tensors.BoundingBoxes` and `tv_tensors.KeyPoints`.
out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment]
out.clamping_mode = self.clamping_mode
return out
Loading