A comprehensive framework for validating and benchmarking neural network operations using JAX profiling and latency prediction models.
This validation framework provides tools to:
- Profile neural network operations (matrix multiplication, pooling, normalization, activation, elementwise)
- Compare predicted latencies against actual TPU execution times
- Generate detailed performance reports and visualizations
- Support multiple operation types with flexible configurations
- Benchmark on TPU by default (CPU also supported for testing)
unified_model_verification.py is the primary script for running verification tests. It orchestrates validation across multiple operation types and generates comprehensive performance reports.
```bash
# Run unified verification for all operation types
python unified_model_verification.py

# Run specific operation type verification
python unified_model_verification.py --matmul
python unified_model_verification.py --pooling
python unified_model_verification.py --elementwise
python unified_model_verification.py --normalization
python unified_model_verification.py --activation
```

| Command | Description |
|---|---|
| (no args) | Run unified verification for all operation types |
| `--matmul` | Matrix multiplication verification (small-medium sizes) |
| `--large-matmul` | Large matrix multiplication verification (1K-4K dimensions) |
| `--pooling` | Max and average pooling operations |
| `--elementwise` | Element-wise operations (add, subtract, multiply, divide) |
| `--elementwise-3d` | 3D element-wise operations |
| `--normalization` | Layer norm, RMS norm, and batch normalization |
| `--activation` | Activation functions (ReLU, Sigmoid, Tanh, etc.) |
| `--help` | Show help message |
The main entry point that orchestrates verification tests.
Key Functions:
- `run_unified_verification()`: Runs comprehensive tests across all operation types
- `matmul_unified_verification()`: Specialized matrix multiplication tests
- `pooling_unified_verification()`: Pooling operation tests
- `elementwise_unified_verification()`: Element-wise operation tests
- `normalization_unified_verification()`: Normalization operation tests
- `activation_unified_verification()`: Activation function tests
Workflow:
- Creates verification directory for results
- Initializes a `ModelVerification` instance
- Adds test configurations for various operations and shapes
- Executes verification and collects results
- Generates detailed analysis reports grouped by operation type
flexible_validation.py provides the low-level validation framework for profiling JAX kernels.
Key Classes:
- `ValidationConfig`: Defines kernel configuration (type, shapes, parameters)
- `ValidationPackage`: Handles single kernel profiling and trace parsing
- `ValidationManager`: Manages multiple validation packages
Features:
- JAX kernel compilation and profiling on TPU
- Trace event filtering and analysis
- StableHLO intermediate representation generation
- SCALE-Sim topology file generation
latency_prediction.py implements latency prediction models for the different operation types.
Key Classes:
- `PredictionManager`: Manages prediction configurations and generates predictions
Supported Operations:
- Elementwise: ADD, SUBTRACT, MULTIPLY, DIVIDE
- Activation: RELU, SIGMOID, TANH, LEAKY_RELU, ELU, SELU, LINEAR, BINARY
- Normalization: LAYER_NORM, RMS_NORM, BATCH_NORM
- Pooling: MAX_POOLING, AVG_POOLING
- Matmul: Matrix multiplication operations
Prediction Strategy:
- Uses linear models from `linear_models.py`
- Handles shape transformations (1D, 2D, 3D+)
- Applies dimension-specific corrections for edge cases
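The shape-transformation step can be pictured as follows. This is a minimal sketch, not the module's actual code: the toy model below only mimics the form of the documented `linear_model_elementwise_add_1d(size)` function, and its coefficients are made up.

```python
import math

# Toy stand-in with the same shape as the documented 1D linear models:
# latency_us = a * size + b. Coefficients here are hypothetical.
def toy_linear_model_add_1d(size):
    return 0.0021 * size + 3.7

def predict_elementwise_latency(shape):
    """Reduce an N-D input shape to a total element count, then apply
    a 1D linear model (the shape-transformation strategy described above)."""
    size = math.prod(shape)  # e.g. (8, 128, 768) -> 786432 elements
    return toy_linear_model_add_1d(size)

print(predict_elementwise_latency((8, 128, 768)))  # predicted latency in us
```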
operation_classification.py defines the operation taxonomy and types.
Enumerations:
- `OperationType`: High-level operation categories
- `OperationElementwise`: Element-wise operations
- `OperationActivation`: Activation functions
- `OperationNormalization`: Normalization operations
- `OperationPooling`: Pooling operations
- `OperationMatmul`: Matrix multiplication types
kernel_functions.py contains the JAX implementations of the validation kernels.
Key Enumerations:
- `KernelType`: Maps to specific JAX kernel implementations
- `ScaleSimTopologyType`: GEMM vs CONV topology types
Example Kernels:
- Matrix operations: `validation_matrix_multiply`, `validation_dot_product`
- Activations: `validation_relu`, `validation_sigmoid`, `validation_tanh`
- Normalizations: `validation_layer_norm`, `validation_batch_norm`, `validation_rms_norm`
- Pooling: `validation_max_pooling`, `validation_avg_pooling`
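As a rough illustration of calling one of these kernels directly (the exact signatures live in kernel_functions.py; the sketch below assumes `validation_relu` accepts a single JAX array, which is an assumption):

```python
import jax.numpy as jnp
from kernel_functions import validation_relu

# Assumption: validation kernels accept JAX arrays directly; check
# kernel_functions.py for the actual signatures.
x = jnp.linspace(-1.0, 1.0, 1024, dtype=jnp.float16)
y = validation_relu(x)  # element-wise max(x, 0)
print(y.shape, y.dtype)
```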
linear_models.py contains pre-trained linear regression models for latency prediction.
Contains operation-specific models like:
- `linear_model_elementwise_add_1d(size)`
- `linear_model_activation_relu_2d(size)`
- `linear_model_matmul(m, n, k)`
- `linear_model_normalization_layer_norm_2d(size)`
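A minimal usage sketch for these models, assuming each returns a predicted latency in microseconds (consistent with the `Predicted_Latency_us` column in the output):

```python
from linear_models import linear_model_elementwise_add_1d, linear_model_matmul

# Assumption: each model returns a predicted latency in microseconds.
add_latency = linear_model_elementwise_add_1d(1024 * 1024)  # 1M-element add
mm_latency = linear_model_matmul(512, 512, 256)             # (512,256) x (256,512)
print(f"add: {add_latency:.2f} us, matmul: {mm_latency:.2f} us")
```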
kernel_configs.py provides helper functions to generate `ValidationConfig` objects.
Example Functions:
- `generate_matrix_multiply_config(name, M, N, K)`
- `generate_layer_norm_config(name, shape, axis)`
- `generate_max_pooling_config(name, shape, window_shape, strides, padding)`
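For example, using the signatures listed above (a sketch; the keyword-argument names follow the documented signatures, and the shape values are illustrative):

```python
from kernel_configs import generate_matrix_multiply_config, generate_max_pooling_config

# Build configs from the documented signatures; names are arbitrary labels.
mm_config = generate_matrix_multiply_config("mm_512", M=512, N=512, K=256)
pool_config = generate_max_pooling_config(
    "pool_2x2",
    shape=(1, 3, 224, 224),  # (N, C, H, W), per the shape conventions later in this README
    window_shape=(2, 2), strides=(2, 2), padding="VALID",
)
```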
Utility classes and functions.
Key Classes:
- `DataFrameGenerator`: Flexible DataFrame construction with column alignment
Parses JAX profiling traces.
Key Class:
- `TraceParser`: Extracts and processes profiling data from trace directories
After running verification, results are saved to the specified directory (e.g., ./unified_verification_results/):
- `merged_verification_results.csv`: Complete verification results with columns:
  - `Operation_Type`: Category (elementwise, activation, matmul, etc.)
  - `Operation`: Specific operation (ADD, RELU, LINEAR, etc.)
  - `Input_Shapes`: Input tensor shapes
  - `Predicted_Latency_us`: Model prediction in microseconds
  - `Actual_Duration_us`: Measured TPU execution time
  - `Error_Percentage`: Prediction error percentage
- `filtered_events_avg_fusion.csv`: Average fusion kernel durations
  - `kernel_name`: Configuration name
  - `dur(us)`: Duration in microseconds
- Individual trace directories for each configuration containing:
  - `trace_events.json`: Raw JAX profiling events
  - `filtered_events.json`: Filtered relevant events
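A quick way to inspect the merged results with pandas, using the column names documented above:

```python
import pandas as pd

# Load the merged results and summarize absolute prediction error per type.
df = pd.read_csv("./unified_verification_results/merged_verification_results.csv")
df["Abs_Error"] = df["Error_Percentage"].abs()  # in case signed errors are stored
summary = df.groupby("Operation_Type")["Abs_Error"].agg(["count", "mean", "max"])
print(summary)
```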
Run the full verification programmatically:

```python
from unified_model_verification import run_unified_verification

# Run comprehensive verification
results = run_unified_verification()

# Results DataFrame includes:
# - Predicted vs Actual latencies
# - Error percentages
# - Operation metadata
```

Or build a custom verification with specific configurations:

```python
import operation_classification as oc
from model_verification import ModelVerification
# Create custom verification
verifier = ModelVerification(profile_dir="./my_results")
# Add custom configurations
verifier.add_verification_config(
operation_type=oc.OperationType.MATMUL,
operation=oc.OperationMatmul.LINEAR,
shapes=[(512, 256), (256, 512)],
operation_params={'M': 512, 'N': 512, 'K': 256}
)
# Run verification
results = verifier.verify()
```

For low-level control, use `ValidationManager` to profile kernels directly:

```python
from flexible_validation import ValidationManager, ValidationConfig
from kernel_functions import KernelType
import jax.numpy as jnp
# Create validation manager
manager = ValidationManager(profile_dir="./my_traces")
# Add configuration
config = ValidationConfig(
name="my_matmul",
kernel_type=KernelType.MATRIX_MULTIPLY,
inputs=[((256, 128), jnp.float16), ((128, 256), jnp.float16)]
)
manager.add_config(config)
# Profile operations
manager.profile_all_packages(repeat=5)
manager.parse_all_packages()
# Get results
df = manager.get_filtered_events_dataframe_for_avg_fusion_duration()
```

The framework calculates several performance metrics:
- MAPE (Mean Absolute Percentage Error): Average absolute prediction error
- RMSE (Root Mean Square Error): RMS of prediction errors in microseconds
- Min/Max Error: Range of prediction errors
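For reference, these metrics follow the standard definitions and can be computed directly from the predicted and actual latency columns (a sketch using numpy):

```python
import numpy as np

def error_metrics(predicted_us, actual_us):
    """Standard MAPE / RMSE definitions over paired latency samples."""
    predicted = np.asarray(predicted_us, dtype=float)
    actual = np.asarray(actual_us, dtype=float)
    pct_errors = 100.0 * np.abs(predicted - actual) / actual
    mape = pct_errors.mean()                            # mean absolute % error
    rmse = np.sqrt(np.mean((predicted - actual) ** 2))  # in microseconds
    return mape, rmse, pct_errors.min(), pct_errors.max()

mape, rmse, lo, hi = error_metrics([95.0, 210.0], [100.0, 200.0])
print(f"MAPE: {mape:.2f}% | RMSE: {rmse:.2f} us | Range: {lo:.1f}%-{hi:.1f}%")
```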
Example output from run_unified_verification():
```
============================================================
DETAILED ANALYSIS BY OPERATION TYPE
============================================================
ADD        | Tests: 11 | MAPE: 8.23%  | RMSE: 45.67 μs  | Range: 2.1%-15.8%
RELU       | Tests: 9  | MAPE: 5.45%  | RMSE: 32.18 μs  | Range: 1.2%-12.3%
MATMUL     | Tests: 9  | MAPE: 12.67% | RMSE: 156.89 μs | Range: 3.4%-28.9%
LAYER_NORM | Tests: 7  | MAPE: 7.89%  | RMSE: 78.45 μs  | Range: 2.8%-18.6%
============================================================
OVERALL STATISTICS
============================================================
Total test cases: 36
Overall MAPE: 8.56%
Overall RMSE: 78.30 μs

🏆 Best prediction: activation - relu
   Shape: [(512,)]
   Error: 1.23%

⚠️ Worst prediction: matmul - linear
   Shape: [(1024, 768), (768, 512)]
   Error: 28.92%
```
Install all required dependencies using pip:
```bash
pip install -r requirements.txt
```

The framework is configured for TPU by default; requirements.txt includes `jax[tpu]` for TPU support.
For CPU-only execution (testing/development):
```bash
pip install "jax>=0.4.0"
```

- JAX (>=0.4.0): For TPU kernel execution and profiling
- pandas (>=1.5.0): Data manipulation and CSV output
- numpy (>=1.23.0): Numerical computations
- Python 3.8+ required
Generate the StableHLO IR for a configured kernel:

```python
package = ValidationPackage(config)
package.setup_validation()
stablehlo_ir = package.get_stableHLO_string()
```

Generate SCALE-Sim topology files:

```python
manager = ValidationManager()
# ... add configs ...
manager.write_scale_sim_topology_csv()
# Generates: scale_sim_gemm_topology.csv, scale_sim_conv_topology.csv
```

Parse and filter a profiling trace:

```python
package.parse_profile_trace()
filtered_events = package.filter_profile_trace_events(trace_events)
```

The framework expects input shapes in the following conventions:
- 1D shapes: For vector operations (elementwise, activations)
- 2D shapes: For matrices and 2D operations
- 3D shapes: For normalization (batch, sequence, hidden)
- 4D shapes: For pooling/convolution (batch, channels, height, width)
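For instance (example shape values only):

```python
# Example shapes by rank (illustrative values)
vector_shape = (1024,)        # 1D: elementwise / activation inputs
matrix_shape = (512, 256)     # 2D: matrices
norm_shape = (8, 128, 768)    # 3D: (batch, sequence, hidden)
pool_shape = (1, 3, 224, 224) # 4D: (batch, channels, height, width)
```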
Different operations support various parameters:
```python
# Layer Norm: axis parameter
operation_params={'axis': (2,)}

# Pooling: window and stride
operation_params={'window_shape': (2, 2), 'strides': (2, 2), 'padding': 'VALID'}
```
- **JAX Not Found**: Ensure JAX is installed with TPU support:
  ```bash
  pip install "jax[tpu]>=0.4.0"
  ```
- **TPU Not Detected**: Verify TPU access and JAX installation:
  ```python
  import jax
  print(jax.devices())  # Should show TPU devices
  ```
- **Profile Directory Errors**: The directory is created automatically, but ensure write permissions.
- **Shape Mismatches**: Verify input shapes match operation requirements:
  - Matmul: `(M, K)` and `(K, N)`
  - Pooling: `(N, C, H, W)` format
- **Missing Results**: Check that `model_verification.py` is properly imported and available.
When adding new operations (see the sketch after this list):
- Add kernel function to `kernel_functions.py`
- Add enum to `operation_classification.py`
- Implement prediction model in `latency_prediction.py`
- Add linear model coefficients to `linear_models.py`
- Create config generator in `kernel_configs.py`
- Add test case to `unified_model_verification.py`
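A rough skeleton of the first two steps. This is illustrative only: the operation (softplus) is a hypothetical example, and the actual kernel signatures and enum layout should follow the existing patterns in those files.

```python
# kernel_functions.py -- illustrative new kernel following the validation_* pattern
import jax.numpy as jnp

def validation_softplus(x):
    """New validation kernel (hypothetical example operation)."""
    return jnp.log1p(jnp.exp(x))

# operation_classification.py -- illustrative enum addition
# class OperationActivation(Enum):
#     ...
#     SOFTPLUS = auto()  # new member alongside RELU, SIGMOID, TANH, ...
```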
- Fork the repository
- Create a feature branch (`git checkout -b feature/your-feature`)
- Commit your changes (`git commit -am 'Add new feature'`)
- Push to the branch (`git push origin feature/your-feature`)
- Open a Pull Request
This project is licensed under the MIT License - see the LICENSE file for details.
Part of the SCALE-Sim Project.