Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 89 additions & 44 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import os
from abc import ABC, abstractmethod
from functools import lru_cache
from functools import lru_cache, partial
from itertools import product
from typing import Callable

Expand Down Expand Up @@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func):
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))

# test boxes as List[Tensor[N, 4]]
# test boxes as list[Tensor[N, 4]]
with pytest.raises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
Expand Down Expand Up @@ -1446,34 +1446,60 @@ def test_bbox_convert_jit(self):


class TestBoxArea:
def area_check(self, box, expected, atol=1e-4):
out = ops.box_area(box)
def area_check(self, box, expected, fmt="xyxy", atol=1e-4):
out = ops.box_area(box, fmt=fmt)
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
def test_int_boxes(self, dtype):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_int_boxes(self, dtype, fmt):
box_tensor = ops.box_convert(
torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), in_fmt="xyxy", out_fmt=fmt
)
expected = torch.tensor([10000, 0], dtype=torch.int32)
self.area_check(box_tensor, expected)
self.area_check(box_tensor, expected, fmt)

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_float_boxes(self, dtype):
box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype)
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_float_boxes(self, dtype, fmt):
box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
self.area_check(box_tensor, expected)

def test_float16_box(self):
box_tensor = torch.tensor(
[[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
self.area_check(box_tensor, expected, fmt)

@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_float16_box(self, fmt):
box_tensor = ops.box_convert(
torch.tensor(
[[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]],
dtype=torch.float16,
),
in_fmt="xyxy",
out_fmt=fmt,
)

expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
self.area_check(box_tensor, expected, atol=0.01)
self.area_check(box_tensor, expected, fmt, atol=0.01)

@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_box_area_jit(self, fmt):
box_tensor = ops.box_convert(
torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt
)
expected = ops.box_area(box_tensor, fmt)

class BoxArea(torch.nn.Module):
# We are using this intermediate class
# since torchscript does not support
# neither partial nor lambda functions for this test.
def __init__(self, fmt):
super().__init__()
self.area = ops.box_area
self.fmt = fmt

def test_box_area_jit(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
expected = ops.box_area(box_tensor)
scripted_fn = torch.jit.script(ops.box_area)
def forward(self, boxes):
return self.area(boxes, self.fmt)

scripted_fn = torch.jit.script(BoxArea(fmt))
scripted_area = scripted_fn(box_tensor)
torch.testing.assert_close(scripted_area, expected)

Expand All @@ -1487,25 +1513,28 @@ def test_box_area_jit(self):
]


def gen_box(size, dtype=torch.float) -> Tensor:
def gen_box(size, dtype=torch.float, fmt="xyxy") -> Tensor:
xy1 = torch.rand((size, 2), dtype=dtype)
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
return torch.cat([xy1, xy2], axis=-1)
return ops.box_convert(torch.cat([xy1, xy2], axis=-1), in_fmt="xyxy", out_fmt=fmt)


class TestIouBase:
@staticmethod
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected, fmt="xyxy"):
for dtype in dtypes:
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
_actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
_actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
expected_box = torch.tensor(expected)
out = target_fn(actual_box1, actual_box2)
out = target_fn(
_actual_box1,
_actual_box2,
)
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

@staticmethod
def _run_jit_test(target_fn: Callable, actual_box: list):
box_tensor = torch.tensor(actual_box, dtype=torch.float)
def _run_jit_test(target_fn: Callable, actual_box: list, fmt="xyxy"):
box_tensor = ops.box_convert(torch.tensor(actual_box, dtype=torch.float), in_fmt="xyxy", out_fmt=fmt)
expected = target_fn(box_tensor, box_tensor)
scripted_fn = torch.jit.script(target_fn)
scripted_out = scripted_fn(box_tensor, box_tensor)
Expand All @@ -1522,17 +1551,17 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
return result

@staticmethod
def _run_cartesian_test(target_fn: Callable):
boxes1 = gen_box(5)
boxes2 = gen_box(7)
def _run_cartesian_test(target_fn: Callable, fmt: str = "xyxy"):
boxes1 = gen_box(5, fmt=fmt)
boxes2 = gen_box(7, fmt=fmt)
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2)
torch.testing.assert_close(a, b)

@staticmethod
def _run_batch_test(target_fn: Callable):
boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
def _run_batch_test(target_fn: Callable, fmt: str = "xyxy"):
boxes1 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0)
boxes2 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0)
native: Tensor = target_fn(boxes1, boxes2)
iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0)
torch.testing.assert_close(native, iterative)
Expand All @@ -1550,17 +1579,33 @@ class TestBoxIou(TestIouBase):
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)

def test_iou_batch(self):
self._run_batch_test(ops.box_iou)
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt):
self._run_test(partial(ops.box_iou, fmt=fmt), actual_box1, actual_box2, dtypes, atol, expected, fmt)

@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_iou_jit(self, fmt):
class IoUJit(torch.nn.Module):
# We are using this intermediate class
# since torchscript does not support
# neither partial nor lambda functions for this test.
def __init__(self, fmt):
super().__init__()
self.iou = ops.box_iou
self.fmt = fmt

def forward(self, boxes1, boxes2):
return self.iou(boxes1, boxes2, fmt=self.fmt)

self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt)

@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_iou_cartesian(self, fmt):
self._run_cartesian_test(partial(ops.box_iou, fmt=fmt))

@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
def test_iou_batch(self, fmt):
self._run_batch_test(partial(ops.box_iou, fmt=fmt))


class TestGeneralizedBoxIou(TestIouBase):
Expand Down
80 changes: 61 additions & 19 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,33 +270,68 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes


def box_area(boxes: Tensor) -> Tensor:
def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by their
(x1, y1, x2, y2) coordinates.
Computes the area of a set of bounding boxes from a given format.

Args:
boxes (Tensor[..., 4]): boxes for which the area will be computed. They
are expected to be in (x1, y1, x2, y2) format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
boxes (Tensor[..., 4]): boxes for which the area will be computed.
fmt (str): Format of the input boxes.
Default is "xyxy" to preserve backward compatibility.
Supported formats are "xyxy", "xywh", and "cxcywh".

Returns:
Tensor[N]: the area for each box
Tensor[N]: Tensor containing the area for each box.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_area)
allowed_fmts = (
"xyxy",
"xywh",
"cxcywh",
)
if fmt not in allowed_fmts:
raise ValueError(f"Unsupported Bounding Box area for given format {fmt}")
boxes = _upcast(boxes)
return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
if fmt == "xyxy":
area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
else:
# For formats with width and height, area = width * height
# Supported: cxcywh, xywh
area = boxes[..., 2] * boxes[..., 3]

return area


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
area1 = box_area(boxes1)
area2 = box_area(boxes2)
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
area1 = box_area(boxes1, fmt=fmt)
area2 = box_area(boxes2, fmt=fmt)

lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
allowed_fmts = (
"xyxy",
"xywh",
"cxcywh",
)
if fmt not in allowed_fmts:
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.")

if fmt == "xyxy":
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
elif fmt == "xywh":
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
rb = torch.min(
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
) # [...,N,M,2]
else: # fmt == "cxcywh":
lt = torch.max(
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
) # [N,M,2]
rb = torch.min(
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
) # [N,M,2]

wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
inter = wh[..., 0] * wh[..., 1] # [N,M]
Expand All @@ -306,24 +341,31 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
return inter, union


def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
"""
Return intersection-over-union (Jaccard index) between two sets of boxes.

Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.

Args:
boxes1 (Tensor[..., N, 4]): first set of boxes
boxes2 (Tensor[..., M, 4]): second set of boxes
fmt (str): Format of the input boxes.
Default is "xyxy" to preserve backward compatibility.
Supported formats are "xyxy", "xywh", and "cxcywh".

Returns:
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_iou)
inter, union = _box_inter_union(boxes1, boxes2)
allowed_fmts = (
"xyxy",
"xywh",
"cxcywh",
)
if fmt not in allowed_fmts:
raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.")
inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
iou = inter / union
return iou

Expand Down
Loading