Skip to content

Commit 796271c

Browse files
ctonghpre-commit-ci[bot]ericspodKbinn
authored
Add options to skip operations for RestoreLabeld Transform (#8125)
Fixes #6380 ### Description Four new bool parameters are added into `RestoreLabeld` to allow users to selectively enable or disable each restoration operation as needed, and a corresponding test case is added to verify that the function runs correctly. This design allows users to selectively enable or disable each restoration operation as needed, providing greater flexibility. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Hsin Tong <[email protected]> Signed-off-by: Hsin-Tong Hsieh <[email protected]> Signed-off-by: kbbbbkb <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: kbbbbkb <[email protected]>
1 parent 76ef9f4 commit 796271c

File tree

2 files changed

+140
-24
lines changed

2 files changed

+140
-24
lines changed

monai/apps/deepgrow/transforms.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,14 @@ class RestoreLabeld(MapTransform):
803803
original_shape_key: key that records original shape for foreground.
804804
cropped_shape_key: key that records cropped shape for foreground.
805805
allow_missing_keys: don't raise exception if key is missing.
806+
restore_resizing: used to enable or disable resizing restoration, default is True.
807+
If True, the transform will resize the items back to its original shape.
808+
restore_cropping: used to enable or disable cropping restoration, default is True.
809+
If True, the transform will restore the items to its uncropped size.
810+
restore_spacing: used to enable or disable spacing restoration, default is True.
811+
If True, the transform will resample the items back to the spacing it had before being altered.
812+
restore_slicing: used to enable or disable slicing restoration, default is True.
813+
If True, the transform will reassemble the full volume by restoring the slices to their original positions.
806814
"""
807815

808816
def __init__(
@@ -819,6 +827,10 @@ def __init__(
819827
original_shape_key: str = "foreground_original_shape",
820828
cropped_shape_key: str = "foreground_cropped_shape",
821829
allow_missing_keys: bool = False,
830+
restore_resizing: bool = True,
831+
restore_cropping: bool = True,
832+
restore_spacing: bool = True,
833+
restore_slicing: bool = True,
822834
) -> None:
823835
super().__init__(keys, allow_missing_keys)
824836
self.ref_image = ref_image
@@ -833,6 +845,10 @@ def __init__(
833845
self.end_coord_key = end_coord_key
834846
self.original_shape_key = original_shape_key
835847
self.cropped_shape_key = cropped_shape_key
848+
self.restore_resizing = restore_resizing
849+
self.restore_cropping = restore_cropping
850+
self.restore_spacing = restore_spacing
851+
self.restore_slicing = restore_slicing
836852

837853
def __call__(self, data: Any) -> dict:
838854
d = dict(data)
@@ -842,38 +858,45 @@ def __call__(self, data: Any) -> dict:
842858
image = d[key]
843859

844860
# Undo Resize
845-
current_shape = image.shape
846-
cropped_shape = meta_dict[self.cropped_shape_key]
847-
if np.any(np.not_equal(current_shape, cropped_shape)):
848-
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
849-
image = resizer(image, mode=mode, align_corners=align_corners)
861+
if self.restore_resizing:
862+
current_shape = image.shape
863+
cropped_shape = meta_dict[self.cropped_shape_key]
864+
if np.any(np.not_equal(current_shape, cropped_shape)):
865+
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
866+
image = resizer(image, mode=mode, align_corners=align_corners)
850867

851868
# Undo Crop
852-
original_shape = meta_dict[self.original_shape_key]
853-
result = np.zeros(original_shape, dtype=np.float32)
854-
box_start = meta_dict[self.start_coord_key]
855-
box_end = meta_dict[self.end_coord_key]
856-
857-
spatial_dims = min(len(box_start), len(image.shape[1:]))
858-
slices = tuple(
859-
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
860-
)
861-
result[slices] = image
869+
if self.restore_cropping:
870+
original_shape = meta_dict[self.original_shape_key]
871+
result = np.zeros(original_shape, dtype=np.float32)
872+
box_start = meta_dict[self.start_coord_key]
873+
box_end = meta_dict[self.end_coord_key]
874+
875+
spatial_dims = min(len(box_start), len(image.shape[1:]))
876+
slices = tuple(
877+
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
878+
)
879+
result[slices] = image
880+
else:
881+
result = image
862882

863883
# Undo Spacing
864-
current_size = result.shape[1:]
865-
# change spatial_shape from HWD to DHW
866-
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
867-
spatial_size = spatial_shape[-len(current_size) :]
884+
if self.restore_spacing:
885+
current_size = result.shape[1:]
886+
# change spatial_shape from HWD to DHW
887+
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
888+
spatial_size = spatial_shape[-len(current_size) :]
868889

869-
if np.any(np.not_equal(current_size, spatial_size)):
870-
resizer = Resize(spatial_size=spatial_size, mode=mode)
871-
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
890+
if np.any(np.not_equal(current_size, spatial_size)):
891+
resizer = Resize(spatial_size=spatial_size, mode=mode)
892+
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
872893

873894
# Undo Slicing
874895
slice_idx = meta_dict.get("slice_idx")
875896
final_result: NdarrayOrTensor
876-
if slice_idx is None or self.slice_only:
897+
if not self.restore_slicing: # do nothing if restore slicing isn't requested
898+
final_result = result
899+
elif slice_idx is None or self.slice_only:
877900
final_result = result if len(result.shape) <= 3 else result[0]
878901
else:
879902
slice_idx = meta_dict["slice_idx"][0]

tests/test_deepgrow_transforms.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,21 @@
141141

142142
DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), PostFix.meta("image"): {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]}
143143

144+
DATA_13 = {
145+
"image": np.arange(64).reshape((1, 4, 4, 4)),
146+
PostFix.meta("image"): {
147+
"spatial_shape": [8, 8, 4],
148+
"foreground_start_coord": np.array([1, 1, 1]),
149+
"foreground_end_coord": np.array([3, 3, 3]),
150+
"foreground_original_shape": (1, 4, 4, 4),
151+
"foreground_cropped_shape": (1, 2, 2, 2),
152+
"original_affine": np.array(
153+
[[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]
154+
),
155+
},
156+
"pred": np.array([[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]),
157+
}
158+
144159
FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]]
145160

146161
FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]]
@@ -329,6 +344,74 @@
329344

330345
RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT]
331346

347+
RESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20))
348+
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 0:10] = 1
349+
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 10:20] = 2
350+
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 0:10] = 3
351+
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 10:20] = 4
352+
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 0:10] = 5
353+
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 10:20] = 6
354+
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 0:10] = 7
355+
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 10:20] = 8
356+
357+
RESTORE_LABEL_TEST_CASE_3 = [
358+
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_cropping": False},
359+
DATA_11,
360+
RESTORE_LABEL_TEST_CASE_3_RESULT,
361+
]
362+
363+
RESTORE_LABEL_TEST_CASE_4_RESULT = np.zeros((4, 8, 8))
364+
RESTORE_LABEL_TEST_CASE_4_RESULT[1, 2:6, 2:6] = np.array(
365+
[[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
366+
)
367+
RESTORE_LABEL_TEST_CASE_4_RESULT[2, 2:6, 2:6] = np.array(
368+
[[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
369+
)
370+
371+
RESTORE_LABEL_TEST_CASE_4 = [
372+
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resizing": False},
373+
DATA_13,
374+
RESTORE_LABEL_TEST_CASE_4_RESULT,
375+
]
376+
377+
RESTORE_LABEL_TEST_CASE_5_RESULT = np.zeros((4, 4, 4))
378+
RESTORE_LABEL_TEST_CASE_5_RESULT[1, 1:3, 1:3] = np.array([[10.0, 20.0], [30.0, 40.0]])
379+
RESTORE_LABEL_TEST_CASE_5_RESULT[2, 1:3, 1:3] = np.array([[50.0, 60.0], [70.0, 80.0]])
380+
381+
RESTORE_LABEL_TEST_CASE_5 = [
382+
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_spacing": False},
383+
DATA_13,
384+
RESTORE_LABEL_TEST_CASE_5_RESULT,
385+
]
386+
387+
RESTORE_LABEL_TEST_CASE_6_RESULT = np.zeros((1, 4, 8, 8))
388+
RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 1, 2:6, 2:6] = np.array(
389+
[[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
390+
)
391+
RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 2, 2:6, 2:6] = np.array(
392+
[[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
393+
)
394+
395+
RESTORE_LABEL_TEST_CASE_6 = [
396+
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_slicing": False},
397+
DATA_13,
398+
RESTORE_LABEL_TEST_CASE_6_RESULT,
399+
]
400+
401+
RESTORE_LABEL_TEST_CASE_7 = [
402+
{
403+
"keys": ["pred"],
404+
"ref_image": "image",
405+
"mode": "nearest",
406+
"restore_resizing": False,
407+
"restore_cropping": False,
408+
"restore_spacing": False,
409+
"restore_slicing": False,
410+
},
411+
DATA_11,
412+
np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),
413+
]
414+
332415
FETCH_2D_SLICE_TEST_CASE_1 = [
333416
{"keys": ["image"], "guidance": "guidance"},
334417
DATA_12,
@@ -445,7 +528,17 @@ def test_correct_results(self, arguments, input_data, expected_result):
445528

446529
class TestRestoreLabeld(unittest.TestCase):
447530

448-
@parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2])
531+
@parameterized.expand(
532+
[
533+
RESTORE_LABEL_TEST_CASE_1,
534+
RESTORE_LABEL_TEST_CASE_2,
535+
RESTORE_LABEL_TEST_CASE_3,
536+
RESTORE_LABEL_TEST_CASE_4,
537+
RESTORE_LABEL_TEST_CASE_5,
538+
RESTORE_LABEL_TEST_CASE_6,
539+
RESTORE_LABEL_TEST_CASE_7,
540+
]
541+
)
449542
def test_correct_results(self, arguments, input_data, expected_result):
450543
result = RestoreLabeld(**arguments)(input_data)
451544
np.testing.assert_allclose(result["pred"], expected_result)

0 commit comments

Comments
 (0)