The probe API provides an interface for defining, configuring, and attaching probes to backbone (base) models. Probes are task-specific heads trained on top of a backbone's features to adapt it to downstream tasks, either with the backbone frozen or while fine-tuning it.
Key Ideas:
- Probes are regular PyTorch modules (linear, MLP, LSTM, attention, and transformer heads), as are the backbone models.
- Configuration is done via `ProbeConfig` (Python) or YAML files that map to `ProbeConfig`.
- Probes may be trained online (attached to a base model) or offline (on pre-computed embeddings).
Begin with a simple linear probe on the backbone's last layer:
```python
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

# Load the backbone so that it returns features instead of logits
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")

cfg = ProbeConfig(
    probe_type="linear",
    target_layers=["last_layer"],
    aggregation="mean",
    freeze_backbone=True,
    online_training=True,
)
probe = build_probe_from_config(cfg, base_model=base, num_classes=50, device="cpu")
```

If performance plateaus, move to MLP, LSTM, attention, or transformer probes by changing `probe_type` and the related fields in `ProbeConfig` (the sequence-based probes, for example, need `aggregation="none"`). In general, attention probes work best with self-supervised and transformer backbones, and bring little improvement on EfficientNet backbones.
Choosing a probe:
- Simple in-domain tasks → linear probes work well for bird classification/detection, because most bioacoustics backbones were trained on these tasks.
- Out-of-domain tasks → attention/transformer probes on all layers (or even on lower layers only) work better for repertoire classification or for species that were under-represented in the backbones' training data.
- Limited resources → `_last` variants with linear/MLP probes.
- Generous resources → `_all` variants with attention/transformer probes.
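The rules of thumb above can be encoded as a starting-point lookup. `suggest_probe_fields` is a hypothetical helper, not part of avex; it returns plain keyword dictionaries whose field names follow the `ProbeConfig` fields documented on this page, and the head sizes it picks (`hidden_dims=[512]`, `num_heads=4`, `attention_dim=128`) are arbitrary illustrative values.

```python
def suggest_probe_fields(in_domain: bool, limited_resources: bool) -> dict:
    """Map the rules of thumb to starting ProbeConfig-style keyword arguments."""
    if in_domain or limited_resources:
        # Cheap 2D head on the last layer (a "_last" variant)
        fields = {
            "probe_type": "linear" if in_domain else "mlp",
            "target_layers": ["last_layer"],
            "aggregation": "mean",
        }
        if fields["probe_type"] == "mlp":
            fields["hidden_dims"] = [512]  # illustrative size
        return fields
    # Out-of-domain task with generous resources: attention head on all layers ("_all")
    return {
        "probe_type": "attention",
        "target_layers": ["all"],
        "aggregation": "none",
        "input_processing": "sequence",
        "num_heads": 4,  # illustrative size
        "attention_dim": 128,  # illustrative size
    }


print(suggest_probe_fields(in_domain=True, limited_resources=True)["probe_type"])  # linear
print(suggest_probe_fields(in_domain=False, limited_resources=False)["probe_type"])  # attention
```

With avex installed, the returned dictionary could be unpacked into `ProbeConfig(**fields, freeze_backbone=True, online_training=True)`.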
`_last` variants (single layer):
Pros:
- Fast execution
- Simple architecture
- Lower memory usage
- Fewer parameters to train
Cons:
- Single-layer representation; for supervised models it tends to be overfitted to species classification (mostly birds)
- May miss multi-scale features
Use when:
- Quick experiments needed
- Limited computational resources
- Strong, well-trained backbone
- Simple classification tasks
`_all` variants (all layers):
Pros:
- Rich multi-scale features
- More expressive models
- Better for complex tasks
- Learns optimal layer weighting
Cons:
- Slower execution
- High disk usage for offline probes (embeddings from every layer must be stored)
- Higher memory requirements
- More parameters to train
Use when:
- Maximum performance needed
- Sufficient computational resources
- Out-of-domain tasks
- Comparing layer-wise representations
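"Learns optimal layer weighting" refers to combining the per-layer embeddings with trainable weights instead of picking a single layer. A common way to do this, assumed here rather than taken from avex, is a softmax over one scalar logit per layer followed by a weighted sum; the plain-Python sketch below shows only the forward computation, and `combine_layers` is a hypothetical name.

```python
import math


def combine_layers(layer_embeddings, layer_logits):
    """Weighted sum of per-layer embeddings with softmax-normalized weights.

    layer_embeddings: one feature vector per layer (all the same length).
    layer_logits: one trainable scalar per layer; training would update these.
    """
    m = max(layer_logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in layer_logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(layer_embeddings[0])
    combined = [
        sum(w * emb[j] for w, emb in zip(weights, layer_embeddings))
        for j in range(dim)
    ]
    return combined, weights


layers = [[0.0, 2.0], [4.0, 6.0]]  # 2 layers, 2 features each
combined, weights = combine_layers(layers, [0.0, 0.0])
print(combined)  # equal logits -> plain average: [2.0, 4.0]
```

Starting from equal logits gives a plain average of the layers; as training pushes one logit up, the combination moves toward that layer.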
```text
Task Complexity: LOW ──────────────────────────────────> HIGH
Probe Type:      linear → mlp → lstm → attention → transformer
Feature Scope:   SINGLE LAYER ─────────────────────────> ALL LAYERS
Variant:         _last ─────────────────────────────────> _all
Computational:   FAST ──────────────────────────────────> SLOW
                 linear_last ──────────────────────> transformer_all
```
Online mode (probe attached to a backbone):

```python
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

# 1. Load a backbone model that returns features
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")

# 2. Define a simple linear probe on the backbone features
probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],  # use final backbone layer
    aggregation="mean",  # mean-pool over time
    freeze_backbone=True,  # keep backbone frozen
    online_training=True,  # end-to-end graph (even if backbone is frozen)
)

# 3. Build the probe
probe = build_probe_from_config(
    probe_config=probe_config,
    base_model=base,
    num_classes=50,
    device="cpu",
)
```

Offline mode (pre-computed embeddings):

```python
import torch

from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

# For pre-computed embeddings (no base model needed)
probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],  # conceptual; not used when base_model=None
    aggregation="none",  # no pooling: the embeddings are already 2D
    freeze_backbone=True,
    online_training=False,
)
probe = build_probe_from_config(
    probe_config=probe_config,
    input_dim=768,  # embedding dimension
    num_classes=50,
    device="cpu",
)

# Placeholder for real pre-computed embeddings, shape (batch, 768)
embeddings = torch.randn(4, 768)

# For inference, set the probe to eval mode and use torch.no_grad()
probe.eval()
with torch.no_grad():
    predictions = probe(embeddings)  # shape: (batch, num_classes)
```

Note: The probe's `forward()` method does not automatically switch to inference mode. For inference (when you don't need gradients), you should:
- Call `probe.eval()` to set the model to evaluation mode (disables dropout, batch norm statistic updates, etc.).
- Wrap the forward pass in `with torch.no_grad():` to disable gradient computation and reduce memory usage.
For training/fine-tuning, use probe.train() and omit the torch.no_grad() context.
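To make offline probing and the train/inference split concrete, here is a framework-free sketch of what an offline linear probe amounts to: a softmax classifier fitted by gradient descent on frozen, pre-computed embeddings, followed by a prediction step that never updates parameters. It uses neither avex nor PyTorch; `train_linear_probe`, `predict`, and the toy data are hypothetical, and a real probe would be a `torch.nn.Module` trained with an optimizer.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_linear_probe(embeddings, labels, num_classes, lr=0.5, epochs=200):
    """Fit W, b with full-batch gradient descent on cross-entropy.

    The embeddings stay fixed (frozen backbone); only W and b are updated.
    """
    dim = len(embeddings[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    n = len(embeddings)
    for _ in range(epochs):
        grad_W = [[0.0] * dim for _ in range(num_classes)]
        grad_b = [0.0] * num_classes
        for x, y in zip(embeddings, labels):
            logits = [sum(w * xi for w, xi in zip(W[c], x)) + b[c] for c in range(num_classes)]
            probs = softmax(logits)
            for c in range(num_classes):
                # d(cross-entropy)/d(logit_c) = p_c - 1[c == y]
                err = probs[c] - (1.0 if c == y else 0.0)
                grad_b[c] += err / n
                for j in range(dim):
                    grad_W[c][j] += err * x[j] / n
        for c in range(num_classes):
            b[c] -= lr * grad_b[c]
            for j in range(dim):
                W[c][j] -= lr * grad_W[c][j]
    return W, b


def predict(W, b, x):
    """Inference only: no parameter updates, just argmax over the logits."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]
    return max(range(len(logits)), key=logits.__getitem__)


# Toy "pre-computed embeddings" (2-D) for two linearly separable classes
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
y = [0, 0, 1, 1]
W, b = train_linear_probe(X, y, num_classes=2)
print([predict(W, b, x) for x in X])  # [0, 0, 1, 1]
```

Because the embeddings never change, they can be computed once and reused across many probe trainings, which is the main appeal of offline mode.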
Common `probe_type` values:
- `linear` – simple linear classifier
- `mlp` – multi-layer perceptron
- `lstm` – LSTM sequence model
- `attention` – self-attention head
- `transformer` – transformer encoder probe
All probe configs support (non-exhaustive):

- Architecture & layers
  - `probe_type`: `"linear" | "mlp" | "lstm" | "attention" | "transformer"`. The architecture of the probe head:
    - `"linear"`: 2D probe. Simple linear classifier (single fully connected layer). Fastest and most memory-efficient. Expects 2D input `(batch, features)`. Use with `aggregation="mean"` or `"max"`. Best for: baseline performance, simple tasks, limited resources.
    - `"mlp"`: 2D probe. Multi-layer perceptron with configurable hidden layers and non-linear activations. More expressive than linear while still efficient. Expects 2D input `(batch, features)`. Use with `aggregation="mean"` or `"max"`. Requires the `hidden_dims` parameter. Best for: tasks needing non-linearity, moderate complexity.
    - `"lstm"`: 3D probe. Long Short-Term Memory network for sequence modeling. Processes temporal sequences and captures long-range dependencies. Expects 3D input `(batch, time, features)`. Use with `aggregation="none"` to preserve the sequence structure. Requires `lstm_hidden_size`, `num_layers`, and optionally `bidirectional`. Best for: temporal/sequential tasks, variable-length sequences.
    - `"attention"`: 3D probe. Self-attention mechanism for sequence modeling. Captures relationships between all positions in a sequence. Expects 3D input `(batch, time, features)`. Use with `aggregation="none"` to preserve the sequence structure. Requires `num_heads` and `attention_dim`. Best for: tasks requiring global sequence understanding, parallel processing.
    - `"transformer"`: 3D probe. Full transformer encoder with multiple attention layers; the most expressive probe type. Expects 3D input `(batch, time, features)`. Use with `aggregation="none"` to preserve the sequence structure. Requires `num_heads`, `attention_dim`, and `num_layers`. Best for: complex tasks, maximum performance, sufficient computational resources.
  - `target_layers`: List of layer names to extract embeddings from. Main options:
    - `["last_layer"]`: Uses the final (non-classification) layer of the model. Best for: single-layer probing, baseline experiments, efficient computation.
    - `["all"]`: Uses all discoverable layers in the model. Best for: multi-layer probing, learning optimal layer combinations, maximum expressiveness.
    - Specific layer names: Use concrete layer names (e.g., `["layer_6", "layer_12"]`). Discover available layers using `list_model_layers(model_name)`. Best for: targeted probing of specific layers, custom layer combinations.
  - `aggregation`: `"mean" | "max" | "none" | "cls_token"`. Controls how the time/sequence dimension of the embeddings is reduced:
    - `"mean"`: Average pooling over the time dimension. Reduces 3D embeddings `(batch, time, features)` to 2D `(batch, features)`. Use with 2D probes (linear, MLP) that expect fixed-size feature vectors.
    - `"max"`: Max pooling over the time dimension. Reduces 3D embeddings `(batch, time, features)` to 2D `(batch, features)`. An alternative to mean pooling that can capture peak activations. Use with 2D probes (linear, MLP).
    - `"none"`: No aggregation; preserves the full sequence structure `(batch, time, features)`. Required for 3D probes (LSTM, attention, transformer) that process sequences. Also enables a learned weighted combination of multiple layers.
    - `"cls_token"`: Uses only the first (CLS) token from transformer models. Reduces to 2D `(batch, features)`. Use with transformer-based backbones and 2D probes.
  - `input_processing`: `"pooled" | "sequence" | "flatten" | "none"`. How input embeddings are processed before being fed to the probe:
    - `"pooled"` (default): Pools embeddings to a fixed dimension. Works with embeddings that have already been aggregated (e.g., via `aggregation="mean"`). Use with 2D probes (linear, MLP) that expect fixed-size feature vectors.
    - `"sequence"`: Keeps the temporal/sequence dimension `(batch, time, features)`. Required for 3D probes (LSTM, attention, transformer) and only compatible with sequence-based probe types. Must be used with `aggregation="none"`.
    - `"flatten"`: Reshapes multi-dimensional embeddings into a single vector per example, converting any shape to `(batch, features)`. Use when you need to flatten complex embeddings (e.g., 4D tensors) for 2D probes.
    - `"none"`: No processing; embeddings are used as-is. Use when the embeddings are already in the correct format for your probe type.
- Training behavior
  - `freeze_backbone`: `True` to keep the base model frozen.
  - `online_training`: `True` for online training (end-to-end graph), `False` for pure offline training.
- Probe-specific parameters
  - MLP: `hidden_dims`, `dropout_rate`, `activation`, ...
  - LSTM: `lstm_hidden_size`, `num_layers`, `bidirectional`, `max_sequence_length`, ...
  - Attention/Transformer: `num_heads`, `attention_dim`, `num_layers`, `max_sequence_length`, `use_positional_encoding`, ...

See the `ProbeConfig` class documentation or use `ProbeConfig.model_json_schema()` for complete parameter details, defaults, and valid ranges.
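The four `aggregation` modes can be illustrated on a single example of shape `(time, features)`. The sketch below is plain Python, not the avex implementation; `aggregate` is a hypothetical helper, and it treats the first time step as the CLS token, which only makes sense for backbones that prepend one.

```python
def aggregate(frames, mode):
    """Reduce one example of shape (time, features) according to `mode`.

    frames: list of per-time-step feature vectors (lists of floats).
    """
    if mode == "mean":
        # Average over time -> (features,), for linear/MLP probes
        n = len(frames)
        return [sum(col) / n for col in zip(*frames)]
    if mode == "max":
        # Max over time -> (features,), keeps peak activations
        return [max(col) for col in zip(*frames)]
    if mode == "cls_token":
        # Keep only the first token -> (features,)
        return list(frames[0])
    if mode == "none":
        # Keep the sequence -> (time, features), for LSTM/attention/transformer probes
        return [list(f) for f in frames]
    raise ValueError(f"unknown aggregation: {mode!r}")


frames = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]  # time=3, features=2
print(aggregate(frames, "mean"))       # [2.0, 2.0]
print(aggregate(frames, "max"))        # [3.0, 4.0]
print(aggregate(frames, "cls_token"))  # [1.0, 4.0]
print(len(aggregate(frames, "none")))  # 3
```

In a batched tensor implementation the same reductions would run over the time axis of a `(batch, time, features)` array.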
```python
from avex.configs import ProbeConfig

probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],
    aggregation="mean",
    freeze_backbone=True,
    online_training=True,
)
```

The same configuration as a YAML file:

```yaml
# my_linear_probe.yml
probe_type: linear
target_layers: ["backbone"]
aggregation: mean
freeze_backbone: true
online_training: true
```

Loading the YAML file and building the probe:

```python
from avex.models.probes.utils import (
    load_probe_config,
    build_probe_from_config,
)
from avex import load_model

config = load_probe_config("my_linear_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")
```

`build_probe_from_config()` is a unified factory function for building probe instances from a `ProbeConfig`. It supports both online (with a base model) and offline (with pre-computed embeddings) modes.
```python
# Signature of avex.models.probes.build_probe_from_config
from typing import Optional

import torch

from avex.configs import ProbeConfig


def build_probe_from_config(
    probe_config: ProbeConfig,
    num_classes: int,
    device: str,
    base_model: Optional[torch.nn.Module] = None,
    input_dim: Optional[int] = None,
    target_length: Optional[int] = None,
    **kwargs,
) -> torch.nn.Module:
    ...
```

Key parameters:
- `probe_config`: The `ProbeConfig` object.
- `num_classes`: Number of output classes.
- `device`: `"cpu"`, `"cuda"`, etc.
- `base_model`: Optional backbone model to attach the probe to (online mode). If provided, the probe is attached for end-to-end training.
- `input_dim`: Optional embedding dimension (offline mode). Required if `base_model` is `None`.
- `target_length`: Optional override for the audio target length.
Mode detection:
- Online mode: When `base_model` is provided, the probe is attached to the base model for end-to-end training.
- Offline mode: When `input_dim` is provided, the probe operates on pre-computed embeddings without a base model.

Returns: A `torch.nn.Module` probe ready for training/inference.
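The mode-detection rule can be summarized as a small decision function. This is a hypothetical sketch of the documented behavior (`detect_probe_mode` is not part of avex); what the real factory does when both arguments are given is not specified here, so the sketch lets `base_model` take precedence as an assumption.

```python
from typing import Any, Optional


def detect_probe_mode(base_model: Optional[Any] = None, input_dim: Optional[int] = None) -> str:
    """Return "online" or "offline" following the documented rules.

    Assumption: if both are given, base_model wins (online mode).
    """
    if base_model is not None:
        return "online"  # probe is attached to the backbone
    if input_dim is not None:
        return "offline"  # probe consumes pre-computed embeddings
    raise ValueError("input_dim is required when base_model is None")


print(detect_probe_mode(base_model=object()))  # online
print(detect_probe_mode(input_dim=768))        # offline
```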
```python
from avex.models.probes.utils import load_probe_config

config = load_probe_config("my_probe.yml")
```

Supports:
- Files with top-level probe fields.
- Files with a nested `probe_config: {...}` block.
All probe configs include:
- `probe_type` – type of probe architecture
- `target_layers` – which layers to extract features from
- `aggregation` – how to aggregate features (`mean`, `max`, `none`, `cls_token`)
- `input_processing` – how to process inputs (`pooled`, `sequence`, `flatten`, `none`)
- `freeze_backbone` – whether to freeze the backbone weights
- `online_training` – whether to train end-to-end (online) or on pre-computed embeddings (offline)

Probe-specific parameters:
- MLP: `hidden_dims`, `dropout_rate`, `activation`
- LSTM: `lstm_hidden_size`, `num_layers`, `bidirectional`, `max_sequence_length`
- Attention: `num_heads`, `attention_dim`, `num_layers`, `max_sequence_length`
- Transformer: `num_heads`, `attention_dim`, `num_layers`, `max_sequence_length`
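The sequence-probe rules from the configuration reference (LSTM, attention, and transformer probes need `aggregation="none"` together with `input_processing="sequence"`, and `"sequence"` only works with those probes) can be expressed as a small check. `check_probe_fields` is a hypothetical helper working on a plain dictionary; it assumes `"pooled"` as the default `input_processing`, as documented, and `ProbeConfig` itself may enforce these rules differently or not at all.

```python
SEQUENCE_PROBES = {"lstm", "attention", "transformer"}  # 3D probes


def check_probe_fields(cfg: dict) -> list[str]:
    """Return a list of problems with a probe config dict (empty if none found)."""
    problems = []
    probe_type = cfg.get("probe_type")
    aggregation = cfg.get("aggregation")
    processing = cfg.get("input_processing", "pooled")  # documented default
    if probe_type in SEQUENCE_PROBES:
        if aggregation != "none":
            problems.append(f"{probe_type} expects aggregation='none', got {aggregation!r}")
        if processing != "sequence":
            problems.append(f"{probe_type} expects input_processing='sequence', got {processing!r}")
    elif processing == "sequence":
        problems.append(f"input_processing='sequence' needs a 3D probe, got {probe_type!r}")
    return problems


ok = {"probe_type": "attention", "aggregation": "none", "input_processing": "sequence"}
print(check_probe_fields(ok))  # []
print(len(check_probe_fields({"probe_type": "lstm", "aggregation": "mean"})))  # 2
```

The check deliberately leaves 2D probes unconstrained on `aggregation`, since an offline linear probe on already-pooled embeddings uses `aggregation="none"`.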
```python
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")

probe_types = [
    ("linear", {"aggregation": "mean"}),
    ("mlp", {"aggregation": "mean", "hidden_dims": [512, 256]}),
    # 3D probes need the sequence kept intact
    ("attention", {"aggregation": "none", "input_processing": "sequence", "num_heads": 4, "attention_dim": 128}),
]

for probe_type, extra_cfg in probe_types:
    cfg = ProbeConfig(
        probe_type=probe_type,
        target_layers=["backbone"],
        freeze_backbone=True,
        online_training=True,
        **extra_cfg,
    )
    probe = build_probe_from_config(
        probe_config=cfg,
        base_model=base,
        num_classes=10,
        device="cpu",
    )
    print(probe_type, "parameters:", sum(p.numel() for p in probe.parameters()))
```

Example output (the exact counts depend on the backbone's feature dimension and the probe implementation):

```text
linear parameters: 7680
mlp parameters: 395264
attention parameters: 66560
```
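Parameter counts for the simpler heads follow directly from the layer sizes: a fully connected layer from `d_in` to `d_out` has `d_in * d_out` weights plus `d_out` biases. The helper below is a hypothetical sketch for linear and MLP heads only (it ignores anything else a probe may contain, such as layer-weighting parameters), and the 768-dimensional feature size is an assumption; for instance, a bias-free linear head on 768-dimensional features with 10 classes has 768 × 10 = 7,680 weights.

```python
def dense_params(dims, bias=True):
    """Count parameters of a stack of fully connected layers.

    dims: [input_dim, hidden_1, ..., num_classes]
    """
    total = 0
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        total += d_in * d_out + (d_out if bias else 0)
    return total


D, C = 768, 10  # assumed feature dimension and number of classes
print(dense_params([D, C], bias=False))  # 7680
print(dense_params([D, C]))              # 7690
print(dense_params([D, 512, 256, C]))    # 527626
```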
```yaml
# custom_probe.yml
probe_type: mlp
target_layers: ["backbone"]
aggregation: mean
hidden_dims: [1024, 512]
```

```python
from avex.models.probes.utils import (
    build_probe_from_config,
    load_probe_config,
)
from avex import load_model

config = load_probe_config("custom_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")
```

Creating a config programmatically:

```python
from avex import load_model
from avex.configs import ProbeConfig
from avex.models.probes.utils import build_probe_from_config

# Create config programmatically
config = ProbeConfig(
    probe_type="attention",
    target_layers=["layer_12"],  # see list_model_layers(model_name) for valid names
    aggregation="none",
    input_processing="sequence",
    num_heads=8,
    attention_dim=64,
    num_layers=1,
)

# Use it with a backbone loaded in feature mode
my_model = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=my_model, num_classes=50, device="cpu")
```

The probe API mirrors the model API structure for consistency:
```text
avex/
├── models/probes/
│   ├── utils/              # Probe utilities (parallel to models/utils/)
│   │   ├── __init__.py
│   │   ├── registry.py     # Probe class discovery + YAML helpers
│   │   └── factory.py      # build_probe_from_config
│   └── [probe implementations]
└── examples/
    └── 07_probe_training_and_inference.py  # Usage examples
```
`registry.py`:
- Probe class registry: `_PROBE_CLASSES` holds the discovered probe implementations.
- Discovery: dynamically finds all probe classes (`LinearProbe`, `MLPProbe`, etc.).
- YAML helpers: `load_probe_config()` loads a `ProbeConfig` from disk.

`factory.py`:
- `build_probe_from_config()`: unified factory for building probes from a `ProbeConfig` (supports both online and offline modes).
- Handles parameter filtering and base-model interaction (freezing, hooks, feature mode).
Quick check (offline mode, no backbone needed):

```python
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
import torch

# Test offline mode (works independently)
cfg = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],
    aggregation="none",
    freeze_backbone=True,
    online_training=False,
)
probe = build_probe_from_config(
    cfg,
    input_dim=768,
    num_classes=10,
    device="cpu",
)

# Test forward pass (inference mode)
probe.eval()
with torch.no_grad():
    dummy_embeddings = torch.randn(2, 768)
    output = probe(dummy_embeddings)
    print(f"Output shape: {output.shape}")  # Should be (2, 10)
```

Running the examples:

```bash
cd /path/to/avex  # repository root
python examples/07_probe_training_and_inference.py
```

Current status:
- ✅ Probe Discovery: Automatically finds all probe classes
- ✅ Config Loading: `load_probe_config()` builds `ProbeConfig` from YAML
- ✅ Factory Usage: `build_probe_from_config()` builds probes from `ProbeConfig` (supports both online and offline modes)
- ✅ Offline Mode: Creates probes for pre-computed embeddings
- ✅ Online Mode: Loads and attaches to base models
- ✅ Forward Pass: Correct output shapes with dummy data
- ✅ No Linter Errors: All code is ruff-compliant
- ✅ Layer Variants: `_last` and `_all` variants work correctly
Known issues:
- Model Registry: a pre-existing circular import prevents model loading in some contexts.
  - This is a separate issue in the existing codebase.
  - It doesn't affect offline probe functionality.
  - It doesn't affect direct model-instance usage.
Files:
- `models/probes/utils/__init__.py`
- `models/probes/utils/registry.py`
- `models/probes/utils/factory.py`
- `examples/07_probe_training_and_inference.py`
- `docs/api_probes.md` (this file)
The following components were intentionally not implemented:
- `models/probes/utils/checkpoint.py`: checkpoint save/load utilities
- Embedding extraction utilities
These can be added in future iterations following the same design patterns.
See also:
- `examples/07_probe_training_and_inference.py`: complete usage examples
- `avex/models/probes/`: probe implementations
- Model API documentation, for the parallel structure reference