| 
 | 1 | +# Composite action to download the jax and jaxlib wheels  | 
 | 2 | +name: Download JAX CPU wheels  | 
 | 3 | + | 
 | 4 | +inputs:  | 
 | 5 | +  runner:  | 
 | 6 | +    description: "Which runner type should the wheels be downloaded for?"  | 
 | 7 | +    type: string  | 
 | 8 | +    default: "linux-x86-n4-16"  | 
 | 9 | +  python:  | 
 | 10 | +    description: "Which python version should the artifact be downloaded for?"  | 
 | 11 | +    required: true  | 
 | 12 | +    type: string  | 
 | 13 | +  jaxlib-version:  | 
 | 14 | +    description: "Which jaxlib version to download? (head/pypi_latest)"  | 
 | 15 | +    type: string  | 
 | 16 | +    default: "head"  | 
 | 17 | +  skip-download-jaxlib-from-gcs:  | 
 | 18 | +    description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)"  | 
 | 19 | +    default: '0'  | 
 | 20 | +    type: string  | 
 | 21 | +  gcs_download_uri:  | 
 | 22 | +    description: "GCS location prefix from where the artifacts should be downloaded"  | 
 | 23 | +    default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'  | 
 | 24 | +    type: string  | 
 | 25 | +permissions: {}  | 
 | 26 | +runs:  | 
 | 27 | +  using: "composite"  | 
 | 28 | + | 
 | 29 | +  steps:  | 
 | 30 | +    # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.  | 
 | 31 | +    - name: Set env vars for use in artifact download URL  | 
 | 32 | +      shell: bash  | 
 | 33 | +      run: |  | 
 | 34 | +        os=$(uname -s | awk '{print tolower($0)}')  | 
 | 35 | +        arch=$(uname -m)  | 
 | 36 | +
  | 
 | 37 | +        # Adjust os and arch for Windows  | 
 | 38 | +        if [[  $os  =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then  | 
 | 39 | +          os="win"  | 
 | 40 | +          arch="amd64"  | 
 | 41 | +        fi  | 
 | 42 | +
  | 
 | 43 | +        # Get the major and minor version of Python.  | 
 | 44 | +        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310  | 
 | 45 | +        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t  | 
 | 46 | +        python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')  | 
 | 47 | +
  | 
 | 48 | +        echo "OS=${os}" >> $GITHUB_ENV  | 
 | 49 | +        echo "ARCH=${arch}" >> $GITHUB_ENV  | 
 | 50 | +        # Python wheels follow a naming convention: standard wheels use the pattern  | 
 | 51 | +        # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use  | 
 | 52 | +        # `*-cp<py_version>-cp<py_version>t-*`.  | 
 | 53 | +        echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV  | 
 | 54 | +    - name: Download wheels from GCS (non-Windows runs)  | 
 | 55 | +      shell: bash  | 
 | 56 | +      id: download-wheel-artifacts-nw  | 
 | 57 | +      # Set continue-on-error to true to prevent actions from failing the workflow if this step  | 
 | 58 | +      # fails. Instead, we verify the outcome in the step below so that we can print a more  | 
 | 59 | +      # informative error message.  | 
 | 60 | +      continue-on-error: true  | 
 | 61 | +      if: ${{ !contains(inputs.runner, 'windows-x86') }}  | 
 | 62 | +      run: |  | 
 | 63 | +        mkdir -p $(pwd)/dist  | 
 | 64 | +        gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/  | 
 | 65 | +
  | 
 | 66 | +        if [[ "${{ inputs.skip-download-jaxlib-from-gcs }}" == "1" ]]; then  | 
 | 67 | +          echo "JAX only release. Only downloading the jax wheel from the release bucket."  | 
 | 68 | +        else  | 
 | 69 | +          if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then  | 
 | 70 | +            gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/  | 
 | 71 | +          elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then  | 
 | 72 | +            PYTHON=python${{ inputs.python }}  | 
 | 73 | +            $PYTHON -m pip download jaxlib --dest $(pwd)/dist/  | 
 | 74 | +          else  | 
 | 75 | +            echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"  | 
 | 76 | +            exit 1  | 
 | 77 | +          fi  | 
 | 78 | +        fi  | 
 | 79 | +    - name: Download wheels from GCS (Windows runs)  | 
 | 80 | +      shell: cmd  | 
 | 81 | +      id: download-wheel-artifacts-w  | 
 | 82 | +      # Set continue-on-error to true to prevent actions from failing the workflow if this step  | 
 | 83 | +      # fails. Instead, we verify the outcome in step below so that we can print a more  | 
 | 84 | +      # informative error message.  | 
 | 85 | +      continue-on-error: true  | 
 | 86 | +      if: ${{ contains(inputs.runner, 'windows-x86') }}  | 
 | 87 | +      run: |  | 
 | 88 | +        mkdir dist  | 
 | 89 | +        @REM Use `call` so that we can run sequential gcloud storage commands on Windows  | 
 | 90 | +        @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652  | 
 | 91 | +        call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/  | 
 | 92 | +
  | 
 | 93 | +        if "${{ inputs.skip-download-jaxlib-from-gcs }}"=="1" (  | 
 | 94 | +          echo "JAX only release. Only downloading the jax wheel from the release bucket."  | 
 | 95 | +        ) else (  | 
 | 96 | +          call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/  | 
 | 97 | +        )  | 
 | 98 | +    - name: Skip the test run if the wheel artifacts were not downloaded successfully  | 
 | 99 | +      shell: bash  | 
 | 100 | +      if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'  | 
 | 101 | +      run: |  | 
 | 102 | +        echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"  | 
 | 103 | +        echo "built successfully by the artifact build jobs and are available in the GCS bucket."  | 
 | 104 | +        echo "Skipping the test run."  | 
 | 105 | +        exit 1  | 
0 commit comments