Skip to content

Commit 06e35fc

Browse files
igorsugakfacebook-github-bot
authored andcommitted
replace uses of np.ndarray with npt.NDArray
Summary: X-link: pytorch/opacus#681 X-link: pytorch/captum#1389 X-link: pytorch/botorch#2586 X-link: pytorch/audio#3846 This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors: ```counterexample Generic type `ndarray` expects 2 type parameters. ``` `numpy.typing.NDArray` is an alias that provides default template parameters. Reviewed By: ryanthomasjohnson Differential Revision: D64619891 fbshipit-source-id: dffc096b1ce90d11e73d475f0bbcb8867ed9ef01
1 parent e737b8f commit 06e35fc

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

torchbenchmark/models/pytorch_unet/pytorch_unet/predict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44

55
import numpy as np
6+
import numpy.typing as npt
67
import torch
78
import torch.nn.functional as F
89
from PIL import Image
@@ -102,7 +103,7 @@ def _generate_name(fn):
102103
return args.output or list(map(_generate_name, args.input))
103104

104105

105-
def mask_to_image(mask: np.ndarray):
106+
def mask_to_image(mask: npt.NDArray):
106107
if mask.ndim == 2:
107108
return Image.fromarray((mask * 255).astype(np.uint8))
108109
elif mask.ndim == 3:

torchbenchmark/models/sam/predictor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Tuple
88

99
import numpy as np
10+
import numpy.typing as npt
1011
import torch
1112

1213
from .sam import Sam
@@ -32,7 +33,7 @@ def __init__(
3233

3334
def set_image(
3435
self,
35-
image: np.ndarray,
36+
image: npt.NDArray,
3637
image_format: str = "RGB",
3738
) -> None:
3839
"""
@@ -92,13 +93,13 @@ def set_torch_image(
9293

9394
def predict(
9495
self,
95-
point_coords: Optional[np.ndarray] = None,
96-
point_labels: Optional[np.ndarray] = None,
97-
box: Optional[np.ndarray] = None,
98-
mask_input: Optional[np.ndarray] = None,
96+
point_coords: Optional[npt.NDArray] = None,
97+
point_labels: Optional[npt.NDArray] = None,
98+
box: Optional[npt.NDArray] = None,
99+
mask_input: Optional[npt.NDArray] = None,
99100
multimask_output: bool = True,
100101
return_logits: bool = False,
101-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102+
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
102103
"""
103104
Predict masks for the given input prompts, using the currently set image.
104105

torchbenchmark/models/sam/transforms.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Tuple
99

1010
import numpy as np
11+
import numpy.typing as npt
1112
import torch
1213
from torch.nn import functional as F
1314
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
@@ -23,7 +24,7 @@ class ResizeLongestSide:
2324
def __init__(self, target_length: int) -> None:
2425
self.target_length = target_length
2526

26-
def apply_image(self, image: np.ndarray) -> np.ndarray:
27+
def apply_image(self, image: npt.NDArray) -> npt.NDArray:
2728
"""
2829
Expects a numpy array with shape HxWxC in uint8 format.
2930
"""
@@ -33,8 +34,8 @@ def apply_image(self, image: np.ndarray) -> np.ndarray:
3334
return np.array(resize(to_pil_image(image), target_size))
3435

3536
def apply_coords(
36-
self, coords: np.ndarray, original_size: Tuple[int, ...]
37-
) -> np.ndarray:
37+
self, coords: npt.NDArray, original_size: Tuple[int, ...]
38+
) -> npt.NDArray:
3839
"""
3940
Expects a numpy array of length 2 in the final dimension. Requires the
4041
original image size in (H, W) format.
@@ -49,8 +50,8 @@ def apply_coords(
4950
return coords
5051

5152
def apply_boxes(
52-
self, boxes: np.ndarray, original_size: Tuple[int, ...]
53-
) -> np.ndarray:
53+
self, boxes: npt.NDArray, original_size: Tuple[int, ...]
54+
) -> npt.NDArray:
5455
"""
5556
Expects a numpy array shape Bx4. Requires the original image size
5657
in (H, W) format.

0 commit comments

Comments
 (0)