From cf93d9e0c8b86117600fc6e8f0e54ca8dc080bcf Mon Sep 17 00:00:00 2001 From: alperenunlu <97191996+alperenunlu@users.noreply.github.com> Date: Wed, 12 Mar 2025 00:22:33 +0300 Subject: [PATCH 01/19] Add box_area_center and box_iou_center for cxcywh format --- docs/source/ops.rst | 2 + test/test_ops.py | 102 ++++++++++++++++++++++++++++++++++++ torchvision/ops/__init__.py | 4 ++ torchvision/ops/boxes.py | 55 +++++++++++++++++++ 4 files changed, 163 insertions(+) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 7124c85bb79..541b5c30c15 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -50,8 +50,10 @@ These utility functions perform various operations on bounding boxes. :template: function.rst box_area + box_area_center box_convert box_iou + box_iou_center clip_boxes_to_image complete_box_iou distance_box_iou diff --git a/test/test_ops.py b/test/test_ops.py index 88124f7ba17..4b94f5018dc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1451,6 +1451,41 @@ def test_box_area_jit(self): torch.testing.assert_close(scripted_area, expected) +class TestBoxAreaCenter: + def area_check(self, box, expected, atol=1e-4): + out = ops.box_area_center(box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) + + @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) + def test_int_boxes(self, dtype): + box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), + in_fmt="xyxy", out_fmt="cxcywh") + expected = torch.tensor([10000, 0], dtype=torch.int32) + self.area_check(box_tensor, expected) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) + def test_float_boxes(self, dtype): + box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh") + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) + self.area_check(box_tensor, expected) + + def test_float16_box(self): + box_tensor = ops.box_convert(torch.tensor( + [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 + ), in_fmt="xyxy", out_fmt="cxcywh") + + expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) + self.area_check(box_tensor, expected, atol=0.01) + + def test_box_area_jit(self): + box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), + in_fmt="xyxy", out_fmt="cxcywh") + expected = ops.box_area_center(box_tensor) + scripted_fn = torch.jit.script(ops.box_area_center) + scripted_area = scripted_fn(box_tensor) + torch.testing.assert_close(scripted_area, expected) + + INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]] INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] FLOAT_BOXES = [ @@ -1459,6 +1494,14 @@ def test_box_area_jit(self): [279.2440, 197.9812, 1189.4746, 849.2019], ] +INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]] +INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]] +FLOAT_BOXES_CXCYWH = [ + [739.4324, 518.5154, 908.1572, 665.8793], + [738.8228, 519.9021, 907.3512, 662.3295], + [734.3593, 523.5916, 910.2306, 651.2207] +] + def gen_box(size, dtype=torch.float): xy1 = torch.rand((size, 2), dtype=dtype) @@ -1525,6 +1568,65 @@ def test_iou_cartesian(self): self._run_cartesian_test(ops.box_iou) +class TestIouCenterBase: + @staticmethod + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): + for dtype in dtypes: + actual_box1 = torch.tensor(actual_box1, dtype=dtype) + actual_box2 = torch.tensor(actual_box2, dtype=dtype) + expected_box = torch.tensor(expected) + out = target_fn(actual_box1, actual_box2) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) + + @staticmethod + def _run_jit_test(target_fn: Callable, actual_box: List): + box_tensor = torch.tensor(actual_box, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected) + + @staticmethod + def _cartesian_product(boxes1, boxes2, target_fn: Callable): + N = boxes1.size(0) + M = boxes2.size(0) + result = torch.zeros((N, M)) + for i in range(N): + for j in range(M): + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + return result + + @staticmethod + def _run_cartesian_test(target_fn: Callable): + boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh") + boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh") + a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2) + torch.testing.assert_close(a, b) + + +class TestBoxIouCenter(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "actual_box1, actual_box2, dtypes, atol, expected", + [ + pytest.param(INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH) + + def test_iou_cartesian(self): + self._run_cartesian_test(ops.box_iou_center) + + class TestGeneralizedBoxIou(TestIouBase): int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..456bde2d036 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -2,8 +2,10 @@ from .boxes import ( batched_nms, box_area, + box_area_center, box_convert, box_iou, + box_iou_center, clip_boxes_to_image, complete_box_iou, distance_box_iou, @@ -40,7 +42,9 @@ "clip_boxes_to_image", "box_convert", "box_area", + "box_area_center", "box_iou", + "box_iou_center", "generalized_box_iou", "distance_box_iou", "complete_box_iou", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 9674d5bfa1d..961396ce9a7 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -290,6 +290,25 @@ def box_area(boxes: Tensor) -> Tensor: return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) +def box_area_center(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by their + (cx, cy, w, h) coordinates. + + Args: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (cx, cy, w, h) format with + ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. + + Returns: + Tensor[N]: the area for each box + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(box_area) + boxes = _upcast(boxes) + return boxes[:, 2] * boxes[:, 3] + + # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py # with slight modifications def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: @@ -328,6 +347,42 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: return iou +def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: + area1 = box_area_center(boxes1) + area2 = box_area_center(boxes2) + + lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] + rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + return inter, union + + +def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Return intersection-over-union (Jaccard index) between two sets of boxes. + + Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with + ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(box_iou_center) + inter, union = _box_inter_union_center(boxes1, boxes2) + iou = inter / union + return iou + + # Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ From 29c5147122f219700707e9faaf1a2614aa186b54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alperen=20=C3=9CNL=C3=9C?= <97191996+alperenunlu@users.noreply.github.com> Date: Mon, 21 Apr 2025 23:58:10 +0300 Subject: [PATCH 02/19] Update boxes.py --- torchvision/ops/boxes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index f0e00e5355f..a16085acaae 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -305,7 +305,7 @@ def box_area_center(boxes: Tensor) -> Tensor: Tensor[N]: the area for each box """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area) + _log_api_usage_once(box_area_center) boxes = _upcast(boxes) return boxes[:, 2] * boxes[:, 3] From 14a8fbee0eacaea95056a83c99f0e62b1c5e3103 Mon Sep 17 00:00:00 2001 From: alperenunlu <97191996+alperenunlu@users.noreply.github.com> Date: Sat, 12 Jul 2025 04:50:56 +0300 Subject: [PATCH 03/19] Dispatch style box_area and box_iou --- docs/source/ops.rst | 2 - test/test_ops.py | 90 ++++++++++++++++++++++++++----------- torchvision/ops/__init__.py | 4 -- torchvision/ops/boxes.py | 80 ++++++++++++++++++++++++++------- 4 files changed, 126 insertions(+), 50 deletions(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 541b5c30c15..7124c85bb79 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -50,10 +50,8 @@ These utility functions perform various operations on bounding boxes. :template: function.rst box_area - box_area_center box_convert box_iou - box_iou_center clip_boxes_to_image complete_box_iou distance_box_iou diff --git a/test/test_ops.py b/test/test_ops.py index 4b94f5018dc..287eef7ed44 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1418,9 +1418,9 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh) -class TestBoxArea: +class TestBoxAreaXYXY: def area_check(self, box, expected, atol=1e-4): - out = ops.box_area(box) + out = ops.box_area(box, fmt="xyxy") torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) @@ -1445,15 +1445,15 @@ def test_float16_box(self): def test_box_area_jit(self): box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) - expected = ops.box_area(box_tensor) + expected = ops.box_area(box_tensor, fmt="xyxy") scripted_fn = torch.jit.script(ops.box_area) scripted_area = scripted_fn(box_tensor) torch.testing.assert_close(scripted_area, expected) -class TestBoxAreaCenter: +class TestBoxAreaCXCYWH: def area_check(self, box, expected, atol=1e-4): - out = ops.box_area_center(box) + out = ops.box_area(box, fmt="cxcywh") torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) @@ -1480,9 +1480,9 @@ def test_float16_box(self): def test_box_area_jit(self): box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt="cxcywh") - expected = ops.box_area_center(box_tensor) - scripted_fn = torch.jit.script(ops.box_area_center) - scripted_area = scripted_fn(box_tensor) + expected = ops.box_area(box_tensor, fmt="cxcywh") + scripted_fn = torch.jit.script(ops.box_area) + scripted_area = scripted_fn(box_tensor, fmt="cxcywh") torch.testing.assert_close(scripted_area, expected) @@ -1509,22 +1509,22 @@ def gen_box(size, dtype=torch.float): return torch.cat([xy1, xy2], axis=-1) -class TestIouBase: +class TestIouXYXYBase: @staticmethod def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): for dtype in dtypes: actual_box1 = torch.tensor(actual_box1, dtype=dtype) actual_box2 = torch.tensor(actual_box2, dtype=dtype) expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2) + out = target_fn(actual_box1, actual_box2, fmt="xyxy") torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod def _run_jit_test(target_fn: Callable, actual_box: List): box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor) + expected = target_fn(box_tensor, box_tensor, fmt="xyxy") scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor) + scripted_out = scripted_fn(box_tensor, box_tensor, fmt="xyxy") torch.testing.assert_close(scripted_out, expected) @staticmethod @@ -1534,19 +1534,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable): result = torch.zeros((N, M)) for i in range(N): for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="xyxy") return result @staticmethod def _run_cartesian_test(target_fn: Callable): boxes1 = gen_box(5) boxes2 = gen_box(7) - a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2) + a = TestIouXYXYBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2, fmt="xyxy") torch.testing.assert_close(a, b) -class TestBoxIou(TestIouBase): +class TestBoxIouXYXY(TestIouXYXYBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1568,22 +1568,22 @@ def test_iou_cartesian(self): self._run_cartesian_test(ops.box_iou) -class TestIouCenterBase: +class TestIouCXCYWHBase: @staticmethod def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): for dtype in dtypes: actual_box1 = torch.tensor(actual_box1, dtype=dtype) actual_box2 = torch.tensor(actual_box2, dtype=dtype) expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2) + out = target_fn(actual_box1, actual_box2, fmt="cxcywh") torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod def _run_jit_test(target_fn: Callable, actual_box: List): box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor) + expected = target_fn(box_tensor, box_tensor, fmt="cxcywh") scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor) + scripted_out = scripted_fn(box_tensor, box_tensor, fmt="cxcywh") torch.testing.assert_close(scripted_out, expected) @staticmethod @@ -1593,19 +1593,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable): result = torch.zeros((N, M)) for i in range(N): for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="cxcywh") return result @staticmethod def _run_cartesian_test(target_fn: Callable): boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh") boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh") - a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2) + a = TestIouCXCYWHBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2, fmt="cxcywh") torch.testing.assert_close(a, b) -class TestBoxIouCenter(TestIouBase): +class TestBoxIouCXCYWH(TestIouCXCYWHBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1618,13 +1618,49 @@ class TestBoxIouCenter(TestIouBase): ], ) def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): - self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected) + self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected) def test_iou_jit(self): - self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH) + self._run_jit_test(ops.box_iou, INT_BOXES_CXCYWH) def test_iou_cartesian(self): - self._run_cartesian_test(ops.box_iou_center) + self._run_cartesian_test(ops.box_iou) + +class TestIouBase: + @staticmethod + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): + for dtype in dtypes: + actual_box1 = torch.tensor(actual_box1, dtype=dtype) + actual_box2 = torch.tensor(actual_box2, dtype=dtype) + expected_box = torch.tensor(expected) + out = target_fn(actual_box1, actual_box2) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) + + @staticmethod + def _run_jit_test(target_fn: Callable, actual_box: List): + box_tensor = torch.tensor(actual_box, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected) + + @staticmethod + def _cartesian_product(boxes1, boxes2, target_fn: Callable): + N = boxes1.size(0) + M = boxes2.size(0) + result = torch.zeros((N, M)) + for i in range(N): + for j in range(M): + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + return result + + @staticmethod + def _run_cartesian_test(target_fn: Callable): + boxes1 = gen_box(5) + boxes2 = gen_box(7) + a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2) + torch.testing.assert_close(a, b) class TestGeneralizedBoxIou(TestIouBase): diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 456bde2d036..827505b842d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -2,10 +2,8 @@ from .boxes import ( batched_nms, box_area, - box_area_center, box_convert, box_iou, - box_iou_center, clip_boxes_to_image, complete_box_iou, distance_box_iou, @@ -42,9 +40,7 @@ "clip_boxes_to_image", "box_convert", "box_area", - "box_area_center", "box_iou", - "box_iou_center", "generalized_box_iou", "distance_box_iou", "complete_box_iou", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a16085acaae..6b66bc4d118 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -272,7 +272,30 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: return boxes -def box_area(boxes: Tensor) -> Tensor: +def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: + """ + Computes the area of a set of bounding boxes from a given format. + + Args: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (x1, y1, x2, y2) format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" + + Returns: + Tensor[N]: the area for each box + """ + if fmt == "xyxy": + boxes = box_area_xyxy(boxes=boxes) + elif fmt == "cxcywh": + boxes = box_area_cxcywh(boxes=boxes) + else: + raise ValueError(f"Unsupported Box Area Calculation for given fmt {fmt}") + + return boxes + + +def box_area_xyxy(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates. @@ -286,12 +309,12 @@ def box_area(boxes: Tensor) -> Tensor: Tensor[N]: the area for each box """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area) + _log_api_usage_once(box_area_xyxy) boxes = _upcast(boxes) return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) -def box_area_center(boxes: Tensor) -> Tensor: +def box_area_cxcywh(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by their (cx, cy, w, h) coordinates. @@ -305,16 +328,39 @@ def box_area_center(boxes: Tensor) -> Tensor: Tensor[N]: the area for each box """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area_center) + _log_api_usage_once(box_area_cxcywh) boxes = _upcast(boxes) return boxes[:, 2] * boxes[:, 3] +def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: + """ + Return intersection-over-union (Jaccard index) between two sets of boxes from a given format. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" + + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + if fmt == "xyxy": + iou = box_iou_xyxy(boxes1=boxes1, boxes2=boxes2) + elif fmt == "cxcywh": + iou = box_iou_cxcywh(boxes1=boxes1, boxes2=boxes2) + else: + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}") + + return iou + + # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py # with slight modifications -def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: - area1 = box_area(boxes1) - area2 = box_area(boxes2) +def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: + area1 = box_area(boxes1, fmt="xyxy") + area2 = box_area(boxes2, fmt="xyxy") lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] @@ -327,7 +373,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: return inter, union -def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: +def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ Return intersection-over-union (Jaccard index) between two sets of boxes. @@ -342,15 +388,15 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou) - inter, union = _box_inter_union(boxes1, boxes2) + _log_api_usage_once(box_iou_xyxy) + inter, union = _box_inter_union_xyxy(boxes1, boxes2) iou = inter / union return iou -def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: - area1 = box_area_center(boxes1) - area2 = box_area_center(boxes2) +def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: + area1 = box_area(boxes1, fmt="cxcywh") + area2 = box_area(boxes2, fmt="cxcywh") lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] @@ -363,7 +409,7 @@ def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Ten return inter, union -def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor: +def box_iou_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ Return intersection-over-union (Jaccard index) between two sets of boxes. @@ -378,8 +424,8 @@ def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor: Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou_center) - inter, union = _box_inter_union_center(boxes1, boxes2) + _log_api_usage_once(box_iou_cxcywh) + inter, union = _box_inter_union_cxcywh(boxes1, boxes2) iou = inter / union return iou @@ -403,7 +449,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou) - inter, union = _box_inter_union(boxes1, boxes2) + inter, union = _box_inter_union_xyxy(boxes1, boxes2) iou = inter / union lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) From 966108a1782ae42077eb08235c57d4f8e7632965 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alperen=20=C3=9CNL=C3=9C?= <97191996+alperenunlu@users.noreply.github.com> Date: Sat, 12 Jul 2025 05:08:52 +0300 Subject: [PATCH 04/19] Update boxes.py --- torchvision/ops/boxes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index ef0c3a46c3b..eb329311a8f 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -360,7 +360,6 @@ def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tenso area1 = box_area(boxes1, fmt="xyxy") area2 = box_area(boxes2, fmt="xyxy") - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] From 09ae6a0eef4162efab6abe4b24c149578f03dc50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alperen=20=C3=9CNL=C3=9C?= <97191996+alperenunlu@users.noreply.github.com> Date: Sat, 12 Jul 2025 15:37:26 +0300 Subject: [PATCH 05/19] Update boxes.py --- torchvision/ops/boxes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index eb329311a8f..7a3b43b61cb 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -392,7 +392,7 @@ def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor: return iou -def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: +def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: area1 = box_area(boxes1, fmt="cxcywh") area2 = box_area(boxes2, fmt="cxcywh") From 3651f9e931e786f69432c9668107332336fa3282 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:22:19 -0700 Subject: [PATCH 06/19] Remove dispatcher --- torchvision/ops/boxes.py | 178 +++++++++++++-------------------------- 1 file changed, 57 insertions(+), 121 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 7a3b43b61cb..ab8430c26ac 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -275,60 +275,32 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: Computes the area of a set of bounding boxes from a given format. Args: - boxes (Tensor[N, 4]): boxes for which the area will be computed. They - are expected to be in (x1, y1, x2, y2) format with - ``0 <= x1 < x2`` and ``0 <= y1 < y2``. - fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" + boxes (Tensor[N, 4]): Tensor containing N boxes. + format (str): Format of the input boxes. + Default is "xyxy" to preserve backward compatibility. + Supported formats are "xyxy", "xywh", and "cxcywh". Returns: - Tensor[N]: the area for each box - """ - if fmt == "xyxy": - boxes = box_area_xyxy(boxes=boxes) - elif fmt == "cxcywh": - boxes = box_area_cxcywh(boxes=boxes) - else: - raise ValueError(f"Unsupported Box Area Calculation for given fmt {fmt}") - - return boxes - - -def box_area_xyxy(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by their - (x1, y1, x2, y2) coordinates. - - Args: - boxes (Tensor[N, 4]): boxes for which the area will be computed. They - are expected to be in (x1, y1, x2, y2) format with - ``0 <= x1 < x2`` and ``0 <= y1 < y2``. - - Returns: - Tensor[N]: the area for each box + Tensor[N]: Tensor containing the area for each box. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area_xyxy) + _log_api_usage_once(box_area) + allowed_fmts = ( + "xyxy", + "xywh", + "cxcywh", + ) + if fmt not in allowed_fmts: + raise ValueError(f"Unsupported Bounding Box area for given fmt {fmt}") boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - -def box_area_cxcywh(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by their - (cx, cy, w, h) coordinates. - - Args: - boxes (Tensor[N, 4]): boxes for which the area will be computed. They - are expected to be in (cx, cy, w, h) format with - ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. + if fmt == "xyxy": + area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + else: + # For formats with width and height, area = width * height + # Supported: cxcywh, xywh + area = boxes[:, 2] * boxes[:, 3] - Returns: - Tensor[N]: the area for each box - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area_cxcywh) - boxes = _upcast(boxes) - return boxes[:, 2] * boxes[:, 3] + return area def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: @@ -336,69 +308,54 @@ def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: Return intersection-over-union (Jaccard index) between two sets of boxes from a given format. Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes - fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" - - - Returns: - Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 - """ - if fmt == "xyxy": - iou = box_iou_xyxy(boxes1=boxes1, boxes2=boxes2) - elif fmt == "cxcywh": - iou = box_iou_cxcywh(boxes1=boxes1, boxes2=boxes2) - else: - raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}") - - return iou - - -# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py -# with slight modifications -def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: - area1 = box_area(boxes1, fmt="xyxy") - area2 = box_area(boxes2, fmt="xyxy") - - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + boxes1 (Tensor[N, 4]): first set of boxes. + boxes2 (Tensor[M, 4]): second set of boxes. + format (str): Format of the input boxes. + Default is "xyxy" to preserve backward compatibility. + Supported formats are "xyxy", "xywh", and "cxcywh". - wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] - inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - return inter, union - - -def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor: - """ - Return intersection-over-union (Jaccard index) between two sets of boxes. - - Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with - ``0 <= x1 < x2`` and ``0 <= y1 < y2``. - - Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes Returns: Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou_xyxy) - inter, union = _box_inter_union_xyxy(boxes1, boxes2) + _log_api_usage_once(box_iou) + allowed_fmts = ( + "xyxy", + "xywh", + "cxcywh", + ) + if fmt not in allowed_fmts: + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") + inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt) iou = inter / union return iou -def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: - area1 = box_area(boxes1, fmt="cxcywh") - area2 = box_area(boxes2, fmt="cxcywh") +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]: + area1 = box_area(boxes1, fmt=fmt) + area2 = box_area(boxes2, fmt=fmt) - lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] - rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + allowed_fmts = ( + "xyxy", + "xywh", + "cxcywh", + ) + if fmt not in allowed_fmts: + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") + if fmt == "xyxy": + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + elif fmt == "xywh": + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + else: # fmt == "cxcywh": + lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] + rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] @@ -407,27 +364,6 @@ def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Ten return inter, union -def box_iou_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tensor: - """ - Return intersection-over-union (Jaccard index) between two sets of boxes. - - Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with - ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. - - Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes - - Returns: - Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou_cxcywh) - inter, union = _box_inter_union_cxcywh(boxes1, boxes2) - iou = inter / union - return iou - - # Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ @@ -447,7 +383,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou) - inter, union = _box_inter_union_xyxy(boxes1, boxes2) + inter, union = _box_inter_union(boxes1, boxes2) iou = inter / union lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) From e6a0c0995309ccc5537b20199bb62796fafbc28c Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:37:45 -0700 Subject: [PATCH 07/19] fix hinting --- test/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1c729a0d86f..c60ba338b56 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func): boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype) func(a, boxes, output_size=(2, 2)) - # test boxes as List[Tensor[N, 4]] + # test boxes as list[Tensor[N, 4]] with pytest.raises(AssertionError): a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) @@ -1547,7 +1547,7 @@ def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expec torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: List): + def _run_jit_test(target_fn: Callable, actual_box: list): box_tensor = torch.tensor(actual_box, dtype=torch.float) expected = target_fn(box_tensor, box_tensor, fmt="xyxy") scripted_fn = torch.jit.script(target_fn) @@ -1617,7 +1617,7 @@ def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expec torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: List): + def _run_jit_test(target_fn: Callable, actual_box: list): box_tensor = torch.tensor(actual_box, dtype=torch.float) expected = target_fn(box_tensor, box_tensor, fmt="cxcywh") scripted_fn = torch.jit.script(target_fn) From 55193f2994a9f5dce6a641db61bbf99afd1f6486 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:38:56 -0700 Subject: [PATCH 08/19] lint --- test/test_ops.py | 28 +++++++++++++++++++--------- torchvision/ops/boxes.py | 2 +- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c60ba338b56..a5fe4095d5c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1485,8 +1485,9 @@ def area_check(self, box, expected, atol=1e-4): @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) def test_int_boxes(self, dtype): - box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), - in_fmt="xyxy", out_fmt="cxcywh") + box_tensor = ops.box_convert( + torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh" + ) expected = torch.tensor([10000, 0], dtype=torch.int32) self.area_check(box_tensor, expected) @@ -1497,16 +1498,22 @@ def test_float_boxes(self, dtype): self.area_check(box_tensor, expected) def test_float16_box(self): - box_tensor = ops.box_convert(torch.tensor( - [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 - ), in_fmt="xyxy", out_fmt="cxcywh") + box_tensor = ops.box_convert( + torch.tensor( + [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], + dtype=torch.float16, + ), + in_fmt="xyxy", + out_fmt="cxcywh", + ) expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) self.area_check(box_tensor, expected, atol=0.01) def test_box_area_jit(self): - box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), - in_fmt="xyxy", out_fmt="cxcywh") + box_tensor = ops.box_convert( + torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt="cxcywh" + ) expected = ops.box_area(box_tensor, fmt="cxcywh") scripted_fn = torch.jit.script(ops.box_area) scripted_area = scripted_fn(box_tensor, fmt="cxcywh") @@ -1526,7 +1533,7 @@ def test_box_area_jit(self): FLOAT_BOXES_CXCYWH = [ [739.4324, 518.5154, 908.1572, 665.8793], [738.8228, 519.9021, 907.3512, 662.3295], - [734.3593, 523.5916, 910.2306, 651.2207] + [734.3593, 523.5916, 910.2306, 651.2207], ] @@ -1650,7 +1657,9 @@ class TestBoxIouCXCYWH(TestIouCXCYWHBase): @pytest.mark.parametrize( "actual_box1, actual_box2, dtypes, atol, expected", [ - pytest.param(INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param( + INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected + ), pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected), pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected), ], @@ -1664,6 +1673,7 @@ def test_iou_jit(self): def test_iou_cartesian(self): self._run_cartesian_test(ops.box_iou) + class TestIouBase: @staticmethod def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index d94c97d3fcb..ae6883c9beb 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -355,7 +355,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple else: # fmt == "cxcywh": lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] - + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] From 7b6a3ea0c4aa472d1a1b9aa47ed1fc1d250406b9 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:44:49 -0700 Subject: [PATCH 09/19] fix `_box_inter_union` for "xywh" format --- torchvision/ops/boxes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index ae6883c9beb..c6638865740 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -351,7 +351,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] elif fmt == "xywh": lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:], boxes2[:, :2] + boxes2[:, 2:]) # [N,M,2] else: # fmt == "cxcywh": lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] From adcc87847e0940e38e20ae994d69105e1c6d1c23 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:45:34 -0700 Subject: [PATCH 10/19] Re-order with original file structure --- torchvision/ops/boxes.py | 58 ++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index c6638865740..ec2bfd3c640 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -303,35 +303,6 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: return area -def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: - """ - Return intersection-over-union (Jaccard index) between two sets of boxes from a given format. - - Args: - boxes1 (Tensor[..., N, 4]): first set of boxes - boxes2 (Tensor[..., M, 4]): second set of boxes - format (str): Format of the input boxes. - Default is "xyxy" to preserve backward compatibility. - Supported formats are "xyxy", "xywh", and "cxcywh". - - Returns: - Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element - in boxes1 and boxes2 - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou) - allowed_fmts = ( - "xyxy", - "xywh", - "cxcywh", - ) - if fmt not in allowed_fmts: - raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") - inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt) - iou = inter / union - return iou - - # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py # with slight modifications def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]: @@ -364,6 +335,35 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple return inter, union +def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: + """ + Return intersection-over-union (Jaccard index) between two sets of boxes from a given format. + + Args: + boxes1 (Tensor[..., N, 4]): first set of boxes + boxes2 (Tensor[..., M, 4]): second set of boxes + format (str): Format of the input boxes. + Default is "xyxy" to preserve backward compatibility. + Supported formats are "xyxy", "xywh", and "cxcywh". + + Returns: + Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element + in boxes1 and boxes2 + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(box_iou) + allowed_fmts = ( + "xyxy", + "xywh", + "cxcywh", + ) + if fmt not in allowed_fmts: + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") + inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt) + iou = inter / union + return iou + + # Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ From 8bc72b615ddd1800976e021d3bee18d7c81533cf Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:50:53 -0700 Subject: [PATCH 11/19] Keep batch dimension --- torchvision/ops/boxes.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index ec2bfd3c640..960424f1a21 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -318,19 +318,25 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") if fmt == "xyxy": - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] + rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2] elif fmt == "xywh": - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:], boxes2[:, :2] + boxes2[:, 2:]) # [N,M,2] + lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] + rb = torch.min( + boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :2] + boxes2[..., None, 2:] + ) # [N,M,2] else: # fmt == "cxcywh": - lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] - rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + lt = torch.max( + boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :2] - boxes2[..., None, 2:] / 2 + ) # [N,M,2] + rb = torch.min( + boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :2] + boxes2[..., None, 2:] / 2 + ) # [N,M,2] - wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] - inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] + inter = wh[..., 0] * wh[..., 1] # [N,M] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] - union = area1[:, None] + area2 - inter + union = area1[..., None] + area2[..., None, :] - inter return inter, union From c509b11299d1d8d20d467dffcde744c0381b6892 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 09:54:08 -0700 Subject: [PATCH 12/19] Keep batch dimension (2/2) --- torchvision/ops/boxes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 960424f1a21..52fb56a48ef 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -294,11 +294,11 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: raise ValueError(f"Unsupported Bounding Box area for given fmt {fmt}") boxes = _upcast(boxes) if fmt == "xyxy": - area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) else: # For formats with width and height, area = width * height # Supported: cxcywh, xywh - area = boxes[:, 2] * boxes[:, 3] + area = boxes[..., 2] * boxes[..., 3] return area @@ -323,14 +323,14 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple elif fmt == "xywh": lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :2] + boxes2[..., None, 2:] + boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] ) # [N,M,2] else: # fmt == "cxcywh": lt = torch.max( - boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :2] - boxes2[..., None, 2:] / 2 + boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2 ) # [N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :2] + boxes2[..., None, 2:] / 2 + boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2 ) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] From 107d2cd0c491e69025925b0324cfd423700417a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alperen=20=C3=9CNL=C3=9C?= <97191996+alperenunlu@users.noreply.github.com> Date: Thu, 14 Aug 2025 11:15:16 +0300 Subject: [PATCH 13/19] Fix f-string boxes.py --- torchvision/ops/boxes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 52fb56a48ef..506e2c7855f 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -315,7 +315,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple "cxcywh", ) if fmt not in allowed_fmts: - raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.") if fmt == "xyxy": lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] @@ -364,7 +364,7 @@ def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: "cxcywh", ) if fmt not in allowed_fmts: - raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.") + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.") inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt) iou = inter / union return iou From b6e7a4c2f7448bb21b8612d558de811888459d17 Mon Sep 17 00:00:00 2001 From: alperenunlu <97191996+alperenunlu@users.noreply.github.com> Date: Thu, 14 Aug 2025 11:25:46 +0300 Subject: [PATCH 14/19] Fix docstring boxes.py --- torchvision/ops/boxes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 506e2c7855f..001c52d9289 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -276,7 +276,7 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: Args: boxes (Tensor[..., 4]): boxes for which the area will be computed. - format (str): Format of the input boxes. + fmt (str): Format of the input boxes. Default is "xyxy" to preserve backward compatibility. Supported formats are "xyxy", "xywh", and "cxcywh". @@ -291,7 +291,7 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: "cxcywh", ) if fmt not in allowed_fmts: - raise ValueError(f"Unsupported Bounding Box area for given fmt {fmt}") + raise ValueError(f"Unsupported Bounding Box area for given format {fmt}") boxes = _upcast(boxes) if fmt == "xyxy": area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) @@ -348,7 +348,7 @@ def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: Args: boxes1 (Tensor[..., N, 4]): first set of boxes boxes2 (Tensor[..., M, 4]): second set of boxes - format (str): Format of the input boxes. + fmt (str): Format of the input boxes. Default is "xyxy" to preserve backward compatibility. Supported formats are "xyxy", "xywh", and "cxcywh". @@ -364,7 +364,7 @@ def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: "cxcywh", ) if fmt not in allowed_fmts: - raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.") + raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.") inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt) iou = inter / union return iou From d7467946f053c735262355a5b5d981e06e1bfccb Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Fri, 22 Aug 2025 08:26:50 -0700 Subject: [PATCH 15/19] simplify test case --- test/test_ops.py | 241 +++++++++++------------------------------------ 1 file changed, 55 insertions(+), 186 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a5fe4095d5c..6f31b647fe7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,7 +1,7 @@ import math import os from abc import ABC, abstractmethod -from functools import lru_cache +from functools import lru_cache, partial from itertools import product from typing import Callable @@ -1445,79 +1445,50 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh) -class TestBoxAreaXYXY: - def area_check(self, box, expected, atol=1e-4): - out = ops.box_area(box, fmt="xyxy") +class TestBoxArea: + def area_check(self, box, expected, fmt="xyxy", atol=1e-4): + out = ops.box_area(box, fmt=fmt) torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) - def test_int_boxes(self, dtype): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) - expected = torch.tensor([10000, 0], dtype=torch.int32) - self.area_check(box_tensor, expected) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) - def test_float_boxes(self, dtype): - box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype) - expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) - self.area_check(box_tensor, expected) - - def test_float16_box(self): - box_tensor = torch.tensor( - [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 - ) - - expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) - self.area_check(box_tensor, expected, atol=0.01) - - def test_box_area_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) - expected = ops.box_area(box_tensor, fmt="xyxy") - scripted_fn = torch.jit.script(ops.box_area) - scripted_area = scripted_fn(box_tensor) - torch.testing.assert_close(scripted_area, expected) - - -class TestBoxAreaCXCYWH: - def area_check(self, box, expected, atol=1e-4): - out = ops.box_area(box, fmt="cxcywh") - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) - - @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) - def test_int_boxes(self, dtype): + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_int_boxes(self, dtype, fmt): box_tensor = ops.box_convert( - torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh" + torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), in_fmt="xyxy", out_fmt=fmt ) expected = torch.tensor([10000, 0], dtype=torch.int32) - self.area_check(box_tensor, expected) + self.area_check(box_tensor, expected, fmt) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) - def test_float_boxes(self, dtype): - box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh") + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_float_boxes(self, dtype, fmt): + box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) - self.area_check(box_tensor, expected) + self.area_check(box_tensor, expected, fmt) - def test_float16_box(self): + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_float16_box(self, fmt): box_tensor = ops.box_convert( torch.tensor( [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16, ), in_fmt="xyxy", - out_fmt="cxcywh", + out_fmt=fmt, ) expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) - self.area_check(box_tensor, expected, atol=0.01) + self.area_check(box_tensor, expected, fmt, atol=0.01) - def test_box_area_jit(self): + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_box_area_jit(self, fmt): box_tensor = ops.box_convert( - torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt="cxcywh" + torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt ) - expected = ops.box_area(box_tensor, fmt="cxcywh") + expected = ops.box_area(box_tensor, fmt) scripted_fn = torch.jit.script(ops.box_area) - scripted_area = scripted_fn(box_tensor, fmt="cxcywh") - torch.testing.assert_close(scripted_area, expected) + scripted_area = scripted_fn(box_tensor) + torch.testing.assert_close(scripted_area, expected, fmt) INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]] @@ -1528,37 +1499,29 @@ def test_box_area_jit(self): [279.2440, 197.9812, 1189.4746, 849.2019], ] -INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]] -INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]] -FLOAT_BOXES_CXCYWH = [ - [739.4324, 518.5154, 908.1572, 665.8793], - [738.8228, 519.9021, 907.3512, 662.3295], - [734.3593, 523.5916, 910.2306, 651.2207], -] - -def gen_box(size, dtype=torch.float) -> Tensor: +def gen_box(size, dtype=torch.float, fmt="xyxy") -> Tensor: xy1 = torch.rand((size, 2), dtype=dtype) xy2 = xy1 + torch.rand((size, 2), dtype=dtype) - return torch.cat([xy1, xy2], axis=-1) + return ops.box_convert(torch.cat([xy1, xy2], axis=-1), in_fmt="xyxy", out_fmt=fmt) -class TestIouXYXYBase: +class TestIouBase: @staticmethod - def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected, fmt="xyxy"): for dtype in dtypes: - actual_box1 = torch.tensor(actual_box1, dtype=dtype) - actual_box2 = torch.tensor(actual_box2, dtype=dtype) + actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) + actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2, fmt="xyxy") + out = target_fn(actual_box1, actual_box2) torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: list): - box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor, fmt="xyxy") + def _run_jit_test(target_fn: Callable, actual_box: list, fmt="xyxy"): + box_tensor = ops.box_convert(torch.tensor(actual_box, dtype=torch.float), in_fmt="xyxy", out_fmt=fmt) + expected = target_fn(box_tensor, box_tensor) scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor, fmt="xyxy") + scripted_out = scripted_fn(box_tensor, box_tensor) torch.testing.assert_close(scripted_out, expected) @staticmethod @@ -1568,27 +1531,27 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable): result = torch.zeros((N, M)) for i in range(N): for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="xyxy") + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) return result @staticmethod - def _run_cartesian_test(target_fn: Callable): - boxes1 = gen_box(5) - boxes2 = gen_box(7) - a = TestIouXYXYBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2, fmt="xyxy") + def _run_cartesian_test(target_fn: Callable, fmt: str = "xyxy"): + boxes1 = gen_box(5, fmt=fmt) + boxes2 = gen_box(7, fmt=fmt) + a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2) torch.testing.assert_close(a, b) @staticmethod - def _run_batch_test(target_fn: Callable): - boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0) - boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0) + def _run_batch_test(target_fn: Callable, fmt: str = "xyxy"): + boxes1 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0) + boxes2 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0) native: Tensor = target_fn(boxes1, boxes2) iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0) torch.testing.assert_close(native, iterative) -class TestBoxIouXYXY(TestIouXYXYBase): +class TestBoxIou(TestIouBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1600,115 +1563,21 @@ class TestBoxIouXYXY(TestIouXYXYBase): pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): - self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.box_iou, INT_BOXES) - - def test_iou_cartesian(self): - self._run_cartesian_test(ops.box_iou) - - def test_iou_batch(self): - self._run_batch_test(ops.box_iou) + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt): + self._run_test(partial(ops.box_iou, fmt=fmt), actual_box1, actual_box2, dtypes, atol, expected, fmt) + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_iou_jit(self, fmt): + self._run_jit_test(partial(ops.box_iou, fmt=fmt), INT_BOXES, fmt) -class TestIouCXCYWHBase: - @staticmethod - def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): - for dtype in dtypes: - actual_box1 = torch.tensor(actual_box1, dtype=dtype) - actual_box2 = torch.tensor(actual_box2, dtype=dtype) - expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2, fmt="cxcywh") - torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) - - @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: list): - box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor, fmt="cxcywh") - scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor, fmt="cxcywh") - torch.testing.assert_close(scripted_out, expected) + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_iou_cartesian(self, fmt): + self._run_cartesian_test(partial(ops.box_iou, fmt=fmt)) - @staticmethod - def _cartesian_product(boxes1, boxes2, target_fn: Callable): - N = boxes1.size(0) - M = boxes2.size(0) - result = torch.zeros((N, M)) - for i in range(N): - for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="cxcywh") - return result - - @staticmethod - def _run_cartesian_test(target_fn: Callable): - boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh") - boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh") - a = TestIouCXCYWHBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2, fmt="cxcywh") - torch.testing.assert_close(a, b) - - -class TestBoxIouCXCYWH(TestIouCXCYWHBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]] - float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "actual_box1, actual_box2, dtypes, atol, expected", - [ - pytest.param( - INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected - ), - pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected), - pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected), - ], - ) - def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): - self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.box_iou, INT_BOXES_CXCYWH) - - def test_iou_cartesian(self): - self._run_cartesian_test(ops.box_iou) - - -class TestIouBase: - @staticmethod - def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): - for dtype in dtypes: - actual_box1 = torch.tensor(actual_box1, dtype=dtype) - actual_box2 = torch.tensor(actual_box2, dtype=dtype) - expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2) - torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) - - @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: list): - box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor) - scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor) - torch.testing.assert_close(scripted_out, expected) - - @staticmethod - def _cartesian_product(boxes1, boxes2, target_fn: Callable): - N = boxes1.size(0) - M = boxes2.size(0) - result = torch.zeros((N, M)) - for i in range(N): - for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) - return result - - @staticmethod - def _run_cartesian_test(target_fn: Callable): - boxes1 = gen_box(5) - boxes2 = gen_box(7) - a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2) - torch.testing.assert_close(a, b) + @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) + def test_iou_batch(self, fmt): + self._run_batch_test(partial(ops.box_iou, fmt=fmt)) class TestGeneralizedBoxIou(TestIouBase): From 405f31e7e165378f6a509527644ef1f88255680b Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Fri, 22 Aug 2025 08:54:42 -0700 Subject: [PATCH 16/19] fix tests --- test/test_ops.py | 20 ++++++++++++++++---- torchvision/ops/boxes.py | 6 +++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 6f31b647fe7..a99f207ba4f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1510,10 +1510,13 @@ class TestIouBase: @staticmethod def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected, fmt="xyxy"): for dtype in dtypes: - actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) - actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) + _actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) + _actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt) expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2) + out = target_fn( + _actual_box1, + _actual_box2, + ) torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod @@ -1569,7 +1572,16 @@ def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt): @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) def test_iou_jit(self, fmt): - self._run_jit_test(partial(ops.box_iou, fmt=fmt), INT_BOXES, fmt) + class IoUJit(torch.nn.Module): + def __init__(self, fmt): + super().__init__() + self.iou = ops.box_iou + self.fmt = fmt + + def forward(self, boxes1, boxes2): + return self.iou(boxes1, boxes2) + + self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt) @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) def test_iou_cartesian(self, fmt): diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 001c52d9289..54f8d6b86e9 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -324,7 +324,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] rb = torch.min( boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] - ) # [N,M,2] + ) # [...,N,M,2] else: # fmt == "cxcywh": lt = torch.max( boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2 @@ -333,8 +333,8 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2 ) # [N,M,2] - wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] - inter = wh[..., 0] * wh[..., 1] # [N,M] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] + inter = wh[..., 0] * wh[..., 1] # [N,M] union = area1[..., None] + area2[..., None, :] - inter From 298cfa966c5ee0e18fd34e3e4125f8952422acfa Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Fri, 22 Aug 2025 09:03:45 -0700 Subject: [PATCH 17/19] add comment --- test/test_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index a99f207ba4f..ffbab5ccbdc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1573,6 +1573,9 @@ def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt): @pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"]) def test_iou_jit(self, fmt): class IoUJit(torch.nn.Module): + # We are using this intermediate class + # since torchscript does not support + # neither partial nor lambda functions for this test. def __init__(self, fmt): super().__init__() self.iou = ops.box_iou From bc36d1e5fac7b30a22b99ae329032b64781a0762 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Fri, 22 Aug 2025 09:34:38 -0700 Subject: [PATCH 18/19] fix area test --- test/test_ops.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index ffbab5ccbdc..ff2dad8cd0e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1486,9 +1486,21 @@ def test_box_area_jit(self, fmt): torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt ) expected = ops.box_area(box_tensor, fmt) - scripted_fn = torch.jit.script(ops.box_area) + class BoxArea(torch.nn.Module): + # We are using this intermediate class + # since torchscript does not support + # neither partial nor lambda functions for this test. + def __init__(self, fmt): + super().__init__() + self.area = ops.box_area + self.fmt = fmt + + def forward(self, boxes): + return self.area(boxes, self.fmt) + + scripted_fn = torch.jit.script(BoxArea(fmt)) scripted_area = scripted_fn(box_tensor) - torch.testing.assert_close(scripted_area, expected, fmt) + torch.testing.assert_close(scripted_area, expected) INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]] @@ -1582,7 +1594,7 @@ def __init__(self, fmt): self.fmt = fmt def forward(self, boxes1, boxes2): - return self.iou(boxes1, boxes2) + return self.iou(boxes1, boxes2, fmt=self.fmt) self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt) From 80d12de6a5c81d976f82104b8d1b8e60d8577a82 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Sat, 23 Aug 2025 08:40:36 -0700 Subject: [PATCH 19/19] lint --- test/test_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_ops.py b/test/test_ops.py index ff2dad8cd0e..d2cf8d29181 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1486,6 +1486,7 @@ def test_box_area_jit(self, fmt): torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt ) expected = ops.box_area(box_tensor, fmt) + class BoxArea(torch.nn.Module): # We are using this intermediate class # since torchscript does not support