Commit 9bed661

Expose hidden block arguments

1 parent 52e78d7 · commit 9bed661
2 files changed: +57 −16 lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 34 additions & 9 deletions
```diff
@@ -30,6 +30,9 @@
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 
+from collections.abc import Iterable, Sequence
+from typing import Literal
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -38,9 +41,22 @@
 
 
 class DeepLabV3Decoder(nn.Sequential):
-    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        atrous_rates: Iterable[int],
+        aspp_separable: bool,
+        aspp_dropout: float,
+    ):
         super().__init__(
-            ASPP(in_channels, out_channels, atrous_rates),
+            ASPP(
+                in_channels,
+                out_channels,
+                atrous_rates,
+                separable=aspp_separable,
+                dropout=aspp_dropout,
+            ),
             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(),
@@ -54,10 +70,12 @@ def forward(self, *features):
 class DeepLabV3PlusDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels,
-        out_channels=256,
-        atrous_rates=(12, 24, 36),
-        output_stride=16,
+        encoder_channels: Sequence[int, ...],
+        out_channels: int,
+        atrous_rates: Iterable[int],
+        output_stride: Literal[8, 16],
+        aspp_separable: bool,
+        aspp_dropout: float,
     ):
         super().__init__()
         if output_stride not in {8, 16}:
@@ -69,7 +87,13 @@ def __init__(
         self.output_stride = output_stride
 
         self.aspp = nn.Sequential(
-            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
+            ASPP(
+                encoder_channels[-1],
+                out_channels,
+                atrous_rates,
+                separable=aspp_separable,
+                dropout=aspp_dropout,
+            ),
             SeparableConv2d(
                 out_channels, out_channels, kernel_size=3, padding=1, bias=False
             ),
@@ -164,7 +188,8 @@ def __init__(
         in_channels: int,
         out_channels: int,
         atrous_rates: Iterable[int],
-        separable: bool=False,
+        separable: bool,
+        dropout: float,
     ):
         super(ASPP, self).__init__()
         modules = [
@@ -189,7 +214,7 @@ def __init__(
             nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(),
-            nn.Dropout(0.5),
+            nn.Dropout(dropout),
        )
 
     def forward(self, x):
```

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 23 additions & 7 deletions
```diff
@@ -24,13 +24,17 @@ class DeepLabV3(SegmentationModel):
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: A number of convolution filters in ASPP module. Default is 256
+        encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
+        decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
+        decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False
+        decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
                 **callable** and **None**.
             Default is **None**
-        upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
+        upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
             - classes (int): A number of classes
@@ -51,11 +55,15 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
+        encoder_output_stride: Literal[8, 16] = 8,
         decoder_channels: int = 256,
+        decoder_atrous_rates: Iterable[int] = (12, 24, 36),
+        decoder_aspp_separable: bool = False,
+        decoder_aspp_dropout: float = 0.5,
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[str] = None,
-        upsampling: int = 8,
+        upsampling: Optional[int] = None,
         aux_params: Optional[dict] = None,
     ):
         super().__init__()
@@ -65,19 +73,23 @@ def __init__(
             in_channels=in_channels,
             depth=encoder_depth,
             weights=encoder_weights,
-            output_stride=8,
+            output_stride=encoder_output_stride,
         )
 
         self.decoder = DeepLabV3Decoder(
-            in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels
+            in_channels=self.encoder.out_channels[-1],
+            out_channels=decoder_channels,
+            atrous_rates=decoder_atrous_rates,
+            aspp_separable=decoder_aspp_separable,
+            aspp_dropout=decoder_aspp_dropout,
         )
 
         self.segmentation_head = SegmentationHead(
             in_channels=self.decoder.out_channels,
             out_channels=classes,
             activation=activation,
             kernel_size=1,
-            upsampling=upsampling,
+            upsampling=encoder_output_stride if upsampling is None else upsampling,
         )
 
         if aux_params is not None:
@@ -102,8 +114,9 @@ class DeepLabV3Plus(SegmentationModel):
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
-        decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
         decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
+        decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True
+        decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5
         decoder_channels: A number of convolution filters in ASPP module. Default is 256
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -134,8 +147,9 @@ def __init__(
         encoder_weights: Optional[str] = "imagenet",
         encoder_output_stride: Literal[8, 16] = 16,
         decoder_channels: int = 256,
-        decoder_atrous_rates: tuple = (12, 24, 36),
         decoder_atrous_rates: Iterable[int] = (12, 24, 36),
+        decoder_aspp_separable: bool = True,
+        decoder_aspp_dropout: float = 0.5,
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[str] = None,
@@ -157,6 +171,8 @@ def __init__(
             out_channels=decoder_channels,
             atrous_rates=decoder_atrous_rates,
             output_stride=encoder_output_stride,
+            aspp_separable=decoder_aspp_separable,
+            aspp_dropout=decoder_aspp_dropout,
         )
 
         self.segmentation_head = SegmentationHead(
```
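At the model level, the same options surface as new constructor keywords. A hedged usage sketch follows; the argument names and defaults come from the diff above, while the concrete values (encoder, rates, dropout, input size) are only illustrative.

```python
# Usage sketch of the newly exposed keyword arguments; names and defaults follow
# the diff above, the concrete values are illustrative only.
import torch
import segmentation_models_pytorch as smp

# DeepLabV3: `upsampling` now defaults to None and falls back to
# `encoder_output_stride`, keeping input and output spatial shapes equal.
deeplabv3 = smp.DeepLabV3(
    encoder_name="resnet34",
    encoder_output_stride=8,            # newly exposed (was hardcoded to 8)
    decoder_atrous_rates=(12, 24, 36),
    decoder_aspp_separable=False,       # newly exposed, default False
    decoder_aspp_dropout=0.5,           # newly exposed (was nn.Dropout(0.5))
    classes=2,
).eval()

deeplabv3plus = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_output_stride=16,
    decoder_atrous_rates=(6, 12, 18),   # e.g. smaller rates for output stride 16
    decoder_aspp_separable=True,        # newly exposed, default True
    decoder_aspp_dropout=0.2,           # illustrative non-default value
    classes=2,
).eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    # both expected: torch.Size([1, 2, 256, 256])
    print(deeplabv3(x).shape, deeplabv3plus(x).shape)
```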
