Commit fd5f406

[tests] Add additional tests and clean up missing parts (#1982)
1 parent 672e03a commit fd5f406

File tree

5 files changed (+95, -79 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
   <img src="https://github.com/mindee/doctr/raw/main/docs/images/Logo_doctr.gif" width="40%">
 </p>
 
-[![Slack Icon](https://img.shields.io/badge/Slack-Community-4A154B?style=flat-square&logo=slack&logoColor=white)](https://slack.mindee.com) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) ![Build Status](https://github.com/mindee/doctr/workflows/builds/badge.svg) [![Docker Images](https://img.shields.io/badge/Docker-4287f5?style=flat&logo=docker&logoColor=white)](https://github.com/mindee/doctr/pkgs/container/doctr) [![codecov](https://codecov.io/gh/mindee/doctr/branch/main/graph/badge.svg?token=577MO567NM)](https://codecov.io/gh/mindee/doctr) [![CodeFactor](https://www.codefactor.io/repository/github/mindee/doctr/badge?s=bae07db86bb079ce9d6542315b8c6e70fa708a7e)](https://www.codefactor.io/repository/github/mindee/doctr) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/340a76749b634586a498e1c0ab998f08)](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [![Doc Status](https://github.com/mindee/doctr/workflows/doc-status/badge.svg)](https://mindee.github.io/doctr) [![Pypi](https://img.shields.io/badge/pypi-v0.12.0-blue.svg)](https://pypi.org/project/python-doctr/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb) [![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20docTR%20Guru-006BFF)](https://gurubase.io/g/doctr)
+[![Slack Icon](https://img.shields.io/badge/Slack-Community-4A154B?style=flat-square&logo=slack&logoColor=white)](https://slack.mindee.com) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) ![Build Status](https://github.com/mindee/doctr/workflows/builds/badge.svg) [![Docker Images](https://img.shields.io/badge/Docker-4287f5?style=flat&logo=docker&logoColor=white)](https://github.com/mindee/doctr/pkgs/container/doctr) [![codecov](https://codecov.io/gh/mindee/doctr/branch/main/graph/badge.svg?token=577MO567NM)](https://codecov.io/gh/mindee/doctr) [![CodeFactor](https://www.codefactor.io/repository/github/mindee/doctr/badge?s=bae07db86bb079ce9d6542315b8c6e70fa708a7e)](https://www.codefactor.io/repository/github/mindee/doctr) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/340a76749b634586a498e1c0ab998f08)](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [![Doc Status](https://github.com/mindee/doctr/workflows/doc-status/badge.svg)](https://mindee.github.io/doctr) [![Pypi](https://img.shields.io/badge/pypi-v1.0.0-blue.svg)](https://pypi.org/project/python-doctr/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb) [![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20docTR%20Guru-006BFF)](https://gurubase.io/g/doctr)
 
 
 **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**

api/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "doctr-api"
-version = "0.12.1a0"
+version = "1.0.0a0"
 description = "Backend template for your OCR API with docTR"
 authors = ["Mindee <[email protected]>"]
 license = "Apache-2.0"

doctr/models/preprocessor/pytorch.py

Lines changed: 19 additions & 22 deletions
@@ -60,24 +60,21 @@ def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
 
         return batches
 
-    def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
+    def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
         if x.ndim != 3:
             raise AssertionError("expected list of 3D Tensors")
-        if isinstance(x, np.ndarray):
-            if x.dtype not in (np.uint8, np.float32, np.float16):
-                raise TypeError("unsupported data type for numpy.ndarray")
-            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
-        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
+        if x.dtype not in (np.uint8, np.float32, np.float16):
+            raise TypeError("unsupported data type for numpy.ndarray")
+        tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
         # Resizing
-        x = self.resize(x)
+        tensor = self.resize(tensor)
         # Data type
-        if x.dtype == torch.uint8:
-            x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
+        if tensor.dtype == torch.uint8:
+            tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
         else:
-            x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
+            tensor = tensor.to(dtype=torch.float32)
 
-        return x
+        return tensor
 
     def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
         """Prepare document data for model forwarding
@@ -94,29 +91,29 @@ def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
                 raise AssertionError("expected 4D Tensor")
             if x.dtype not in (np.uint8, np.float32, np.float16):
                 raise TypeError("unsupported data type for numpy.ndarray")
-            x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)  # type: ignore[assignment]
+            tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
 
             # Resizing
-            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
-                x = F.resize(
-                    x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+            if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
+                tensor = F.resize(
+                    tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
             # Data type
-            if x.dtype == torch.uint8:  # type: ignore[union-attr]
-                x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
+            if tensor.dtype == torch.uint8:
+                tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
             else:
-                x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
-            batches = [x]
+                tensor = tensor.to(dtype=torch.float32)
+            batches = [tensor]
 
         elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)  # type: ignore[assignment]
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))
 
-        return batches  # type: ignore[return-value]
+        return batches
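For context, a minimal usage sketch of the refactored preprocessor (the constructor arguments output_size and batch_size are assumed from the usual PreProcessor signature; verify against the class before relying on them). After this change, only NumPy inputs are accepted, since the torch.Tensor branch of sample_transforms was removed:

import numpy as np
from doctr.models.preprocessor.pytorch import PreProcessor

# Assumed constructor signature (output_size, batch_size, ...); check the class definition.
processor = PreProcessor(output_size=(1024, 1024), batch_size=2)

# Inputs must now be numpy arrays of dtype uint8, float16, or float32;
# a torch.Tensor input would raise a TypeError after this commit.
pages = [np.zeros((768, 512, 3), dtype=np.uint8) for _ in range(3)]

# Each page is converted to a CHW tensor, resized, rescaled to [0, 1] if uint8,
# then batched and normalized.
batches = processor(pages)
print([b.shape for b in batches])  # e.g. [torch.Size([2, 3, 1024, 1024]), torch.Size([1, 3, 1024, 1024])]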

doctr/transforms/modules/pytorch.py

Lines changed: 24 additions & 22 deletions
@@ -27,7 +27,21 @@
 
 
 class Resize(T.Resize):
-    """Resize the input image to the given size"""
+    """Resize the input image to the given size
+
+    >>> import torch
+    >>> from doctr.transforms import Resize
+    >>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
+    >>> out = transfo(torch.rand((3, 64, 64)))
+
+    Args:
+        size: output size in pixels, either a tuple (height, width) or a single integer for square images
+        interpolation: interpolation mode to use for resizing, default is bilinear
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            if True, the image will be resized to fit within the target size while maintaining its aspect ratio
+        symmetric_pad: whether to symmetrically pad the image to the target size,
+            if True, the image will be padded equally on both sides to fit the target size
+    """
 
     def __init__(
         self,
@@ -36,42 +50,30 @@ def __init__(
         preserve_aspect_ratio: bool = False,
         symmetric_pad: bool = False,
     ) -> None:
-        super().__init__(size, interpolation, antialias=True)
+        super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
 
-        if not isinstance(self.size, (int, tuple, list)):
-            raise AssertionError("size should be either a tuple, a list or an int")
-
     def forward(
         self,
         img: torch.Tensor,
         target: np.ndarray | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
-        if isinstance(self.size, int):
-            target_ratio = img.shape[-2] / img.shape[-1]
-        else:
-            target_ratio = self.size[0] / self.size[1]
+        target_ratio = self.size[0] / self.size[1]
         actual_ratio = img.shape[-2] / img.shape[-1]
 
-        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio and (isinstance(self.size, (tuple, list)))):
+        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
             # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
             # We can use with the regular resize
             if target is not None:
                 return super().forward(img), target
             return super().forward(img)
         else:
             # Resize
-            if isinstance(self.size, (tuple, list)):
-                if actual_ratio > target_ratio:
-                    tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
-                else:
-                    tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
-            elif isinstance(self.size, int):  # self.size is the longest side, infer the other
-                if img.shape[-2] <= img.shape[-1]:
-                    tmp_size = (max(int(self.size * actual_ratio), 1), self.size)
-                else:
-                    tmp_size = (self.size, max(int(self.size / actual_ratio), 1))
+            if actual_ratio > target_ratio:
+                tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
+            else:
+                tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
 
             # Scale image
             img = F.resize(img, tmp_size, self.interpolation, antialias=True)
@@ -93,14 +95,14 @@ def forward(
         if self.preserve_aspect_ratio:
             # Get absolute coords
             if target.shape[1:] == (4,):
-                if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+                if self.symmetric_pad:
                     target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                 else:
                     target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
             elif target.shape[1:] == (4, 2):
-                if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+                if self.symmetric_pad:
                     target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                     target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                 else:
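The net effect of the __init__ change above is that an integer size is normalized to a square (size, size) tuple at construction time, which is why forward could drop its isinstance(self.size, ...) branches. A small sketch of the now-equivalent behavior (shapes follow the docstring example; this is an illustration, not part of the diff):

import torch
from doctr.transforms import Resize

# An int size is now promoted to (size, size) in __init__, so these two
# transforms are configured identically.
square = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True)
explicit = Resize((32, 32), preserve_aspect_ratio=True, symmetric_pad=True)

img = torch.rand((3, 64, 32))  # taller than wide: the width axis gets padded
out = square(img)
assert out.shape[-2:] == (32, 32)
assert torch.equal(out, explicit(img))  # same configuration, same result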

tests/pytorch/test_transforms_pt.py

Lines changed: 50 additions & 33 deletions
@@ -27,60 +27,77 @@ def test_resize():
 
     assert torch.all(out == 1)
     assert out.shape[-2:] == output_size
-    assert repr(transfo) == f"Resize(output_size={output_size}, interpolation='bilinear')"
+    assert repr(transfo) == "Resize(output_size=(32, 32), interpolation='bilinear')"
 
-    transfo = Resize(output_size, preserve_aspect_ratio=True)
+    # Test with preserve_aspect_ratio
+    output_size = (32, 32)
     input_t = torch.ones((3, 32, 64), dtype=torch.float32)
-    out = transfo(input_t)
 
+    # Asymmetric padding
+    transfo = Resize(output_size, preserve_aspect_ratio=True)
+    out = transfo(input_t)
     assert out.shape[-2:] == output_size
     assert not torch.all(out == 1)
-    # Asymetric padding
     assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1)
 
-    # Symetric padding
-    transfo = Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True)
-    assert repr(transfo) == (
-        f"Resize(output_size={output_size}, interpolation='bilinear', preserve_aspect_ratio=True, symmetric_pad=True)"
-    )
+    # Symmetric padding
+    transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True)
     out = transfo(input_t)
     assert out.shape[-2:] == output_size
-    # symetric padding
-    assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 0)
+    assert torch.all(out[:, 0] == 0) and torch.all(out[:, -1] == 0)
 
-    # Inverse aspect ratio
+    expected = "Resize(output_size=(32, 32), interpolation='bilinear', preserve_aspect_ratio=True, symmetric_pad=True)"
+    assert repr(transfo) == expected
+
+    # Test with inverse resize
     input_t = torch.ones((3, 64, 32), dtype=torch.float32)
+    transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True)
     out = transfo(input_t)
+    assert out.shape[-2:] == (32, 32)
 
-    assert not torch.all(out == 1)
-    assert out.shape[-2:] == output_size
-
-    # Same aspect ratio
-    output_size = (32, 128)
-    transfo = Resize(output_size, preserve_aspect_ratio=True)
+    # Test resize with same ratio
+    transfo = Resize((32, 128), preserve_aspect_ratio=True)
     out = transfo(torch.ones((3, 16, 64), dtype=torch.float32))
-    assert out.shape[-2:] == output_size
+    assert out.shape[-2:] == (32, 128)
 
-    # FP16
+    # Test with fp16 input
+    transfo = Resize((32, 128), preserve_aspect_ratio=True)
     input_t = torch.ones((3, 64, 64), dtype=torch.float16)
     out = transfo(input_t)
     assert out.dtype == torch.float16
 
-    # --- Test with target (bounding boxes) ---
-
-    target_boxes = np.array([[0.1, 0.1, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]])
-    output_size = (64, 64)
-
-    transfo = Resize(output_size, preserve_aspect_ratio=True)
+    padding = [True, False]
+    for symmetric_pad in padding:
+        # Test with target boxes
+        target_boxes = np.array([[0.1, 0.1, 0.3, 0.4], [0.2, 0.2, 0.8, 0.8]])
+        transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad)
+        input_t = torch.ones((3, 32, 64), dtype=torch.float32)
+        out, new_target = transfo(input_t, target_boxes)
+
+        assert out.shape[-2:] == (64, 64)
+        assert new_target.shape == target_boxes.shape
+        assert np.all((0 <= new_target) & (new_target <= 1))
+
+        # Test with target polygons
+        target_boxes = np.array([
+            [[0.1, 0.1], [0.9, 0.1], [0.9, 0.9], [0.1, 0.9]],
+            [[0.2, 0.2], [0.8, 0.2], [0.8, 0.8], [0.2, 0.8]],
+        ])
+        transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad)
+        input_t = torch.ones((3, 32, 64), dtype=torch.float32)
+        out, new_target = transfo(input_t, target_boxes)
+
+        assert out.shape[-2:] == (64, 64)
+        assert new_target.shape == target_boxes.shape
+        assert np.all((0 <= new_target) & (new_target <= 1))
+
+    # Test with invalid target shape
     input_t = torch.ones((3, 32, 64), dtype=torch.float32)
-    out, new_target = transfo(input_t, target_boxes)
+    target = np.ones((2, 5))  # Invalid shape
 
-    assert out.shape[-2:] == output_size
-    assert new_target.shape == target_boxes.shape
-    assert np.all(new_target >= 0) and np.all(new_target <= 1)
-
-    out = transfo(input_t)
-    assert out.shape[-2:] == output_size
+    transfo = Resize((64, 64), preserve_aspect_ratio=True)
+    with pytest.raises(AssertionError):
+        transfo(input_t, target)
 
 
 @pytest.mark.parametrize(