forked from earthspecies/avex
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path06_classifier_head_loading.py
More file actions
197 lines (163 loc) · 7.44 KB
/
Copy path06_classifier_head_loading.py
File metadata and controls
197 lines (163 loc) · 7.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Example 6: Classifier Head and Probe Behavior

This example demonstrates how load_model and probe heads interact:
- How classifier weights are preserved when loading from checkpoints
- How to use return_features_only for embedding extraction
- How to attach a new linear probe head on top of a backbone

Audio Requirements:
- Each model expects a specific sample rate (defined in model_spec.audio_config.sample_rate)
- Check with: describe_model("model_name") or get_model_spec("model_name").audio_config.sample_rate
- For full reproducibility, resample using librosa with these exact parameters:

    import librosa
    audio_resampled = librosa.resample(
        audio, orig_sr=original_sr, target_sr=target_sr,
        res_type="kaiser_best", scale=True
    )
"""
import argparse
from pathlib import Path
import torch
from avex import load_model
from avex.configs import ProbeConfig
from avex.models.probes.utils import build_probe_from_config
def main(device: str = "cpu") -> None:
    """Walk through classifier-head loading and probe attachment.

    Part 1 loads ``esp_aves2_sl_beats_all``, checks that it exposes a
    ``classifier``, saves its state dict into the ``checkpoints`` directory
    one level above the directory containing this file, reloads the model
    from that checkpoint, and reports whether the classifier weight and bias
    match the saved ones.

    Part 2 loads ``esp_aves2_naturelm_audio_v1_beats`` three ways: with the
    default settings, as the backbone of a linear probe, and with
    ``return_features_only=True``. Each variant is run once on random input
    and the output shape is printed.

    Args:
        device: Device passed to model loading and tensor creation.

    Raises:
        ValueError: If the Part 1 model has no ``classifier`` attribute or
            the attribute is ``None``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    supervised_name = "esp_aves2_sl_beats_all"
    ssl_name = "esp_aves2_naturelm_audio_v1_beats"
    # BEATs expects 16 kHz audio; the demos feed five seconds of random samples.
    waveform_len = 16000 * 5

    print("Example 6: Classifier Head and Probe Behavior")
    print(heavy_rule)

    # The checkpoint written in Part 1 needs a directory to land in.
    ckpt_dir = Path(__file__).parent.parent / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)

    # ---------------------------------------------------------------------
    # Part 1: classifier weights survive a checkpoint round trip
    # ---------------------------------------------------------------------
    print("\nPart 1: Checkpoint-based classifier loading")
    print(light_rule)

    print("\nLoading BEATs model with classifier from checkpoint ...")
    net = load_model(supervised_name, device=device)
    net = net.to(device)

    # A missing attribute and an explicit None are treated the same way.
    if getattr(net, "classifier", None) is None:
        raise ValueError(
            "Model does not have a classifier. This might happen if the checkpoint "
            "doesn't contain classifier weights or if the model was loaded in "
            "return_features_only mode."
        )

    # Snapshot the head parameters so the reloaded model can be compared to them.
    reference_weight = net.classifier.weight.clone()
    reference_bias = net.classifier.bias.clone()
    class_count = reference_weight.shape[0]
    print(f"Loaded model with {class_count} classes")
    print(f" Classifier weight shape: {reference_weight.shape}")

    ckpt_file = ckpt_dir / "test_beats_checkpoint.pt"
    torch.save(net.state_dict(), ckpt_file)
    print(f"Saved checkpoint to: {ckpt_file}")

    # Demo 1: reload from the checkpoint that was just written.
    print("\nDemo 1: Loading from explicit checkpoint")
    print(" Behavior: Classifier weights loaded from checkpoint")
    reloaded = load_model(
        supervised_name,
        checkpoint_path=str(ckpt_file),
        device=device,
    )
    reloaded_head = reloaded.classifier
    same_weight = torch.allclose(reloaded_head.weight, reference_weight, atol=1e-6)
    same_bias = torch.allclose(reloaded_head.bias, reference_bias, atol=1e-6)
    print(f" Classifier weights match checkpoint: {same_weight and same_bias}")

    # ---------------------------------------------------------------------
    # Part 2: working with the self-supervised model
    # ---------------------------------------------------------------------
    print("\n" + heavy_rule)
    print("Part 2: Self-supervised model (esp_aves2_naturelm_audio_v1_beats)")
    print(heavy_rule)
    print("\nesp_aves2_naturelm_audio_v1_beats is a self-supervised model without a trained classifier.")
    print("This demonstrates different ways to use such models.\n")

    # Use case 1: load with default settings and inspect what comes back.
    print("Use case 1: Embedding extraction (default behavior)")
    print(light_rule)
    net = load_model(ssl_name, device=device)
    net.eval()

    head_present = getattr(net, "classifier", None) is not None
    print(f" Has classifier: {head_present}")
    print(f" Return features only: {getattr(net, '_return_features_only', 'N/A')}")

    waveform = torch.randn(1, waveform_len, device=device)
    with torch.no_grad():
        features = net(waveform, padding_mask=None)
    print(f" Output shape: {features.shape} (batch, time_steps, features)")

    # Use case 2: put a fresh linear probe on top of a features-only backbone.
    print("\nUse case 2: Adding a new classification head with linear probe")
    print(light_rule)
    probe_classes = 10
    encoder = load_model(ssl_name, device=device, return_features_only=True)
    encoder.eval()

    linear_probe = build_probe_from_config(
        probe_config=ProbeConfig(
            probe_type="linear",
            target_layers=["backbone"],
            aggregation="mean",
            freeze_backbone=True,
            online_training=True,
        ),
        base_model=encoder,
        num_classes=probe_classes,
        device=device,
    )
    linear_probe.eval()

    waveform = torch.randn(1, waveform_len, device=device)
    with torch.no_grad():
        probe_logits = linear_probe(waveform)
    print(f" Probe output shape: {probe_logits.shape} (batch, num_classes)")

    # Use case 3: ask for embedding mode explicitly.
    print("\nUse case 3: Explicit embedding extraction mode")
    print(light_rule)
    net = load_model(ssl_name, return_features_only=True, device=device)
    net.eval()

    waveform = torch.randn(1, waveform_len, device=device)
    with torch.no_grad():
        features = net(waveform, padding_mask=None)
    print(f" Output shape: {features.shape} (batch, time_steps, features)")

    # ---------------------------------------------------------------------
    # Summary
    # ---------------------------------------------------------------------
    print("\n" + heavy_rule)
    print("Key takeaways")
    print(heavy_rule)
    takeaways = """
1. Supervised models with checkpoint classifiers:
- load_model keeps the classifier weights from the checkpoint
2. Self-supervised models (like esp_aves2_naturelm_audio_v1_beats):
- No trained classifier exists, so they default to embedding extraction mode
3. return_features_only=True:
- Explicitly requests embedding extraction mode
- Returns unpooled features (batch, time_steps, features)
4. Probe heads:
- You can attach a simple linear probe via build_probe_from_config
- Backbones stay reusable across tasks and heads
"""
    print(takeaways)
if __name__ == "__main__":
    # Script entry point: read the target device from the command line and
    # run the demonstration on it.
    cli = argparse.ArgumentParser(description="Classifier Head Loading Example")
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for model and data (e.g. cpu, cuda, cuda:0)",
    )
    cli_args = cli.parse_args()
    main(device=cli_args.device)