This document describes the new flexible probe system that replaces the old rigid linear probe approach. The new system allows you to configure different types of probes with various aggregation and processing strategies.
The new probe system provides:
- Multiple probe types: Linear, MLP, LSTM, Attention, and Transformer probes, plus weighted versions of each
- Flexible aggregation: Mean, max, concatenation, CLS token, or no aggregation
- Input processing options: Flatten, sequence, pooled, or no processing
- Probe-specific parameters: Hidden dimensions, attention heads, LSTM configuration, etc.
- Training overrides: Per-probe learning rates, batch sizes, epochs, etc.
- Backward compatibility: Legacy configurations still work automatically
Simple linear classification layer. Good for baseline performance.
probe_config:
  probe_type: "linear"
  aggregation: "mean"
  input_processing: "pooled"
  target_layers: ["layer_12"]

Multi-layer perceptron with configurable hidden dimensions.
probe_config:
  probe_type: "mlp"
  aggregation: "mean"
  input_processing: "pooled"
  target_layers: ["layer_8", "layer_12"]
  hidden_dims: [512, 256]
  dropout_rate: 0.2
  activation: "gelu"

Long Short-Term Memory network for sequence modeling.
probe_config:
  probe_type: "lstm"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  lstm_hidden_size: 256
  num_layers: 2
  bidirectional: true
  max_sequence_length: 1000

Attention mechanism for sequence modeling.
probe_config:
  probe_type: "attention"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_10"]
  num_heads: 8
  attention_dim: 512
  num_layers: 2
  max_sequence_length: 800
  use_positional_encoding: true

Full transformer architecture for complex sequence modeling.
probe_config:
  probe_type: "transformer"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 12
  attention_dim: 768
  num_layers: 4
  max_sequence_length: 1200
  use_positional_encoding: true

Weighted probe types are enhanced versions of the standard probes that use learned weights to combine multiple layer embeddings. Each one uses a single architecture head that learns optimal weights for combining embeddings from different layers.
Single linear classifier with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_linear"
  aggregation: "none"  # Required for weighted probes
  input_processing: "pooled"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  freeze_backbone: true

Single MLP with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_mlp"
  aggregation: "none"  # Required for weighted probes
  input_processing: "pooled"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  hidden_dims: [512, 256]
  dropout_rate: 0.2
  activation: "gelu"
  freeze_backbone: true

Single LSTM with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_lstm"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  lstm_hidden_size: 128
  num_layers: 2
  bidirectional: true
  max_sequence_length: 1000
  use_positional_encoding: false
  dropout_rate: 0.3
  freeze_backbone: true

Single attention mechanism with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_attention"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 8
  attention_dim: 256
  num_layers: 2
  max_sequence_length: 800
  use_positional_encoding: false
  dropout_rate: 0.3
  freeze_backbone: true

Single minimal attention mechanism with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_attention_minimal"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 4
  freeze_backbone: true

Single transformer encoder with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_transformer"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 12
  attention_dim: 768
  num_layers: 4
  max_sequence_length: 1200
  use_positional_encoding: true
  dropout_rate: 0.3
  freeze_backbone: true

- Single Architecture Head: Each weighted probe uses one architecture component (linear, MLP, LSTM, attention, transformer) instead of multiple projection heads per layer
- Learned Weighted Sum: Uses `nn.Parameter` to learn optimal weights for combining multiple layer embeddings
- Dimension Validation: Ensures all embeddings have the same dimension for weighted sum aggregation
- Weight Debugging: All weighted probes implement a `print_learned_weights()` method to show which layers are most important
- Efficiency: More efficient than multiple projection heads while maintaining or improving performance
- Aggregation: Must use `aggregation: "none"` to enable learned weights
- Multiple Layers: Requires multiple target layers to learn meaningful weights
- Same Dimensions: All layer embeddings must have the same dimension for weighted sum
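
The learned weighted sum described above can be illustrated with a small standard-library sketch. It is not the project's implementation: the function names are hypothetical, the softmax normalization of the weights is an assumption, and plain floats stand in for the `nn.Parameter` values that an optimizer would update during training.

```python
import math


def softmax(logits):
    """Turn raw per-layer logits into positive weights that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_layer_sum(layer_embeddings, layer_logits):
    """Combine one embedding per layer into a single vector.

    Hypothetical helper: softmax-normalized weights are an assumption,
    not something the document specifies.
    """
    if len(layer_embeddings) != len(layer_logits):
        raise ValueError("need exactly one logit per layer")
    dims = {len(emb) for emb in layer_embeddings}
    if len(dims) != 1:
        # Mirrors the "Same Dimensions" requirement for weighted probes.
        raise ValueError(f"all layers must share one dimension, got {sorted(dims)}")
    weights = softmax(layer_logits)
    dim = dims.pop()
    return [
        sum(w * emb[i] for w, emb in zip(weights, layer_embeddings))
        for i in range(dim)
    ]


# Four layers with two-dimensional embeddings and equal (untrained) logits.
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
combined = weighted_layer_sum(embeddings, [0.0, 0.0, 0.0, 0.0])
```

With equal logits every layer gets weight 0.25, so the result is the plain average of the layers. Training would move the logits so that more useful layers receive larger weights, which is what `print_learned_weights()` is described as reporting.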
Aggregation methods:
- Mean: Average embeddings across layers (default for backward compatibility).
- Max: Take the maximum values across layers.
- Concat: Concatenate embeddings from all layers (requires larger probe networks).
- CLS token: Use only the CLS token from sequence-based models.
- None: No aggregation; pass embeddings directly to sequence-based probes.
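
As a rough illustration of how the layer-wise aggregation options differ, the sketch below applies them to one pooled embedding per layer. `aggregate_layers` is a hypothetical name, and treating mean and max as element-wise operations across layers is an assumption. The CLS-token option is left out because it selects a token position rather than combining layers.

```python
def aggregate_layers(layer_embeddings, method):
    """Combine per-layer vectors (one list of floats per layer).

    Hypothetical helper for illustration only.
    """
    if method == "none":
        # Hand the per-layer embeddings through untouched.
        return layer_embeddings
    if method == "concat":
        # Output grows with the number of layers, hence the larger probe.
        return [value for emb in layer_embeddings for value in emb]
    if method == "mean":
        return [sum(column) / len(column) for column in zip(*layer_embeddings)]
    if method == "max":
        return [max(column) for column in zip(*layer_embeddings)]
    raise ValueError(f"unknown aggregation method: {method!r}")


layers = [[1.0, 4.0], [3.0, 2.0]]
mean_vec = aggregate_layers(layers, "mean")      # [2.0, 3.0]
max_vec = aggregate_layers(layers, "max")        # [3.0, 4.0]
concat_vec = aggregate_layers(layers, "concat")  # [1.0, 4.0, 3.0, 2.0]
```

Mean and max keep the embedding dimension fixed, while concatenation multiplies it by the number of target layers.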
Input processing methods:
- Pooled: Pool embeddings to a fixed dimension (default for backward compatibility).
- Sequence: Keep the sequence structure for sequence-based probes.
- Flatten: Flatten all dimensions into a single vector.
- None: No processing; use embeddings as-is.
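
The input processing options can be sketched in the same spirit for a single layer's token embeddings. `process_input` is a hypothetical name. Mean pooling for `"pooled"` and truncation to `max_sequence_length` for `"sequence"` are assumptions about behavior the document does not spell out.

```python
def process_input(token_embeddings, method, max_sequence_length=None):
    """Reshape a [sequence_length][dim] nested list for a probe.

    Hypothetical helper for illustration only.
    """
    if method == "none":
        return token_embeddings
    if method == "sequence":
        # Assumption: sequences longer than the limit are truncated.
        if max_sequence_length is None:
            return list(token_embeddings)
        return list(token_embeddings[:max_sequence_length])
    if method == "flatten":
        return [value for token in token_embeddings for value in token]
    if method == "pooled":
        # Assumption: pooling means averaging over the sequence positions.
        return [sum(column) / len(column) for column in zip(*token_embeddings)]
    raise ValueError(f"unknown input_processing method: {method!r}")


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = process_input(tokens, "pooled")      # [3.0, 4.0]
flat = process_input(tokens, "flatten")       # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
truncated = process_input(tokens, "sequence", max_sequence_length=2)
```

Pooled output has a fixed size regardless of sequence length, which is why it pairs with linear and MLP probes, while the sequence form preserves token order for LSTM, attention, and transformer probes.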
experiments:
  - run_name: "simple_linear"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    layers: "layer_12"  # Legacy field
    frozen: true  # Legacy field

experiments:
  - run_name: "advanced_mlp"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    probe_config:
      name: "advanced_mlp"
      probe_type: "mlp"
      aggregation: "concat"
      input_processing: "pooled"
      target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
      freeze_backbone: true
      learning_rate: 3e-4  # Override global LR
      batch_size: 4  # Override global batch size
      hidden_dims: [1024, 512, 256]
      dropout_rate: 0.15
      activation: "relu"

experiments:
  - run_name: "sequence_lstm"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    probe_config:
      name: "sequence_lstm"
      probe_type: "lstm"
      aggregation: "none"
      input_processing: "sequence"
      target_layers: ["layer_8", "layer_12"]
      lstm_hidden_size: 256
      num_layers: 2
      bidirectional: true
      max_sequence_length: 1000
      use_positional_encoding: false

The new system automatically handles legacy configurations:
- Legacy fields still work: `layers` and `frozen` fields are automatically converted to `probe_config`
- No breaking changes: Existing configurations continue to work without modification
- Gradual migration: You can update configurations one at a time
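
The automatic conversion of legacy fields can be sketched as follows. `upgrade_legacy_experiment` is a hypothetical helper, not the project's actual function. The target values (linear probe, mean aggregation, pooled input) follow the defaults this document marks as backward compatible. Dropping the legacy keys after conversion and defaulting `frozen` to `True` are assumptions.

```python
def upgrade_legacy_experiment(experiment):
    """Return a copy of an experiment dict with a probe_config entry.

    Hypothetical helper for illustration only.
    """
    if "probe_config" in experiment:
        return dict(experiment)  # already new-style, nothing to convert
    if "layers" not in experiment:
        raise ValueError("legacy experiment needs a 'layers' field")

    upgraded = dict(experiment)
    layers = upgraded.pop("layers")
    frozen = upgraded.pop("frozen", True)  # assumption: frozen by default
    if isinstance(layers, str):
        layers = [layers]  # legacy configs give a single layer name

    upgraded["probe_config"] = {
        "probe_type": "linear",
        "aggregation": "mean",
        "input_processing": "pooled",
        "target_layers": list(layers),
        "freeze_backbone": frozen,
    }
    return upgraded


old_style = {"run_name": "old_style", "layers": "layer_12", "frozen": True}
new_style = upgrade_legacy_experiment(old_style)
```

The input dict is left unmodified, so a loader built this way could convert configurations on read without rewriting the files on disk.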
experiments:
  - run_name: "old_style"
    layers: "layer_12"
    frozen: true

experiments:
  - run_name: "new_style"
    probe_config:
      probe_type: "linear"
      aggregation: "mean"
      input_processing: "pooled"
      target_layers: ["layer_12"]
      freeze_backbone: true

Each probe can override global training parameters:
probe_config:
  # ... other config ...
  learning_rate: 5e-4  # Override global lr
  batch_size: 4  # Override global batch_size
  train_epochs: 15  # Override global train_epochs
  optimizer: "adam"  # Override global optimizer
  weight_decay: 0.001  # Override global weight_decay

- Linear: Baseline performance, quick experiments
- MLP: Better performance, moderate complexity
- LSTM: Sequence modeling, moderate complexity
- Attention: Sequence modeling, higher complexity
- Transformer: Complex sequence modeling, highest complexity
- Weighted Probes: Enhanced versions that learn optimal weights for combining multiple layers
  - Use when you want to leverage multiple layers efficiently
  - Better performance than concatenation with lower computational cost
  - Provides interpretability through learned layer weights
- Single layer: Use `["layer_12"]` for final representations
- Multiple layers: Use `["layer_6", "layer_8", "layer_10", "layer_12"]` for hierarchical features
- Early layers: Use `["layer_1", "layer_2", "layer_3"]` for low-level features
- Mean/Max: Good for classification tasks
- Concat: Better for complex tasks, requires larger probe networks
- None: Required for sequence-based probes and weighted probes
- Weighted Sum: Automatic with weighted probes when using `aggregation: "none"`
- Pooled: Good for classification tasks
- Sequence: Required for sequence-based probes
- Flatten: Good for spatial features
The system automatically validates configurations:
- Required parameters for each probe type
- Compatibility between aggregation and input processing methods
- Valid parameter ranges (positive integers, valid activation functions, etc.)
- Layer name consistency
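
A minimal sketch of such checks is shown below, assuming a probe configuration has already been loaded into a plain dict. `validate_probe_config` and its exact messages are hypothetical. The rules are taken from the requirements stated in this document and are not a complete list of what the real validator enforces.

```python
POSITIVE_INT_FIELDS = (
    "num_heads",
    "num_layers",
    "lstm_hidden_size",
    "attention_dim",
    "max_sequence_length",
)


def validate_probe_config(config):
    """Return a list of human-readable problems (empty if none found).

    Hypothetical helper for illustration only.
    """
    errors = []
    probe_type = config.get("probe_type", "linear")

    # Required parameters per probe type.
    if probe_type.endswith("mlp") and not config.get("hidden_dims"):
        errors.append("MLP probe requires hidden_dims")

    # Compatibility between aggregation and input processing.
    if (config.get("aggregation") == "cls_token"
            and config.get("input_processing") != "sequence"):
        errors.append("cls_token aggregation requires sequence input_processing")

    # Weighted probes: aggregation "none" and more than one layer.
    if probe_type.startswith("weighted_"):
        if config.get("aggregation") != "none":
            errors.append('weighted probes require aggregation: "none"')
        if len(config.get("target_layers", [])) < 2:
            errors.append("weighted probes require multiple target_layers")

    # Valid parameter ranges (bool is excluded because it subclasses int).
    for field in POSITIVE_INT_FIELDS:
        value = config.get(field)
        if value is None:
            continue
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            errors.append(f"{field} must be a positive integer")

    return errors


problems = validate_probe_config({"probe_type": "mlp"})
```

Collecting every problem into a list, rather than raising on the first one, lets a user fix a configuration in a single pass.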
Common validation errors and solutions:
# Error: MLP probe requires hidden_dims
probe_config:
  probe_type: "mlp"
  # Missing: hidden_dims

# Solution: Add required parameters
probe_config:
  probe_type: "mlp"
  hidden_dims: [512, 256]

# Error: cls_token aggregation requires sequence input_processing
probe_config:
  aggregation: "cls_token"
  input_processing: "pooled"

# Solution: Use sequence input_processing
probe_config:
  aggregation: "cls_token"
  input_processing: "sequence"

- Linear/MLP: Low memory usage
- LSTM: Moderate memory usage
- Attention/Transformer: Higher memory usage
- Linear: Fastest training
- MLP: Fast training
- LSTM: Moderate training speed
- Attention/Transformer: Slower training
- Linear: Fastest inference
- MLP: Fast inference
- LSTM: Moderate inference speed
- Attention/Transformer: Slower inference
- Out of Memory: Reduce batch size or use simpler probe types
- Slow Training: Use simpler probe types or reduce hidden dimensions
- Poor Performance: Try different aggregation methods or layer combinations
- Validation Errors: Check parameter compatibility and required fields
Enable debug logging to see detailed configuration validation:
import logging
logging.basicConfig(level=logging.DEBUG)

The system is designed to be extensible:
- New probe types: Easy to add new probe architectures
- Custom aggregations: Support for custom aggregation functions
- Advanced processing: More sophisticated input processing methods
- Hyperparameter optimization: Integration with hyperparameter search tools