Skip to content
148 changes: 71 additions & 77 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,94 +13,89 @@ env:

jobs:
prek:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v6
- name: prek check
uses: j178/prek-action@v1
with:
extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy


lint:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.13"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt install -y pandoc gsfonts
python -m pip install --upgrade pip
pip install jaxlib
pip install jax
pip install '.[doc,test]'
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -r docs/requirements.txt
pip freeze
uv pip install --upgrade jaxlib jax
uv pip install --upgrade '.[doc,test]'
uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --upgrade -r docs/requirements.txt
uv pip freeze
- name: Lint with mypy and ruff
run: |
make lint
uv run make lint
- name: Build documentation
run: |
make docs
uv run make docs
- name: Test documentation
run: |
make doctest
python -m doctest -v README.md

uv run make doctest
uv run python -m doctest -v README.md

test-modeling:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.11", "3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
update-path: true
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt install -y graphviz
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,test]'
pip freeze
# See: https://github.com/pyro-ppl/pyro-api/pull/26
# uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip
uv pip install --upgrade jaxlib jax
uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --upgrade -e '.[dev,test]'
uv pip freeze
- name: Test with pytest
run: |
CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
CI=1 uv run pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum"
JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum"
- name: Test tracer leak
if: matrix.python-version == '3.13'
env:
JAX_CHECK_TRACER_LEAKS: 1
run: |
pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
pytest -vs test/test_distributions.py::test_mean_var -k Gompertz

uv run pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
uv run pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
uv run pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
uv run pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -109,51 +104,53 @@ jobs:
parallel: true
flag-name: test-modeling


test-inference:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.11", "3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
update-path: true
python-version: ${{ matrix.python-version }}
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,test]'
pip freeze
# See: https://github.com/pyro-ppl/pyro-api/pull/26
# uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip
uv pip install --upgrade jaxlib jax
uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --upgrade -e '.[dev,test]'
uv pip freeze
- name: Test with pytest
run: |
pytest -vs --durations=20 test/infer/test_mcmc.py
pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
uv run pytest -vs --durations=20 test/infer/test_mcmc.py
uv run pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
uv run pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
JAX_ENABLE_X64=1 uv run pytest -vs test/infer/test_mcmc.py -k x64
- name: Test chains
run: |
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/stochastic_support/test_dcc.py
XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
- name: Test custom prng
run: |
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
JAX_ENABLE_CUSTOM_PRNG=1 uv run pytest -vs test/infer/test_mcmc.py
- name: Test nested sampling
run: |
JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py
JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -162,32 +159,32 @@ jobs:
parallel: true
flag-name: test-inference


examples:

runs-on: ubuntu-latest
needs: [lint, prek]
strategy:
matrix:
python-version: ["3.13"]
env:
UV_PYTHON: ${{ matrix.python-version }}

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
update-path: true
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install jaxlib
pip install jax
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -e '.[dev,examples,test]'
pip freeze
uv pip install --upgrade jaxlib jax
uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip
uv pip install --upgrade -e '.[dev,examples,test]'
uv pip freeze
- name: Test with pytest
run: |
CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example
CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example
- name: Coveralls
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
uses: coverallsapp/github-action@v2
Expand All @@ -196,9 +193,7 @@ jobs:
parallel: true
flag-name: examples


finish:

needs: [test-modeling, test-inference, examples]
runs-on: ubuntu-latest
steps:
Expand All @@ -208,4 +203,3 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel-finished: true
carryforward: "test-modeling,test-inference,examples"

Loading