Skip to content

Commit a0afa60

Browse files
authored
2420 Add image_only option to several spatial transforms (#2421)
* [DLMED] add image_only and unit tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix flake8 issue Signed-off-by: Nic Ma <[email protected]>
1 parent 88fe0fc commit a0afa60

File tree

4 files changed

+47
-10
lines changed

4 files changed

+47
-10
lines changed

monai/transforms/spatial/array.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
8989
align_corners: bool = False,
9090
dtype: DtypeLike = np.float64,
91+
image_only: bool = False,
9192
) -> None:
9293
"""
9394
Args:
@@ -114,13 +115,16 @@ def __init__(
114115
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
115116
If None, use the data type of input data. To be compatible with other modules,
116117
the output data type is always ``np.float32``.
118+
image_only: if True return only the image volume, otherwise return (image, original affine, new affine).
119+
117120
"""
118121
self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)
119122
self.diagonal = diagonal
120123
self.mode: GridSampleMode = GridSampleMode(mode)
121124
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
122125
self.align_corners = align_corners
123126
self.dtype = dtype
127+
self.image_only = image_only
124128

125129
def __call__(
126130
self,
@@ -131,7 +135,7 @@ def __call__(
131135
align_corners: Optional[bool] = None,
132136
dtype: DtypeLike = None,
133137
output_spatial_shape: Optional[np.ndarray] = None,
134-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
138+
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
135139
"""
136140
Args:
137141
data_array: in shape (num_channels, H[, W, ...]).
@@ -204,7 +208,8 @@ def __call__(
204208
)
205209
output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore
206210
new_affine = to_affine_nd(affine, new_affine)
207-
return output_data, affine, new_affine
211+
212+
return output_data if self.image_only else (output_data, affine, new_affine)
208213

209214

210215
class Orientation(Transform):
@@ -217,6 +222,7 @@ def __init__(
217222
axcodes: Optional[str] = None,
218223
as_closest_canonical: bool = False,
219224
labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")),
225+
image_only: bool = False,
220226
) -> None:
221227
"""
222228
Args:
@@ -229,6 +235,7 @@ def __init__(
229235
labels: optional, None or sequence of (2,) sequences
230236
(2,) sequences are labels for (beginning, end) of output axis.
231237
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
238+
image_only: if True return only the image volume, otherwise return (image, original affine, new affine).
232239
233240
Raises:
234241
ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values.
@@ -243,10 +250,11 @@ def __init__(
243250
self.axcodes = axcodes
244251
self.as_closest_canonical = as_closest_canonical
245252
self.labels = labels
253+
self.image_only = image_only
246254

247255
def __call__(
248256
self, data_array: np.ndarray, affine: Optional[np.ndarray] = None
249-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
257+
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
250258
"""
251259
original orientation of `data_array` is defined by `affine`.
252260
@@ -289,7 +297,8 @@ def __call__(
289297
data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt))
290298
new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape)
291299
new_affine = to_affine_nd(affine, new_affine)
292-
return data_array, affine, new_affine
300+
301+
return data_array if self.image_only else (data_array, affine, new_affine)
293302

294303

295304
class Flip(Transform):
@@ -1270,6 +1279,7 @@ def __init__(
12701279
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION,
12711280
as_tensor_output: bool = False,
12721281
device: Optional[torch.device] = None,
1282+
image_only: bool = False,
12731283
) -> None:
12741284
"""
12751285
The affine transformations are applied in rotate, shear, translate, scale order.
@@ -1296,6 +1306,7 @@ def __init__(
12961306
as_tensor_output: the computation is implemented using pytorch tensors, this option specifies
12971307
whether to convert it back to numpy arrays.
12981308
device: device on which the tensor will be allocated.
1309+
image_only: if True return only the image volume, otherwise return (image, affine).
12991310
"""
13001311
self.affine_grid = AffineGrid(
13011312
rotate_params=rotate_params,
@@ -1305,6 +1316,7 @@ def __init__(
13051316
as_tensor_output=True,
13061317
device=device,
13071318
)
1319+
self.image_only = image_only
13081320
self.resampler = Resample(as_tensor_output=as_tensor_output, device=device)
13091321
self.spatial_size = spatial_size
13101322
self.mode: GridSampleMode = GridSampleMode(mode)
@@ -1316,7 +1328,7 @@ def __call__(
13161328
spatial_size: Optional[Union[Sequence[int], int]] = None,
13171329
mode: Optional[Union[GridSampleMode, str]] = None,
13181330
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1319-
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
1331+
):
13201332
"""
13211333
Args:
13221334
img: shape must be (num_channels, H, W[, D]),
@@ -1334,10 +1346,9 @@ def __call__(
13341346
"""
13351347
sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:])
13361348
grid, affine = self.affine_grid(spatial_size=sp_size)
1337-
return (
1338-
self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode),
1339-
affine,
1340-
)
1349+
ret = self.resampler(img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode)
1350+
1351+
return ret if self.image_only else (ret, affine)
13411352

13421353

13431354
class RandAffine(RandomizableTransform):

tests/test_affine.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
{"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)},
2424
np.arange(9).reshape(1, 3, 3),
2525
],
26+
[
27+
dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True),
28+
{"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)},
29+
np.arange(9).reshape(1, 3, 3),
30+
],
2631
[
2732
dict(padding_mode="zeros", as_tensor_output=False, device=None),
2833
{"img": np.arange(4).reshape((1, 2, 2))},
@@ -78,7 +83,9 @@ class TestAffine(unittest.TestCase):
7883
@parameterized.expand(TEST_CASES)
7984
def test_affine(self, input_param, input_data, expected_val):
8085
g = Affine(**input_param)
81-
result, _ = g(**input_data)
86+
result = g(**input_data)
87+
if isinstance(result, tuple):
88+
result = result[0]
8289
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
8390
np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
8491

tests/test_orientation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
np.arange(12).reshape((2, 1, 2, 3)),
2626
"RAS",
2727
],
28+
[
29+
{"axcodes": "RAS", "image_only": True},
30+
np.arange(12).reshape((2, 1, 2, 3)),
31+
{"affine": np.eye(4)},
32+
np.arange(12).reshape((2, 1, 2, 3)),
33+
"RAS",
34+
],
2835
[
2936
{"axcodes": "ALS"},
3037
np.arange(12).reshape((2, 1, 2, 3)),
@@ -114,6 +121,9 @@ class TestOrientationCase(unittest.TestCase):
114121
def test_ornt(self, init_param, img, data_param, expected_data, expected_code):
115122
ornt = Orientation(**init_param)
116123
res = ornt(img, **data_param)
124+
if not isinstance(res, tuple):
125+
np.testing.assert_allclose(res, expected_data)
126+
return
117127
np.testing.assert_allclose(res[0], expected_data)
118128
original_affine = data_param["affine"]
119129
np.testing.assert_allclose(original_affine, res[1])

tests/test_spacing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
{"affine": np.eye(4)},
3131
np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
3232
],
33+
[
34+
{"pixdim": 1.0, "padding_mode": "zeros", "dtype": float, "image_only": True},
35+
np.ones((1, 2, 1, 2)), # data
36+
{"affine": np.eye(4)},
37+
np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
38+
],
3339
[
3440
{"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float},
3541
np.ones((1, 2, 1, 2)), # data
@@ -145,6 +151,9 @@ class TestSpacingCase(unittest.TestCase):
145151
@parameterized.expand(TEST_CASES)
146152
def test_spacing(self, init_param, img, data_param, expected_output):
147153
res = Spacing(**init_param)(img, **data_param)
154+
if not isinstance(res, tuple):
155+
np.testing.assert_allclose(res, expected_output, atol=1e-6)
156+
return
148157
np.testing.assert_allclose(res[0], expected_output, atol=1e-6)
149158
sr = len(res[0].shape) - 1
150159
if isinstance(init_param["pixdim"], float):

0 commit comments

Comments
 (0)