Skip to content

Commit 7c8b015

Browse files
authored
Merge pull request #54 from magcil/52-integrate-mobilenets-as-feature-extractor-backbones
52 integrate mobilenets as feature extractor backbones
2 parents 5c15433 + e0b631a commit 7c8b015

File tree

15 files changed

+1056
-34
lines changed

15 files changed

+1056
-34
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# DeepAudioX
22

3+
[![docs-dev](https://img.shields.io/badge/docs--dev-latest-blue)](https://deepaudio-x.readthedocs.io/en/latest/)
34
[![PyPI version](https://img.shields.io/pypi/v/deepaudio-x.svg?cacheSeconds=60&v=1)](https://pypi.org/project/deepaudio-x/)
45
[![Python versions](https://img.shields.io/pypi/pyversions/deepaudio-x.svg?cacheSeconds=300)](https://pypi.org/project/deepaudio-x/)
56
[![License](https://img.shields.io/github/license/magcil/deepaudio-x.svg)](https://github.com/magcil/deepaudio-x/blob/main/LICENSE)
67
[![Run Tests](https://github.com/magcil/deepaudio-x/actions/workflows/tests.yml/badge.svg)](https://github.com/magcil/deepaudio-x/actions/workflows/tests.yml)
78

9+
810
<p align="left">
911
<img src="docs/source/_static/DeepAudioX_whitebg.png" style="width: 60%" alt="DeepAudio-X logo">
1012
</p>
@@ -203,6 +205,10 @@ classifier = AudioClassifier(
203205

204206
- **BEATs** (`"beats"`): BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
205207
- **PaSST** (`"passt"`): Efficient Training of Audio Transformers with Patchout (https://arxiv.org/abs/2110.05069)
208+
- **MobileNet (0.5x, AudioSet)** (`"mobilenet_05_as"`): MobileNetV3 audio backbone pretrained on AudioSet
209+
- **MobileNet (1.0x, AudioSet)** (`"mobilenet_10_as"`): MobileNetV3 audio backbone pretrained on AudioSet
210+
- **MobileNet (4.0x, AudioSet)** (`"mobilenet_40_as"`): MobileNetV3 audio backbone pretrained on AudioSet
211+
Width multipliers (`0.5x`, `1.0x`, `4.0x`) scale convolution channel sizes. Reference: https://arxiv.org/abs/2211.04772
206212

207213
### Key Parameters
208214

docs/source/conf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
from __future__ import annotations
44

5-
from datetime import datetime
65
import os
76
import sys
8-
7+
from datetime import datetime
98
from importlib import metadata
109

1110
# -- Path setup --------------------------------------------------------------

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "deepaudio-x"
3-
version = "0.3.7"
3+
version = "0.4.0"
44
description = "DeepAudio-X: Self-supervised audio toolkit for audio classification and beyond."
55
authors = [
66
{ name = "Christos Nikou", email = "chrisnick92@gmail.com" },
@@ -23,6 +23,7 @@ dependencies = [
2323
"soundfile>=0.13.1",
2424
"torch>=2.8.0",
2525
"torchaudio>=2.8.0",
26+
"torchvision>=0.23.0",
2627
"tqdm>=4.67.1",
2728
]
2829

src/deepaudiox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This page provides the core API reference for DeepAudioX.
33
"""
44

5-
__version__ = "0.3.7"
5+
__version__ = "0.4.0"
66

77
# Top-level API exports
88
from deepaudiox.datasets.audio_classification_dataset import ( # noqa: F401

src/deepaudiox/modules/backbones/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Callable
44

55
from deepaudiox.modules.backbones.beats.beats_modules.BEATs import BEATs
6+
from deepaudiox.modules.backbones.mobilenet.model import MobileNet, MobileNetConfig
67
from deepaudiox.modules.backbones.passt.passt import PaSST
78
from deepaudiox.modules.baseclasses import BaseBackbone as Backbone
89

@@ -32,3 +33,21 @@ def beats_base() -> BEATs:
3233
def passt_base() -> PaSST:
3334
"""PaSST backbone"""
3435
return PaSST()
36+
37+
38+
@register_backbone("mobilenet_05_as")
39+
def monilenet_05_base() -> MobileNet:
40+
"""MobileNet backbone"""
41+
return MobileNet(cfg=MobileNetConfig({"width_mult": 0.5}))
42+
43+
44+
@register_backbone("mobilenet_10_as")
45+
def monilenet_10_base() -> MobileNet:
46+
"""MobileNet backbone"""
47+
return MobileNet(cfg=MobileNetConfig({"width_mult": 1}))
48+
49+
50+
@register_backbone("mobilenet_40_as")
51+
def monilenet_40_base() -> MobileNet:
52+
"""MobileNet backbone"""
53+
return MobileNet(cfg=MobileNetConfig({"width_mult": 4}))
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from collections.abc import Callable
2+
3+
import torch
4+
import torch.nn as nn
5+
from torch import Tensor
6+
from torchvision.ops.misc import Conv2dNormActivation
7+
8+
from deepaudiox.modules.backbones.mobilenet.utils import cnn_out_size, make_divisible
9+
10+
11+
class ConcurrentSEBlock(torch.nn.Module):
    """
    Applies multiple Squeeze-and-Excitation (SE) operations concurrently across
    different dimensions and aggregates the results.

    This block allows the model to attend to channel, frequency, or time dimensions
    independently before merging the attention masks using a specified aggregation
    operation (max, avg, add, or min).
    """

    def __init__(self, c_dim: int, f_dim: int, t_dim: int, se_cnf: dict) -> None:
        """
        Initializes the ConcurrentSEBlock.

        Args:
            c_dim (int): Number of channels.
            f_dim (int): Frequency dimension size.
            t_dim (int): Time dimension size.
            se_cnf (dict): Configuration dictionary containing:
                - 'se_dims': List of dimensions to apply SE on (1=C, 2=F, 3=T).
                - 'se_r': Reduction ratio for the bottleneck.
                - 'se_agg': Aggregation method ('max', 'avg', 'add', 'min').

        Raises:
            NotImplementedError: If 'se_agg' is not a supported aggregation.
        """
        super().__init__()
        dims = [c_dim, f_dim, t_dim]
        self.conc_se_layers = nn.ModuleList()
        for d in se_cnf["se_dims"]:
            input_dim = dims[d - 1]
            squeeze_dim = make_divisible(input_dim // se_cnf["se_r"], 8)
            self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
        agg = se_cnf["se_agg"]
        if agg == "max":
            self.agg_op = lambda x: torch.max(x, dim=0)[0]
        elif agg == "avg":
            self.agg_op = lambda x: torch.mean(x, dim=0)
        elif agg == "add":
            self.agg_op = lambda x: torch.sum(x, dim=0)
        elif agg == "min":
            self.agg_op = lambda x: torch.min(x, dim=0)[0]
        else:
            # Bug fix: the original interpolated `self.agg_op` here, which is
            # never assigned on this path and raised AttributeError instead of
            # the intended NotImplementedError message.
            raise NotImplementedError(f"SE aggregation operation '{agg}' not implemented")

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass of the concurrent SE block.

        Args:
            input (Tensor): Input tensor of shape (B, C, F, T).

        Returns:
            Tensor: Attention-weighted tensor aggregated from multiple SE paths.
        """
        # Each SE path produces a full-size weighted tensor; they are stacked
        # along a new leading dim and reduced by the configured aggregation.
        se_outs = []
        for se_layer in self.conc_se_layers:
            se_outs.append(se_layer(input))
        out = self.agg_op(torch.stack(se_outs, dim=0))
        return out
67+
68+
69+
class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507,
    generalized so the preserved (attended) dimension may be channel, frequency, or time.
    """

    def __init__(
        self,
        input_dim: int,
        squeeze_dim: int,
        se_dim: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        """
        Initializes the SE block.

        Args:
            input_dim (int): Number of features in the target dimension.
            squeeze_dim (int): Size of the bottleneck (input_dim // reduction_ratio).
            se_dim (int): The dimension to preserve (1, 2, or 3).
            activation (Callable): Non-linear activation for the bottleneck.
            scale_activation (Callable): Activation for the final attention mask.
        """
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
        self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
        assert se_dim in [1, 2, 3]
        # self.se_dim holds the two dims to average over (the non-preserved
        # ones), in ascending order.
        self.se_dim = [1, 2, 3]
        self.se_dim.remove(se_dim)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        """
        Computes the attention mask by squeezing spatial/channel information.

        Args:
            input (Tensor): Input feature map.

        Returns:
            Tensor: The computed attention weights (0 to 1).
        """
        # Average over the two non-target dims, keeping them as singletons so
        # `shape` can be used to restore a broadcastable mask at the end.
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # Bug fix: drop exactly the reduced singleton dims (higher index first
        # so the lower index stays valid). The original squeezed dims (2, 2)
        # unconditionally, which failed when the preserved dim was 2 (F).
        # Also removed a dead `scale = scale` no-op statement.
        scale = scale.squeeze(self.se_dim[1]).squeeze(self.se_dim[0])
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale).view(shape)

    def forward(self, input: Tensor) -> Tensor:
        """
        Applies the computed attention mask to the input tensor.

        Args:
            input (Tensor): Input feature map.

        Returns:
            Tensor: Element-wise scaled feature map.
        """
        scale = self._scale(input)
        return scale * input
131+
132+
133+
class InvertedResidualConfig:
    """
    Architectural settings for one MobileNetV3 inverted-residual block.

    Channel counts are scaled by the width multiplier at construction time.
    The spatial dims ``f_dim``/``t_dim`` start as ``None`` and are filled in
    later, before any SE blocks that need them are built.
    """

    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        """
        Store the block parameters, scaling every channel count by ``width_mult``.
        """
        adjust = self.adjust_channels
        self.input_channels = adjust(input_channels, width_mult)
        self.expanded_channels = adjust(expanded_channels, width_mult)
        self.out_channels = adjust(out_channels, width_mult)
        self.kernel = kernel
        self.use_se = use_se
        # "HS" selects Hardswish; anything else falls back to ReLU downstream.
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation
        # Unknown until the overall input size is fixed by the model builder.
        self.f_dim: int | None = None
        self.t_dim: int | None = None

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """
        Scale ``channels`` by ``width_mult``, rounded to a multiple of 8.

        Args:
            channels (int): Base number of channels.
            width_mult (float): Scaling factor.

        Returns:
            int: Adjusted channel count.
        """
        return make_divisible(channels * width_mult, 8)

    def out_size(self, in_size: int):
        """
        Output spatial size of this block's convolution for a given input size.

        Args:
            in_size (int): Input height or width.

        Returns:
            int: Output height or width after convolution.
        """
        pad = self.dilation * ((self.kernel - 1) // 2)
        return cnn_out_size(in_size, pad, self.dilation, self.kernel, self.stride)
193+
194+
195+
class InvertedResidual(nn.Module):
    """
    MobileNetV3 inverted-residual block.

    Pipeline: optional 1x1 expansion conv -> depthwise conv -> optional
    concurrent Squeeze-and-Excitation -> 1x1 projection conv, with a
    residual skip when stride is 1 and the channel count is unchanged.
    """

    def __init__(
        self,
        cnf: InvertedResidualConfig,
        se_cnf: dict,
        norm_layer: Callable[..., nn.Module],
        depthwise_norm_layer: Callable[..., nn.Module],
    ):
        """
        Initializes the Inverted Residual block.

        Args:
            cnf (InvertedResidualConfig): Structural settings for the block.
            se_cnf (dict): Configuration for the Squeeze-Excitation layers.
            norm_layer (Callable): Normalization for expansion and projection.
            depthwise_norm_layer (Callable): Normalization for the depthwise layer.

        Raises:
            ValueError: If the stride is outside [1, 2], or SE is requested
                before ``cnf.f_dim``/``cnf.t_dim`` have been set.
        """
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        # Skip connection is only valid when the tensor shape is preserved.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        stages: list[nn.Module] = []

        # 1x1 expansion — omitted when no channel expansion is configured.
        if cnf.expanded_channels != cnf.input_channels:
            stages.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            )

        # Depthwise conv; with dilation > 1 the stride is forced back to 1.
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=1 if cnf.dilation > 1 else cnf.stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=depthwise_norm_layer,
                activation_layer=act,
            )
        )

        # Optional concurrent SE attention on the expanded representation.
        if cnf.use_se and se_cnf["se_dims"] is not None:
            if cnf.f_dim is None or cnf.t_dim is None:
                raise ValueError("cnf.f_dim and cnf.t_dim must be set before constructing SE blocks")
            stages.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))

        # 1x1 linear projection — no activation, per the MobileNetV3 design.
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, inp: Tensor) -> Tensor:
        """
        Forward pass with optional residual skip connection.

        Args:
            inp (Tensor): Input feature map of shape (B, C, F, T).

        Returns:
            Tensor: Processed feature map.
        """
        out = self.block(inp)
        if self.use_res_connect:
            out += inp
        return out

0 commit comments

Comments
 (0)