GPU-Native Bayesian Inference with JAX and BlackJAX
Essential Reading: For the authoritative reference on BlackJAX nested sampling theory and applications, see the Nested Sampling Book by David Yallup.
"Nested sampling is a Bayesian computational technique that solves the key problem of evidence evaluation" – from the Nested Sampling Book
Click here to run the workshop interactively in Google Colab
Click here to preview the workshop with all plots and outputs
This interactive workshop demonstrates modern nested sampling using BlackJAX, a GPU-native probabilistic programming library built on JAX. The modular design allows flexible delivery from a 20-minute core workshop to a comprehensive 110-minute session covering advanced topics. Learn how to leverage automatic differentiation and JIT compilation for high-performance Bayesian inference.
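The two JAX features mentioned above, automatic differentiation and JIT compilation, can be illustrated with a minimal sketch. The toy straight-line log-likelihood below is for illustration only; the workshop notebook defines its own model.

```python
import jax
import jax.numpy as jnp

# Toy data from a straight-line model y = m*x + c (illustrative only).
x = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0

def loglikelihood(params):
    m, c = params
    residuals = y - (m * x + c)
    return -0.5 * jnp.sum(residuals**2)

# JIT-compile for near-native speed, and differentiate automatically.
fast_loglik = jax.jit(loglikelihood)
grad_loglik = jax.jit(jax.grad(loglikelihood))

params = jnp.array([1.5, 0.5])
print(fast_loglik(params))  # scalar log-likelihood
print(grad_loglik(params))  # gradient with respect to (m, c)
```

The same `loglikelihood` function can be handed to any gradient-based sampler without writing derivatives by hand.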
workshop-blackjax-nested-sampling/
├── workshop_nested_sampling.py             # Source script (development)
├── workshop_nested_sampling.ipynb          # Clean interactive notebook
├── workshop_nested_sampling_executed.ipynb # Pre-executed with outputs
├── CLAUDE.md                               # Claude development guidance
├── CLAUDE_WORKSHOP_TEMPLATE.md             # Workshop development template
├── README.md                               # This file
└── development/                            # Development materials
    ├── docs/                               # Development documentation
    ├── history/                            # Development conversation logs
    ├── reference-materials/                # Source materials and examples
    └── scripts/                            # Helper scripts and utilities
Core Duration: 20 minutes (suitable for talks)
Full Duration: 110 minutes (core + extensions)
Format: Hands-on Jupyter notebook
Platform: Runnable in Google Colab (no local installation required)
Core Workshop (20 minutes):
1. GPU-Native Nested Sampling with BlackJAX
2. JAX Integration for automatic differentiation and JIT compilation
3. Anesthetic Visualization for professional nested sampling post-processing
4. Performance Comparisons between different sampling algorithms
Advanced Extensions (90 minutes, optional):
5. Custom Sampler Development using BlackJAX's modular components
6. JAX Ecosystem Integration with gradient descent and optimization
7. Simulation-Based Inference with neural posterior estimation
- Basic nested sampling workflow
- Evidence computation and uncertainty quantification
- Posterior visualization with true value overlays
- Multivariate parameter estimation
- Correlation coefficient inference with proper transforms
- Advanced anesthetic plotting techniques
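The "evidence computation" step above reduces to a weighted sum over the dead points a nested sampling run produces. The sketch below uses the standard deterministic shrinkage approximation (prior volume X_i = exp(-i/nlive)); it is a minimal NumPy/SciPy illustration, not the workshop's actual implementation.

```python
import numpy as np
from scipy.special import logsumexp

def log_evidence(log_L, nlive):
    """Estimate log Z from a sequence of dead-point log-likelihoods.

    With X_i = exp(-i / nlive), each shell has prior-volume weight
    w_i = X_{i-1} - X_i, and Z ~= sum_i L_i * w_i.  Working in log
    space via logsumexp keeps the sum numerically stable.
    """
    i = np.arange(1, len(log_L) + 1)
    log_w = -(i - 1) / nlive + np.log1p(-np.exp(-1.0 / nlive))
    return logsumexp(log_L + log_w)

# Sanity check: a constant likelihood L = 3 over the whole prior
# should integrate to Z = 3 once the run has compressed far enough.
log_L = np.log(3.0) * np.ones(5000)
print(log_evidence(log_L, nlive=100))  # ~= log(3) ~= 1.0986
```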
- BlackJAX nested sampling vs. NUTS (Hamiltonian Monte Carlo)
- Timing benchmarks and sampler trade-offs
- When to use nested sampling vs. other methods
- Understanding BlackJAX's modular architecture
- Implementing custom MCMC kernels and adaptive schemes
- Research applications and specialized sampling strategies
- Gradient-based optimization with Optax
- Image-based inference problems
- Complementary strengths of different approaches
- Neural posterior estimation with Flax
- Amortized inference and training workflows
- Modern SBI vs. traditional Bayesian methods
Click the "Open in Colab" badge above to run the clean, interactive workshop in your browser. No installation required!
Click the "View on GitHub" badge to see the workshop with all plots and outputs pre-executed for quick reference.
# Clone the repository
git clone https://github.com/handley-lab/workshop-blackjax-nested-sampling.git
cd workshop-blackjax-nested-sampling
# Core dependencies (required for Parts 1-3)
pip install git+https://github.com/handley-lab/blackjax
pip install anesthetic tqdm matplotlib jupyter
# Advanced extensions (optional for Parts 4-6)
pip install optax flax
# Launch the notebook (clean version)
jupyter notebook workshop_nested_sampling.ipynb
# Or view the executed version with plots
jupyter notebook workshop_nested_sampling_executed.ipynb
# Run the standalone Python script
python workshop_nested_sampling.py
- JAX: Automatic differentiation and JIT compilation
- BlackJAX: GPU-native MCMC and nested sampling
- Anesthetic: Nested sampling visualization and analysis
- NumPy/SciPy: Scientific computing foundations
- Matplotlib: Plotting and visualization
- Optax: Gradient-based optimization (Part 5)
- Flax: Neural networks and SBI (Part 6)
- Additional JAX ecosystem packages
The workshop generates several publication-ready visualizations:
- Data Visualization: Synthetic datasets with true model overlays
- Posterior Plots: Corner plots with true parameter markers
- Performance Comparisons: Sampler timing and accuracy benchmarks
- Evidence Computation: Bayesian model comparison metrics
- Python Experience: Basic familiarity with NumPy/SciPy
- Bayesian Inference: Understanding of posteriors, priors, and likelihoods
- Optional: Previous exposure to MCMC methods (helpful but not required)
- JAX Backend: Automatic vectorization and GPU acceleration
- JIT Compilation: Compiled-code performance from pure Python
- Automatic Differentiation: Efficient gradient computation
- Anesthetic Integration: Industry-standard nested sampling post-processing
- Evidence Computation: Natural Bayesian model comparison
- Parameter Transforms: Proper handling of constrained parameters
- Progressive Complexity: From simple line fitting to multivariate inference
- Hands-on Examples: Interactive code cells with immediate feedback
- Performance Insights: Real timing comparisons between methods
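The "parameter transforms" point above refers to mapping unit-hypercube samples to physical parameters through the inverse CDF of each prior, the standard prior transform in nested sampling. The priors below are illustrative, not the workshop's.

```python
import numpy as np
from scipy.special import ndtri  # inverse standard-normal CDF

def prior_transform(u):
    """Map u in [0, 1]^3 to physical parameters via inverse prior CDFs."""
    m = 10.0 * u[0] - 5.0   # m   ~ Uniform(-5, 5)
    c = ndtri(u[1])         # c   ~ Normal(0, 1)
    rho = 2.0 * u[2] - 1.0  # rho ~ Uniform(-1, 1), a bounded correlation
    return np.array([m, c, rho])

print(prior_transform(np.array([0.5, 0.5, 0.5])))  # -> [0. 0. 0.]
```

Sampling in the hypercube and transforming keeps constrained parameters (like a correlation coefficient in [-1, 1]) inside their support by construction.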
- Main Repository: handley-lab/blackjax
- Nested Sampling Branch: Focus on the `nested_sampling` branch for latest features
- Documentation: anesthetic.readthedocs.io
- Plotting Examples: Comprehensive visualization gallery
- JAX Documentation: jax.readthedocs.io
- Scientific Computing: Auto-differentiation and JIT compilation tutorials
This workshop was developed for the SBI Galaxy Evolution 2025 conference. Contributions and improvements are welcome!
- Event: SBI Galev 2025 (sbi-galev.github.io/2025)
- Session: Nested sampling for simulation-based inference
- Builds On: JAX/SciML workshop content by Viraj Pandya
This workshop is open-source and available for educational use. Please see individual dependency licenses for JAX, BlackJAX, and Anesthetic.
After completing this workshop, consider:
- Apply to Your Data: Use BlackJAX nested sampling on your research problems
- Explore Other Samplers: Try BlackJAX's HMC, NUTS, and MALA implementations
- GPU Acceleration: Run on TPUs/GPUs for large-scale inference problems
- Model Comparison: Use evidence computation for Bayesian model selection
- Community: Join discussions in the BlackJAX and JAX communities
Workshop Development: Generated with Claude Code • Author: Will Handley • Institution: University of Cambridge