|
| 1 | +# Composite action to download the jax and jaxlib wheels |
| 2 | +name: Download JAX CPU wheels |
| 3 | + |
| 4 | +inputs: |
| 5 | + runner: |
| 6 | + description: "Which runner type should the wheels be downloaded for?" |
| 7 | + type: string |
| 8 | + default: "linux-x86-n4-16" |
| 9 | + python: |
| 10 | + description: "Which python version should the artifact be downloaded for?" |
| 11 | + required: true |
| 12 | + type: string |
| 13 | + jaxlib-version: |
| 14 | + description: "Which jaxlib version to download? (head/pypi_latest)" |
| 15 | + type: string |
| 16 | + default: "head" |
| 17 | + skip-download-jaxlib-from-gcs: |
| 18 | + description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)" |
| 19 | + default: '0' |
| 20 | + type: string |
| 21 | + gcs_download_uri: |
| 22 | + description: "GCS location prefix from where the artifacts should be downloaded" |
| 23 | + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' |
| 24 | + type: string |
| 25 | +permissions: {} |
| 26 | +runs: |
| 27 | + using: "composite" |
| 28 | + |
| 29 | + steps: |
| 30 | + # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow. |
| 31 | + - name: Set env vars for use in artifact download URL |
| 32 | + shell: bash |
| 33 | + run: | |
| 34 | + os=$(uname -s | awk '{print tolower($0)}') |
| 35 | + arch=$(uname -m) |
| 36 | +
|
| 37 | + # Adjust os and arch for Windows |
| 38 | + if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then |
| 39 | + os="win" |
| 40 | + arch="amd64" |
| 41 | + fi |
| 42 | +
|
| 43 | + # Get the major and minor version of Python. |
| 44 | + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 |
| 45 | + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t |
| 46 | + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') |
| 47 | +
|
| 48 | + echo "OS=${os}" >> $GITHUB_ENV |
| 49 | + echo "ARCH=${arch}" >> $GITHUB_ENV |
| 50 | + # Python wheels follow a naming convention: standard wheels use the pattern |
| 51 | + # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use |
| 52 | + # `*-cp<py_version>-cp<py_version>t-*`. |
| 53 | + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV |
| 54 | + - name: Download wheels from GCS (non-Windows runs) |
| 55 | + shell: bash |
| 56 | + id: download-wheel-artifacts-nw |
| 57 | + # Set continue-on-error to true to prevent actions from failing the workflow if this step |
| 58 | + # fails. Instead, we verify the outcome in the step below so that we can print a more |
| 59 | + # informative error message. |
| 60 | + continue-on-error: true |
| 61 | + if: ${{ !contains(inputs.runner, 'windows-x86') }} |
| 62 | + run: | |
| 63 | + mkdir -p $(pwd)/dist |
| 64 | + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ |
| 65 | +
|
| 66 | + if [[ "${{ inputs.skip-download-jaxlib-from-gcs }}" == "1" ]]; then |
| 67 | + echo "JAX only release. Only downloading the jax wheel from the release bucket." |
| 68 | + else |
| 69 | + if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then |
| 70 | + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ |
| 71 | + elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then |
| 72 | + PYTHON=python${{ inputs.python }} |
| 73 | + $PYTHON -m pip download jaxlib --dest $(pwd)/dist/ |
| 74 | + else |
| 75 | + echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}" |
| 76 | + exit 1 |
| 77 | + fi |
| 78 | + fi |
| 79 | + - name: Download wheels from GCS (Windows runs) |
| 80 | + shell: cmd |
| 81 | + id: download-wheel-artifacts-w |
| 82 | + # Set continue-on-error to true to prevent actions from failing the workflow if this step |
| 83 | + # fails. Instead, we verify the outcome in step below so that we can print a more |
| 84 | + # informative error message. |
| 85 | + continue-on-error: true |
| 86 | + if: ${{ contains(inputs.runner, 'windows-x86') }} |
| 87 | + run: | |
| 88 | + mkdir dist |
| 89 | + @REM Use `call` so that we can run sequential gcloud storage commands on Windows |
| 90 | + @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 |
| 91 | + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ |
| 92 | +
|
| 93 | + if "${{ inputs.skip-download-jaxlib-from-gcs }}"=="1" ( |
| 94 | + echo "JAX only release. Only downloading the jax wheel from the release bucket." |
| 95 | + ) else ( |
| 96 | + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ |
| 97 | + ) |
| 98 | + - name: Skip the test run if the wheel artifacts were not downloaded successfully |
| 99 | + shell: bash |
| 100 | + if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' |
| 101 | + run: | |
| 102 | + echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" |
| 103 | + echo "built successfully by the artifact build jobs and are available in the GCS bucket." |
| 104 | + echo "Skipping the test run." |
| 105 | + exit 1 |
0 commit comments