Skip to content

Commit

Permalink
Merge pull request #3996 from jakevdp:oldest-jax
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646664157
  • Loading branch information
Flax Authors committed Jun 26, 2024
2 parents e5cb2f7 + 58020c4 commit 3b21870
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,18 @@ jobs:
matrix:
python-version: ['3.9', '3.10', '3.11']
test-type: [doctest, pytest, pytype, mypy]
jax-version: [newest]
exclude:
- test-type: pytype
python-version: '3.9'
- test-type: pytype
python-version: '3.10'
- test-type: mypy
python-version: '3.11'
include:
- python-version: '3.9'
test-type: pytest
jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -119,7 +124,13 @@ jobs:
- name: Install Flax
run: |
venv/bin/python3 -m pip install -e .[all,testing]
venv/bin/python3 -m pip install -U jax jaxlib # Ensure we have the latest JAX
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
venv/bin/python3 -m pip install -U jax jaxlib
else
venv/bin/python3 -m pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Cached mypy cache
id: mypy_cache
uses: actions/cache@v3
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies = [
"numpy>=1.22",
"numpy>=1.23.2; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"jax>=0.4.19",
"jax>=0.4.27", # keep in sync with jax-version in .github/workflows/build.yml
"msgpack",
"optax",
"orbax-checkpoint",
Expand Down

0 comments on commit 3b21870

Please sign in to comment.