|
| 1 | +from collections.abc import Callable |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from torch import Tensor |
| 6 | +from torchvision.ops.misc import Conv2dNormActivation |
| 7 | + |
| 8 | +from deepaudiox.modules.backbones.mobilenet.utils import cnn_out_size, make_divisible |
| 9 | + |
| 10 | + |
class ConcurrentSEBlock(torch.nn.Module):
    """
    Applies multiple Squeeze-and-Excitation (SE) operations concurrently across
    different dimensions and aggregates the results.

    This block allows the model to attend to channel, frequency, or time dimensions
    independently before merging the attention masks using a specified aggregation
    operation (max, avg, add, or min).
    """

    def __init__(self, c_dim: int, f_dim: int, t_dim: int, se_cnf: dict) -> None:
        """
        Initializes the ConcurrentSEBlock.

        Args:
            c_dim (int): Number of channels.
            f_dim (int): Frequency dimension size.
            t_dim (int): Time dimension size.
            se_cnf (dict): Configuration dictionary containing:
                - 'se_dims': List of dimensions to apply SE on (1=C, 2=F, 3=T).
                - 'se_r': Reduction ratio for the bottleneck.
                - 'se_agg': Aggregation method ('max', 'avg', 'add', 'min').

        Raises:
            NotImplementedError: If 'se_agg' names an unsupported aggregation.
        """
        super().__init__()
        dims = [c_dim, f_dim, t_dim]
        self.conc_se_layers = nn.ModuleList()
        for d in se_cnf["se_dims"]:
            input_dim = dims[d - 1]
            squeeze_dim = make_divisible(input_dim // se_cnf["se_r"], 8)
            self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
        agg = se_cnf["se_agg"]
        if agg == "max":
            self.agg_op = lambda x: torch.max(x, dim=0)[0]
        elif agg == "avg":
            self.agg_op = lambda x: torch.mean(x, dim=0)
        elif agg == "add":
            self.agg_op = lambda x: torch.sum(x, dim=0)
        elif agg == "min":
            self.agg_op = lambda x: torch.min(x, dim=0)[0]
        else:
            # Bug fix: the previous message interpolated `self.agg_op`, which is
            # never assigned on this branch, so an unknown aggregation raised
            # AttributeError instead of reporting the offending config value.
            raise NotImplementedError(f"SE aggregation operation '{agg}' not implemented")

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass of the concurrent SE block.

        Args:
            input (Tensor): Input tensor of shape (B, C, F, T).

        Returns:
            Tensor: Attention-weighted tensor aggregated from multiple SE paths.
        """
        se_outs = [se_layer(input) for se_layer in self.conc_se_layers]
        return self.agg_op(torch.stack(se_outs, dim=0))
| 67 | + |
| 68 | + |
class SqueezeExcitation(torch.nn.Module):
    """
    Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507,
    generalized to attend over a chosen tensor dimension (channel, frequency,
    or time) of a (B, C, F, T) feature map.
    """

    def __init__(
        self,
        input_dim: int,
        squeeze_dim: int,
        se_dim: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        """
        Initializes the SE block.

        Args:
            input_dim (int): Number of features in the target dimension.
            squeeze_dim (int): Size of the bottleneck (input_dim // reduction_ratio).
            se_dim (int): The dimension to preserve (1=C, 2=F, 3=T).
            activation (Callable): Non-linear activation for the bottleneck.
            scale_activation (Callable): Activation for the final attention mask.

        Raises:
            ValueError: If se_dim is not 1, 2, or 3.
        """
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
        self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
        # Raise instead of assert so the check survives `python -O`.
        if se_dim not in (1, 2, 3):
            raise ValueError(f"se_dim must be 1, 2 or 3, got {se_dim}")
        # self.se_dim holds the two dimensions that get averaged out,
        # i.e. every dim of (B, C, F, T) except batch and the preserved one.
        self.se_dim = [1, 2, 3]
        self.se_dim.remove(se_dim)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        """
        Computes the attention mask by squeezing spatial/channel information.

        Args:
            input (Tensor): Input feature map of shape (B, C, F, T).

        Returns:
            Tensor: Attention weights in (0, 1), broadcastable to the input.
        """
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # Bug fix: the old code always called .squeeze(2).squeeze(2), which only
        # removes the averaged-out singleton dims when the preserved dim is C
        # (and, by accident, worked for T). For se_dim=2 the mean has shape
        # (B, 1, F, 1); neither squeeze fires and fc1 was applied to a trailing
        # dim of size 1, crashing at runtime. Squeeze the actual reduced dims
        # instead — highest index first so the lower index stays valid.
        scale = scale.squeeze(self.se_dim[1]).squeeze(self.se_dim[0])
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale).view(shape)

    def forward(self, input: Tensor) -> Tensor:
        """
        Applies the computed attention mask to the input tensor.

        Args:
            input (Tensor): Input feature map.

        Returns:
            Tensor: Element-wise scaled feature map.
        """
        return self._scale(input) * input
| 131 | + |
| 132 | + |
class InvertedResidualConfig:
    """
    Configuration holder for a single MobileNetV3 Inverted Residual block.

    Captures the block's structural parameters (expansion, kernel size, stride,
    dilation, Squeeze-and-Excitation usage) with channel counts already scaled
    by the width multiplier.
    """

    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        """
        Stores block settings, scaling all channel counts by `width_mult`.
        """
        adjust = self.adjust_channels
        self.input_channels = adjust(input_channels, width_mult)
        self.expanded_channels = adjust(expanded_channels, width_mult)
        self.out_channels = adjust(out_channels, width_mult)
        self.kernel = kernel
        self.use_se = use_se
        # "HS" selects Hardswish; anything else falls back to ReLU downstream.
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation
        # Spatial (frequency/time) sizes are unknown here; the model builder
        # fills them in before SE blocks are constructed.
        self.f_dim: int | None = None
        self.t_dim: int | None = None

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """
        Scales a channel count by `width_mult`, rounded to a multiple of 8.

        Args:
            channels (int): Base number of channels.
            width_mult (float): Scaling factor.

        Returns:
            int: Adjusted channel count.
        """
        return make_divisible(channels * width_mult, 8)

    def out_size(self, in_size: int):
        """
        Computes the spatial output size this block produces for `in_size`.

        Args:
            in_size (int): Input height or width.

        Returns:
            int: Output height or width after convolution.
        """
        pad = self.dilation * ((self.kernel - 1) // 2)
        return cnn_out_size(in_size, pad, self.dilation, self.kernel, self.stride)
| 193 | + |
| 194 | + |
class InvertedResidual(nn.Module):
    """
    MobileNetV3 Inverted Residual Block.

    Pipeline:
    1. 1x1 expansion convolution (skipped when no expansion is needed).
    2. Depthwise convolution.
    3. Squeeze-and-Excitation (optional).
    4. 1x1 projection convolution.
    5. Residual connection (when stride == 1 and channel counts match).
    """

    def __init__(
        self,
        cnf: InvertedResidualConfig,
        se_cnf: dict,
        norm_layer: Callable[..., nn.Module],
        depthwise_norm_layer: Callable[..., nn.Module],
    ):
        """
        Builds the block from its configuration.

        Args:
            cnf (InvertedResidualConfig): Structural settings for the block.
            se_cnf (dict): Configuration for the Squeeze-Excitation layers.
            norm_layer (Callable): Normalization for expansion and projection.
            depthwise_norm_layer (Callable): Normalization for the depthwise layer.

        Raises:
            ValueError: If the stride is outside [1, 2], or SE is requested
                before cnf.f_dim / cnf.t_dim have been filled in.
        """
        super().__init__()
        if cnf.stride not in (1, 2):
            raise ValueError("illegal stride value")

        # A skip connection is only valid when the block preserves resolution
        # and channel count.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        stages: list[nn.Module] = []

        # 1x1 expansion — omitted when expansion would be a no-op.
        if cnf.input_channels != cnf.expanded_channels:
            stages.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            )

        # Depthwise conv; dilation replaces striding when dilation > 1.
        dw_stride = 1 if cnf.dilation > 1 else cnf.stride
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=dw_stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=depthwise_norm_layer,
                activation_layer=act,
            )
        )

        # Optional concurrent Squeeze-and-Excitation.
        if cnf.use_se and se_cnf["se_dims"] is not None:
            if cnf.f_dim is None or cnf.t_dim is None:
                raise ValueError("cnf.f_dim and cnf.t_dim must be set before constructing SE blocks")
            stages.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))

        # 1x1 linear projection (no activation).
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, inp: Tensor) -> Tensor:
        """
        Runs the block, adding the skip connection when enabled.

        Args:
            inp (Tensor): Input feature map of shape (B, C, F, T).

        Returns:
            Tensor: Processed feature map.
        """
        out = self.block(inp)
        if self.use_res_connect:
            out += inp
        return out