Skip to content

Commit df6bc9c

Browse files
authored
5782 Flexible interp modes in regunet (#5807)
Signed-off-by: Wenqi Li <[email protected]>

Fixes #5782

### Description

- adds 'mode' and 'align_corners' options to the blocks and nets
- fixes a few typos

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
1 parent 315d2d2 commit df6bc9c

File tree

7 files changed

+65
-14
lines changed

7 files changed

+65
-14
lines changed

.github/workflows/conda.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: false
2020
matrix:
21-
os: [windows-latest, ubuntu-latest]
21+
os: [ubuntu-latest]
2222
python-version: ["3.9"]
2323
runs-on: ${{ matrix.os }}
2424
env:

monai/networks/blocks/localnet_block.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
166166

167167
class LocalNetUpSampleBlock(nn.Module):
168168
"""
169-
A up-sample module that can be used for LocalNet, based on:
169+
An up-sample module that can be used for LocalNet, based on:
170170
`Weakly-supervised convolutional neural networks for multimodal image registration
171171
<https://doi.org/10.1016/j.media.2018.07.002>`_.
172172
`Label-driven weakly-supervised learning for multimodal deformable image registration
@@ -176,12 +176,21 @@ class LocalNetUpSampleBlock(nn.Module):
176176
DeepReg (https://github.com/DeepRegNet/DeepReg)
177177
"""
178178

179-
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
179+
def __init__(
180+
self,
181+
spatial_dims: int,
182+
in_channels: int,
183+
out_channels: int,
184+
mode: str = "nearest",
185+
align_corners: Optional[bool] = None,
186+
) -> None:
180187
"""
181188
Args:
182189
spatial_dims: number of spatial dimensions.
183190
in_channels: number of input channels.
184191
out_channels: number of output channels.
192+
mode: interpolation mode of the additive upsampling, default to 'nearest'.
193+
align_corners: whether to align corners for the additive upsampling, default to None.
185194
Raises:
186195
ValueError: when ``in_channels != 2 * out_channels``
187196
"""
@@ -199,9 +208,11 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> No
199208
f"got in_channels={in_channels}, out_channels={out_channels}"
200209
)
201210
self.out_channels = out_channels
211+
self.mode = mode
212+
self.align_corners = align_corners
202213

203-
def addictive_upsampling(self, x, mid) -> torch.Tensor:
204-
x = F.interpolate(x, mid.shape[2:])
214+
def additive_upsampling(self, x, mid) -> torch.Tensor:
215+
x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners)
205216
# [(batch, out_channels, ...), (batch, out_channels, ...)]
206217
x = x.split(split_size=int(self.out_channels), dim=1)
207218
# (batch, out_channels, ...)
@@ -226,7 +237,7 @@ def forward(self, x, mid) -> torch.Tensor:
226237
"expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
227238
f"got x of shape {x.shape}, mid of shape {mid.shape}"
228239
)
229-
h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid)
240+
h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
230241
r1 = h0 + mid
231242
r2 = self.conv_block(h0)
232243
out: torch.Tensor = self.residual_block(r2, r1)

monai/networks/blocks/regunet_block.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,8 @@ def __init__(
200200
out_channels: int,
201201
kernel_initializer: Optional[str] = "kaiming_uniform",
202202
activation: Optional[str] = None,
203+
mode: str = "nearest",
204+
align_corners: Optional[bool] = None,
203205
):
204206
"""
205207
@@ -211,6 +213,8 @@ def __init__(
211213
out_channels: number of output channels
212214
kernel_initializer: kernel initializer
213215
activation: kernel activation function
216+
mode: feature map interpolation mode, default to "nearest".
217+
align_corners: whether to align corners for feature map interpolation.
214218
"""
215219
super().__init__()
216220
self.extract_levels = extract_levels
@@ -228,6 +232,8 @@ def __init__(
228232
for d in extract_levels
229233
]
230234
)
235+
self.mode = mode
236+
self.align_corners = align_corners
231237

232238
def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
233239
"""
@@ -240,7 +246,9 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
240246
Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size``
241247
"""
242248
feature_list = [
243-
F.interpolate(layer(x[self.max_level - level]), size=image_size)
249+
F.interpolate(
250+
layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners
251+
)
244252
for layer, level in zip(self.layers, self.extract_levels)
245253
]
246254
out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0)

monai/networks/nets/regunet.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -337,14 +337,23 @@ def build_output_block(self):
337337

338338

339339
class AdditiveUpSampleBlock(nn.Module):
340-
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):
340+
def __init__(
341+
self,
342+
spatial_dims: int,
343+
in_channels: int,
344+
out_channels: int,
345+
mode: str = "nearest",
346+
align_corners: Optional[bool] = None,
347+
):
341348
super().__init__()
342349
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
350+
self.mode = mode
351+
self.align_corners = align_corners
343352

344353
def forward(self, x: torch.Tensor) -> torch.Tensor:
345354
output_size = [size * 2 for size in x.shape[2:]]
346355
deconved = self.deconv(x)
347-
resized = F.interpolate(x, output_size)
356+
resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)
348357
resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
349358
out: torch.Tensor = deconved + resized
350359
return out
@@ -372,8 +381,10 @@ def __init__(
372381
out_activation: Optional[str] = None,
373382
out_channels: int = 3,
374383
pooling: bool = True,
375-
use_addictive_sampling: bool = True,
384+
use_additive_sampling: bool = True,
376385
concat_skip: bool = False,
386+
mode: str = "nearest",
387+
align_corners: Optional[bool] = None,
377388
):
378389
"""
379390
Args:
@@ -385,10 +396,14 @@ def __init__(
385396
out_channels: number of channels for the output
386397
extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
387398
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
388-
use_addictive_sampling: whether use additive up-sampling layer for decoding.
399+
use_additive_sampling: whether use additive up-sampling layer for decoding.
389400
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
401+
mode: mode for interpolation when use_additive_sampling, default is "nearest".
402+
align_corners: align_corners for interpolation when use_additive_sampling, default is None.
390403
"""
391-
self.use_additive_upsampling = use_addictive_sampling
404+
self.use_additive_upsampling = use_additive_sampling
405+
self.mode = mode
406+
self.align_corners = align_corners
392407
super().__init__(
393408
spatial_dims=spatial_dims,
394409
in_channels=in_channels,
@@ -412,7 +427,11 @@ def build_bottom_block(self, in_channels: int, out_channels: int):
412427
def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
413428
if self.use_additive_upsampling:
414429
return AdditiveUpSampleBlock(
415-
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels
430+
spatial_dims=self.spatial_dims,
431+
in_channels=in_channels,
432+
out_channels=out_channels,
433+
mode=self.mode,
434+
align_corners=self.align_corners,
416435
)
417436

418437
return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)

tests/test_localnet.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,8 @@
3232
"extract_levels": (0, 1),
3333
"pooling": False,
3434
"concat_skip": True,
35+
"mode": "bilinear",
36+
"align_corners": True,
3537
},
3638
(1, 2, 16, 16),
3739
(1, 2, 16, 16),

tests/test_localnet_block.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,17 @@
2525
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3]
2626
]
2727

28-
TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]]
28+
TEST_CASE_UP_SAMPLE = [
29+
[
30+
{
31+
"spatial_dims": spatial_dims,
32+
"in_channels": 4,
33+
"out_channels": 2,
34+
"mode": "bilinear" if spatial_dims == 2 else "trilinear",
35+
}
36+
]
37+
for spatial_dims in [2, 3]
38+
]
2939

3040
TEST_CASE_EXTRACT = [
3141
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}]

tests/test_regunet_block.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
"out_channels": 1,
5454
"kernel_initializer": "zeros",
5555
"activation": "sigmoid",
56+
"mode": "trilinear",
5657
},
5758
[(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)],
5859
(3, 3, 3),

0 commit comments

Comments
 (0)