2828# =========================================================================
2929
3030from collections import OrderedDict
31- from enum import Enum
3231from typing import Callable , Dict , List , Optional , Sequence , Type , Union
3332
3433import torch
3736from monai .networks .blocks import UpSample
3837from monai .networks .layers .factories import Conv , Dropout
3938from monai .networks .layers .utils import get_act_layer , get_norm_layer
40- from monai .utils import InterpolateMode , UpsampleMode , export
41- from monai .utils .enums import StrEnum
39+ from monai .utils . enums import HoVerNetBranch , HoVerNetMode , InterpolateMode , UpsampleMode
40+ from monai .utils .module import export , look_up_option
4241
4342__all__ = ["HoVerNet" , "Hovernet" , "HoVernet" , "HoVerNet" ]
4443
@@ -380,40 +379,21 @@ class HoVerNet(nn.Module):
380379 Medical Image Analysis 2019
381380
382381 Args:
382+ mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
383+ a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
383384 in_channels: number of the input channel.
384385 out_classes: number of the nuclear type classes.
385386 act: activation type and arguments. Defaults to relu.
386387 norm: feature normalization type and arguments. Defaults to batch norm.
387388 dropout_prob: dropout rate after each dense layer.
388389 """
389390
390- class Mode (Enum ):
391- FAST : int = 0
392- ORIGINAL : int = 1
393-
394- class Branch (StrEnum ):
395- """
396- Three branches of HoVerNet model, which results in three outputs:
397- `HOVER` is horizontal and vertical regressed gradient map of each nucleus,
398- `NUCLEUS` is the segmentation of all nuclei, and
399- `TYPE` is the type of each nucleus.
400-
401- """
402-
403- HV = "horizontal_vertical"
404- NP = "nucleus_prediction"
405- NC = "type_prediction"
406-
407- def _mode_to_int (self , mode ) -> int :
408-
409- if mode == self .Mode .FAST :
410- return 0
411- else :
412- return 1
391+ Mode = HoVerNetMode
392+ Branch = HoVerNetBranch
413393
414394 def __init__ (
415395 self ,
416- mode : Mode = Mode .FAST ,
396+ mode : Union [ HoVerNetMode , str ] = HoVerNetMode .FAST ,
417397 in_channels : int = 3 ,
418398 out_classes : int = 0 ,
419399 act : Union [str , tuple ] = ("relu" , {"inplace" : True }),
@@ -423,10 +403,9 @@ def __init__(
423403
424404 super ().__init__ ()
425405
426- self .mode : int = self ._mode_to_int (mode )
427-
428- if mode not in [self .Mode .ORIGINAL , self .Mode .FAST ]:
429- raise ValueError ("Input size should be 270 x 270 when using Mode.ORIGINAL" )
406+ if isinstance (mode , str ):
407+ mode = mode .upper ()
408+ self .mode = look_up_option (mode , HoVerNetMode )
430409
431410 if out_classes > 128 :
432411 raise ValueError ("Number of nuclear types classes exceeds maximum (128)" )
@@ -441,7 +420,7 @@ def __init__(
441420 # number of layers in each pooling block.
442421 _block_config : Sequence [int ] = (3 , 4 , 6 , 3 )
443422
444- if mode == self . Mode .FAST :
423+ if self . mode == HoVerNetMode .FAST :
445424 _ksize = 3
446425 _pad = 3
447426 else :
@@ -510,12 +489,12 @@ def __init__(
510489
511490 def forward (self , x : torch .Tensor ) -> Dict [str , torch .Tensor ]:
512491
513- if self .mode == 1 :
492+ if self .mode == HoVerNetMode . ORIGINAL . value :
514493 if x .shape [- 1 ] != 270 or x .shape [- 2 ] != 270 :
515- raise ValueError ("Input size should be 270 x 270 when using Mode .ORIGINAL" )
494+ raise ValueError ("Input size should be 270 x 270 when using HoVerNetMode .ORIGINAL" )
516495 else :
517496 if x .shape [- 1 ] != 256 or x .shape [- 2 ] != 256 :
518- raise ValueError ("Input size should be 256 x 256 when using Mode .FAST" )
497+ raise ValueError ("Input size should be 256 x 256 when using HoVerNetMode .FAST" )
519498
520499 x = x / 255.0 # to 0-1 range to match XY
521500 x = self .input_features (x )
@@ -531,11 +510,11 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
531510 x = self .upsample (x )
532511
533512 output = {
534- HoVerNet . Branch .NP .value : self .nucleus_prediction (x , short_cuts ),
535- HoVerNet . Branch .HV .value : self .horizontal_vertical (x , short_cuts ),
513+ HoVerNetBranch .NP .value : self .nucleus_prediction (x , short_cuts ),
514+ HoVerNetBranch .HV .value : self .horizontal_vertical (x , short_cuts ),
536515 }
537516 if self .type_prediction is not None :
538- output [HoVerNet . Branch .NC .value ] = self .type_prediction (x , short_cuts )
517+ output [HoVerNetBranch .NC .value ] = self .type_prediction (x , short_cuts )
539518
540519 return output
541520
0 commit comments