Skip to content

Commit 607123b

Browse files
HoVerNet Mode and Branch to independent StrEnum (#5219)
Fixes #5218 ### Description This PR moves `HoVerNet` Mode and Branch outside of the module and make them independent `StrEnums`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. Signed-off-by: Behrooz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6528014 commit 607123b

File tree

4 files changed

+47
-38
lines changed

4 files changed

+47
-38
lines changed

monai/networks/nets/hovernet.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
# =========================================================================
2929

3030
from collections import OrderedDict
31-
from enum import Enum
3231
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
3332

3433
import torch
@@ -37,8 +36,8 @@
3736
from monai.networks.blocks import UpSample
3837
from monai.networks.layers.factories import Conv, Dropout
3938
from monai.networks.layers.utils import get_act_layer, get_norm_layer
40-
from monai.utils import InterpolateMode, UpsampleMode, export
41-
from monai.utils.enums import StrEnum
39+
from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode
40+
from monai.utils.module import export, look_up_option
4241

4342
__all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"]
4443

@@ -380,40 +379,21 @@ class HoVerNet(nn.Module):
380379
Medical Image Analysis 2019
381380
382381
Args:
382+
mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
383+
a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
383384
in_channels: number of the input channel.
384385
out_classes: number of the nuclear type classes.
385386
act: activation type and arguments. Defaults to relu.
386387
norm: feature normalization type and arguments. Defaults to batch norm.
387388
dropout_prob: dropout rate after each dense layer.
388389
"""
389390

390-
class Mode(Enum):
391-
FAST: int = 0
392-
ORIGINAL: int = 1
393-
394-
class Branch(StrEnum):
395-
"""
396-
Three branches of HoVerNet model, which results in three outputs:
397-
`HOVER` is horizontal and vertical regressed gradient map of each nucleus,
398-
`NUCLEUS` is the segmentation of all nuclei, and
399-
`TYPE` is the type of each nucleus.
400-
401-
"""
402-
403-
HV = "horizontal_vertical"
404-
NP = "nucleus_prediction"
405-
NC = "type_prediction"
406-
407-
def _mode_to_int(self, mode) -> int:
408-
409-
if mode == self.Mode.FAST:
410-
return 0
411-
else:
412-
return 1
391+
Mode = HoVerNetMode
392+
Branch = HoVerNetBranch
413393

414394
def __init__(
415395
self,
416-
mode: Mode = Mode.FAST,
396+
mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST,
417397
in_channels: int = 3,
418398
out_classes: int = 0,
419399
act: Union[str, tuple] = ("relu", {"inplace": True}),
@@ -423,10 +403,9 @@ def __init__(
423403

424404
super().__init__()
425405

426-
self.mode: int = self._mode_to_int(mode)
427-
428-
if mode not in [self.Mode.ORIGINAL, self.Mode.FAST]:
429-
raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
406+
if isinstance(mode, str):
407+
mode = mode.upper()
408+
self.mode = look_up_option(mode, HoVerNetMode)
430409

431410
if out_classes > 128:
432411
raise ValueError("Number of nuclear types classes exceeds maximum (128)")
@@ -441,7 +420,7 @@ def __init__(
441420
# number of layers in each pooling block.
442421
_block_config: Sequence[int] = (3, 4, 6, 3)
443422

444-
if mode == self.Mode.FAST:
423+
if self.mode == HoVerNetMode.FAST:
445424
_ksize = 3
446425
_pad = 3
447426
else:
@@ -510,12 +489,12 @@ def __init__(
510489

511490
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
512491

513-
if self.mode == 1:
492+
if self.mode == HoVerNetMode.ORIGINAL.value:
514493
if x.shape[-1] != 270 or x.shape[-2] != 270:
515-
raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
494+
raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL")
516495
else:
517496
if x.shape[-1] != 256 or x.shape[-2] != 256:
518-
raise ValueError("Input size should be 256 x 256 when using Mode.FAST")
497+
raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST")
519498

520499
x = x / 255.0 # to 0-1 range to match XY
521500
x = self.input_features(x)
@@ -531,11 +510,11 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
531510
x = self.upsample(x)
532511

533512
output = {
534-
HoVerNet.Branch.NP.value: self.nucleus_prediction(x, short_cuts),
535-
HoVerNet.Branch.HV.value: self.horizontal_vertical(x, short_cuts),
513+
HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
514+
HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
536515
}
537516
if self.type_prediction is not None:
538-
output[HoVerNet.Branch.NC.value] = self.type_prediction(x, short_cuts)
517+
output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)
539518

540519
return output
541520

monai/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
GridPatchSort,
3030
GridSampleMode,
3131
GridSamplePadMode,
32+
HoVerNetBranch,
33+
HoVerNetMode,
3234
InterpolateMode,
3335
InverseKeys,
3436
JITMetadataKeys,

monai/utils/enums.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
"ImageStatsKeys",
5353
"LabelStatsKeys",
5454
"AlgoEnsembleKeys",
55+
"HoVerNetMode",
56+
"HoVerNetBranch",
5557
]
5658

5759

@@ -587,3 +589,28 @@ class AlgoEnsembleKeys(StrEnum):
587589
ID = "identifier"
588590
ALGO = "infer_algo"
589591
SCORE = "best_metric"
592+
593+
594+
class HoVerNetMode(StrEnum):
595+
"""
596+
Modes for HoVerNet model:
597+
`FAST`: a faster implementation (than original)
598+
`ORIGINAL`: the original implementation
599+
"""
600+
601+
FAST = "FAST"
602+
ORIGINAL = "ORIGINAL"
603+
604+
605+
class HoVerNetBranch(StrEnum):
606+
"""
607+
Three branches of HoVerNet model, which results in three outputs:
608+
`HV` is horizontal and vertical gradient map of each nucleus (regression),
609+
`NP` is the pixel prediction of all nuclei (segmentation), and
610+
`NC` is the type of each nucleus (classification).
611+
612+
"""
613+
614+
HV = "horizontal_vertical"
615+
NP = "nucleus_prediction"
616+
NC = "type_prediction"

tests/test_hovernet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
ILL_CASES = [
5656
[{"out_classes": 6, "mode": 3}],
57+
[{"out_classes": 6, "mode": "Wrong"}],
5758
[{"out_classes": 1000, "mode": HoVerNet.Mode.ORIGINAL}],
5859
[{"out_classes": 1, "mode": HoVerNet.Mode.ORIGINAL}],
5960
[{"out_classes": 6, "mode": HoVerNet.Mode.ORIGINAL, "dropout_prob": 100}],

0 commit comments

Comments
 (0)