Skip to content

Commit fdb16ac

Browse files
authored
4731 squeeze dim update meta (#5041)
Signed-off-by: Wenqi Li <[email protected]> Fixes #4731 (additionally fixes Project-MONAI/tutorials#569) ### Description squeeze dim drops a spatial axis should also update the affine ### Status **Ready** ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]>
1 parent 63a198a commit fdb16ac

File tree

5 files changed

+48
-9
lines changed

5 files changed

+48
-9
lines changed

monai/transforms/utility/array.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,11 +603,12 @@ class SqueezeDim(Transform):
603603

604604
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
605605

606-
def __init__(self, dim: Optional[int] = 0) -> None:
606+
def __init__(self, dim: Optional[int] = 0, update_meta=True) -> None:
607607
"""
608608
Args:
609609
dim: dimension to be squeezed. Default = 0
610610
"None" works when the input is numpy array.
611+
update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.
611612
612613
Raises:
613614
TypeError: When ``dim`` is not an ``Optional[int]``.
@@ -616,6 +617,7 @@ def __init__(self, dim: Optional[int] = 0) -> None:
616617
if dim is not None and not isinstance(dim, int):
617618
raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.")
618619
self.dim = dim
620+
self.update_meta = update_meta
619621

620622
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
621623
"""
@@ -624,11 +626,25 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
624626
"""
625627
img = convert_to_tensor(img, track_meta=get_track_meta())
626628
if self.dim is None:
629+
if self.update_meta:
630+
warnings.warn("update_meta=True is ignored when dim=None.")
627631
return img.squeeze()
632+
dim = (self.dim + len(img.shape)) if self.dim < 0 else self.dim
628633
# for pytorch/numpy unification
629-
if img.shape[self.dim] != 1:
630-
raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape}.")
631-
return img.squeeze(self.dim)
634+
if img.shape[dim] != 1:
635+
raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape[dim]} of {img.shape}.")
636+
img = img.squeeze(dim)
637+
if self.update_meta and isinstance(img, MetaTensor) and dim > 0 and len(img.affine.shape) == 2:
638+
h, w = img.affine.shape
639+
affine, device = img.affine, img.affine.device if isinstance(img.affine, torch.Tensor) else None
640+
if h > dim:
641+
affine = affine[torch.arange(0, h, device=device) != dim - 1]
642+
if w > dim:
643+
affine = affine[:, torch.arange(0, w, device=device) != dim - 1]
644+
if (affine.shape[0] == affine.shape[1]) and not np.linalg.det(convert_to_numpy(affine, wrap_sequence=True)):
645+
warnings.warn(f"After SqueezeDim, img.affine is ill-posed: \n{img.affine}.")
646+
img.affine = affine
647+
return img
632648

633649

634650
class DataStats(Transform):

monai/transforms/utility/dictionary.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,16 +763,19 @@ class SqueezeDimd(MapTransform):
763763

764764
backend = SqueezeDim.backend
765765

766-
def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None:
766+
def __init__(
767+
self, keys: KeysCollection, dim: int = 0, update_meta: bool = True, allow_missing_keys: bool = False
768+
) -> None:
767769
"""
768770
Args:
769771
keys: keys of the corresponding items to be transformed.
770772
See also: :py:class:`monai.transforms.compose.MapTransform`
771773
dim: dimension to be squeezed. Default: 0 (the first dimension)
774+
update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.
772775
allow_missing_keys: don't raise exception if key is missing.
773776
"""
774777
super().__init__(keys, allow_missing_keys)
775-
self.converter = SqueezeDim(dim=dim)
778+
self.converter = SqueezeDim(dim=dim, update_meta=update_meta)
776779

777780
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
778781
d = dict(data)

monai/utils/jupyter_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Matplotlib produce common plots for metrics and images.
1515
"""
1616

17+
import copy
1718
from enum import Enum
1819
from threading import RLock, Thread
1920
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -340,7 +341,7 @@ def status_dict(self) -> Dict[str, str]:
340341

341342
def status(self) -> str:
342343
"""Returns a status string for the current state of the engine."""
343-
stats = self.status_dict
344+
stats = copy.deepcopy(self.status_dict)
344345

345346
msgs = [stats.pop(StatusMembers.STATUS.value), "Iters: " + str(stats.pop(StatusMembers.ITERS.value, 0))]
346347

tests/test_squeezedim.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
import numpy as np
1515
from parameterized import parameterized
1616

17+
from monai.data import MetaTensor
1718
from monai.transforms import SqueezeDim
18-
from tests.utils import TEST_NDARRAYS
19+
from tests.utils import TEST_NDARRAYS, assert_allclose
1920

2021
TESTS, TESTS_FAIL = [], []
2122
for p in TEST_NDARRAYS:
@@ -34,13 +35,28 @@ def test_shape(self, input_param, test_data, expected_shape):
3435

3536
result = SqueezeDim(**input_param)(test_data)
3637
self.assertTupleEqual(result.shape, expected_shape)
38+
if "dim" in input_param and input_param["dim"] == 2 and isinstance(result, MetaTensor):
39+
assert_allclose(result.affine.shape, [3, 3])
3740

3841
@parameterized.expand(TESTS_FAIL)
3942
def test_invalid_inputs(self, exception, input_param, test_data):
4043

4144
with self.assertRaises(exception):
4245
SqueezeDim(**input_param)(test_data)
4346

47+
def test_affine_ill_inputs(self):
48+
img = MetaTensor(
49+
np.random.rand(1, 2, 1, 3),
50+
affine=[
51+
[-0.7422, 0.0, 0.0, 186.3210],
52+
[0.0, 0.0, -3.0, 70.6580],
53+
[0.0, -0.7422, 0.0, 189.4130],
54+
[0.0, 0.0, 0.0, 1.0],
55+
],
56+
)
57+
with self.assertWarns(UserWarning):
58+
SqueezeDim(dim=2)(img)
59+
4460

4561
if __name__ == "__main__":
4662
unittest.main()

tests/test_squeezedimd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
import numpy as np
1515
from parameterized import parameterized
1616

17+
from monai.data import MetaTensor
1718
from monai.transforms import SqueezeDimd
18-
from tests.utils import TEST_NDARRAYS
19+
from tests.utils import TEST_NDARRAYS, assert_allclose
1920

2021
TESTS, TESTS_FAIL = [], []
2122
for p in TEST_NDARRAYS:
@@ -82,6 +83,8 @@ def test_shape(self, input_param, test_data, expected_shape):
8283
result = SqueezeDimd(**input_param)(test_data)
8384
self.assertTupleEqual(result["img"].shape, expected_shape)
8485
self.assertTupleEqual(result["seg"].shape, expected_shape)
86+
if "dim" in input_param and isinstance(result["img"], MetaTensor) and input_param["dim"] == 2:
87+
assert_allclose(result["img"].affine.shape, [3, 3])
8588

8689
@parameterized.expand(TESTS_FAIL)
8790
def test_invalid_inputs(self, exception, input_param, test_data):

0 commit comments

Comments
 (0)