diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..03490db50 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,3 @@ +[*] +trim_trailing_whitespace = true +insert_final_newline = true diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..d953c93dd --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,14 @@ +# ran black and isort for coherent code formatting +bfa0e33294f2b1dc25e65a33be2397f989824298 + +# reran black with linelength 80 for greater readability +ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 + +# Remove f-prefix from strings that don't use formatting +7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6 + +# format tests/linear_4bit.py +34735ba89de8235ea9da6ef409f814dcea9e2038 + +# Reformat with ruff-format +5a4263f4dc05fe8f78f4111beab9f68a81deeab1 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index ac8e9de00..6ae3c7c0a 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -18,15 +18,15 @@ body: label: Reproduction description: | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. - Please provide the simplest reproducer as possible so that we can quickly fix the issue. + Please provide the simplest reproducer as possible so that we can quickly fix the issue. placeholder: | - Reproducer: - + Reproducer: + - type: textarea id: expected-behavior validations: required: true attributes: label: Expected behavior - description: "A clear and concise description of what you would expect to happen." \ No newline at end of file + description: "A clear and concise description of what you would expect to happen." diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml index 4e75c2a64..1dc2a298d 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yml +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -1,6 +1,6 @@ name: "\U0001F680 Feature request" description: Submit a proposal/request for a new feature -labels: [ "feature" ] +labels: ["feature"] body: - type: textarea id: feature-request @@ -18,7 +18,7 @@ body: attributes: label: Motivation description: | - Please outline the motivation for the proposal. Is your feature request related to a problem? + Please outline the motivation for the proposal. Is your feature request related to a problem? - type: textarea id: contribution @@ -27,4 +27,4 @@ body: attributes: label: Your contribution description: | - Is there any way that you could help, e.g. by submitting a PR? \ No newline at end of file + Is there any way that you could help, e.g. by submitting a PR? diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..8a36c3689 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: pip + directory: "/" + schedule: + interval: "weekly" + groups: + major: + update-types: [major] + minor-patch: + update-types: [minor, patch] diff --git a/.github/scripts/auditwheel_show.py b/.github/scripts/auditwheel_show.py new file mode 100755 index 000000000..c9dd09cc2 --- /dev/null +++ b/.github/scripts/auditwheel_show.py @@ -0,0 +1,31 @@ +import argparse +import subprocess + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("wheels", nargs="*") + args = ap.parse_args() + if not args.wheels: + ap.error("At least one wheel must be provided.") + for whl in args.wheels: + print(f"### `{whl}`") + + audit_wheel_output = subprocess.run( + ["auditwheel", "show", whl], + capture_output=True, + text=True, + errors="backslashreplace", + ) + + if audit_wheel_output.stdout: + print(audit_wheel_output.stdout) + + if audit_wheel_output.stderr: + print(f"**Error:**\n```{audit_wheel_output.stderr}```") + + print("---") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/build-cpu.sh b/.github/scripts/build-cpu.sh new file mode 100644 index 000000000..6dc6a8ddf --- /dev/null +++ b/.github/scripts/build-cpu.sh @@ -0,0 +1,23 @@ +#!/bin/bash +declare build_arch +declare build_os + +set -xeuo pipefail + +pip install cmake==3.28.3 + +if [ "${build_os:0:6}" == ubuntu ] && [ "${build_arch}" == aarch64 ]; then + # Allow cross-compile on aarch64 + sudo apt-get update + sudo apt-get install -y gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu g++-aarch64-linux-gnu + cmake -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ -DCOMPUTE_BACKEND=cpu . +elif [ "${build_os:0:5}" == macos ] && [ "${build_arch}" == aarch64 ]; then + cmake -DCMAKE_OSX_ARCHITECTURES=arm64 -DCOMPUTE_BACKEND=cpu . +else + cmake -DCOMPUTE_BACKEND=cpu . +fi +cmake --build . --config Release + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh new file mode 100644 index 000000000..0f9b8d726 --- /dev/null +++ b/.github/scripts/build-cuda.sh @@ -0,0 +1,29 @@ +#!/bin/bash +declare build_arch +declare build_os +declare cuda_version + +set -xeuo pipefail +build_capability="50;52;60;61;70;75;80;86;89;90" +[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????} +[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???} +[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja +for NO_CUBLASLT in ON OFF; do + if [ "${build_os:0:6}" == ubuntu ]; then + image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04 + echo "Using image $image" + docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \ + && cmake --build ." + else + pip install cmake==3.28.3 + cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . + cmake --build . --config Release + fi +done + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/scripts/set_platform_tag.py b/.github/scripts/set_platform_tag.py new file mode 100644 index 000000000..c82077074 --- /dev/null +++ b/.github/scripts/set_platform_tag.py @@ -0,0 +1,32 @@ +import argparse +import platform +import sys + + +def get_platform_tag(architecture): + system = platform.system() + + if system == "Linux": + tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" + elif system == "Darwin": + tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64" + elif system == "Windows": + tag = "win_amd64" if architecture == "x86_64" else "win_arm64" + else: + sys.exit(f"Unsupported system: {system}") + + return tag + + +def main(): + parser = argparse.ArgumentParser(description="Determine platform tag.") + parser.add_argument("arch", type=str, help="Architecture (e.g., x86_64, aarch64)") + args = parser.parse_args() + + tag = get_platform_tag(args.arch) + + print(tag) # This will be captured by the GitHub Actions workflow + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index 2921d70df..10272be87 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -8,11 +8,11 @@ on: - v*-release jobs: - build: + build: uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main with: commit_sha: ${{ github.sha }} package: bitsandbytes repo_owner: TimDettmers secrets: - token: ${{ secrets.HUGGINGFACE_PUSH }} \ No newline at end of file + hf_token: ${{ secrets.HUGGINGFACE_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index dace206b1..d6455fd11 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -9,9 +9,10 @@ concurrency: jobs: build: + if: github.repository == 'TimDettmers/bitsandbytes' uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} package: bitsandbytes - repo_owner: TimDettmers \ No newline at end of file + repo_owner: TimDettmers diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..01084d44f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,19 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + Lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.12" + - uses: pre-commit/action@v3.0.0 + env: + RUFF_OUTPUT_FORMAT: github diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 000000000..ba5961f72 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,203 @@ +name: Python package + +on: + push: {} + pull_request: + branches: [main] + paths: + - ".github/workflows/python-package.yml" + - "bitsandbytes/**" + - "csrc/**" + - "include/**" + - "tests/**" + - "CMakeLists.txt" + - "requirements*.txt" + - "setup.py" + - "pyproject.toml" + - "pytest.ini" + release: + types: [published] + workflow_dispatch: {} # Allow manual trigger + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ## + # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. + ## + build-shared-libs: + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + arch: [x86_64, aarch64] + exclude: + - os: windows-latest # This probably requires arm64 Windows agents + arch: aarch64 + - os: ubuntu-latest # Temporary. Takes too long, not ready yet. + arch: aarch64 + runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents + steps: + - uses: actions/checkout@v4 + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cpu.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_${{ matrix.os }}_${{ matrix.arch }} + path: output/* + retention-days: 7 + ## + # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) + ## + build-shared-libs-cuda: + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + arch: [x86_64, aarch64] + cuda_version: + ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2"] + exclude: + - os: windows-latest # This probably requires arm64 Windows agents + arch: aarch64 + - os: ubuntu-latest # Temporary. Takes too long, not ready yet. + arch: aarch64 + runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents + steps: + - uses: actions/checkout@v4 + # Linux: We use Docker to build cross platform Cuda (aarch64 is built in emulation) + - name: Set up Docker multiarch + if: startsWith(matrix.os, 'ubuntu') + uses: docker/setup-qemu-action@v2 + # Windows: We install Cuda on the agent (slow) + - uses: Jimver/cuda-toolkit@v0.2.14 + if: startsWith(matrix.os, 'windows') + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda_version }} + method: "network" + sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + linux-local-args: '["--toolkit"]' + use-github-cache: false + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cuda.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + cuda_version: ${{ matrix.cuda_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} + path: output/* + retention-days: 7 + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + # The specific Python version is irrelevant in this context as we are only packaging non-C extension + # code. This ensures compatibility across Python versions, including Python 3.8, as compatibility is + # dictated by the packaged code itself, not the Python version used for packaging. + python-version: ["3.10"] + arch: [x86_64, aarch64] + exclude: + - os: windows-latest # This probably requires arm64 Windows agents + arch: aarch64 + - os: ubuntu-latest # Temporary. Takes too long, not ready yet. + arch: aarch64 + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Download build artifact + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" + path: output/ + - name: Copy correct platform shared library + shell: bash + run: | + ls -lR output/ + cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install build wheel + - run: python -m build . + - name: Determine and Set Platform Tag, then Tag Wheel + shell: bash + run: | + PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") + echo "PLATFORM_TAG=$PLATFORM_TAG" + wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: dist/bitsandbytes-*.whl + retention-days: 7 + + audit-wheels: + needs: build-wheels + runs-on: ubuntu-latest + env: + PIP_DISABLE_PIP_VERSION_CHECK: 1 + steps: + - uses: actions/checkout@v4 + - name: Download all wheels + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "bdist_wheel_*" + path: wheels/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install auditwheel + - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY + +# test: +# needs: +# - build-wheels +# strategy: +# fail-fast: false +# matrix: +# include: +# - os: ubuntu-latest +# arch: x86_64 +# python-version: "3.8" +# - os: windows-latest +# arch: x86_64 +# python-version: "3.8" +# runs-on: ${{ matrix.os }} +# steps: +# - uses: actions/checkout@v4 +# - uses: actions/download-artifact@v4 +# with: +# merge-multiple: true +# pattern: "bdist_wheel_${{ matrix.os }}_${{ matrix.arch }}*" +# path: wheel/ +# - uses: actions/setup-python@v5 +# with: +# python-version: ${{ matrix.python-version }} +# cache: pip +# - shell: bash +# run: ls -lar wheel/ +# - run: pip install wheel/*.whl -r requirements-ci.txt +# - run: pytest --log-cli-level=DEBUG --continue-on-collection-errors tests diff --git a/.github/workflows/stale.yml.disabled b/.github/workflows/stale.yml.disabled index ec011c7fb..0b4f789ea 100644 --- a/.github/workflows/stale.yml.disabled +++ b/.github/workflows/stale.yml.disabled @@ -24,4 +24,4 @@ jobs: pip install PyGithub - name: Close stale issues run: | - python scripts/stale.py \ No newline at end of file + python scripts/stale.py diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 000000000..6497caf2d --- /dev/null +++ b/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: + workflow_run: + workflows: ["Build PR Documentation"] + types: + - completed + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: bitsandbytes + secrets: + hf_token: ${{ secrets.HUGGINGFACE_PUSH }} + comment_bot_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 2f929968b..22f5a6cd6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,29 @@ __pycache__/ *.py[cod] *$py.class - -# C extensions *.so +*.dll +*.dylib +*.o +*.obj +*.air +*.metallib + +# CMake generated files +CMakeCache.txt +CMakeScripts/ +cmake_install.cmake +Makefile +CMakeFiles/ +*.sln +*.vcxproj* +*.xcodeproj/ +bitsandbytes.dir/ +Debug/ +Release/ + +# IDE local files +.vs/ # Distribution / packaging .Python @@ -133,4 +153,4 @@ dmypy.json dependencies cuda_build -.vscode/* +output/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..a859d05af --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.2 + hooks: + - id: ruff + args: + - --fix + - id: ruff-format + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-merge-conflict + - id: check-yaml + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: trailing-whitespace + - id: mixed-line-ending + args: + - --fix=lf + - repo: https://github.com/crate-ci/typos + rev: v1.18.2 + hooks: + - id: typos diff --git a/.style.yapf b/.style.yapf index a185235cf..e60ac16e5 100644 --- a/.style.yapf +++ b/.style.yapf @@ -10,4 +10,4 @@ SPLIT_BEFORE_BITWISE_OPERATOR = True SPLIT_BEFORE_FIRST_ARGUMENT = True SPLIT_BEFORE_LOGICAL_OPERATOR = True SPLIT_BEFORE_NAMED_ASSIGNS = True -SPLIT_COMPLEX_COMPREHENSION = True \ No newline at end of file +SPLIT_COMPLEX_COMPREHENSION = True diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 000000000..939843f43 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,7 @@ +{ + "recommendations": [ + "ms-python.python", + "charliermarsh.ruff", + "twxs.cmake" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..906f28588 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "ruff.fixAll": true, + "ruff.lint.run": "onType", + "editor.codeActionsOnSave": { + "source.fixAll": "always" + } +} diff --git a/CHANGELOG.md b/CHANGELOG.md index c12443cf3..397dceb77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -342,3 +342,27 @@ Bug fixes: - Fixed a bug where kgetColRowStats (LLM.int8()) would fail for certain dimensions @LucQueen @905 - Fixed a bug where the adjusted regular Embedding layer was not available via bnb.nn.Embedding @neel04 #563 - Fixed added missing scipy requirement @dulalbert #525 + +### 0.43.0 + +#### Improvements and New Features: +- QLoRA + FSDP official support is now live! https://github.com/TimDettmers/bitsandbytes/pull/970 by @warner-benjamin and team - with FSDP you can train very large models (70b scale) on multiple 24GB consumer-type GPUs. See https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html for more details. +- Introduced improvements to the CI process for enhanced performance and efficiency during builds, specifically enabling more effective cross-compilation on Linux platforms. This was accomplished by deprecating Make and migrating to Cmake, as well as implementing new corresponding workflows. Huge thanks go to @wkpark, @rickardp, @matthewdouglas and @younesbelkada; #1055, #1050, #1111. +- Windows should be officially supported in bitsandbytes if you install the library from source. See: https://huggingface.co/docs/bitsandbytes/main/en/index for more details +- Updated installation instructions to provide more comprehensive guidance for users. This includes clearer explanations and additional tips for various setup scenarios, making the library more accessible to a broader audience (@rickardp, #1047). +- Enhanced the library's compatibility and setup process, including fixes for CPU-only installations and improvements in CUDA setup error messaging. This effort aims to streamline the installation process and improve user experience across different platforms and setups (@wkpark, @akx, #1038, #996, #1012). +- Setup a new documentation at https://huggingface.co/docs/bitsandbytes/main with extensive new sections and content to help users better understand and utilize the library. Especially notable are the new API docs. (big thanks to @stevhliu and @mishig25 from HuggingFace #1012). The API docs have been also addressed in #1075. + +#### Bug Fixes: +- Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061). +- Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063). + +#### Internal and Build System Enhancements: +- Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037). + +#### Contributors: +This release is made possible thanks to the many active contributors that submitted PRs and many others who contributed to discussions, reviews, and testing. Your efforts greatly enhance the library's quality and user experience. It's truly inspiring to work with such a dedicated and competent group of volunteers and professionals! + +We give a special thanks to @TimDettmers for managing to find a little bit of time for valuable consultations on critical topics, despite preparing for and touring the states applying for professor positions. We wish him the utmost success! + +We also extend our gratitude to the broader community for your continued support, feedback, and engagement, which play a crucial role in driving the library's development forward. diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..3bedefd51 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,308 @@ +# This CMake config hopefully makes it easier to compile. +# Ensure the CUDA Toolkit is available on your path. Then run: +# For GCC: `cmake -B build . && cmake --build build` +# For MSVC: `cmake -B build . && cmake --build build --config Release` +# You can also use the following options and variables +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend +# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support +# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version +# is whatever CMake finds on your path. +# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. +# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` +# Check your compute capability here: https://developer.nvidia.com/cuda-gpus +# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler +cmake_minimum_required(VERSION 3.22.1) + +project(bitsandbytes LANGUAGES CXX) + +# If run without specifying a build type, default to using the Release configuration: +# optimizing the generated binaries for performance and also adds the `-DNDEBUG` flag, +# which turns off a bunch of asserts which seem to link to new symbols in libstdc++, +# worsening our many_linux compliance.. +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Define included source files +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(HIP_FILES csrc/ops.hip csrc/kernels.hip) +set(MPS_FILES csrc/mps_ops.mm) +set(METAL_FILES csrc/mps_kernels.metal) +# C++ sources are always included +list(APPEND SRC_FILES ${CPP_FILES}) + +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) +option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) + +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET 13.1) +endif() + +set(BNB_OUTPUT_NAME "bitsandbytes") + +message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})") + +if(${COMPUTE_BACKEND} STREQUAL "cuda") + if(APPLE) + message(FATAL_ERROR "CUDA is not supported on macOS" ) + endif() + option(NO_CUBLASLT "Disable CUBLAS" OFF) + set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + option(NO_CUBLASLT "Disable HIPBLASLT" OFF) + set(BUILD_CUDA OFF) + set(BUILD_HIP ON) + set(BUILD_MPS OFF) + message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") +elseif(${COMPUTE_BACKEND} STREQUAL "mps") + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS ON) +else() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) +endif() + + +if(BUILD_CUDA) + enable_language(CUDA) # This will fail if CUDA is not found + find_package(CUDAToolkit REQUIRED) + + # Convert the CUDA version from X.Y.z to XY. There's probably a shorter way of doing this + string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") + string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") + + # Expose a cache variable that the user can set to ensure the correct version of CUDA is found + set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode") + + message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") + message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + + # It should match the discovered version + if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}") + message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION} however the CUDA compiler found is ${CUDA_VERSION_SHORT}." + " Ensure the desired CUDA compiler is the first one available on your PATH." + ) + endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.0") + message(FATAL_ERROR "CUDA Version < 11 is not supported") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + message(FATAL_ERROR "CUDA Version > 12 is not supported") + endif() + + # CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL. + if(CMAKE_VERSION VERSION_LESS "3.23.0") + message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...") + + # 11.x and 12.x both support these at a minimum. + set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80) + + # CUDA 11.1 adds Ampere support for GA102-GA107. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 86) + endif() + + # CUDA 11.4 adds Ampere support for GA10B. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.4") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 87) + endif() + + # CUDA 11.8 adds support for Ada and Hopper. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90) + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90) + endif() + endif() + + string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math") + + if(PTXAS_VERBOSE) + # Verbose? Outputs register usage information, and other things... + string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v") + endif() + + foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL}) + # Most of the items here are like: `xx-real`, so we just extract the `xx` portion + string(REGEX MATCH "[0-9]+" capability_id "${capability}") + if(capability_id GREATER 0) + list(APPEND POSSIBLE_CAPABILITIES ${capability_id}) + endif() + endforeach() + + # This can be changed via -D argument to CMake + # By default all possible capabilities are compiled + set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted") + + message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}") + message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") + + # Use the "real" option to build native cubin for all selections. + # Ensure we build the PTX for the latest version. + # This behavior of adding a PTX (virtual) target for the highest architecture + # is similar to how the "all" and "all-major" options would behave in CMake >= 3.23. + # TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default + list(REMOVE_DUPLICATES COMPUTE_CAPABILITY) + list(SORT COMPUTE_CAPABILITY COMPARE NATURAL) + list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY) + list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES) + list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY}) + + message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") + + list(APPEND SRC_FILES ${CUDA_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") + if(NO_CUBLASLT) + string(APPEND BNB_OUTPUT_NAME "_nocublaslt") + endif() + add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx940;gfx941;gfx942") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${HIP_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_hip") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + + if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1") + string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") + endif() + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) +elseif(BUILD_MPS) + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + + enable_language(OBJCXX) + + list(APPEND SRC_FILES ${MPS_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_mps") + add_compile_definitions(BUILD_MPS) + file(MAKE_DIRECTORY "build") + add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" + COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} + COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" + DEPENDS "${METAL_FILES}" + COMMENT "Compiling Metal kernels" + VERBATIM) + add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) +endif() + + +if(WIN32) + # Export all symbols + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + +# Weird MSVC hacks +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") +endif() + +set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) +add_library(bitsandbytes SHARED ${SRC_FILES}) +target_compile_features(bitsandbytes PUBLIC cxx_std_14) +target_include_directories(bitsandbytes PUBLIC csrc include) + + +if(BUILD_CUDA) + target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) + if(NO_CUBLASLT) + target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) + else() + target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) + endif() + + set_target_properties(bitsandbytes + PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + ) +endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() +if(BUILD_MPS) + add_dependencies(bitsandbytes metallib) + target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") +endif() + +if(WIN32) + set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") +endif() +set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) +if(MSVC) + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") +endif() + +set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes") diff --git a/Makefile b/Makefile deleted file mode 100644 index 00f5869b3..000000000 --- a/Makefile +++ /dev/null @@ -1,154 +0,0 @@ -MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) -ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) - -GPP:= /usr/bin/g++ -#GPP:= /sw/gcc/11.2.0/bin/g++ -ifeq ($(CUDA_HOME),) - CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) -endif - -ROCM_HOME := /opt/rocm - -ifndef CUDA_VERSION -ifneq ($(MAKECMDGOALS),clean) -$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU) -CUDA_VERSION:= -endif -endif - - - -NVCC := $(CUDA_HOME)/bin/nvcc -HIPCC := $(ROCM_HOME)/bin/hipcc - -########################################### - -CSRC := $(ROOT_DIR)/csrc -BUILD_DIR:= $(ROOT_DIR)/build - -FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu -FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c - -INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib - -INCLUDE_ROCM := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include -LIB_ROCM := -L $(ROCM_HOME)/lib -lhipblas -lhipblaslt -lhiprand -lhipsparse -L $(CONDA_PREFIX)/lib - -# NVIDIA NVCC compilation flags -COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell -COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell -COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal -COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal -COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta - -CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler -CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler - -# Later versions of CUDA support the new architectures -CC_CUDA11x := -gencode arch=compute_75,code=sm_75 -CC_CUDA11x += -gencode arch=compute_80,code=sm_80 -CC_CUDA11x += -gencode arch=compute_86,code=sm_86 - - -CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 -CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 - -CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 -CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 -CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 - -CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 -CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 - - -all: $(BUILD_DIR) env - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - -cuda110_nomatmul_kepler: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda11x_nomatmul_kepler: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - - -cuda110_nomatmul: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda11x_nomatmul: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda118_nomatmul: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda12x_nomatmul: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda110: $(BUILD_DIR) env - $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - -cuda11x: $(BUILD_DIR) env - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - -cuda118: $(BUILD_DIR) env - $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - -hip: $(BUILD_DIR) env - $(HIPCC) -std=c++14 -fPIC -c $(INCLUDE_ROCM) $(LIB_ROCM) $(CSRC)/ops.hip -o $(BUILD_DIR)/ops.o - $(HIPCC) -std=c++14 -fPIC -c $(INCLUDE_ROCM) $(LIB_ROCM) $(CSRC)/kernels.hip -o $(BUILD_DIR)/kernels.o - $(GPP) -std=c++14 -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ -DBUILD_HIP -shared -fPIC $(INCLUDE_ROCM) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so $(LIB_ROCM) - -cuda12x: $(BUILD_DIR) env - $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++20 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - -cpuonly: $(BUILD_DIR) env - $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so - -env: - @echo "ENVIRONMENT" - @echo "============================" - @echo "CUDA_VERSION: $(CUDA_VERSION)" - @echo "============================" - @echo "NVCC path: $(NVCC)" - @echo "HIPCC path: $(HIPCC)" - @echo "GPP path: $(GPP) VERSION: `$(GPP) --version | head -n 1`" - @echo "CUDA_HOME: $(CUDA_HOME)" - @echo "HIP_HOME: $(HIP_HOME)" - @echo "CONDA_PREFIX: $(CONDA_PREFIX)" - @echo "PATH: $(PATH)" - @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" - @echo "============================" - -$(BUILD_DIR): - mkdir -p build - mkdir -p dependencies - -$(ROOT_DIR)/dependencies/cub: - git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub - cd dependencies/cub; git checkout 1.11.0 - -clean: - rm -rf build/* *.egg* - rm -f bitsandbytes/libbitsandbytes*.so diff --git a/README.md b/README.md index d73713c87..377ca2e86 100644 --- a/README.md +++ b/README.md @@ -1,170 +1,52 @@ -# bitsandbytes-rocm +# `bitsandbytes` -The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. -This fork is the ROCm adaptation of bitsandbytes 0.39.1. The repo is inspired by [agrocylo/bitsandbytes-rocm](https://github.com/agrocylo/bitsandbytes-rocm/tree/main/bitsandbytes), which is a ROCm version of bitsandbytes 0.37. While this fork incorporating the majority of features from bitsandbytes 0.39.1, including the crucial 4 bit quantization feature, certain features such as hipblaslt and hip_bfloat16 have been disabled. Enabling these features is listed as a task for the future. +[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes) +The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions. +The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module. -Resources: -- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/) +**Installation for ROCm:** -- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) - -## TL;DR -**Requirements** -Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + ROCm >= 5.4.2 or CUDA > 10.0 - - -**Installation**: - - -You need to compile from source for ROCm. - -Compilation quickstart: +To install develop version: ```bash -# Run Docker -docker run -it --network=host --device=/dev/kfd --device=/dev/dri --name=bnb_test --shm-size=8g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --group-add video rocm/pytorch:latest +git clone --recurse https://github.com/ROCm/bitsandbytes +cd bitsandbytes +git checkout rocm_enabled +pip install -r requirements-dev.txt +cmake -DCOMPUTE_BACKEND=hip -S . (Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch) +make +pip install . +``` +For ROCm specific versions: -# Install Dependencies +Install Dependencies: +```bash +# hipblaslt installation needed only for rocm<6.0 apt install hipblaslt pip install --upgrade pip pip install einops lion_pytorch accelerate pip install git+https://github.com/ROCm/transformers.git - - -# Install BitsandBytes -git clone --recurse https://github.com/ROCmSoftwarePlatform/bitsandbytes +``` +Install Bitsandbytes: +```bash +git clone --recurse https://github.com/ROCm/bitsandbytes cd bitsandbytes # Checkout branch as needed -# for general use - rocm_enabled # for rocm 5.7 - rocm5.7_internal_testing -# for rocm 6.2 - rocm6.2_internal_testing -git checkout rocm_enabled +# for rocm 6.x - rocm6.2_internal_testing +git checkout make hip python setup.py install - - -# Run the unit test. If it runs successfully, the library has been installed successfully. -pytest -vvv ./tests/ 2>&1 | tee BitsAndBytes_UT_summary.log -``` - -**Using Int8 inference with HuggingFace Transformers** - -```python -from transformers import AutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained( - 'decapoda-research/llama-7b-hf', - device_map='auto', - load_in_8bit=True, - max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') -``` - -A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). - -**Using 8-bit optimizer**: -1. Comment out optimizer: ``#torch.optim.Adam(....)`` -2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) -3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`` - - -**Using 8-bit Inference**: -1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)`` -2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same) -3. There are two modes: - - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default) - - Int8 inference. Pass the argument ``has_fp16_weights=False`` -4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``. -```python -# LLM.int8() -linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0) -# inputs need to be fp16 -out = linear(x.to(torch.float16)) -``` - - -## Features -- 8-bit Matrix multiplication with mixed precision decomposition -- LLM.int8() inference -- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) -- Stable Embedding Layer: Improved stability through better initialization, and normalization -- 8-bit quantization: Quantile, Linear, and Dynamic quantization -- Fast quantile estimation: Up to 100x faster than other algorithms - -## Using bitsandbytes - -### Using Int8 Matrix Multiplication - -For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: -```python -bnb.matmul(..., threshold=6.0) -``` - -For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://huggingface.co/blog/hf-bitsandbytes-integration). - -### Using the 8-bit Optimizers - -With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way: -```python -import bitsandbytes as bnb - -# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer -adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer -adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent - - -torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models -``` - -Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so: -```python -# parameter tensors with less than 16384 values are optimized in 32-bit -# it is recommended to use multiplies of 4096 -adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) ``` -### Change Bits and other Hyperparameters for Individual Parameters - -If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details - -### Fairseq Users - -To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.). - -## Release and Feature History - -For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md). - -## Errors +**For more details, please head to the official documentation page:** -1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) -2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) +**[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)** ## License -The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. +The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license. We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. - -## How to cite us -If you found this library and found LLM.int8() useful, please consider citing our work: - -```bibtex -@article{dettmers2022llmint8, - title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale}, - author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, - journal={arXiv preprint arXiv:2208.07339}, - year={2022} -} -``` - -For 8-bit optimizers or quantization routines, please consider citing the following work: - -```bibtex -@article{dettmers2022optimizers, - title={8-bit Optimizers via Block-wise Quantization}, - author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke}, - journal={9th International Conference on Learning Representations, ICLR}, - year={2022} -} -``` diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 000000000..a04206b8d --- /dev/null +++ b/_typos.toml @@ -0,0 +1,11 @@ +[files] + +[default.extend-identifiers] + +[type.py.extend-words] +"BA" = "BA" # used as a commented-out variable in tests + +[type.cuda.extend-words] +"subtile" = "subtile" +"subtiles" = "subtiles" +"transation" = "transation" # TODO: is this transition, transaction, translation..? diff --git a/benchmarking/accuracy/bnb_accuracy.py b/benchmarking/accuracy/bnb_accuracy.py index bd3b81db4..2860338ec 100644 --- a/benchmarking/accuracy/bnb_accuracy.py +++ b/benchmarking/accuracy/bnb_accuracy.py @@ -1,8 +1,6 @@ import torch -import bitsandbytes as bnb -from bitsandbytes import functional as F - +from bitsandbytes import functional as F def debug_blocksize(block): @@ -11,6 +9,7 @@ def debug_blocksize(block): dq = F.dequantize_fp4(qx, qstate) return torch.sum(torch.linalg.norm(x - dq, ord="fro")) + def test_blocksize(block): x = torch.randn(10, 10).cuda() qx, qstate = F.quantize_fp4(x, blocksize=block) @@ -20,10 +19,8 @@ def test_blocksize(block): print("---------------") print(qstate) - - for block in [128, 256, 512, 1024, 2048]: print(debug_blocksize(block)) -#test_blocksize(2048) +# test_blocksize(2048) diff --git a/benchmarking/switchback/README.md b/benchmarking/switchback/README.md index bb33b5bbd..b73569030 100644 --- a/benchmarking/switchback/README.md +++ b/benchmarking/switchback/README.md @@ -1,4 +1,4 @@ Steps: 1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). -2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. \ No newline at end of file +2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py index 8897564e7..fd0dd7d58 100644 --- a/benchmarking/switchback/make_plot_with_jsonl.py +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -1,15 +1,11 @@ +import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import pandas as pd -import numpy as np -import os - -import matplotlib.gridspec as gridspec -cmap=plt.get_cmap('cool') +cmap = plt.get_cmap("cool") -if __name__ == '__main__': - - fig = plt.figure(tight_layout=True, figsize=(12,3.5)) +if __name__ == "__main__": + fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) gs = gridspec.GridSpec(1, 2) dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] @@ -21,25 +17,28 @@ ax = fig.add_subplot(gs[0, 0]) # TODO: change this to what you want. - rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) + rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) df = rdf[rdf.batch_size == batch_size_for_plot1] # first plot the time occupied by different operations for k, marker, ls, color, name in [ - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), - - ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), - ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), - ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), - - ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), - ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), - - ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), - ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), - ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), - ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), + ( + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", + "o", + "-", + "C4", + "SwitchBack int8 (sum of parts)", + ), + ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), + ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), + ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), + ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), + ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), + ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), + ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), + ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"), + ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), ]: xs = [] ys = [] @@ -49,40 +48,46 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] ys.append(y_ * 0.5) - - ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) - + ax.plot( + xs, + ys, + color=color, + label=name, + marker=marker, + markersize=5 if marker == "s" else 5, + linestyle=ls, + linewidth=2 if "+" in k else 1.0, + ) - ax.set_xlabel('dim', fontsize=13) - ax.set_ylabel('time (ms)', fontsize=13) + ax.set_xlabel("dim", fontsize=13) + ax.set_ylabel("time (ms)", fontsize=13) ax.grid() - ax.set_xscale('log') + ax.set_xscale("log") if logscale_plot1: - ax.set_yscale('log') - - ax.tick_params(axis='x', labelsize=11) - ax.tick_params(axis='y', labelsize=11) + ax.set_yscale("log") + + ax.tick_params(axis="x", labelsize=11) + ax.tick_params(axis="y", labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) - leg.get_texts()[0].set_fontweight('bold') - leg.get_texts()[1].set_fontweight('bold') + leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight("bold") + leg.get_texts()[1].set_fontweight("bold") plt.subplots_adjust(left=0.1) - ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) - + ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) ax = fig.add_subplot(gs[0, 1]) @@ -90,10 +95,15 @@ for j, batch_size in enumerate(batch_sizes_for_plot2): all_xs, all_ys = [], [] for k, marker, ls, color, name in [ - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), + ( + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", + "o", + "-", + "C4", + "SwitchBack int8 (total time)", + ), ]: - xs, ys = [], [] df = rdf[rdf.batch_size == batch_size] for embed_dim in dims_to_consider: @@ -101,11 +111,11 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] ys.append(y_ * 0.5) all_xs.append(xs) @@ -113,26 +123,29 @@ color = cmap(j * 0.25) real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] - markers = ['^', 'v', 'P', 'o'] - ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + markers = ["^", "v", "P", "o"] + ax.plot( + all_xs[0], + real_ys, + color=color, + label=f"batch * sequence length = {batch_size}", + marker=markers[j], + markersize=5 if marker == "s" else 5, + ) ax.legend() - ax.set_xlabel('dim', fontsize=13) - ax.set_xscale('log') + ax.set_xlabel("dim", fontsize=13) + ax.set_xscale("log") ax.grid() - ax.set_ylabel(r'% speedup', fontsize=13) - + ax.set_ylabel(r"% speedup", fontsize=13) - ax.tick_params(axis='x', labelsize=11) - ax.tick_params(axis='y', labelsize=11) + ax.tick_params(axis="x", labelsize=11) + ax.tick_params(axis="y", labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) - - - - plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') + ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) + plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py index b0983d0b8..eaba0e9cd 100644 --- a/benchmarking/switchback/speed_benchmark.py +++ b/benchmarking/switchback/speed_benchmark.py @@ -1,26 +1,34 @@ import json - import time + import torch -import torch.nn as nn +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. -def get_time(k, fn, info_dict): +def get_time(k, fn, info_dict): for _ in range(repeat // 2): - fn() + fn() torch.cuda.synchronize() start = time.time() for _ in range(repeat): - fn() + fn() torch.cuda.synchronize() end = time.time() @@ -28,16 +36,15 @@ def get_time(k, fn, info_dict): print(f"time {k}: {ms:.3f} ms") info_dict[k] = ms -if __name__ == '__main__': + +if __name__ == "__main__": torch.manual_seed(0) wm = 4 for dim in [1024, 1280, 1408, 1664, 2048, 4096]: # note "batch_size" is actually "batch_size * embed_dim", which is why it's large - for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: - + for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]: # switch switches dim_in and dim_out for switch in [False, True]: - # hparams repeat = 64 batch_size = batch_size @@ -54,7 +61,7 @@ def get_time(k, fn, info_dict): x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() - + x_int8 = x.clone().to(torch.int8) g_int8 = g.clone().to(torch.int8) w_int8 = w.clone().to(torch.int8) @@ -65,35 +72,86 @@ def get_time(k, fn, info_dict): state_w_rowwise = w.max(dim=1)[0] state_w_global = w.max() - info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} - - get_time('standard_fwd', lambda : x.matmul(w.t()), info) - get_time('standard_gw', lambda : g.t().matmul(x), info) - get_time('standard_gx', lambda : g.matmul(w), info) - get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) - get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) - get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) - get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) - get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) - get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) - get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) - get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) - get_time('w_quantize_global', lambda : quantize_global(w), info) - get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) - - time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] - time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] - time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] - - print('TOTAL STANDARD', time_standard) - print('TOTAL ROWWISE', time_rowwise) - print('TOTAL GLOBAL', time_global) - - print('speedup', -100*(time_global - time_standard)/time_standard) - - info['time_standard'] = time_standard - info['time_rowwise'] = time_rowwise - info['time_global'] = time_global + info = { + "repeat": repeat, + "batch_size": batch_size, + "dim_out": dim_out, + "dim_in": dim_in, + "wm": wm, + "switch": switch, + } + + get_time("standard_fwd", lambda: x.matmul(w.t()), info) + get_time("standard_gw", lambda: g.t().matmul(x), info) + get_time("standard_gx", lambda: g.matmul(w), info) + get_time( + "rowwise_fwd", + lambda: int8_matmul_rowwise_dequantize( + x_int8, + w_int8.t(), + state_x_rowwise, + state_w_columnwise, + None, + ), + info, + ) + get_time( + "rowwise_bwd", + lambda: int8_matmul_rowwise_dequantize( + g_int8, + wt_int8.t(), + state_x_rowwise, + state_w_rowwise, + None, + ), + info, + ) + get_time( + "global_fwd", + lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), + info, + ) + get_time( + "global_bwd", + lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), + info, + ) + get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info) + get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info) + get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info) + get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info) + get_time("w_quantize_global", lambda: quantize_global(w), info) + get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info) + + time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"] + time_rowwise = ( + info["x_quantize_rowwise"] + + info["g_quantize_rowwise"] + + info["w_quantize_colwise_transpose"] + + info["w_quantize_rowwise"] + + info["standard_gw"] + + info["rowwise_fwd"] + + info["rowwise_bwd"] + ) + time_global = ( + info["x_quantize_rowwise"] + + info["g_quantize_rowwise"] + + info["w_quantize_global"] + + info["w_quantize_global_transpose"] + + info["standard_gw"] + + info["global_fwd"] + + info["global_bwd"] + ) + + print("TOTAL STANDARD", time_standard) + print("TOTAL ROWWISE", time_rowwise) + print("TOTAL GLOBAL", time_global) + + print("speedup", -100 * (time_global - time_standard) / time_standard) + + info["time_standard"] = time_standard + info["time_rowwise"] = time_rowwise + info["time_global"] = time_global info_json = json.dumps(info) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 01d5527f5..78c99355b 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,20 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils, research +from . import research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, + matmul_4bit, matmul_cublas, mm_cublas, - matmul_4bit ) -from .cextension import COMPILED_WITH_CUDA from .nn import modules - -if COMPILED_WITH_CUDA: - from .optim import adam +from .optim import adam __pdoc__ = { "libbitsandbytes": False, @@ -24,6 +21,4 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.42.0" - -PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +__version__ = "0.44.0.dev" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index ebbf2653e..e716b6f3f 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,155 +1,4 @@ -import os -import sys -import shlex -import subprocess +if __name__ == "__main__": + from bitsandbytes.diagnostics.main import main -from warnings import warn -from typing import Tuple -from os.path import isdir - -import torch - -HEADER_WIDTH = 60 - -def execute_and_return(command_string: str) -> Tuple[str, str]: - def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) - - def execute_and_return_decoded_std_streams(command_string): - return _decode( - subprocess.Popen( - shlex.split(command_string), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ).communicate() - ) - - std_out, std_err = execute_and_return_decoded_std_streams(command_string) - return std_out, std_err - -def find_file_recursive(folder, filename): - folder = shlex.quote(folder) - filename = shlex.quote(filename) - cmd = f'find {folder} -name {filename}' - out, err = execute_and_return(cmd) - if len(err) > 0: - raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?') - - return out - - -def generate_bug_report_information(): - print_header("") - print_header("BUG REPORT INFORMATION") - print_header("") - print('') - - if 'CONDA_PREFIX' in os.environ: - paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') - print_header("ANACONDA CUDA PATHS") - print(paths) - print('') - if isdir('/usr/local/'): - paths = find_file_recursive('/usr/local', '*cuda*so') - print_header("/usr/local CUDA PATHS") - print(paths) - print('') - - if isdir(os.getcwd()): - paths = find_file_recursive(os.getcwd(), '*cuda*so') - print_header("WORKING DIRECTORY CUDA PATHS") - print(paths) - print('') - - print_header("LD_LIBRARY CUDA PATHS") - if 'LD_LIBRARY_PATH' in os.environ: - lib_path = os.environ['LD_LIBRARY_PATH'].strip() - for path in set(lib_path.split(':')): - try: - if isdir(path): - print_header(f"{path} CUDA PATHS") - paths = find_file_recursive(path, '*cuda*so') - print(paths) - except: - print(f'Could not read LD_LIBRARY_PATH: {path}') - print('') - - - - - -def print_header( - txt: str, width: int = HEADER_WIDTH, filler: str = "+" -) -> None: - txt = f" {txt} " if txt else "" - print(txt.center(width, filler)) - - -def print_debug_info() -> None: - print( - "\nAbove we output some debug information. Please provide this info when " - f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" - ) - - -generate_bug_report_information() - - -from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities - - -print_header("OTHER") -print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") -print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") -print_header("") -print_header("DEBUG INFO END") -print_header("") -print( - """ -Running a quick check that: - + library is importable - + CUDA function is callable -""" -) -print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") - -try: - from bitsandbytes.optim import Adam - - p = torch.nn.Parameter(torch.rand(10, 10).cuda()) - a = torch.rand(10, 10).cuda() - - p1 = p.data.sum().item() - - adam = Adam([p]) - - out = a * p - loss = out.sum() - loss.backward() - adam.step() - - p2 = p.data.sum().item() - - assert p1 != p2 - print("SUCCESS!") - print("Installation was successful!") - sys.exit(0) - -except ImportError: - print() - warn( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" - ) - print_debug_info() - sys.exit(0) -except Exception as e: - print(e) - print_debug_info() - sys.exit(1) + main() diff --git a/bitsandbytes/archive_functional.py b/bitsandbytes/archive_functional.py index 226c9e51f..dac7430ed 100644 --- a/bitsandbytes/archive_functional.py +++ b/bitsandbytes/archive_functional.py @@ -3,17 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct +from functools import reduce # Required in Python 3 import itertools import operator -import random -import torch -import itertools -import math -from scipy.stats import norm -import numpy as np - -from functools import reduce # Required in Python 3 from typing import Tuple + +import numpy as np +from scipy.stats import norm +import torch from torch import Tensor from .cextension import COMPILED_WITH_CUDA, lib @@ -23,12 +20,13 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) + name2qmap = {} if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16) #, lib.cadam32bit_grad_bf16) + str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16) # , lib.cadam32bit_grad_bf16) str2optimizer32bit["momentum"] = ( lib.cmomentum32bit_grad_32, lib.cmomentum32bit_grad_16, @@ -37,7 +35,7 @@ def prod(iterable): lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_16, ) - str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16) #, lib.clion32bit_grad_bf16) + str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16) # , lib.clion32bit_grad_bf16) str2optimizer32bit["adagrad"] = ( lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_16, @@ -73,7 +71,7 @@ def prod(iterable): str2optimizer8bit_blockwise["adam"] = ( lib.cadam_8bit_blockwise_grad_fp32, lib.cadam_8bit_blockwise_grad_fp16, - #lib.cadam_8bit_blockwise_grad_bf16, + # lib.cadam_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( lib.cmomentum_8bit_blockwise_grad_fp32, @@ -86,13 +84,14 @@ def prod(iterable): str2optimizer8bit_blockwise["lion"] = ( lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp16, - #lib.clion_8bit_blockwise_grad_bf16, + # lib.clion_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["adagrad"] = ( lib.cadagrad_8bit_blockwise_grad_fp32, lib.cadagrad_8bit_blockwise_grad_fp16, ) + class GlobalPageManager: _instance = None @@ -110,14 +109,13 @@ def get_instance(cls): return cls._instance def prefetch_all(self, to_cpu=False): - # assume the first added, will be hte + # assume the first added, will be the # ones that are used first, so swap them in last # in the case they are evicted again for t in self.paged_tensors[::-1]: prefetch_tensor(t, to_cpu) - class CUBLAS_Context: _instance = None @@ -150,7 +148,7 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def initialize(self): - #self.context = ct.c_void_p(lib.get_cusparse()) + # self.context = ct.c_void_p(lib.get_cusparse()) if torch.version.cuda: self.context = ct.c_void_p(lib.get_cusparse()) elif torch.version.hip: @@ -163,6 +161,7 @@ def get_instance(cls): cls._instance.initialize() return cls._instance + dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -170,8 +169,9 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): - num_bytes = dtype2bytes[dtype]*prod(shape) + +def get_paged(*shape, dtype=torch.float32, device=torch.device("cuda", index=0)): + num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -180,74 +180,86 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)) out.page_deviceid = device.index return out + def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' + assert A.is_paged, "Only paged tensors can be prefetched!" if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype]*A.numel() + num_bytes = dtype2bytes[A.dtype] * A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) + func = getattr(lib, f"c{func_name}_fp32", None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) + func = getattr(lib, f"c{func_name}_uint8", None) cvalue = ct.c_uint8(value) - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + if func is None: + raise NotImplementedError(f"Function not implemented: {func_name}") - is_managed = getattr(A, 'is_managed', False) + is_managed = getattr(A, "is_managed", False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: prefetch_tensor(B) + if B is not None: + prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: # paged function are fully asynchronous # if we return from this function, we want to the tensor # to be in the correct state, that is the final state after the - # operation occured. So we synchronize. + # operation occurred. So we synchronize. torch.cuda.synchronize() -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + +def fill(A, value, device=None, prefetch=True): + elementwise_func("fill", A, None, value) + + +def arange(A, device=None): + elementwise_func("arange", A, None, 0) + + +def _mul(A, B, device=None): + elementwise_func("_mul", A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) + sign = -1.0 if signed else 0.0 total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) + total_values = 2**total_bits if not signed else 2**total_bits - 1 values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel()//2 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + l = values.numel() // 2 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) -def create_normal_map(offset=0.9677083, use_extra_value=True): +def create_normal_map(offset=0.9677083, use_extra_value=True): if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -257,38 +269,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): assert values.numel() == 256 return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign + assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): evalues.append(2**val) - values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): for bit_pattern in lst: - value = (1 if evalue != 0 else 0) + value = 1 if evalue != 0 else 0 for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) + value += pval * (2 ** -(i + 1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value * 2**-(bias) else: # normals - value = value*2**-(evalue-bias-1) + value = value * 2 ** -(evalue - bias - 1) values.append(value) if signed: values.append(-value) - assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -302,7 +313,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) return code - def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -329,7 +339,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): if not signed: additional_items = 2 * additional_items for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1 + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -353,8 +367,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) + def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() q.append(0) @@ -365,11 +380,13 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q/q.abs().max() + q = q / q.abs().max() return q + def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' + if not torch.cuda.is_available(): + return "col_turing" major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -378,22 +395,27 @@ def get_special_format_str(): return "col_turing" - def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}" + ) if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}" + ) return on_gpu + def get_ptr(A: Tensor) -> ct.c_void_p: """ Get the ctypes pointer from a PyTorch Tensor. @@ -434,9 +456,7 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): return getattr(lib, name) -def get_transform_buffer( - shape, dtype, device, to_order, from_order="row", transpose=False -): +def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): # init_func = torch.empty init_func = torch.zeros dims = len(shape) @@ -489,9 +509,7 @@ def nvidia_transform( else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1] - ) + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) else: new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) @@ -516,7 +534,7 @@ def nvidia_transform( def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -543,14 +561,21 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): + """ + if A.numel() < 256: + raise NotImplementedError( + f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values." + ) + if num_quantiles > 256: + raise NotImplementedError( + f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}" + ) + if num_quantiles < 256 and offset == 1 / (512): # override default arguments - offset = 1/(2*num_quantiles) + offset = 1 / (2 * num_quantiles) - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: @@ -562,14 +587,16 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n post_call(device) if num_quantiles < 256: - step = round(256/num_quantiles) + step = round(256 / num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] return out -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: +def quantize_blockwise( + A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False +) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -596,7 +623,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou The quantization state to undo the quantization. """ - if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -611,23 +637,34 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != 'cpu': + if A.device.type != "cpu": assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()) + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()) + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) if nested: offset = absmax.mean() @@ -637,8 +674,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou else: state = [absmax, code, blocksize, nested, None, None] - - return out, state @@ -649,7 +684,7 @@ def dequantize_blockwise( code: Tensor = None, out: Tensor = None, blocksize: int = 4096, - nested=False + nested=False, ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -686,41 +721,58 @@ def dequantize_blockwise( out = torch.zeros_like(A, dtype=torch.float32) if quant_state is None: - quant_state = (absmax, code, blocksize) - assert absmax is not None and out is not None + quant_state = (absmax, code, blocksize) + assert absmax is not None and out is not None else: - absmax, code, blocksize, nested, offset, state2 = quant_state - if nested: - absmax = dequantize_blockwise(absmax, state2) - absmax += offset - + absmax, code, blocksize, nested, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset - if A.device.type != 'cpu': + if A.device.type != "cpu": device = pre_call(A.device) code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()) + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()) + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) return out + def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4") + def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4") + -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: +def quantize_4bit( + A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type="fp4" +) -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -746,10 +798,10 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") n = A.numel() input_shape = A.shape @@ -759,9 +811,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + out = torch.zeros(((n + 1) // 2, 1), dtype=torch.uint8, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -769,15 +820,23 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -785,8 +844,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if compress_statistics: offset = absmax.mean() absmax -= offset - #code = create_custom_map().to(absmax.device) - #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) + # code = create_custom_map().to(absmax.device) + # qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] @@ -795,13 +854,35 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz return out, state -def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') +def dequantize_fp4( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") -def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + +def dequantize_nf4( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + + +def dequantize_4bit( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="fp4", +) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -829,9 +910,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Dequantized tensor. """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") if quant_state is None: assert absmax is not None and out is not None @@ -840,7 +923,6 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: else: absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state - if compressed_stats is not None: offset, state2 = compressed_stats absmax = dequantize_blockwise(absmax, state2) @@ -851,26 +933,35 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: n = out.numel() - device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n) + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n) + ) elif out.dtype == torch.float16: - if quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n) + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n) + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: @@ -907,7 +998,7 @@ def dequantize( def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -926,9 +1017,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -936,7 +1028,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -955,9 +1047,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1024,16 +1117,17 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1053,7 +1147,8 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) + ct.c_int32(g.numel()), + ) post_call(prev_device) @@ -1209,7 +1304,6 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1217,8 +1311,11 @@ def optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( @@ -1250,9 +1347,8 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) -def percentile_clipping( - grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 -): + +def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): """Applies percentile clipping grad: torch.Tensor @@ -1294,9 +1390,7 @@ def percentile_clipping( return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d( - histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor -): +def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 @@ -1313,12 +1407,12 @@ def histogram_scatter_add_2d( is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError( - f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" - ) + raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}") sA = A.shape sB = B.shape @@ -1359,12 +1453,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if ( - sout[0] == sA[2] - and sout[1] == sB[2] - and sA[0] == sB[0] - and sA[1] == sB[1] - ): + if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]: correct = True else: if len(sA) == 2 and len(sB) == 2: @@ -1402,15 +1491,9 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout -def cutlass3_gemm( - A: Tensor, - B: Tensor, - out: Tensor = None, - transposed_A=False, - transposed_B=False, - state=None -): - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + +def cutlass3_gemm(A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False, state=None): + # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: Bshape = B.shape bout = Bshape[1] @@ -1489,15 +1572,15 @@ def cutlass3_gemm( # B^T @ A^T = C^T # [km, nk -> mn] - #lda = ldb = ldc = 1 - #lda = 1 + # lda = ldb = ldc = 1 + # lda = 1 if state is not None: m = Bshape[0] k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (ldb+1)//2 - #print(m, n, k, lda, ldb, ldc) + ldb = (ldb + 1) // 2 + # print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1507,19 +1590,19 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) if B.dtype == torch.uint8: - lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference( + m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]) + ) elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") return out - - def igemm( A: Tensor, B: Tensor, @@ -1604,8 +1687,20 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out @@ -1617,9 +1712,7 @@ def batched_igemm( transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError( - f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" - ) + raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}") sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) @@ -1686,9 +1779,24 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out @@ -1697,14 +1805,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -1713,13 +1821,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -1761,46 +1865,30 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') + print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) return out, Sout -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1808,19 +1896,11 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -1834,15 +1914,15 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols + ) post_call(prev_device) return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): +def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -1855,18 +1935,12 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) + nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -1940,14 +2014,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values - ) + return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) def coo2csc(cooA): @@ -1956,14 +2026,10 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -1973,9 +2039,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -1988,9 +2052,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2008,9 +2070,7 @@ def double_quant( if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -2069,12 +2129,16 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2085,7 +2149,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2106,7 +2170,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") post_call(prev_device) @@ -2115,9 +2179,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2144,16 +2206,28 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: - out = torch.zeros( - (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype - ) + out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) nnz = cooA.nnz prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz @@ -2171,9 +2245,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert ( - max_count[0] <= 32 - ), f"Current max count per row is 8 but found {max_count[0]}." + assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -2261,9 +2333,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( - x, dim=dim, keepdim=True - ) + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) dyna[dyna == 0] = 1 qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) @@ -2371,9 +2441,7 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2383,7 +2451,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2391,6 +2459,7 @@ def extract_outliers(A, SA, idx): return out + def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index 6b9a7e4d1..f262d89ed 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -1 +1 @@ -from ._functions import undo_layout, get_inverse_transform_indices +from ._functions import get_inverse_transform_indices, undo_layout diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 59b0ac7b2..18ca66b17 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,12 +1,13 @@ -import operator -import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional, List +import operator +from typing import Callable, Optional, Tuple +import warnings from warnings import warn import torch +from bitsandbytes.cextension import BNB_HIP_VERSION import bitsandbytes.functional as F @@ -14,19 +15,18 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor - # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py - """ This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. """ + + class GlobalOutlierPooler: _instance = None @@ -56,7 +56,10 @@ def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) -def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): +def get_inverse_transform_indices( + transform_tile: Callable[[torch.Tensor], torch.Tensor], + tile_size: Tuple[int, int], +): """ Compute a permutation of indices that invert the specified (tiled) matrix transformation @@ -83,6 +86,7 @@ def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int break # if all indices fit in i bytes, stop early return permuted_tile_indices + def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: """ Undo a tiled permutation such as turing or ampere layout @@ -159,20 +163,12 @@ def backward(ctx, grad_output): ) if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant( - A.view(-1, A.shape[2]), dim=0, quant_type=quant_type - ) + qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant( - igrad_B, S2.t(), S1, grad_output.dtype, quant_type - ) + grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) - qA, S2 = F.vectorwise_quant( - A, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -201,9 +197,7 @@ def backward(ctx, grad_output): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( @@ -225,11 +219,11 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" if torch.version.hip: - return True + return False if BNB_HIP_VERSION < 601 else True if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) - nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series + nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores return True @@ -248,6 +242,7 @@ def get_tile_inds(format, device): with torch.no_grad(): return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) + @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None @@ -498,7 +493,7 @@ class MatMul4Bit(torch.autograd.Function): # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @staticmethod - def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None): + def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None): # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -512,7 +507,6 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None): else: return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - # 1. Dequantize # 2. MatmulnN output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) @@ -534,7 +528,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad A, B = ctx.tensors grad_A, grad_B, grad_bias = None, None, None @@ -544,19 +538,20 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) # not supported by PyTorch. TODO: create work-around - #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: + grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None def matmul( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: @@ -564,11 +559,19 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None): +def matmul_4bit( + A: torch.Tensor, + B: torch.Tensor, + quant_state: F.QuantState, + out: Optional[torch.Tensor] = None, + bias=None, +): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: - warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') + warn( + f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", + ) return MatMul4Bit.apply(A, B, out, bias, quant_state) else: out = F.gemv_4bit(A, B.t(), out, state=quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 03a208995..69cf0b15f 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,50 +1,137 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + import ctypes as ct +import logging import os +from pathlib import Path + import torch -from pathlib import Path -from warnings import warn +from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs + +logger = logging.getLogger(__name__) + + +def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: + """ + Get the disk path to the CUDA BNB native library specified by the + given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable. + + The library is not guaranteed to exist at the returned path. + """ + if torch.version.hip: + if BNB_HIP_VERSION < 601: + return PACKAGE_DIR / f"libbitsandbytes_hip_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}" + else: + return PACKAGE_DIR / f"libbitsandbytes_hip{DYNAMIC_LIBRARY_SUFFIX}" + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" + if not cuda_specs.has_cublaslt: + # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt + library_name += "_nocublaslt" + library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name_stem, _, library_name_ext = library_name.rpartition(".") + # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`; + # let's remove any trailing numbers: + library_name_stem = library_name_stem.rstrip("0123456789") + # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`; + # let's tack the new version number and the original extension back on. + library_name = f"{library_name_stem}{override_value}.{library_name_ext}" + logger.warning( + f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" + "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" + "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" + "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" + "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLibrary: + binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" + cuda_specs = get_cuda_specs() + if cuda_specs: + cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) + if cuda_binary_path.exists(): + binary_path = cuda_binary_path + else: + logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") + dll = ct.cdll.LoadLibrary(str(binary_path)) + + if hasattr(dll, "get_context"): # only a CUDA-built library exposes this + return CudaBNBNativeLibrary(dll) + + logger.warning( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", + ) + return BNBNativeLibrary(dll) -setup = CUDASetup.get_instance() -if setup.initialized != True: - setup.run_cuda_setup() -lib = setup.lib try: - if lib is None and torch.cuda.is_available() : - CUDASetup.get_instance().generate_instructions() - CUDASetup.get_instance().print_log_stack() - raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Please run the following command to get more information: - - python -m bitsandbytes - - Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them - to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes - and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - - lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False - lib.get_context.restype = ct.c_void_p - - HIP_ENVIRONMENT = False - if torch.version.cuda: - lib.get_cusparse.restype = ct.c_void_p - elif torch.version.hip: - HIP_ENVIRONMENT = True - lib.get_hipsparse.restype = ct.c_void_p - - lib.cget_managed_ptr.restype = ct.c_void_p - COMPILED_WITH_CUDA = True -except AttributeError as ex: - warn("The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") - COMPILED_WITH_CUDA = False - print(str(ex)) - - -# print the setup details after checking for errors so we do not print twice -#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - #setup.print_log_stack() + if torch.version.hip: + hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) + HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor + else: + HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0 + lib = get_native_library() +except Exception as e: + lib = None + logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) + if torch.cuda.is_available(): + logger.warning( + """ +CUDA Setup failed despite CUDA being available. Please run the following command to get more information: + +python -m bitsandbytes + +Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them +to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes +and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues +""", + ) diff --git a/bitsandbytes/consts.py b/bitsandbytes/consts.py new file mode 100644 index 000000000..8242d104e --- /dev/null +++ b/bitsandbytes/consts.py @@ -0,0 +1,12 @@ +from pathlib import Path +import platform + +DYNAMIC_LIBRARY_SUFFIX = { + "Darwin": ".dylib", + "Linux": ".so", + "Windows": ".dll", +}.get(platform.system(), ".so") + +PACKAGE_DIR = Path(__file__).parent +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx" diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py deleted file mode 100644 index e8268fcaa..000000000 --- a/bitsandbytes/cuda_setup/env_vars.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Dict - - -def to_be_ignored(env_var: str, value: str) -> bool: - ignorable = { - "PWD", # PWD: this is how the shell keeps track of the current working dir - "OLDPWD", - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks - "HOME", # Linux shell default - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "MAIL", # something related to emails - "SHELL", # binary for currently invoked shell - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "PATH", # this is for finding binaries, not libraries - "LESSOPEN", # related to the `less` command - "LESSCLOSE", - "_", # current Python interpreter - } - return env_var in ignorable - - -def might_contain_a_path(candidate: str) -> bool: - return "/" in candidate - - -def is_active_conda_env(env_var: str) -> bool: - return "CONDA_PREFIX" == env_var - - -def is_other_conda_env_var(env_var: str) -> bool: - return "CONDA" in env_var - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return is_active_conda_env(env_var) or ( - might_contain_a_path(value) and not - is_other_conda_env_var(env_var) and not - to_be_ignored(env_var, value) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return { - env_var: value - for env_var, value in os.environ.items() - if is_relevant_candidate_env_var(env_var, value) - } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index b4962c1a0..b2f9214a4 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -4,7 +4,7 @@ [ ] TODO: Q - What if we have multiple GPUs of different makes? - CUDA version - Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication) - CuBLAS-LT: full-build 8-bit optimizer - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) @@ -17,25 +17,32 @@ """ import ctypes as ct -import os import errno -import torch -from warnings import warn -from itertools import product - +import os from pathlib import Path from typing import Set, Union +from warnings import warn + +import torch + from .env_vars import get_potentially_lib_path_containing_env_vars # these are the most common libs names # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead # we have libcudart.so.11.0 which causes a lot of errors before # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt -CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2'] +CUDA_RUNTIME_LIBS: list = [ + "libcudart.so", + "libcudart.so.11.0", + "libcudart.so.12.0", + "libcudart.so.12.1", + "libcudart.so.12.2", +] # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths backup_paths = [] -backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0') +backup_paths.append("$CONDA_PREFIX/lib/libcudart.so.11.0") + class CUDASetup: _instance = None @@ -44,59 +51,89 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def generate_instructions(self): - if getattr(self, 'error', False): return + if getattr(self, "error", False): + return print(self.error) self.error = True if not self.cuda_available: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.') - self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') - self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') - self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') - self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)') + self.add_log_entry( + "CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed." + ) + self.add_log_entry( + "CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig." + ) + self.add_log_entry("CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:") + self.add_log_entry( + "CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null" + ) + self.add_log_entry( + "CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a" + ) + self.add_log_entry( + "CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc" + ) + self.add_log_entry( + "CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)" + ) return if self.cudart_path is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') - self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') - self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') - self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh') - self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') - self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') + self.add_log_entry( + "CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected." + ) + self.add_log_entry( + "CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable" + ) + self.add_log_entry( + "CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null" + ) + self.add_log_entry( + "CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a" + ) + self.add_log_entry( + "CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc" + ) + self.add_log_entry("CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.") + self.add_log_entry( + "CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh" + ) + self.add_log_entry( + "CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO." + ) + self.add_log_entry( + 'CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local' + ) return - make_cmd = f'CUDA_VERSION={self.cuda_version_string}' + make_cmd = f"CUDA_VERSION={self.cuda_version_string}" if len(self.cuda_version_string) < 3: - make_cmd += ' make cuda92' - elif self.cuda_version_string == '110': - make_cmd += ' make cuda110' - elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: - make_cmd += ' make cuda11x' - elif self.cuda_version_string[:2] == '12' and 1 >= int(self.cuda_version_string[2]) >= 0: - make_cmd += ' make cuda12x' - elif self.cuda_version_string == '100': - self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') - self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') + make_cmd += " make cuda92" + elif self.cuda_version_string == "110": + make_cmd += " make cuda110" + elif self.cuda_version_string[:2] == "11" and int(self.cuda_version_string[2]) > 0: + make_cmd += " make cuda11x" + elif self.cuda_version_string[:2] == "12" and 1 >= int(self.cuda_version_string[2]) >= 0: + make_cmd += " make cuda12x" + elif self.cuda_version_string == "100": + self.add_log_entry("CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.") + self.add_log_entry( + "CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables." + ) return - has_cublaslt = is_cublasLt_compatible(self.cc) if not has_cublaslt: - make_cmd += '_nomatmul' + make_cmd += "_nomatmul" - self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') - self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git') - self.add_log_entry('cd bitsandbytes') + self.add_log_entry("CUDA SETUP: Something unexpected happened. Please compile from source:") + self.add_log_entry("git clone https://github.com/TimDettmers/bitsandbytes.git") + self.add_log_entry("cd bitsandbytes") self.add_log_entry(make_cmd) - self.add_log_entry('python setup.py install') + self.add_log_entry("python setup.py install") def initialize(self): - if not getattr(self, 'initialized', False): + if not getattr(self, "initialized", False): self.has_printed = False self.lib = None self.initialized = False @@ -104,16 +141,18 @@ def initialize(self): def manual_override(self): if torch.cuda.is_available(): - if 'BNB_CUDA_VERSION' in os.environ: - if len(os.environ['BNB_CUDA_VERSION']) > 0: - warn((f'\n\n{"="*80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: 0: + warn( + f'\n\n{"="*80}\n' + 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' + 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' + 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' + 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' + 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} @@ -202,7 +253,7 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: if path.exists(): existent_directories.add(path) except PermissionError as pex: - # Handle the PermissionError first as it is a subtype of OSError + # Handle the PermissionError first as it is a subtype of OSError # https://docs.python.org/3/library/exceptions.html#exception-hierarchy pass except OSError as exc: @@ -211,8 +262,11 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: non_existent_directories: Set[Path] = candidate_paths - existent_directories if non_existent_directories: - CUDASetup.get_instance().add_log_entry("The following directories listed in your path were found to " - f"be non-existent: {non_existent_directories}", is_warning=False) + CUDASetup.get_instance().add_log_entry( + "The following directories listed in your path were found to " + f"be non-existent: {non_existent_directories}", + is_warning=False, + ) return existent_directories @@ -238,9 +292,7 @@ def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: - return get_cuda_runtime_lib_paths( - resolve_paths_list(paths_list_candidate) - ) + return get_cuda_runtime_lib_paths(resolve_paths_list(paths_list_candidate)) def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: @@ -248,27 +300,28 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: warning_msg = ( f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," - "but this might missmatch with the CUDA version that is needed for bitsandbytes." + "but this might mismatch with the CUDA version that is needed for bitsandbytes." "To override this behavior set the BNB_CUDA_VERSION= environmental variable" "For example, if you want to use the CUDA version 122" "BNB_CUDA_VERSION=122 python ..." "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." - "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2") + "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2" + ) CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) def determine_cuda_runtime_lib_path() -> Union[Path, None]: """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. """ candidate_env_vars = get_potentially_lib_path_containing_env_vars() @@ -282,8 +335,11 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: if conda_cuda_libs: cuda_runtime_libs.update(conda_cuda_libs) - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + CUDASetup.get_instance().add_log_entry( + f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', + is_warning=True, + ) if "LD_LIBRARY_PATH" in candidate_env_vars: lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) @@ -292,11 +348,15 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: cuda_runtime_libs.update(lib_ld_cuda_libs) warn_in_case_of_duplicates(lib_ld_cuda_libs) - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + CUDASetup.get_instance().add_log_entry( + f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', + is_warning=True, + ) remaining_candidate_env_vars = { - env_var: value for env_var, value in candidate_env_vars.items() + env_var: value + for env_var, value in candidate_env_vars.items() if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} } @@ -305,13 +365,15 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: cuda_runtime_libs.update(find_cuda_lib_in(value)) if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') - cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) + CUDASetup.get_instance().add_log_entry( + "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths..." + ) + cuda_runtime_libs.update(find_cuda_lib_in("/usr/local/cuda/lib64")) warn_in_case_of_duplicates(cuda_runtime_libs) cuda_setup = CUDASetup.get_instance() - cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}') + cuda_setup.add_log_entry(f"DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}") return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None @@ -321,9 +383,12 @@ def get_cuda_version(): major, minor = map(int, torch.version.cuda.split(".")) if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + CUDASetup.get_instance().add_log_entry( + "CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!" + ) + + return f"{major}{minor}" - return f'{major}{minor}' def get_compute_capabilities(): ccs = [] @@ -338,25 +403,34 @@ def get_compute_capabilities(): def evaluate_cuda_setup(): cuda_setup = CUDASetup.get_instance() - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - cuda_setup.add_log_entry('') - cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35) - cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), - ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) - cuda_setup.add_log_entry('='*80) - if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None - if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None + if "BITSANDBYTES_NOWELCOME" not in os.environ or str(os.environ["BITSANDBYTES_NOWELCOME"]) == "0": + cuda_setup.add_log_entry("") + cuda_setup.add_log_entry("=" * 35 + "BUG REPORT" + "=" * 35) + cuda_setup.add_log_entry( + ("Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n"), + ( + "and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues" + ), + ) + cuda_setup.add_log_entry("=" * 80) + if not torch.cuda.is_available(): + return "libbitsandbytes_cpu.so", None, None, None + if torch.version.hip: + return "libbitsandbytes_hip_nohipblaslt.so", None, None, None cudart_path = determine_cuda_runtime_lib_path() ccs = get_compute_capabilities() ccs.sort() - cc = ccs[-1] # we take the highest capability + cc = ccs[-1] # we take the highest capability cuda_version_string = get_cuda_version() - cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") - cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:" - "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md") - + cuda_setup.add_log_entry( + f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}." + ) + cuda_setup.add_log_entry( + "CUDA SETUP: To manually override the PyTorch CUDA version please see:" + "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md" + ) # 7.5 is the minimum CC vor cublaslt has_cublaslt = is_cublasLt_compatible(cc) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py new file mode 100644 index 000000000..50c139317 --- /dev/null +++ b/bitsandbytes/cuda_specs.py @@ -0,0 +1,44 @@ +import dataclasses +from typing import List, Optional, Tuple + +import torch + + +@dataclasses.dataclass(frozen=True) +class CUDASpecs: + highest_compute_capability: Tuple[int, int] + cuda_version_string: str + cuda_version_tuple: Tuple[int, int] + + @property + def has_cublaslt(self) -> bool: + return self.highest_compute_capability >= (7, 5) + + +def get_compute_capabilities() -> List[Tuple[int, int]]: + return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())) + + +def get_cuda_version_tuple() -> Tuple[int, int]: + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION + if torch.version.cuda: + major, minor = map(int, torch.version.cuda.split(".")) + elif torch.version.hip: + major, minor = map(int, torch.version.hip.split(".")[0:2]) + return major, minor + + +def get_cuda_version_string() -> str: + major, minor = get_cuda_version_tuple() + return f"{major}{minor}" + + +def get_cuda_specs() -> Optional[CUDASpecs]: + if not torch.cuda.is_available(): + return None + + return CUDASpecs( + highest_compute_capability=(get_compute_capabilities()[-1]), + cuda_version_string=(get_cuda_version_string()), + cuda_version_tuple=get_cuda_version_tuple(), + ) diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/diagnostics/__init__.py similarity index 100% rename from bitsandbytes/cuda_setup/__init__.py rename to bitsandbytes/diagnostics/__init__.py diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py new file mode 100644 index 000000000..8974c6400 --- /dev/null +++ b/bitsandbytes/diagnostics/cuda.py @@ -0,0 +1,176 @@ +import logging +import os +from pathlib import Path +from typing import Dict, Iterable, Iterator + +import torch + +from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.consts import NONPYTORCH_DOC_URL +from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.diagnostics.utils import print_dedented + +CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") + +CUDART_PATH_IGNORED_ENVVARS = { + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "LESSCLOSE", + "LESSOPEN", # related to the `less` command + "MAIL", # something related to emails + "OLDPWD", + "PATH", # this is for finding binaries, not libraries + "PWD", # PWD: this is how the shell keeps track of the current working dir + "SHELL", # binary for currently invoked shell + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "_", # current Python interpreter +} + +CUDA_RUNTIME_LIB_PATTERNS = ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows +) + +logger = logging.getLogger(__name__) + + +def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: + for dir_string in paths_list_candidate.split(os.pathsep): + if not dir_string: + continue + if os.sep not in dir_string: + continue + try: + dir = Path(dir_string) + try: + if not dir.exists(): + logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") + continue + except OSError: # Assume an esoteric error trying to poke at the directory + pass + for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: + for pth in dir.glob(lib_pattern): + if pth.is_file(): + yield pth + except (OSError, PermissionError): + pass + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return ( + env_var in CUDART_PATH_PREFERRED_ENVVARS # is a preferred location + or ( + os.sep in value # might contain a path + and env_var not in CUDART_PATH_IGNORED_ENVVARS # not ignored + and "CONDA" not in env_var # not another conda envvar + and "BASH_FUNC" not in env_var # not a bash function defined via envvar + and "\n" not in value # likely e.g. a script or something? + ) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} + + +def find_cudart_libraries() -> Iterator[Path]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in CUDART_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_cuda_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_cuda_libraries_in_path_list(value) + + +def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " + f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + ) + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + + cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not cuda_specs.has_cublaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! + If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + for pth in cudart_paths: + print(f"* Found CUDA runtime at: {pth}") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py new file mode 100644 index 000000000..1ce096f69 --- /dev/null +++ b/bitsandbytes/diagnostics/main.py @@ -0,0 +1,85 @@ +import sys +import traceback + +import torch + +from bitsandbytes.consts import PACKAGE_GITHUB_URL +from bitsandbytes.cuda_specs import get_cuda_specs +from bitsandbytes.diagnostics.cuda import ( + print_cuda_diagnostics, + print_cuda_runtime_diagnostics, +) +from bitsandbytes.diagnostics.utils import print_dedented, print_header + + +def sanity_check(): + from bitsandbytes.cextension import lib + + if lib is None: + print_dedented( + """ + Couldn't load the bitsandbytes library, likely due to missing binaries. + Please ensure bitsandbytes is properly installed. + + For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`. + See the documentation for more details if needed. + + Trying a simple check anyway, but this will likely fail... + """, + ) + + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + p1 = p.data.sum().item() + adam = Adam([p]) + out = a * p + loss = out.sum() + loss.backward() + adam.step() + p2 = p.data.sum().item() + assert p1 != p2 + + +def main(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + + print_header("OTHER") + cuda_specs = get_cuda_specs() + print("CUDA specs:", cuda_specs) + if not torch.cuda.is_available(): + print("Torch says CUDA is not available. Possible reasons:") + print("1. CUDA driver not installed") + print("2. CUDA not installed") + print("3. You have multiple conflicting CUDA libraries") + if cuda_specs: + print_cuda_diagnostics(cuda_specs) + print_cuda_runtime_diagnostics() + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + try: + sanity_check() + print("SUCCESS!") + print("Installation was successful!") + return + except ImportError: + print( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!", + ) + except Exception: + traceback.print_exc() + print_dedented( + f""" + Above we output some debug information. + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + """, + ) + sys.exit(1) diff --git a/bitsandbytes/diagnostics/utils.py b/bitsandbytes/diagnostics/utils.py new file mode 100644 index 000000000..770209b9d --- /dev/null +++ b/bitsandbytes/diagnostics/utils.py @@ -0,0 +1,12 @@ +import textwrap + +HEADER_WIDTH = 60 + + +def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_dedented(text): + print("\n".join(textwrap.dedent(text).strip().split("\n"))) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a4e93bf37..37728bb4a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,97 +3,106 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct +from functools import reduce # Required in Python 3 import itertools import operator -import random -import torch -import itertools -import math -import numpy as np +from typing import Any, Dict, Optional, Tuple -from functools import reduce # Required in Python 3 -from typing import Tuple, Any, Dict +import numpy as np +import torch from torch import Tensor -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import COMPILED_WITH_CUDA, lib, HIP_ENVIRONMENT +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -# Remark: for AMD GPU we need to disable blocksize == 64 +from .cextension import HIP_ENVIRONMENT, lib # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) + name2qmap = {} -if COMPILED_WITH_CUDA: +if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS""" - str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) - str2optimizer32bit["momentum"] = ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, - ) - str2optimizer32bit["rmsprop"] = ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, - ) - str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) - str2optimizer32bit["adagrad"] = ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ) + str2optimizer32bit = { + "adam": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "momentum": ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), + } + + str2optimizer8bit = { + "adam": ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ), + "momentum": ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, + ), + "lion": ( + lib.clion_static_8bit_grad_32, + lib.clion_static_8bit_grad_16, + ), + "lamb": ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ), + "lars": ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ), + } + + str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + ), + } - str2optimizer8bit = {} - str2optimizer8bit["adam"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["momentum"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - str2optimizer8bit["rmsprop"] = ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, - ) - str2optimizer8bit["lion"] = ( - lib.clion_static_8bit_grad_32, - lib.clion_static_8bit_grad_16, - ) - str2optimizer8bit["lamb"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["lars"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - - str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise["adam"] = ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["momentum"] = ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["rmsprop"] = ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["lion"] = ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["adagrad"] = ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - ) class GlobalPageManager: _instance = None @@ -112,14 +121,13 @@ def get_instance(cls): return cls._instance def prefetch_all(self, to_cpu=False): - # assume the first added, will be hte + # assume the first added, will be the # ones that are used first, so swap them in last # in the case they are evicted again for t in self.paged_tensors[::-1]: prefetch_tensor(t, to_cpu) - class CUBLAS_Context: _instance = None @@ -152,7 +160,7 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def initialize(self): - #self.context = ct.c_void_p(lib.get_cusparse()) + # self.context = ct.c_void_p(lib.get_cusparse()) if torch.version.cuda: self.context = ct.c_void_p(lib.get_cusparse()) elif torch.version.hip: @@ -165,6 +173,7 @@ def get_instance(cls): cls._instance.initialize() return cls._instance + dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -172,8 +181,11 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): - num_bytes = dtype2bytes[dtype]*prod(shape) +FIRST_CUDA_DEVICE = torch.device("cuda", index=0) + + +def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): + num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -182,74 +194,92 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)) out.page_deviceid = device.index return out + def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' + assert A.is_paged, "Only paged tensors can be prefetched!" if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype]*A.numel() + num_bytes = dtype2bytes[A.dtype] * A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) + func = getattr(lib, f"c{func_name}_fp32", None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) + func = getattr(lib, f"c{func_name}_uint8", None) cvalue = ct.c_uint8(value) - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + if func is None: + raise NotImplementedError(f"Function not implemented: {func_name}") - is_managed = getattr(A, 'is_managed', False) + is_managed = getattr(A, "is_managed", False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: prefetch_tensor(B) + if B is not None: + prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: # paged function are fully asynchronous # if we return from this function, we want to the tensor # to be in the correct state, that is the final state after the - # operation occured. So we synchronize. + # operation occurred. So we synchronize. torch.cuda.synchronize() -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + +def fill(A, value, device=None, prefetch=True): + elementwise_func("fill", A, None, value) + + +def arange(A, device=None): + elementwise_func("arange", A, None, 0) + + +def _mul(A, B, device=None): + elementwise_func("_mul", A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) + sign = -1.0 if signed else 0.0 total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) + total_values = 2**total_bits if not signed else 2**total_bits - 1 values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel()//2 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + l = values.numel() // 2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + def create_normal_map(offset=0.9677083, use_extra_value=True): - from scipy.stats import norm + try: + from scipy.stats import norm + except ImportError as ie: + raise ImportError( + "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.", + ) from ie if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -262,38 +292,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign + assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): evalues.append(2**val) - values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): for bit_pattern in lst: - value = (1 if evalue != 0 else 0) + value = 1 if evalue != 0 else 0 for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) + value += pval * (2 ** -(i + 1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value * 2**-(bias) else: # normals - value = value*2**-(evalue-bias-1) + value = value * 2 ** -(evalue - bias - 1) values.append(value) if signed: values.append(-value) - assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -307,7 +336,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) return code - def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -332,7 +360,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -358,8 +390,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) + def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() q.append(0) @@ -370,11 +403,13 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q/q.abs().max() + q = q / q.abs().max() return q + def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' + if not torch.cuda.is_available(): + return "col_turing" major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -383,23 +418,28 @@ def get_special_format_str(): return "col_turing" - def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", + ) if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", + ) return on_gpu -def get_ptr(A: Tensor) -> ct.c_void_p: + +def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: """ Get the ctypes pointer from a PyTorch Tensor. @@ -433,15 +473,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): if not hasattr(lib, name): print(name) raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", ) else: return getattr(lib, name) -def get_transform_buffer( - shape, dtype, device, to_order, from_order="row", transpose=False -): +def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): # init_func = torch.empty init_func = torch.zeros dims = len(shape) @@ -490,17 +528,15 @@ def nvidia_transform( ld=None, ): if HIP_ENVIRONMENT: - to_order = "col" if to_order in ["col32","col_turing","col_ampere"] else to_order - from_order = "col" if from_order in ["col32","col_turing","col_ampere"] else from_order + to_order = "col" if to_order in ["col32", "col_turing", "col_ampere"] else to_order + from_order = "col" if from_order in ["col32", "col_turing", "col_ampere"] else from_order if state is None: state = (A.shape, from_order) else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1], transpose - ) + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) else: new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) @@ -524,8 +560,13 @@ def nvidia_transform( return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' +def estimate_quantiles( + A: Tensor, + out: Optional[torch.Tensor] = None, + offset: float = 1 / 512, + num_quantiles=256, +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -552,14 +593,21 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): + """ + if A.numel() < 256: + raise NotImplementedError( + f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.", + ) + if num_quantiles > 256: + raise NotImplementedError( + f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}", + ) + if num_quantiles < 256 and offset == 1 / (512): # override default arguments - offset = 1/(2*num_quantiles) + offset = 1 / (2 * num_quantiles) - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: @@ -571,7 +619,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n post_call(device) if num_quantiles < 256: - step = round(256/num_quantiles) + step = round(256 / num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -579,13 +627,36 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n class QuantState: - """container for quantization state components to work with Params4bit and similar clases""" - valid_quant_types = ('fp4', 'nf4') - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', - 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] + """container for quantization state components to work with Params4bit and similar classes""" - def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): self.absmax = absmax self.shape = shape self.code = code @@ -604,13 +675,20 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] else: list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": """ unpacks components of state_dict into QuantState where necessary, convert into strings, torch.dtype, ints, etc. @@ -622,37 +700,39 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState # unpacking tensor with non-tensor components qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and 'quant_type' not in qs_dict: + if not len(qs_key) and "quant_type" not in qs_dict: raise ValueError("Expected packed or unpacked quant_state items, found neither") elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: - qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key))) + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if 'nested_absmax' in qs_dict: - offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) state2 = cls( - absmax=qs_dict['nested_absmax'].to(device), - blocksize=qs_dict['nested_blocksize'], - code=qs_dict['nested_quant_map'].to(device), - dtype=getattr(torch, qs_dict['nested_dtype']), + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict['quant_type'], - absmax=qs_dict['absmax'].to(device), - blocksize=qs_dict['blocksize'], - code=qs_dict['quant_map'].to(device), - dtype=getattr(torch, qs_dict['dtype']), - shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, offset=offset, state2=state2, ) @@ -664,21 +744,23 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - 'quant_type': self.quant_type, - 'absmax': self.absmax, - 'blocksize': self.blocksize, - 'quant_map': self.code, - 'dtype': str(self.dtype).strip('torch.'), - 'shape': tuple(self.shape), + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), } if self.nested: - qs_dict.update({ - 'nested_absmax': self.state2.absmax, - 'nested_blocksize': self.state2.blocksize, - 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - 'nested_dtype': str(self.state2.dtype).strip('torch.'), - 'nested_offset': self.offset.item(), - }) + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) if not packed: return qs_dict @@ -696,8 +778,38 @@ def to(self, device): self.state2.absmax = self.state2.absmax.to(device) self.state2.code = self.state2.code.to(device) + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) + -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: +def quantize_blockwise( + A: Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of size 4096 values. @@ -724,7 +836,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou The quantization state to undo the quantization. """ - if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -739,7 +850,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != 'cpu': + if A.device.type != "cpu": if not HIP_ENVIRONMENT: assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] else: @@ -749,24 +860,59 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) else: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) @@ -775,12 +921,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou def dequantize_blockwise( A: Tensor, - quant_state: QuantState = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False + nested=False, ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -814,46 +960,79 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != 'cpu': + if A.device.type != "cpu": device = pre_call(A.device) code = quant_state.code.to(A.device) supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] if HIP_ENVIRONMENT: supported_blocksizes = supported_blocksizes[:-1] if quant_state.blocksize not in supported_blocksizes: - raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}") + raise ValueError( + f"The blockwise of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}", + ) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_bf16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) return out + def get_4bit_type(typename, device=None, blocksize=64): - if device is None: device = 'cuda' + if device is None: + device = "cuda" data = None - if typename == 'nf4': - ''' Implements the NF4 data type. + if typename == "nf4": + """ Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. @@ -862,12 +1041,26 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - ''' - data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, - -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, - 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0] - elif typename == 'fp4': + """ + data = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + elif typename == "fp4": # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -878,20 +1071,35 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] - elif typename == 'int4': + elif typename == "int4": data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == 'af4': + elif typename == "af4": # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, - -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, - 0.42563882, 0.55496234, 0.72424863, 1.][::-1] + data = [ + -1.0, + -0.69441008, + -0.51243739, + -0.3736951, + -0.25607552, + -0.14982478, + -0.04934812, + 0.0, + 0.04273164, + 0.12934483, + 0.21961274, + 0.31675666, + 0.42563882, + 0.55496234, + 0.72424863, + 1.0, + ][::-1] else: - raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.") if data is None: - raise NotImplementedError(f'Typename {typename} not supported') + raise NotImplementedError(f"Typename {typename} not supported") data = Tensor(data) data /= data.abs().max() @@ -900,17 +1108,41 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): +def quantize_fp4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=None, + compress_statistics=False, + quant_storage=torch.uint8, +): if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) + -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): +def quantize_nf4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=None, + compress_statistics=False, + quant_storage=torch.uint8, +): if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) + -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor: +def quantize_4bit( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=None, + compress_statistics=False, + quant_type="fp4", + quant_storage=torch.uint8, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of 4-bit values. @@ -938,10 +1170,10 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz """ if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") n = A.numel() input_shape = A.shape @@ -951,10 +1183,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) if not HIP_ENVIRONMENT: assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -965,20 +1196,62 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -990,25 +1263,63 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) return out, state -def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: + +def dequantize_fp4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, +) -> Tensor: if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") + -def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: +def dequantize_nf4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, +) -> Tensor: if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None, quant_type='fp4') -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + + +def dequantize_4bit( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, + quant_type="fp4", +) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1037,29 +1348,37 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = """ if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] if HIP_ENVIRONMENT: supported_blocksizes = supported_blocksizes[:-1] if blocksize not in supported_blocksizes: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}", + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) else: absmax = quant_state.absmax - if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1069,30 +1388,78 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out -def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: +def quantize( + A: Tensor, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -1100,7 +1467,8 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: code = code.to(A.device) absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1108,10 +1476,10 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: def dequantize( A: Tensor, - state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + state: Optional[Tuple[Tensor, Tensor]] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> Tensor: assert state is not None or absmax is not None if code is None and state is None: @@ -1126,8 +1494,8 @@ def dequantize( return out * state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' +def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1146,17 +1514,18 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1175,9 +1544,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1193,11 +1563,11 @@ def optimizer_update_32bit( eps: float, step: int, lr: float, - state2: Tensor = None, + state2: Optional[torch.Tensor] = None, beta2: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, skip_zeros=False, ) -> None: @@ -1244,16 +1614,17 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1273,7 +1644,8 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) + ct.c_int32(g.numel()), + ) post_call(prev_device) @@ -1282,21 +1654,21 @@ def optimizer_update_8bit( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], max1: Tensor, - max2: Tensor, + max2: Optional[torch.Tensor], new_max1: Tensor, - new_max2: Tensor, + new_max2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, ) -> None: """ @@ -1405,7 +1777,7 @@ def optimizer_update_8bit( ) else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1415,21 +1787,20 @@ def optimizer_update_8bit_blockwise( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], absmax1: Tensor, - absmax2: Tensor, + absmax2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1437,12 +1808,15 @@ def optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1470,9 +1844,8 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) -def percentile_clipping( - grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 -): + +def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): """Applies percentile clipping grad: torch.Tensor @@ -1514,9 +1887,7 @@ def percentile_clipping( return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d( - histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor -): +def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 @@ -1533,12 +1904,12 @@ def histogram_scatter_add_2d( is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError( - f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" - ) + raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}") sA = A.shape sB = B.shape @@ -1579,12 +1950,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if ( - sout[0] == sA[2] - and sout[1] == sB[2] - and sA[0] == sB[0] - and sA[1] == sB[1] - ): + if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]: correct = True else: if len(sA) == 2 and len(sB) == 2: @@ -1617,26 +1983,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 if not correct: raise ValueError( - f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.", ) return sout + def gemv_4bit( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None + state=None, ): prev_device = pre_call(A.device) - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") if A.numel() != A.shape[-1]: - raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', + ) Bshape = state.shape bout = Bshape[0] @@ -1656,7 +2025,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1]+1)//2 + ldb = (A.shape[-1] + 1) // 2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1667,25 +2036,65 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") post_call(prev_device) return out + def igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): @@ -1747,7 +2156,7 @@ def igemm( assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}", ) transposed_A = True @@ -1766,22 +2175,32 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out def batched_igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError( - f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" - ) + raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}") sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) @@ -1848,9 +2267,24 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out @@ -1859,14 +2293,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -1876,22 +2310,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): if dimsA == 2 and out is None: if HIP_ENVIRONMENT: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col", "row") else: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") elif dimsA == 3 and out is None: if HIP_ENVIRONMENT: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row") else: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -1942,48 +2368,35 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing' or HIP_ENVIRONMENT: + if formatB == "col_turing" or HIP_ENVIRONMENT: if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + + if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu/ops.hip` + raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") - if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') + if has_error: + print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) return out, Sout -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): if HIP_ENVIRONMENT: - A, quant_state = nvidia_transform(A, "row", state = quant_state) + A, quant_state = nvidia_transform(A, "row", state=quant_state) assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1991,19 +2404,11 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -2017,15 +2422,23 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + ptrBias, + numRows, + numCols, + ) post_call(prev_device) return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): +def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -2038,18 +2451,12 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) + nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -2123,14 +2530,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values - ) + return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) def coo2csc(cooA): @@ -2139,14 +2542,10 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -2156,9 +2555,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2171,9 +2568,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2191,9 +2586,7 @@ def double_quant( if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -2252,15 +2645,19 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): if HIP_ENVIRONMENT: - return nvidia_transform(A,to_order,from_order,out,transpose,state,ld) + return nvidia_transform(A, to_order, from_order, out, transpose, state, ld) prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2271,7 +2668,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2292,7 +2689,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") post_call(prev_device) @@ -2301,9 +2698,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2330,16 +2725,28 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: - out = torch.zeros( - (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype - ) + out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) nnz = cooA.nnz prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz @@ -2357,9 +2764,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert ( - max_count[0] <= 32 - ), f"Current max count per row is 8 but found {max_count[0]}." + assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -2447,9 +2852,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( - x, dim=dim, keepdim=True - ) + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) dyna[dyna == 0] = 1 qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) @@ -2560,9 +2963,7 @@ def extract_outliers(A, SA, idx): assert formatA in ["col"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2572,7 +2973,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing' or HIP_ENVIRONMENT: + if formatA == "col_turing" or HIP_ENVIRONMENT: lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2580,6 +2981,7 @@ def extract_outliers(A, SA, idx): return out + def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 6fa6d1183..96f4359bf 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding -from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear +from .modules import ( + Embedding, + Int8Params, + Linear4bit, + Linear8bitLt, + LinearFP4, + LinearNF4, + OutlierAwareLinear, + Params4bit, + StableEmbedding, + SwitchBackLinearBnb, +) +from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, +) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e0d94d861..3684badf6 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,23 +2,50 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import copy from typing import Any, Dict, Optional, TypeVar, Union, overload - import warnings + import torch -import torch.nn.functional as F from torch import Tensor, device, dtype, nn +import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.autograd._functions import get_tile_inds, undo_layout +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState -from bitsandbytes.autograd._functions import undo_layout, get_tile_inds from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims +from bitsandbytes.utils import OutlierTracer T = TypeVar("T", bound="torch.nn.Module") class StableEmbedding(torch.nn.Embedding): + """ + Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization. + + Example: + + ``` + # Initialize StableEmbedding layer with vocabulary size 1000, embedding dimension 300 + embedding_layer = StableEmbedding(num_embeddings=1000, embedding_dim=300) + + # Reset embedding parameters + embedding_layer.reset_parameters() + + # Perform a forward pass with input tensor + input_tensor = torch.tensor([1, 2, 3]) + output_embedding = embedding_layer(input_tensor) + ``` + + Attributes: + norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding. + + Methods: + reset_parameters(): Reset embedding parameters using Xavier uniform initialization. + forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. + """ + def __init__( self, num_embeddings: int, @@ -32,6 +59,25 @@ def __init__( device=None, dtype=None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. + """ super().__init__( num_embeddings, embedding_dim, @@ -45,9 +91,7 @@ def __init__( dtype, ) self.norm = torch.nn.LayerNorm(embedding_dim, device=device) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} - ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -83,6 +127,10 @@ def forward(self, input: Tensor) -> Tensor: class Embedding(torch.nn.Embedding): + """ + Embedding class to store and retrieve word embeddings from their indices. + """ + def __init__( self, num_embeddings: int, @@ -95,6 +143,25 @@ def __init__( _weight: Optional[Tensor] = None, device: Optional[device] = None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. + """ super().__init__( num_embeddings, embedding_dim, @@ -104,11 +171,9 @@ def __init__( scale_grad_by_freq, sparse, _weight, - device=device - ) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} + device=device, ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -141,22 +206,24 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - # Remark: change blocksize to 128 for AMD gpu def __new__( - cls, - data: Optional[torch.Tensor] = None, - requires_grad=True, - quant_state: QuantState = None, - blocksize: int = 128, - compress_statistics: bool = True, - quant_type: str = 'fp4', - quant_storage: torch.dtype = torch.uint8, - module: Optional["Linear4bit"] = None, - bnb_quantized: bool = False + cls, + data: Optional[torch.Tensor] = None, + requires_grad=False, # quantized weights should be frozen by default + quant_state: Optional[QuantState] = None, + blocksize: Optional[int] = None, + compress_statistics: bool = True, + quant_type: str = "fp4", + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False, ) -> "Params4bit": if data is None: data = torch.empty(0) + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics @@ -168,8 +235,46 @@ def __new__( self.module = module return self + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.requires_grad = state["requires_grad"] + self.blocksize = state["blocksize"] + self.compress_statistics = state["compress_statistics"] + self.quant_type = state["quant_type"] + self.quant_state = state["quant_state"] + self.data = state["data"] + self.quant_storage = state["quant_storage"] + self.bnb_quantized = state["bnb_quantized"] + self.module = state["module"] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quant_state = copy.deepcopy(state["quant_state"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + @classmethod - def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": + def from_prequantized( + cls, + data: torch.Tensor, + quantized_stats: Dict[str, Any], + requires_grad: bool = False, + device="cuda", + **kwargs, + ) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) self.requires_grad = requires_grad self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) @@ -181,8 +286,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], def _quantize(self, device): w = self.data.contiguous().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type, quant_storage=self.quant_storage) + w_4bit, quant_state = bnb.functional.quantize_4bit( + w, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + quant_storage=self.quant_storage, + ) self.data = w_4bit self.quant_state = quant_state if self.module is not None: @@ -191,42 +301,107 @@ def _quantize(self, device): return self def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): - return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: - ... + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and device.type == "cuda" and not self.bnb_quantized): + if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) - new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, quant_state=self.quant_state, - blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type) + new_param = Params4bit( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + quant_state=self.quant_state, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + ) return new_param class Linear4bit(nn.Linear): + """ + This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314). + QLoRA 4-bit linear layers uses blockwise k-bit quantization under the hood, with the possibility of selecting various + compute datatypes such as FP4 and NF4. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights. + + Example: + + ```python + import torch + import torch.nn as nn + + import bitsandbytes as bnb + from bnb.nn import Linear4bit + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + quantized_model = nn.Sequential( + Linear4bit(64, 64), + Linear4bit(64, 64) + ) + + quantized_model.load_state_dict(fp16_model.state_dict()) + quantized_model = quantized_model.to(0) # Quantization happens here + ``` + """ - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_type="fp4", + quant_storage=torch.uint8, + device=None, + ): + """ + Initialize Linear4bit class. + + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) + self.weight = Params4bit( + self.weight.data, + requires_grad=False, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + module=self, + ) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False @@ -243,11 +418,15 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') - warnings.filterwarnings('ignore', message='.*inference.') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.", + ) + warnings.filterwarnings("ignore", message=".*inference.") if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') - warnings.filterwarnings('ignore', message='.*inference or training') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.", + ) + warnings.filterwarnings("ignore", message=".*inference or training") def _save_to_state_dict(self, destination, prefix, keep_vars): """ @@ -265,8 +444,8 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, 'quant_state', None) is None: - if getattr(self, 'quant_state', None) is not None: + if getattr(self.weight, "quant_state", None) is None: + if getattr(self, "quant_state", None) is not None: # the quant state got lost when the parameter got converted. This happens for example for fsdp # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 @@ -274,7 +453,9 @@ def forward(self, x: torch.Tensor): self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight.quant_state = self.quant_state else: - print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", + ) if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -292,23 +473,82 @@ def forward(self, x: torch.Tensor): class LinearFP4(Linear4bit): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) + """ + Implements the FP4 data type. + """ + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "fp4", + quant_storage, + device, + ) class LinearNF4(Linear4bit): - ''' Implements the NF4 data type. + """Implements the NF4 data type. - Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that - is normalized into the range [-1, 1]. + Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that + is normalized into the range [-1, 1]. - For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) + For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) - Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in - the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - ''' - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) + Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in + the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. + """ + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "nf4", + quant_storage, + device, + ) class Int8Params(torch.nn.Parameter): @@ -325,7 +565,9 @@ def __new__( cls.SCB = None if data is None: data = torch.empty(0) - return torch.Tensor._make_subclass(cls, data, requires_grad) + obj = torch.Tensor._make_subclass(cls, data, requires_grad) + obj.CB, obj.SCB = cls.CB, cls.SCB + return obj def cuda(self, device): if self.has_fp16_weights: @@ -338,8 +580,8 @@ def cuda(self, device): del CBt del SCBt self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + self.CB = CB + self.SCB = SCB return self @@ -349,33 +591,22 @@ def to( device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ..., - ) -> T: - ... + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( - *args, **kwargs - ) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if ( - device is not None - and device.type == "cuda" - and self.data.device.type == "cpu" - ): + if device is not None and device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) else: new_param = Int8Params( - super().to( - device=device, dtype=dtype, non_blocking=non_blocking - ), + super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -398,8 +629,59 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k class Linear8bitLt(nn.Linear): - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None, device=None): + """ + This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm. + To read more about it, have a look at the paper. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights. + + Example: + + ```python + import torch + import torch.nn as nn + + import bitsandbytes as bnb + from bnb.nn import Linear8bitLt + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + int8_model = nn.Sequential( + Linear8bitLt(64, 64, has_fp16_weights=False), + Linear8bitLt(64, 64, has_fp16_weights=False) + ) + + int8_model.load_state_dict(fp16_model.state_dict()) + int8_model = int8_model.to(0) # Quantization happens here + ``` + """ + + def __init__( + self, + input_features: int, + output_features: int, + bias=True, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + device=None, + ): + """ + Initialize Linear8bitLt class. + + Args: + input_features (`int`): + Number of input features of the linear layer. + output_features (`int`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, device) assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() @@ -441,19 +723,36 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[format_name] = self.state.formatB - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) unexpected_copy = list(unexpected_keys) for key in unexpected_copy: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name == "SCB": if self.weight.SCB is None: # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. Please call module.cuda() before module.load_state_dict()") + raise RuntimeError( + "Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()", + ) input_param = state_dict[key] self.weight.SCB.copy_(input_param) @@ -496,18 +795,18 @@ def __init__(self, input_features, output_features, bias=True, device=None): self.is_quantized = False def forward_with_outliers(self, x, outlier_idx): - raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function") def quantize_weight(self, w, outlier_idx): - raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function") def forward(self, x): if self.outlier_dim is None: tracer = OutlierTracer.get_instance() if not tracer.is_initialized(): - print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer") outlier_idx = tracer.get_outliers(self.weight) - #print(outlier_idx, tracer.get_hvalue(self.weight)) + # print(outlier_idx, tracer.get_hvalue(self.weight)) self.outlier_dim = outlier_idx if not self.is_quantized: @@ -515,6 +814,7 @@ def forward(self, x): self.weight.data.copy_(w) self.is_quantized = True + class SwitchBackLinearBnb(nn.Linear): def __init__( self, @@ -525,11 +825,9 @@ def __init__( memory_efficient_backward=False, threshold=0.0, index=None, - device=None + device=None, ): - super().__init__( - input_features, output_features, bias, device - ) + super().__init__(input_features, output_features, bias, device) self.state = bnb.MatmulLtState() self.index = index @@ -539,9 +837,7 @@ def __init__( if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index de07ac647..aa8494942 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -1,20 +1,27 @@ -import torch -import torch.nn as nn -import time from functools import partial -from bitsandbytes.triton.triton_utils import is_triton_available +import torch +import torch.nn as nn from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize +from bitsandbytes.triton.triton_utils import is_triton_available class _switchback_global(torch.autograd.Function): - @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -29,9 +36,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -48,7 +53,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_global_transpose(W) grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -58,8 +64,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_vectorrize(torch.autograd.Function): +class _switchback_vectorrize(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -73,9 +79,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call kernel which expects rowwise quantized X and W - return int8_matmul_rowwise_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -91,7 +95,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_columnwise_and_transpose(W) grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -101,8 +106,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_global_mem_efficient(torch.autograd.Function): +class _switchback_global_mem_efficient(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -119,9 +124,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D_sz[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -143,35 +146,34 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) del G W_int8 = W_int8.t().contiguous() - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D_sz[:-1], -1 - ) + grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1) return grad_X, grad_W, grad_bias + class SwitchBackLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - vector_wise_quantization: bool = False, - mem_efficient : bool = False, - ): + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient: bool = False, + ): super().__init__(in_features, out_features, bias, device, dtype) - if not is_triton_available: - raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. - Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + if not is_triton_available(): + raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""") # By default, we use the global quantization. self.vector_wise_quantization = vector_wise_quantization if self.vector_wise_quantization: self._fn = _switchback_vectorrize if mem_efficient: - print('mem efficient is not supported for vector-wise quantization.') + print("mem efficient is not supported for vector-wise quantization.") exit(1) else: if mem_efficient: @@ -187,7 +189,7 @@ def prepare_for_eval(self): # if hasattr(m, "prepare_for_eval"): # m.prepare_for_eval() # model.apply(cond_prepare) - print('=> preparing for eval.') + print("=> preparing for eval.") if self.vector_wise_quantization: W_int8, state_W = quantize_rowwise(self.weight) else: @@ -211,18 +213,22 @@ def forward(self, x): X_int8, state_X = quantize_rowwise(X) if self.vector_wise_quantization: - return int8_matmul_rowwise_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) else: - return int8_matmul_mixed_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) + SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) + # This is just the standard linear function. class StandardLinearFunction(torch.autograd.Function): @staticmethod @@ -252,7 +258,7 @@ def backward(ctx, grad_output_3D): return grad_input, grad_weight, grad_bias -class StandardLinear(nn.Linear): +class StandardLinear(nn.Linear): def forward(self, x): return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 83a57bd9f..b4c95793a 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -3,14 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from bitsandbytes.cextension import COMPILED_WITH_CUDA - from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit +from .adamw import ( + AdamW, + AdamW8bit, + AdamW32bit, + PagedAdamW, + PagedAdamW8bit, + PagedAdamW32bit, +) from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7d8df58ac..7459dece1 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -20,12 +20,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -62,12 +87,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 8): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -105,12 +155,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 86981eb86..740db26ac 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,34 +14,370 @@ class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class PagedAdam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Paged 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. @@ -119,9 +455,7 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") amsgrad = group.get("amsgrad", False) assert not amsgrad @@ -138,15 +472,9 @@ def step(self, closure=None): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state["abserrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["relerrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["counts"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) + state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -154,25 +482,19 @@ def step(self, closure=None): state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) if amsgrad: - state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( - p_data_fp32 - ) + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) state["step"] += 1 beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = ( - group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - ) + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] if group["weight_decay"] != 0: - p_data_fp32.add_( - p_data_fp32, alpha=-group["weight_decay"] * group["lr"] - ) + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] if amsgrad: @@ -185,10 +507,7 @@ def step(self, closure=None): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if ( - p_data_fp32.numel() <= 8192 - or p_data_fp32.numel() > 50000 * 1000 - ): + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -227,9 +546,7 @@ def step(self, closure=None): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError( - f"Invalid analysis value: {self.analysis}!" - ) + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -241,9 +558,7 @@ def step(self, closure=None): F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d( - counts, C1.int(), C2.int(), torch.ones_like(abserr) - ) + F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) p_data_fp32 += -step_size * update_fp32 @@ -251,18 +566,10 @@ def step(self, closure=None): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join( - [str(dim) for dim in p_data_fp32.shape] - ) - pathe = os.path.join( - self.savedir, f"{p_id}_{shapestr}_abserr.pkl" - ) - pathrele = os.path.join( - self.savedir, f"{p_id}_{shapestr}_relerr.pkl" - ) - pathcounts = os.path.join( - self.savedir, f"{p_id}_{shapestr}_counts.pkl" - ) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl") + pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl") + pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl") torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 21077f1a0..4bf3f6436 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,35 +5,364 @@ from bitsandbytes.optim.optimizer import Optimizer2State - class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedAdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 32-bit AdamW optimizer. + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 1fbb6fadc..8d29cbbfe 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -23,6 +23,39 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + Base LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -56,6 +89,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 8-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -89,6 +153,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 32-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 73554e3cc..90c3686fe 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -23,10 +23,35 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + Base LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -57,10 +82,33 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 8-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -91,10 +139,33 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 32-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -127,9 +198,7 @@ def __init__( if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") if weight_decay < 0.0: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, @@ -140,9 +209,7 @@ def __init__( max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError( - "Nesterov momentum requires a momentum and zero dampening" - ) + raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2bde1a447..2e4163694 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,27 +4,315 @@ # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class Lion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedLion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index fb83eddf0..f1e60e5e7 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import abc as container_abcs -from collections import defaultdict +from collections import abc as container_abcs, defaultdict from copy import deepcopy from itertools import chain @@ -19,6 +18,10 @@ def __init__(self, initial_data): class GlobalOptimManager: + """ + A global optimizer manager for enabling custom optimizer configs. + """ + _instance = None def __init__(self): @@ -46,30 +49,44 @@ def register_parameters(self, params): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[ - id(p) - ] + self.index2config[(group_index, p_index)] = self.pid2config[id(p)] - def override_config( - self, parameters, key=None, value=None, key_value_dict=None - ): + def override_config(self, parameters, key=None, value=None, key_value_dict=None): """ - Overrides initial optimizer config for specific parameters. + Override initial optimizer config with specific hyperparameters. The key-values of the optimizer config for the input parameters are overridden - This can be both, optimizer parameters like "betas", or "lr" or it can be - 8-bit specific parameters like "optim_bits", "percentile_clipping". - - Parameters - ---------- - parameters : torch.Tensor or list(torch.Tensors) - The input parameters. - key : str - The hyperparamter to override. - value : object - The value for the hyperparamters. - key_value_dict : dict - A dictionary with multiple key-values to override. + This can be both, optimizer parameters like `betas` or `lr`, or it can be + 8-bit specific parameters like `optim_bits` or `percentile_clipping`. + + Arguments: + parameters (`torch.Tensor` or `list(torch.Tensors)`): + The input parameters. + key (`str`): + The hyperparamter to override. + value: + The hyperparameter values. + key_value_dict (`dict`): + A dictionary with multiple key-values to override. + + Example: + + ```py + import torch + import bitsandbytes as bnb + + mng = bnb.optim.GlobalOptimManager.get_instance() + + model = MyModel() + mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU + + model = model.cuda() + # use 8-bit optimizer states for all parameters + adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) + + # 2. override: the parameter model.fc1.weight now uses 32-bit Adam + mng.override_config(model.fc1.weight, 'optim_bits', 32) + ``` """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): @@ -93,6 +110,17 @@ def register_module_override(self, module, param_name, config): class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32, is_paged=False): + """ + Base 8-bit optimizer class. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__(params, defaults) self.initialized = False self.name2qmap = {} @@ -101,18 +129,18 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { - "qmap1", - "qmap2", - "max1", - "max2", - "new_max1", - "new_max2", - "state1", - "state2", - "gnorm_vec", - "absmax1", - "absmax2", - "unorm_vec", + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", } if optim_bits == 8: @@ -126,11 +154,11 @@ def __setstate__(self, state): super().__setstate__(state) def load_state_dict(self, state_dict): - r"""Loads the optimizer state. + """Load an optimizer state. - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. + Arguments: + state_dict (`dict`): + An optimizer state (should be returned from a call to `state_dict`) to load. """ # deepcopy, to be consistent with module API state_dict = deepcopy(state_dict) @@ -139,16 +167,12 @@ def load_state_dict(self, state_dict): saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError( - "loaded state dict has a different number of " - "parameter groups" - ) + raise ValueError("loaded state dict has a different number of parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( - "loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group" + "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", ) # Update the state @@ -197,9 +221,7 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups) - ] + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -209,7 +231,7 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - is_paged = getattr(v, 'is_paged', False) + is_paged = getattr(v, "is_paged", False) if not is_paged: self.state[p][k] = v.to(p.device) @@ -217,9 +239,7 @@ def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance( - pmodule, torch.Parameter - ) + assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) found = False for gindex, group in enumerate(self.param_groups): if found: @@ -231,18 +251,16 @@ def check_overrides(self): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[ - (gindex, pindex) - ] = self.mng.pid2config[id(p)] + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] found = True @torch.no_grad() def step(self, closure=None): - """Performs a single optimization step. + """Perform a single optimization step. Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + closure (`Callable`, *optional*, defaults to `None`): + A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: @@ -256,7 +274,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True - #if self.is_paged: self.page_mng.prefetch_all() + # if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -273,7 +291,6 @@ def step(self, closure=None): # to sync to make sure all tensors are in the right state torch.cuda.synchronize() - return loss def get_config(self, gindex, pindex, group): @@ -297,9 +314,7 @@ def init_state(self, group, p, gindex, pindex): raise NotImplementedError("init_state method needs to be overridden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError( - "The update_step method needs to be overridden" - ) + raise NotImplementedError("The update_step method needs to be overridden") def get_state_buffer(self, p, dtype=torch.float32): if not self.is_paged or p.numel() < 1e5: @@ -314,12 +329,12 @@ def get_state_buffer(self, p, dtype=torch.float32): def prefetch_state(self, p): if self.is_paged: state = self.state[p] - s1 = state['state1'] - is_paged = getattr(s1, 'is_paged', False) + s1 = state["state1"] + is_paged = getattr(s1, "is_paged", False) if is_paged: - F.prefetch_tensor(state['state1']) - if 'state2' in state: - F.prefetch_tensor(state['state2']) + F.prefetch_tensor(state["state1"]) + if "state2" in state: + F.prefetch_tensor(state["state2"]) class Optimizer2State(Optimizer8bit): @@ -338,8 +353,41 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): + """ + Base 2-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.999)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: @@ -350,13 +398,9 @@ def __init__( betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -385,9 +429,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -395,21 +437,15 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -422,25 +458,13 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) - state["absmax2"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -460,7 +484,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -504,9 +531,7 @@ def update_step(self, group, p, gindex, pindex): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] - if config["max_unorm"] > 0.0 - else None, + unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, max_unorm=config["max_unorm"], ) @@ -551,21 +576,50 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): + """ + Base 1-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.0)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -594,9 +648,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -604,17 +656,13 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -624,16 +672,10 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -653,7 +695,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -669,7 +714,7 @@ def update_step(self, group, p, gindex, pindex): step, config["lr"], None, - config['betas'][1], + config["betas"][1], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 2853ca723..25611309b 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -21,10 +21,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -57,10 +84,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -93,11 +147,38 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 3c0fc2b9f..ec18f036c 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -20,6 +20,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -51,6 +78,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -82,6 +134,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py index 47b720d78..31db4f282 100644 --- a/bitsandbytes/research/__init__.py +++ b/bitsandbytes/research/__init__.py @@ -1,6 +1,6 @@ from . import nn from .autograd._functions import ( - switchback_bnb, matmul_fp8_global, matmul_fp8_mixed, + switchback_bnb, ) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 883121759..e5655b546 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -1,21 +1,19 @@ +from functools import reduce # Required in Python 3 import operator +from typing import Optional import warnings -from dataclasses import dataclass -from functools import reduce # Required in Python 3 import torch -import bitsandbytes.functional as F - -from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler +from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState from bitsandbytes.cextension import HIP_ENVIRONMENT +import bitsandbytes.functional as F # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor class MatMulFP8Mixed(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs @@ -86,7 +84,7 @@ def backward(ctx, grad_output): # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) # not supported by PyTorch. TODO: create work-around - if req_gradA: + if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) if req_gradB: @@ -170,7 +168,7 @@ def backward(ctx, grad_output): # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) # not supported by PyTorch. TODO: create work-around - if req_gradA: + if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) if req_gradB: @@ -187,7 +185,9 @@ def backward(ctx, grad_output): class SwitchBackBnb(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # TODO: the B008 on the line below is a likely bug; the current implementation will + # have each SwitchBackBnb instance share a single MatmulLtState instance!!! + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008 # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -196,9 +196,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -217,9 +217,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -235,14 +233,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - #print('A shape', A.shape) + # print('A shape', A.shape) if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B if state.has_fp16_weights: - #print('B shape', B.shape) + # print('B shape', B.shape) has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: @@ -273,12 +271,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -321,14 +314,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors @@ -343,9 +335,7 @@ def backward(ctx, grad_output): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) @@ -358,25 +348,24 @@ def backward(ctx, grad_output): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) # print('back B shape', state.CxBt.shape) # print('back grad shape', C32grad.shape) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception('State must contain either CBt or CB matrix for backward') + raise Exception("State must contain either CBt or CB matrix for backward") return grad_A, grad_B, None, grad_bias, None + def get_block_sizes(input_matrix, weight_matrix): input_features = input_matrix.shape[-1] - output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) + output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] if not HIP_ENVIRONMENT: array = [4096, 2048, 1024, 512, 256, 128, 64, 0] else: @@ -393,21 +382,42 @@ def get_block_sizes(input_matrix, weight_matrix): return bsz, bsz2 -def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + +def matmul_fp8_global( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) -def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + +def matmul_fp8_mixed( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + def switchback_bnb( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py index 8faec10bb..417011218 100644 --- a/bitsandbytes/research/nn/__init__.py +++ b/bitsandbytes/research/nn/__init__.py @@ -1 +1 @@ -from .modules import LinearFP8Mixed, LinearFP8Global +from .modules import LinearFP8Global, LinearFP8Mixed diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 2a46b40c3..57c0f3358 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -1,12 +1,9 @@ -from typing import Optional, TypeVar, Union, overload +from typing import TypeVar import torch -import torch.nn.functional as F -from torch import Tensor, device, dtype, nn +from torch import nn import bitsandbytes as bnb -from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims T = TypeVar("T", bound="torch.nn.Module") @@ -31,12 +28,20 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.research.matmul_fp8_mixed( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias return out + class LinearFP8Global(nn.Linear): def __init__(self, input_features, output_features, bias=True): super().__init__(input_features, output_features, bias) @@ -57,7 +62,14 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.matmul_fp8_global( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index e092680b8..26eab84f2 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -1,35 +1,36 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None -else: + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _dequantize_rowwise( @@ -50,7 +51,6 @@ def _dequantize_rowwise( max_val = tl.load(state_x + pid) output = max_val * x * inv_127 tl.store(output_ptr + offsets, output, mask=row_mask) - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) @@ -60,5 +60,5 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index b0961f558..583371d91 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -1,15 +1,16 @@ import torch + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None -else: + def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): + return None +else: import triton import triton.language as tl from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and global quantized weight # It's purpose is fused matmul then dequantize @@ -26,57 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + *get_configs_io_bound(), + ], + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_mixed_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor: tl.constexpr, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -113,13 +140,13 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) + + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) # conditionally add bias @@ -135,10 +162,9 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): device = a.device - divfactor = 1. / (127. * 127.) + divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -152,12 +178,28 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_mixed_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 33f4d13f2..e3d192ded 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -3,7 +3,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None + + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + return None else: import triton import triton.language as tl @@ -17,7 +19,6 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None def init_to_zero(name): return lambda nargs: nargs[name].zero_() - def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -26,57 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + *get_configs_io_bound(), + ], + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_rowwise_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -113,13 +140,13 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) + + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) if has_bias: @@ -134,9 +161,8 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1. / (127. * 127.) + divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 @@ -153,12 +179,28 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_rowwise_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index 54220d95a..b8eeffd0c 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -1,37 +1,38 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_columnwise_and_transpose(x: torch.Tensor): return None -else: + def quantize_columnwise_and_transpose(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # This kernel does fused columnwise quantization and transpose. # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_columnwise_and_transpose( @@ -39,7 +40,8 @@ def _quantize_columnwise_and_transpose( output_ptr, output_maxs, n_elements, - M : tl.constexpr, N : tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr, ): @@ -47,14 +49,14 @@ def _quantize_columnwise_and_transpose( block_start = pid p2_arange = tl.arange(0, P2) p2_arange_mask = p2_arange < M - arange = p2_arange * N + arange = p2_arange * N offsets = block_start + arange x = tl.load(x_ptr + offsets, mask=p2_arange_mask) abs_x = tl.abs(x) max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) - new_start = pid * M + new_start = pid * M new_offsets = new_start + p2_arange tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) tl.store(output_maxs + pid, max_val) @@ -68,7 +70,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index 845db6ecd..f35bdd304 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,25 +1,25 @@ -import math import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_global_transpose(input): return None - def quantize_global(x: torch.Tensor): return None -else: + def quantize_global_transpose(input): + return None + + def quantize_global(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # global quantize @triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), - triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), - - ], - key=['n_elements'] + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), + ], + key=["n_elements"], ) @triton.jit def _quantize_global( @@ -35,73 +35,90 @@ def _quantize_global( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127. * (x * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax - output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + absmax_inv = 1.0 / absmax + output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_global[grid](x, absmax_inv, output, n_elements) return output, absmax - # global quantize and transpose @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - - # ... - ], - key=['M', 'N'] + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + # ... + ], + key=["M", "N"], ) @triton.jit - def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - GROUP_M : tl.constexpr): + def _quantize_global_transpose( + A, + absmax_inv_ptr, + B, + stride_am, + stride_an, + stride_bn, + stride_bm, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + ): pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N - + width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // group_size - + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) mask = (rm < M)[:, None] & (rn < N)[None, :] a = tl.load(A, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - + # rematerialize to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] - output = tl.libdevice.llrint(127. * (a * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) tl.store(B, output, mask=mask) def quantize_global_transpose(input): absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax + absmax_inv = 1.0 / absmax M, N = input.shape - out = torch.empty(N, M, device='cuda', dtype=torch.int8) - + out = torch.empty(N, M, device="cuda", dtype=torch.int8) + assert out.size(0) == N and out.size(1) == M assert input.stride(0) == 1 or input.stride(1) == 1 assert out.stride(0) == 1 or out.stride(1) == 1 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) - return out, absmax + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _quantize_global_transpose[grid]( + input, + absmax_inv, + out, + input.stride(0), + input.stride(1), + out.stride(0), + out.stride(1), + M, + N, + ) + return out, absmax diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index 26d218321..f92ace02c 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -1,36 +1,36 @@ import math + import torch -import time from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_rowwise(x: torch.Tensor): return None -else: + def quantize_rowwise(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_rowwise( @@ -47,10 +47,10 @@ def _quantize_rowwise( offsets = block_start + arange row_mask = arange < BLOCK_SIZE x = tl.load(x_ptr + offsets, mask=row_mask) - + abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_maxs + pid, max_val) @@ -65,4 +65,3 @@ def quantize_rowwise(x: torch.Tensor): grid = lambda meta: (x.shape[0],) _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index c74c23962..6bbdbf1c1 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -1,4 +1,5 @@ import importlib + def is_triton_available(): return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 48373a1fe..0229e59e2 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,9 +1,11 @@ import json import shlex import subprocess -import torch from typing import Tuple +import torch + + def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) tracer = OutlierTracer.get_instance() @@ -28,7 +30,7 @@ def outlier_hook(module, input): # (1) zscore test of std of hidden dimension outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) # (2) magnitude > 6 test - dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) + dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1))) outlier_idx2 = torch.where(dims > 0)[0] outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() tracer.hvalue2outlier_idx[hvalue] = outlier_idx @@ -37,7 +39,7 @@ def outlier_hook(module, input): hook.remove() -class OutlierTracer(object): +class OutlierTracer: _instance = None def __init__(self): @@ -57,14 +59,14 @@ def initialize(self, model): self.hooks.append(m.register_forward_pre_hook(outlier_hook)) def is_initialized(self): - return getattr(self, 'initialized', False) + return getattr(self, "initialized", False) def get_hvalue(self, weight): return weight.data.storage().data_ptr() def get_outliers(self, weight): if not self.is_initialized(): - print('Outlier tracer is not initialized...') + print("Outlier tracer is not initialized...") return None hvalue = self.get_hvalue(weight) if hvalue in self.hvalue2outlier_idx: @@ -78,6 +80,7 @@ def get_instance(cls): cls._instance = cls.__new__(cls) return cls._instance + def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): if rdm: return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() @@ -85,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) m = weight.mean(reduction_dim) mm = m.mean() mstd = m.std() - zm = (m-mm)/mstd + zm = (m - mm) / mstd std = weight.std(reduction_dim) stdm = std.mean() stdstd = std.std() - zstd = (std-stdm)/stdstd + zstd = (std - stdm) / stdstd if topk is not None: val, idx = torch.topk(std.abs(), k=topk, dim=0) @@ -103,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) def execute_and_return(command_string: str) -> Tuple[str, str]: def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) + return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple) def execute_and_return_decoded_std_streams(command_string): return _decode( @@ -114,15 +114,20 @@ def execute_and_return_decoded_std_streams(command_string): shlex.split(command_string), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - ).communicate() + ).communicate(), ) std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err - -def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): +def replace_linear( + model, + linear_replacement, + skip_modules=("lm_head",), + copy_weights=False, + post_processing_function=None, +): """ Replace linear modules with a new Linear module. Parameters: @@ -135,7 +140,7 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei List of modules names not to convert. Defaults to `lm_head`. copy_weights (`bool`): Copy the weights from the old linear module to the new one - post_processing_fun_name (`str`): + post_processing_function (`str`): A function name of the replacement linear class that is called after processing. """ @@ -155,8 +160,9 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei model._modules[name].bias = old_module.bias if post_processing_function is not None: - func = getattr(module, post_processing_function, None) - if func is not None: func(module) + func = getattr(module, post_processing_function, None) + if func is not None: + func(module) return model @@ -171,7 +177,7 @@ def pack_dict_to_tensor(source_dict): A torch tensor containing the packed data. """ json_str = json.dumps(source_dict) - json_bytes = json_str.encode('utf-8') + json_bytes = json_str.encode("utf-8") tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8) return tensor_data @@ -188,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data): A Python dictionary containing the unpacked data. """ json_bytes = bytes(tensor_data.cpu().numpy()) - json_str = json_bytes.decode('utf-8') + json_str = json_bytes.decode("utf-8") unpacked_dict = json.loads(json_str) return unpacked_dict diff --git a/check_bnb_install.py b/check_bnb_install.py index e50afb0a1..7a9dc93fc 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -1,16 +1,15 @@ -import bitsandbytes as bnb import torch -p = torch.nn.Parameter(torch.rand(10,10).cuda()) -a = torch.rand(10,10).cuda() +import bitsandbytes as bnb + +p = torch.nn.Parameter(torch.rand(10, 10).cuda()) +a = torch.rand(10, 10).cuda() p1 = p.data.sum().item() adam = bnb.optim.Adam([p]) - - -out = a*p +out = a * p loss = out.sum() loss.backward() adam.step() @@ -18,5 +17,5 @@ p2 = p.data.sum().item() assert p1 != p2 -print('SUCCESS!') -print('Installation was successful!') +print("SUCCESS!") +print("Installation was successful!") diff --git a/compile_from_source.md b/compile_from_source.md deleted file mode 100644 index 23afe1591..000000000 --- a/compile_from_source.md +++ /dev/null @@ -1,40 +0,0 @@ -# Compiling from source - -Basic steps. -1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` -2. `python setup.py install` - -To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). - -You can install CUDA locally without sudo by following the following steps: - -```bash -wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh -# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122} -# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True - -# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc -bash install_cuda.sh 117 ~/local 1 -``` - -By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler. - -Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed - -If you type `nvcc` and it cannot be found, you might need to add to your path or set the CUDA_HOME variable. You can run `python -m bitsandbytes` to find the path to CUDA. For example if `python -m bitsandbytes` shows you the following: -``` -++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++ -/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so -``` -You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this. - -``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` - - -If you have problems compiling the library with these instructions from source, please open an issue. - -## Compilation with Kepler - -Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler` - diff --git a/csrc/common.cpp b/csrc/common.cpp index 52f029917..0a9601689 100644 --- a/csrc/common.cpp +++ b/csrc/common.cpp @@ -1,39 +1,35 @@ #include #include -void *quantize_block(void *arguments) { +void quantize_block(const quantize_block_args& args) { // 1. find absmax in block // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value // 4. check minimal distance // 5. store index - struct quantize_block_args *args = (quantize_block_args *) arguments; - // 1. find absmax in block float absmax_block = -FLT_MAX; - for (long long i = args->block_idx; i < args->block_end; i++) - absmax_block = fmax(absmax_block, fabs(args->A[i])); + for (long long i = args.block_idx; i < args.block_end; i++) + absmax_block = fmax(absmax_block, fabs(args.A[i])); - args->absmax[args->block_idx / args->blocksize] = absmax_block; + args.absmax[args.block_idx / args.blocksize] = absmax_block; - for (long long i = args->block_idx; i < args->block_end; i++) { + for (long long i = args.block_idx; i < args.block_end; i++) { // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value - float normed_value = args->A[i] / absmax_block; - long long idx = args->bin_searcher->scalar(normed_value); + float normed_value = args.A[i] / absmax_block; + long long idx = args.bin_searcher->scalar(normed_value); // 4. check minimal distance // The binary search returns always the value to the left, which might not be the closest value if (idx < 255) { - float dist_left = fabs(normed_value - (args->code[idx])); - float dist_right = fabs(normed_value - (args->code[idx + 1])); + float dist_left = fabs(normed_value - (args.code[idx])); + float dist_right = fabs(normed_value - (args.code[idx + 1])); if (dist_right < dist_left) { idx += 1; } } // 5. store index - args->out[i] = (unsigned char) idx; + args.out[i] = (unsigned char) idx; } - - return NULL; } diff --git a/csrc/common.h b/csrc/common.h index c99034e78..e513f2875 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -20,6 +20,6 @@ struct quantize_block_args { }; -void *quantize_block(void *arguments); +void quantize_block(const quantize_block_args& args); #endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e28e7b2c2..e67135360 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,6 +1,6 @@ #include -#include #include +#include using namespace BinSearch; @@ -26,17 +26,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long BinAlgo bin_searcher(code, elements_code); int thread_wave_size = 256; - // we chunk the thresds into waves of 256 since the max limit is + // we chunk the threads into waves of 256 since the max limit is // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) { long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); - - struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); - - for(long long i = 0; i < valid_chunks; i++) - args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + std::vector threads(valid_chunks); + std::vector args(valid_chunks); int chunks_processed = 0; for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) @@ -44,30 +40,24 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; long long block_end = block_idx + valid_items; - struct quantize_block_args *arg = args[chunks_processed]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx / blocksize; - arg->blocksize = blocksize; - - pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + struct quantize_block_args& arg = args[chunks_processed]; + arg.bin_searcher = &bin_searcher; + arg.code = code; + arg.A = A; + arg.absmax = absmax; + arg.out = out; + arg.block_end = block_end; + arg.block_idx = block_idx; + arg.threadidx = block_idx / blocksize; + arg.blocksize = blocksize; + + threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); chunks_processed += 1; if(chunks_processed == valid_chunks){ break; } } for (int i = 0; i < valid_chunks; i++) - int err = pthread_join(threads[i], NULL); - - free(threads); - for (int i = 0; i < valid_chunks; i++) - free(args[i]); - free(args); - + threads[i].join(); } } diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 9ebe0a69e..f4673359b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -110,7 +110,7 @@ __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 - else + else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else @@ -134,10 +134,10 @@ __device__ unsigned char dQuantizeFP4(float x) // we do a binary search // the pivots are divided by 12 (the FP4 absmax) - // since we assum input data is in [-1.0, 1.0] + // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake - // that is difficult to noice if you add an extra + // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 0b1000 : 0b0000; @@ -174,36 +174,36 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -211,12 +211,12 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -229,36 +229,36 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -266,12 +266,12 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -654,6 +654,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) temp_storage.smem_qidx[j] = -1; + __syncthreads(); + if(threadIdx.x < 256) { float q_interval = (1.0f-(2.0f*offset))/255.0f; @@ -1863,7 +1865,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; - + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -2259,8 +2261,8 @@ template__global__ void kd // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) // C2. Compute normalization values and store col values in register // S1. Store C1 into 16-bit output @@ -2383,7 +2385,7 @@ template __global__ void kd if(valid_items <= 0) // the sub-tile might have more elements than the tile itself break; - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); @@ -2650,7 +2652,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B +//// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memroy block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. write outputs to matmul output matrix //} @@ -3531,7 +3533,7 @@ template __global__ void kgemm_4bit_inference(int M, i template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { - // per threadblock: + // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block @@ -3764,7 +3766,7 @@ template __global__ void kfunc(T *A, T *B, T value, long { switch(FUNC) { - case FILL: + case FILL: A[i] = (T)value; break; case ARANGE: @@ -3821,12 +3823,12 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 54b6afb9d..ca77dceda 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -5,7 +5,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include +#include #include #include #include @@ -22,7 +22,7 @@ // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda -// Luckily we have atomicmax and atomicmin in ROCm +// Luckily we have atomicmax and atomicmin in ROCm __device__ float dDequantizeFP4(unsigned char val, float absmax) { @@ -86,7 +86,7 @@ __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 - else + else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else @@ -110,10 +110,10 @@ __device__ unsigned char dQuantizeFP4(float x) // we do a binary search // the pivots are divided by 12 (the FP4 absmax) - // since we assum input data is in [-1.0, 1.0] + // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake - // that is difficult to noice if you add an extra + // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 0b1000 : 0b0000; @@ -150,36 +150,36 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -187,12 +187,12 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -205,36 +205,36 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -242,12 +242,12 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -1841,7 +1841,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; - + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -2118,7 +2118,7 @@ template__global__ void kd // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) // C2. Compute normalization values and store col values in register // S1. Store C1 into 16-bit output @@ -2367,7 +2367,7 @@ template __global__ void kd #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]); - + // each block processes SUBTILE_ROWS*32 elements #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -2390,14 +2390,14 @@ template __global__ void kd if(valid_items <= 0) // the sub-tile might have more elements than the tile itself break; - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); @@ -2657,7 +2657,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * { //col-major offset int offset = local_colidx * rowsA + row; - + char val = A[offset]; int out_idx = (row*idx_size) + blockIdx.x; out[out_idx] = val; @@ -3087,11 +3087,11 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B +//// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memroy block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. write outputs to matmul output matrix //} @@ -3549,7 +3549,7 @@ template __global__ void kgemm_4bit_inference(int M, i template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { - // per threadblock: + // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block @@ -3782,7 +3782,7 @@ template __global__ void kfunc(T *A, T *B, T value, long { switch(FUNC) { - case FILL: + case FILL: A[i] = (T)value; break; case ARANGE: diff --git a/csrc/kernels.hiph b/csrc/kernels_hip.cuh similarity index 99% rename from csrc/kernels.hiph rename to csrc/kernels_hip.cuh index c842cc754..430218736 100644 --- a/csrc/kernels.hiph +++ b/csrc/kernels_hip.cuh @@ -6,7 +6,7 @@ // LICENSE file in the root directory of this source tree. #include -#include +#include #ifndef kernels #define kernels diff --git a/csrc/mps_kernels.metal b/csrc/mps_kernels.metal new file mode 100644 index 000000000..63b3bf78c --- /dev/null +++ b/csrc/mps_kernels.metal @@ -0,0 +1,117 @@ +#include +using namespace metal; + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +template +static unsigned char quantize_scalar( + float rand, + device float* code, + float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = code[pivot]; + } + + if(upper_pivot == 255) + upper = code[upper_pivot]; + if(lower_pivot == 0) + lower = code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabs(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabs(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +kernel void quantize(device float* code [[buffer(0)]], + device float* A [[buffer(1)]], + device uchar* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint id [[thread_position_in_grid]]) { + const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; + const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); + + float vals[NUM]; + uchar qvals[NUM]; + + for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + vals[j] = A[i + j]; + } + + for (uint j = 0; j < valid_items; j++) { + qvals[j] = quantize_scalar(0.0f, code, vals[j]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + out[i + j] = qvals[j]; + } + } +} diff --git a/csrc/mps_ops.h b/csrc/mps_ops.h new file mode 100644 index 000000000..e69de29bb diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm new file mode 100644 index 000000000..d198b3552 --- /dev/null +++ b/csrc/mps_ops.mm @@ -0,0 +1,67 @@ +#import + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +static inline MPSGraph* get_graph() +{ + static MPSGraph* cur = nil; + if(!cur) { + cur = [[MPSGraph alloc] init]; + } + return cur; +} + +static inline id get_device() +{ + NSError *error = nil; + static id device = nil; + if(!device) { + device = MTLCreateSystemDefaultDevice(); + } + if(!device) { + NSLog(@"Failed to get MPS device"); + abort(); + } + return device; +} + +static inline id get_library() +{ + NSError *error = nil; + static id library = nil; + if(!library) { + library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; + } + if(!library) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + return library; +} + +/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; + return out; +}*/ + + +// MPSGraph function for quantize +extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id device = get_device(); + id library = get_library(); + static id kernel = nil; + if(!kernel) { + kernel = [library newFunctionWithName:@"quantize"]; + if(!kernel) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + } + NSLog(@"Not implemented"); + return nil; +} diff --git a/csrc/ops.cu b/csrc/ops.cu index 97761216c..796211fed 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -11,6 +11,8 @@ #include #include +#define ERR_NOT_IMPLEMENTED 100 + using namespace BinSearch; using std::cout; @@ -421,14 +423,7 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT - cout << "" << endl; - cout << "=============================================" << endl; - cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; - cout << "=============================================" << endl; - cout << "" << endl; - assert(false); - - return 0; + return ERR_NOT_IMPLEMENTED; #else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; @@ -484,7 +479,7 @@ template int igemmlt(cublasLtHandle printf("error detected"); return has_error; -#endif +#endif // NO_CUBLASLT } int fill_up_to_nearest_multiple(int value, int multiple) diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f37b3b3af..da9df6af0 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -9,7 +9,6 @@ #include #include -#include #include #include diff --git a/csrc/ops.hip b/csrc/ops.hip index e98e3b817..67cece5c1 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -5,8 +5,8 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include -#include +#include +#include #include #include #include @@ -18,6 +18,7 @@ #include #include +#define ERR_NOT_IMPLEMENTED 100 using namespace BinSearch; using std::cout; @@ -304,7 +305,8 @@ int roundoff(int v, int d) { } -#ifndef NO_HIPBLASLT +#ifdef NO_HIPBLASLT +#else template hipblasLtOrder_t get_order() { switch(ORDER) @@ -377,9 +379,10 @@ template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); -#ifndef NO_HIPBLASLT template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { +#ifdef NO_HIPBLASLT +#else hipblasLtOrder_t orderA = get_order(); hipblasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); @@ -434,6 +437,7 @@ template void trans if (B_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(B_desc)); if (out_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkHipblasStatus(hipblasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -449,7 +453,7 @@ template void transform(hipblasLtHandle_t ltHandle template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -#endif + static std::string hipError_to_string(const hipError_t ret) { switch(ret) @@ -504,18 +508,11 @@ static std::string hipError_to_string(const hipError_t ret) throw std::runtime_error("unknown hipError"); } } -#ifndef NO_HIPBLASLT + template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { -#ifdef NO_CUBLASLT - cout << "" << endl; - cout << "=============================================" << endl; - cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; - cout << "=============================================" << endl; - cout << "" << endl; - assert(false); - - return 0; +#ifdef NO_HIPBLASLT + return ERR_NOT_IMPLEMENTED; #else int has_error = 0; hipblasLtMatmulDesc_t matmulDesc = NULL; @@ -641,9 +638,8 @@ template int igemmlt(hipblasLtHandl printf("error detected"); return has_error; -#endif +#endif // NO_HIPBLASLT } -#endif int fill_up_to_nearest_multiple(int value, int multiple) { @@ -761,6 +757,10 @@ template void transformRowToFormat(char * A, char *o void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { + +#ifdef NO_HIPBLASLT +#else + hipsparseSpMatDescr_t descA; hipsparseDnMatDescr_t descB, descC; @@ -807,6 +807,7 @@ void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_va CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( hipFree(dBuffer) ); +#endif } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) @@ -930,14 +931,12 @@ template void extractOutliers(char * A, int *idx, char *out, int idx template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -#ifndef NO_HIPBLASLT template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -#endif template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.hiph b/csrc/ops_hip.cuh similarity index 98% rename from csrc/ops.hiph rename to csrc/ops_hip.cuh index 8e41f852a..1b9c13063 100644 --- a/csrc/ops.hiph +++ b/csrc/ops_hip.cuh @@ -16,9 +16,7 @@ #include #include #include -//#ifndef NO_HIPBLASLT #include -//#endif #include #include #include @@ -120,7 +118,6 @@ class Context }; -#ifndef NO_HIPBLASLT class ContextLt { public: @@ -133,7 +130,6 @@ class ContextLt m_handle = handle; } }; -#endif class ContextHipsparse { @@ -185,12 +181,9 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i long long int strideA, long long int strideB, long long int strideC, int batchCount); -#ifndef NO_HIPBLASLT template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); -#endif - void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.cpp similarity index 99% rename from csrc/pythonInterface.c rename to csrc/pythonInterface.cpp index c74357758..be6abc070 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.cpp @@ -7,7 +7,10 @@ #include #endif #if BUILD_HIP -#include +#include +#endif +#if BUILD_MPS +// #include #endif #include @@ -170,7 +173,6 @@ void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } #endif -#ifndef NO_HIPBLASLT #if BUILD_CUDA #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -204,7 +206,6 @@ MAKE_FUNC_TRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32); MAKE_FUNC_TRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8); MAKE_FUNC_TRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32); #endif -#endif void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } @@ -216,8 +217,6 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } -#ifndef NO_HIPBLASLT - #if defined(BUILD_CUDA) int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -258,8 +257,6 @@ void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int row { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } #endif -#endif - void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } @@ -406,8 +403,6 @@ extern "C" ContextHipsparse *get_hipsparse(){ return new ContextHipsparse(); } #endif - -#ifndef NO_HIPBLASLT #if BUILD_CUDA int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -434,7 +429,7 @@ extern "C" { \ transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ } \ - + #endif #if BUILD_HIP @@ -484,7 +479,7 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32) #endif -#endif + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) @@ -555,7 +550,7 @@ extern "C" int hasPrefetch = 0; CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead if (hasPrefetch == 0) return; - + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -577,7 +572,7 @@ extern "C" int hasPrefetch = 0; CUDA_CHECK_RETURN(hipDeviceGetAttribute(&hasPrefetch, hipDeviceAttributeConcurrentManagedAccess, device)); // 40ns overhead if (hasPrefetch == 0) return; - + CUDA_CHECK_RETURN(hipMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(hipPeekAtLastError()); } @@ -605,6 +600,7 @@ extern "C" { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } #endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } diff --git a/deploy.sh b/deploy.sh index c261ee9a9..e60373627 100644 --- a/deploy.sh +++ b/deploy.sh @@ -5,7 +5,7 @@ echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" echo $LD_LIBRARY_PATH if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -24,7 +24,7 @@ make cpuonly CUDA_VERSION="CPU" if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -34,7 +34,7 @@ make cuda110 CUDA_VERSION=110 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -44,7 +44,7 @@ make cuda11x CUDA_VERSION=111 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -54,7 +54,7 @@ make cuda11x CUDA_VERSION=114 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -64,7 +64,7 @@ make cuda11x CUDA_VERSION=115 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -74,7 +74,7 @@ make cuda11x CUDA_VERSION=117 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -84,7 +84,7 @@ make cuda118 CUDA_VERSION=118 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -94,7 +94,7 @@ make cuda12x CUDA_VERSION=120 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -104,7 +104,7 @@ make cuda12x CUDA_VERSION=121 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -114,7 +114,7 @@ make cuda12x CUDA_VERSION=122 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -124,7 +124,7 @@ make cuda12x CUDA_VERSION=123 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -138,7 +138,7 @@ make cuda110_nomatmul CUDA_VERSION=110 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -149,7 +149,7 @@ make cuda11x_nomatmul CUDA_VERSION=111 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -159,7 +159,7 @@ make cuda11x_nomatmul CUDA_VERSION=114 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -169,7 +169,7 @@ make cuda11x_nomatmul CUDA_VERSION=115 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -179,7 +179,7 @@ make cuda11x_nomatmul CUDA_VERSION=117 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -189,7 +189,7 @@ make cuda118_nomatmul CUDA_VERSION=118 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -199,7 +199,7 @@ make cuda12x_nomatmul CUDA_VERSION=120 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -209,7 +209,7 @@ make cuda12x_nomatmul CUDA_VERSION=121 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -219,7 +219,7 @@ make cuda12x_nomatmul CUDA_VERSION=122 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -229,7 +229,7 @@ make cuda12x_nomatmul CUDA_VERSION=123 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 28da69eb0..2184cce8c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -1,8 +1,58 @@ -- sections: +- title: Get started + sections: - local: index - title: Bits & Bytes + title: bitsandbytes - local: quickstart title: Quickstart - local: installation title: Installation - title: Get started \ No newline at end of file +- title: Guides + sections: + - local: optimizers + title: 8-bit optimizers + - local: algorithms + title: Algorithms + - local: integrations + title: Integrations + - local: errors + title: Troubleshoot + - local: contributing + title: Contribute + - local: faqs + title: FAQs +- title: Explanation + sections: + - local: explanations/optimizers + title: 8-bit optimizers + - local: explanations/resources + title: Papers, resources & how to cite +- title: API reference + sections: + - title: Optimizers + sections: + - local: reference/optim/optim_overview + title: Overview + - local: reference/optim/adagrad + title: AdaGrad + - local: reference/optim/adam + title: Adam + - local: reference/optim/adamw + title: AdamW + - local: reference/optim/lamb + title: LAMB + - local: reference/optim/lars + title: LARS + - local: reference/optim/lion + title: Lion + - local: reference/optim/rmsprop + title: RMSprop + - local: reference/optim/sgd + title: SGD + - title: k-bit quantizers + sections: + - local: reference/nn/linear8bit + title: 8-bit quantizer + - local: reference/nn/linear4bit + title: 4-bit quantizer + - local: reference/nn/embeddings + title: Embedding diff --git a/docs/source/algorithms.mdx b/docs/source/algorithms.mdx new file mode 100644 index 000000000..d9db5cb04 --- /dev/null +++ b/docs/source/algorithms.mdx @@ -0,0 +1,12 @@ +# Other algorithms +_WIP: Still incomplete... Community contributions would be greatly welcome!_ + +This is an overview of the `bnb.functional` API in `bitsandbytes` that we think would also be useful as standalone entities. + +## Using Int8 Matrix Multiplication + +For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: + +```py +bnb.matmul(..., threshold=6.0) +``` diff --git a/docs/source/contributing.mdx b/docs/source/contributing.mdx new file mode 100644 index 000000000..4fe6b7541 --- /dev/null +++ b/docs/source/contributing.mdx @@ -0,0 +1,25 @@ +# Contributors guidelines +... still under construction ... (feel free to propose materials, `bitsandbytes` is a community project) + +## Setup + +### Setup pre-commit hooks +- Install pre-commit hooks with `pip install pre-commit`. +- Run `pre-commit autoupdate` once to configure the hooks. +- Re-run `pre-commit autoupdate` every time a new hook got added. + +Now all the pre-commit hooks will be automatically run when you try to commit and if they introduce some changes, you need to re-add the changed files before being able to commit and push. + +### Ignore formatting revs +- Run `git config blame.ignoreRevsFile .git-blame-ignore-revs`. This will make it so that `git blame` is aware of commits that were logged to be solely formatting-related. + +## Doc-string syntax + +We're following NumPy doc-string conventions with the only notable difference being that we use Markdown instead of Rich text format (RTF) for markup within the doc-strings. + +Please see the existing documentation to see how to generate autodocs. + +## Documentation +- [guideline for documentation syntax](https://github.com/huggingface/doc-builder#readme) +- images shall be uploaded via PR in the `bitsandbytes/` directory [here](https://huggingface.co/datasets/huggingface/documentation-images) +- find the documentation builds for each PR in a link posted to the PR, such as https://moon-ci-docs.huggingface.co/docs/bitsandbytes/pr_1012/en/introduction diff --git a/errors_and_solutions.md b/docs/source/errors.mdx similarity index 68% rename from errors_and_solutions.md rename to docs/source/errors.mdx index 5b8cbcdd5..95594ea11 100644 --- a/errors_and_solutions.md +++ b/docs/source/errors.mdx @@ -1,21 +1,22 @@ -# No kernel image available +# Troubleshoot -This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? +## No kernel image available -If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. +This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. +To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME`` as well as ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? -__If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future. +If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. +## `fatbinwrap` -# fatbinwrap +This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your `$PATH` and `$LD_LIBRARY_PATH` variable. In the conda base environment you can find the library under: -This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your $PATH and $LD_LIBRARY_PATH variable. In the conda base environment you can find the library under: ```bash ls $CONDA_PREFIX/lib/*cudart* ``` Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart). -If this does not fix the issue, please try [compilation from source](compile_from_source.md) next. +If this does not fix the issue, please try compilation from source next. If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb. diff --git a/docs/source/explanations/optimizers.mdx b/docs/source/explanations/optimizers.mdx new file mode 100644 index 000000000..327938e54 --- /dev/null +++ b/docs/source/explanations/optimizers.mdx @@ -0,0 +1,51 @@ +# 8-bit optimizers + +Stateful optimizers maintain gradient statistics over time, for example, the exponentially smoothed sum (SGD with momentum) or squared sum (Adam) of past gradient values. This state can be used to accelerate optimization compared to plain stochastic gradient descent, but uses memory that might otherwise be allocated to model parameters. As a result, this limits the maximum size of models that can be trained in practice. Now take a look at the biggest models that can be trained with 8-bit optimizers. + +
+
+ +
Depending on your GPU size, you can train a much larger model with a 8-bit optimizer.
+
+
+ +bitsandbytes optimizers use 8-bit statistics, while maintaining the performance levels of using 32-bit optimizer states. + +To overcome the resulting computational, quantization and stability challenges, 8-bit optimizers have three components: + +1. Block-wise quantization: divides input tensors into smaller blocks that are independently quantized, isolating outliers and distributing the error more equally over all bits. Each block is processed in parallel across cores, yielding faster optimization and high precision quantization. +2. Dynamic quantization: quantizes both small and large values with high precision. +3. Stable embedding layer: improves stability during optimization for models with word embeddings. + +With these components, performing an optimizer update with 8-bit states is straightforward. The 8-bit optimizer states are dequantized to 32-bit before you perform the update, and then the states are quantized back to 8-bit for storage. + +The 8-bit to 32-bit conversion happens element-by-element in registers, meaning no slow copies to GPU memory or additional temporary memory are needed to perform quantization and dequantization. For GPUs, this makes 8-bit optimizers much faster than regular 32-bit optimizers. + +
+
+ +
A comparison of memory and time saved using 8-bit and 32-bit optimizers.
+
+
+ +## Stable embedding layer + +The stable embedding layer improves the training stability of the standard word embedding layer for NLP tasks. It addresses the challenge of non-uniform input distributions and mitigates extreme gradient variations. This means the stable embedding layer can support more aggressive quantization strategies without compromising training stability, and it can help achieve stable training outcomes, which is particularly important for models dealing with diverse and complex language data. + +There are three features of the stable embedding layer: + +- Initialization: utilizes Xavier uniform initialization to maintain consistent variance, reducing the likelihood of large gradients. +- Normalization: incorporates layer normalization before adding positional embeddings, aiding in output stability. +- Optimizer states: employs 32-bit optimizer states exclusively for this layer to enhance stability, while the rest of the model may use standard 16-bit precision. + +## Paged optimizers + +Paged optimizers are built on top of the [unified memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) feature of CUDA. Unified memory provides a single memory space the GPU and CPU can easily access. While this feature is not supported by PyTorch, it has been added to bitsandbytes. + +Paged optimizers works like regular CPU paging, which means that it *only becomes active if you run out of GPU memory*. When that happens, memory is transferred page-by-page from GPU to CPU. The memory is mapped, meaning that pages are pre-allocated on the CPU but they are not updated automatically. Pages are only updated if the memory is accessed or a swapping operation is launched. + +The unified memory feature is less efficient than regular asynchronous memory transfers, and you usually won't be able to get full PCIe memory bandwidth utilization. If you do a manual prefetch, transfer speeds can be high but still only about half or worse than the full PCIe memory bandwidth (tested on 16x lanes PCIe 3.0). + +This means performance depends highly on the particular use-case. For example, if you evict 1 GB of memory per forward-backward-optimizer loop, then you can expect about 50% of the PCIe bandwidth as time in the best case. So, 1 GB for PCIe 3.0 with 16x lanes would run at 16 GB/s, which is `1/(16*0.5) = 1/8 = 125ms` of overhead per optimizer step. Other overhead can be estimated for the particular use-case given a PCIe interface, lanes, and the memory evicted in each iteration. + +Compared to CPU offloading, a paged optimizer has zero overhead if all the memory fits onto the device and only some overhead if some of memory needs to be evicted. For offloading, you usually offload fixed parts of the model and need to off and onload all this memory with each iteration through the model (sometimes twice for both forward and backward pass). diff --git a/docs/source/explanations/resources.mdx b/docs/source/explanations/resources.mdx new file mode 100644 index 000000000..56330175a --- /dev/null +++ b/docs/source/explanations/resources.mdx @@ -0,0 +1,92 @@ +# Papers, related resources & how to cite + +The below academic work is ordered in reverse chronological order. + +## [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression (Jun 2023)](https://arxiv.org/abs/2306.03078) + +Authors: Tim Dettmers, Ruslan Svirschevski, Vage Egiazarian, Denis Kuznedelev, Elias Frantar, Saleh Ashkboos, Alexander Borzunov, Torsten Hoefler, Dan Alistarh + +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1666076553665744896) + +``` +@article{dettmers2023spqr, + title={SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression}, + author={Dettmers, Tim and Svirschevski, Ruslan and Egiazarian, Vage and Kuznedelev, Denis and Frantar, Elias and Ashkboos, Saleh and Borzunov, Alexander and Hoefler, Torsten and Alistarh, Dan}, + journal={arXiv preprint arXiv:2306.03078}, + year={2023} +} +``` + +## [QLoRA: Efficient Finetuning of Quantized LLMs (May 2023)](https://arxiv.org/abs/2305.14314) +Authors: Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=y9PHWGOa8HA&ab_channel=LondonMachineLearningMeetup) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1661379354507476994) + +``` +@article{dettmers2023qlora, + title={Qlora: Efficient finetuning of quantized llms}, + author={Dettmers, Tim and Pagnoni, Artidoro and Holtzman, Ari and Zettlemoyer, Luke}, + journal={arXiv preprint arXiv:2305.14314}, + year={2023} +} +``` + +## [The case for 4-bit precision: k-bit Inference Scaling Laws (Dec 2022)](https://arxiv.org/abs/2212.09720) +Authors: Tim Dettmers, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=odlQa6AE1gY&ab_channel=TheInsideView) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1605209171758284805) + +``` +@inproceedings{dettmers2023case, + title={The case for 4-bit precision: k-bit inference scaling laws}, + author={Dettmers, Tim and Zettlemoyer, Luke}, + booktitle={International Conference on Machine Learning}, + pages={7750--7774}, + year={2023}, + organization={PMLR} +} +``` + +## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) +Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer + +- [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) +- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) +- [Introduction to Weight Quantization](https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c) +- [Poster](https://twitter.com/Tim_Dettmers/status/1598351301942951937) + +``` +@article{dettmers2022llm, + title={Llm. int8 (): 8-bit matrix multiplication for transformers at scale}, + author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, + journal={arXiv preprint arXiv:2208.07339}, + year={2022} +} +``` + +## [8-bit Optimizers via Block-wise Quantization (Oct 2021)](https://arxiv.org/abs/2110.02861) +Authors: Tim Dettmers, Mike Lewis, Sam Shleifer, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1446472128979562499) + +``` +@article{DBLP:journals/corr/abs-2110-02861, + author = {Tim Dettmers and + Mike Lewis and + Sam Shleifer and + Luke Zettlemoyer}, + title = {8-bit Optimizers via Block-wise Quantization}, + journal = {CoRR}, + volume = {abs/2110.02861}, + year = {2021}, + url = {https://arxiv.org/abs/2110.02861}, + eprinttype = {arXiv}, + eprint = {2110.02861}, + timestamp = {Thu, 21 Oct 2021 16:20:08 +0200}, + biburl = {https://dblp.org/rec/journals/corr/abs-2110-02861.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/docs/source/faqs.mdx b/docs/source/faqs.mdx new file mode 100644 index 000000000..b9549e9d8 --- /dev/null +++ b/docs/source/faqs.mdx @@ -0,0 +1,7 @@ +# FAQs + +Please submit your questions in [this Github Discussion thread](https://github.com/TimDettmers/bitsandbytes/discussions/1013) if you feel that they will likely affect a lot of other users and that they haven't been sufficiently covered in the documentation. + +We'll pick the most generally applicable ones and post the QAs here or integrate them into the general documentation (also feel free to submit doc PRs, please). + +# ... under construction ... diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 68ad433e6..5943e7d1d 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -1,191 +1,13 @@ # bitsandbytes -The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. +bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. bitsandbytes provides three main features for dramatically reducing memory consumption for inference and training: +* 8-bit optimizers uses block-wise quantization to maintain 32-bit performance at a small fraction of the memory cost. +* LLM.Int() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. +* QLoRA or 4-bit quantization enables large language model training with several memory-saving techniques that don't compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training. +# License -Resources: -- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/) - -- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) - -## TL;DR -**Requirements** -Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. - -(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0) will be supported with release 0.39.0) - -**Installation**: - -``pip install bitsandbytes`` - -In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below. - -Compilation quickstart: -```bash -git clone https://github.com/timdettmers/bitsandbytes.git -cd bitsandbytes - -# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120} -# make argument in {cuda110, cuda11x, cuda12x} -# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes -CUDA_VERSION=117 make cuda11x -python setup.py install -``` - -**Using Int8 inference with HuggingFace Transformers** - -```python -from transformers import AutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained( - 'decapoda-research/llama-7b-hf', - device_map='auto', - load_in_8bit=True, - max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') -``` - -A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). - -**Using 8-bit optimizer**: -1. Comment out optimizer: ``#torch.optim.Adam(....)`` -2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) -3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`` - - -**Using 8-bit Inference**: -1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)`` -2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same) -3. There are two modes: - - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default) - - Int8 inference. Pass the argument ``has_fp16_weights=False`` -4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``. -```python -# LLM.int8() -linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0) -# inputs need to be fp16 -out = linear(x.to(torch.float16)) -``` - - -## Features -- 8-bit Matrix multiplication with mixed precision decomposition -- LLM.int8() inference -- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) -- Stable Embedding Layer: Improved stability through better initialization, and normalization -- 8-bit quantization: Quantile, Linear, and Dynamic quantization -- Fast quantile estimation: Up to 100x faster than other algorithms - -## Requirements & Installation - -Requirements: anaconda, cudatoolkit, pytorch - -Hardware requirements: - - LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or newer). - - 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X). - -Supported CUDA versions: 10.2 - 12.0 - -The bitsandbytes library is currently only supported on Linux distributions. Windows is not supported at the moment. - -The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website. - -To install run: - -``pip install bitsandbytes`` - -## Using bitsandbytes - -### Using Int8 Matrix Multiplication - -For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: -```python -bnb.matmul(..., threshold=6.0) -``` - -For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://huggingface.co/blog/hf-bitsandbytes-integration). - -### Using the 8-bit Optimizers - -With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way: -```python -import bitsandbytes as bnb - -# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer -adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer -adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent - - -torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models -``` - -Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so: -```python -# parameter tensors with less than 16384 values are optimized in 32-bit -# it is recommended to use multiplies of 4096 -adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) -``` - -### Change Bits and other Hyperparameters for Individual Parameters - -If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details - -### Fairseq Users - -To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.). - -## Release and Feature History - -For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md). - -## Errors - -1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) -2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) - -## Compile from source -To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands. - -```bash -wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh -# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122} -# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True - -# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc -bash install_cuda.sh 117 ~/local 1 -``` - -To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`: - -``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` - -For more detailed instruction, please follow the [compile_from_source.md](compile_from_source.md) instructions. - -## License - -The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. +bitsandbytes is MIT licensed. We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. - -## How to cite us -If you found this library and found LLM.int8() useful, please consider citing our work: - -```bibtex -@article{dettmers2022llmint8, - title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale}, - author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, - journal={arXiv preprint arXiv:2208.07339}, - year={2022} -} -``` - -For 8-bit optimizers or quantization routines, please consider citing the following work: - -```bibtex -@article{dettmers2022optimizers, - title={8-bit Optimizers via Block-wise Quantization}, - author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke}, - journal={9th International Conference on Learning Representations, ICLR}, - year={2022} -} -``` \ No newline at end of file diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 035e3e70d..d0dd7ba76 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,3 +1,112 @@ # Installation -... work in progress ... \ No newline at end of file +bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**. Select your operating system below to see the installation instructions. + + + + +For Linux systems, make sure your hardware meets the following requirements to use bitsandbytes features. + +| **Feature** | **Hardware requirement** | +|---|---| +| LLM.int8() | NVIDIA Turing (RTX 20 series, T4) or Ampere (RTX 30 series, A4-A100) GPUs | +| 8-bit optimizers/quantization | NVIDIA Kepler (GTX 780 or newer) | + +> [!WARNING] +> bitsandbytes >= 0.39.1 no longer includes Kepler binaries in pip installations. This requires manual compilation, and you should follow the general steps and use `cuda11x_nomatmul_kepler` for Kepler-targeted compilation. + +To install from PyPI. + +```bash +pip install bitsandbytes +``` + +## Compile from source + +To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu: + +```bash +apt-get install -y build-essential cmake +``` + +You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide from NVIDIA. + +Now to install the bitsandbytes package from source, run the following commands: + +```bash +git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ +pip install -r requirements-dev.txt +cmake -DCOMPUTE_BACKEND=cuda -S . +make +pip install . +``` + +> [!TIP] +> If you have multiple versions of CUDA installed or installed it in a non-standard location, please refer to CMake CUDA documentation for how to configure the CUDA compiler. + + + + +Windows systems require Visual Studio with C++ support as well as an installation of the CUDA SDK. + +You'll need to build bitsandbytes from source. To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. + +```bash +git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ +pip install -r requirements-dev.txt +cmake -DCOMPUTE_BACKEND=cuda -S . +cmake --build . --config Release +python -m build --wheel +``` + +Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows. + + + + +> [!TIP] +> MacOS support is still a work in progress! Subscribe to this [issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about discussions and to track the integration progress. + + + + +## PyTorch CUDA versions + +Some bitsandbytes features may need a newer CUDA version than the one currently supported by PyTorch binaries from Conda and pip. In this case, you should follow these instructions to load a precompiled bitsandbytes binary. + +1. Determine the path of the CUDA version you want to use. Common paths include: + +* `/usr/local/cuda` +* `/usr/local/cuda-XX.X` where `XX.X` is the CUDA version number + +Then locally install the CUDA version you need with this script from bitsandbytes: + +```bash +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc + +bash install_cuda.sh 117 ~/local 1 +``` + +2. Set the environment variables `BNB_CUDA_VERSION` and `LD_LIBRARY_PATH` by manually overriding the CUDA version installed by PyTorch. + +> [!TIP] +> It is recommended to add the following lines to the `.bashrc` file to make them permanent. + +```bash +export BNB_CUDA_VERSION= +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: +``` + +For example, to use a local install path: + +```bash +export BNB_CUDA_VERSION=117 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 +``` + +3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded. diff --git a/docs/source/integrations.mdx b/docs/source/integrations.mdx new file mode 100644 index 000000000..4badece49 --- /dev/null +++ b/docs/source/integrations.mdx @@ -0,0 +1,137 @@ +# Integrations + +bitsandbytes is widely integrated with many of the libraries in the Hugging Face and wider PyTorch ecosystem. This guide provides a brief overview of the integrations and how to use bitsandbytes with them. For more details, you should refer to the linked documentation for each library. + +## Transformers + +> [!TIP] +> Learn more in the bitsandbytes Transformers integration [guide](https://huggingface.co/docs/transformers/quantization#bitsandbytes). + +With Transformers, it's very easy to load any model in 4 or 8-bit and quantize them on the fly. To configure the quantization parameters, specify them in the [`~transformers.BitsAndBytesConfig`] class. + +For example, to load and quantize a model to 4-bits and use the bfloat16 data type for compute: + +> [!WARNING] +> bfloat16 is the optimal compute data type if your hardware supports it. The default is float32 for backward compatibility and numerical stability, but it can often lead to numerical instabilities. bfloat16 provides the best of both worlds, numerical stability equivalent to float32, but combined with the memory footprint and significant computation speedup of a 16-bit data type. Make sure to check if your hardware supports bfloat16 and if it does, configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]! + +```py +from transformers import AutoModelForCausalLM, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) +model_4bit = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-1b7", + device_map=device_map, + quantization_config=quantization_config, +) +``` + +### 8-bit optimizers + +You can use any of the 8-bit or paged optimizers with Transformers by passing them to the [`~transformers.Trainer`] class on initialization. All bitsandbytes optimizers are supported by passing the correct string in the [`~transformers.TrainingArguments`] `optim` parameter. For example, to load a [`~bitsandbytes.optim.PagedAdamW32bit`] optimizer: + +```py +from transformers import TrainingArguments, Trainer + +training_args = TrainingArguments( + ..., + optim="paged_adamw_32bit", +) +trainer = Trainer(model, training_args, ...) +trainer.train() +``` + +## PEFT + +> [!TIP] +> Learn more in the bitsandbytes PEFT integration [guide](https://huggingface.co/docs/peft/developer_guides/quantization#quantization). + +PEFT builds on the bitsandbytes Transformers integration, and extends it for training with a few more steps. Let's prepare the 4-bit model from the section above for training. + +Call the [`~peft.prepare_model_for_kbit_training`] method to prepare the model for training. This only works for Transformers models! + +```py +from peft import prepare_model_for_kbit_training + +model_4bit = prepare_model_for_kbit_training(model_4bit) +``` + +Setup a [`~peft.LoraConfig`] to use QLoRA: + +```py +from peft import LoraConfig + +config = LoraConfig( + r=16, + lora_alpha=8, + target_modules="all-linear", + lora_dropout=0.05 + bias="none", + task_type="CAUSAL_LM" +) +``` + +Now call the [`~peft.get_peft_model`] function on your model and config to create a trainable [`PeftModel`]. + +```py +from peft import get_peft_model + +model = get_peft_model(model_4bit, config) +``` + +## Accelerate + +> [!TIP] +> Learn more in the bitsandbytes Accelerate integration [guide](https://huggingface.co/docs/accelerate/usage_guides/quantization). + +bitsandbytes is also easily usable from Accelerate and you can quantize any PyTorch model by passing a [`~accelerate.utils.BnbQuantizationConfig`] with your desired settings, and then calling the [`~accelerate.utils.load_and_quantize_model`] function to quantize it. + +```py +from accelerate import init_empty_weights +from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model +from mingpt.model import GPT + +model_config = GPT.get_default_config() +model_config.model_type = 'gpt2-xl' +model_config.vocab_size = 50257 +model_config.block_size = 1024 + +with init_empty_weights(): + empty_model = GPT(model_config) + +bnb_quantization_config = BnbQuantizationConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, # optional + bnb_4bit_use_double_quant=True, # optional + bnb_4bit_quant_type="nf4" # optional +) + +quantized_model = load_and_quantize_model( + empty_model, + weights_location=weights_location, + bnb_quantization_config=bnb_quantization_config, + device_map = "auto" +) +``` + +## PyTorch Lightning and Lightning Fabric + +bitsandbytes is available from: + +- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. +- [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), a fast and lightweight way to scale PyTorch models without boilerplate. + +Learn more in the bitsandbytes PyTorch Lightning integration [guide](https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#quantization-via-bitsandbytes). + + +## Lit-GPT + +bitsandbytes is integrated with [Lit-GPT](https://github.com/Lightning-AI/lit-gpt), a hackable implementation of state-of-the-art open-source large language models. Lit-GPT is based on Lightning Fabric, and it can be used for quantization during training, finetuning, and inference. + +Learn more in the bitsandbytes Lit-GPT integration [guide](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md). + +## Blog posts + +To learn in more detail about some of bitsandbytes integrations, take a look at the following blog posts: + +- [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) +- [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration) diff --git a/docs/source/optimizers.mdx b/docs/source/optimizers.mdx new file mode 100644 index 000000000..7d04f82b1 --- /dev/null +++ b/docs/source/optimizers.mdx @@ -0,0 +1,94 @@ +# 8-bit optimizers + +With 8-bit optimizers, large models can be finetuned with 75% less GPU memory without losing any accuracy compared to training with standard 32-bit optimizers. The reduced memory requirements means 8-bit optimizers are 4x faster than a standard optimizer, and no hyperparameter tuning is required. + +This guide will show you how to use 8-bit optimizers. + +> [!WARNING] +> 8-bit optimizers reduce memory usage and accelerate optimization on a wide range of tasks. However, since 8-bit optimizers only reduce memory proportional to the number of parameters, models that use large amounts of activation memory, such as convolutional networks, don't really benefit from 8-bit optimizers. 8-bit optimizers are most beneficial for training or finetuning models with many parameters on highly memory-constrained GPUs. + +8-bit optimizers are a drop-in replacement for regular optimizers which means they also accept the same arguments as a regular optimizer. For NLP models, it is recommended to use the [`~nn.StableEmbedding`] class to improve stability and results. + +```diff +import bitsandbytes as bnb + +- adam = torch.optim.Adam(...) ++ adam = bnb.optim.Adam8bit(...) + +# recommended for NLP models +- before: torch.nn.Embedding(...) ++ bnb.nn.StableEmbedding(...) +``` + +By default, all parameter tensors with less than 4096 elements are kept at 32-bits even if you initialize those parameters with 8-bit optimizers. This is done because small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). + +You can change this value with the `min_8bit_size` parameter. For example, if you want to optimize parameters to 8-bits only if the minimum size is 16384 values (it is recommended to use multiples of 4096): + +```py +import bitsandbytes as bnb + +adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) +``` + +Other parameters you can configure include the learning rate (`lr`), the decay rates (`betas`), the number of bits of the optimizer state (`optim_bits`), and percentile clipping (`percentile_clipping`) which can increase stability. For example, to initialize a 32-bit [`~bitsandbytes.optim.Adam`] optimizer with 5th percentile clipping: + +```py +import bitsandbytes as bnb + +adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) +``` + +## Optimize unstable parameters + +To optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, use the [`~bitsandbytes.optim.GlobalOptimManager`] class to override the specific hyperparameters for a particular layer. You'll need to: + +1. Register the parameters while they're on the CPU. + +```py +import torch +import bitsandbytes as bnb + +mng = bnb.optim.GlobalOptimManager.get_instance() + +model = MyModel() +mng.register_parameters(model.parameters()) +``` + +2. Override the config with the new desired hyperparameters. For example, let's override the `model.fc1.weight` layer to use 32-bit Adam. + +> [!TIP] +> Check the optimizer API documentation for more information about other hyperparameters you can override. + +```py +model = model.cuda() +# use 8-bit optimizer states for all parameters +adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) + +# override the parameter model.fc1.weight now uses 32-bit Adam +mng.override_config(model.fc1.weight, "optim_bits", 32) +``` + +You can also override multiple layers at once by passing them as a list and the new hyperparameters as a dictionary. For example, let's override the `model.special.weight` and `model.also_special.weight` layers to use sparse optimization and a lower learning and decay rate. + +```py +mng.override_config([model.special.weight, model.also_special.weight], + key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) +``` + +For a specific layer, we recommend overriding locally in each module. Pass the module, the parameter, and its attribute name to the [`~bitsandbytes.optim.GlobalOptimManager`]: + +```py +class MyModule(torch.nn.Module): + def __init__(d_in, d_out): + super(MyModule, self).__init__() + self.linear = torch.nn.Linear(d_in, d_out) + # optimization will happen in 32-bit and + # learning rate will be set to 0.0001 independent of the main learning rate + config = {'optim_bits': 32, 'lr' : 0.0001} + GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) + +``` + +## Next steps + +For more conceptual details and explanation about 8-bit optimizers, take a look at the [8-bit optimizers](./explanations/optimizers) guide. diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx index 4dff2ba46..ed92c896b 100644 --- a/docs/source/quickstart.mdx +++ b/docs/source/quickstart.mdx @@ -4,9 +4,12 @@ ... work in progress ... -## Minimal example +(Community contributions would we very welcome!) -The following code illustrates the steps above. +## Minimal examples -```python -``` \ No newline at end of file +The following code illustrates the steps above. + +```py +code examples will soon follow +``` diff --git a/docs/source/reference/nn/embeddings.mdx b/docs/source/reference/nn/embeddings.mdx new file mode 100644 index 000000000..e725ecb17 --- /dev/null +++ b/docs/source/reference/nn/embeddings.mdx @@ -0,0 +1,15 @@ +# Embedding + +The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes, the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class. + +The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce gradient variance as a result of the non-uniform distribution of input tokens. This class is designed to support quantization. + +## Embedding + +[[autodoc]] bitsandbytes.nn.Embedding + - __init__ + +## StableEmbedding + +[[autodoc]] bitsandbytes.nn.StableEmbedding + - __init__ diff --git a/docs/source/reference/nn/linear4bit.mdx b/docs/source/reference/nn/linear4bit.mdx new file mode 100644 index 000000000..3cbf6509d --- /dev/null +++ b/docs/source/reference/nn/linear4bit.mdx @@ -0,0 +1,23 @@ +# 4-bit quantization + +[QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bits and adds a set of low-rank adaptation (LoRA) weights to the model and tuning them through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`) in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is a quantization data type for normally distributed data and can improve performance. + +## Linear4bit + +[[autodoc]] bitsandbytes.nn.Linear4bit + - __init__ + +## LinearFP4 + +[[autdodoc]] bitsandbytes.nn.LinearFP4 + - __init__ + +## LinearNF4 + +[[autodoc]] bitsandbytes.nn.LinearNF4 + - __init__ + +## Params4bit + +[[autodoc]] bitsandbytes.nn.Params4bit + - __init__ diff --git a/docs/source/reference/nn/linear8bit.mdx b/docs/source/reference/nn/linear8bit.mdx new file mode 100644 index 000000000..73254fe67 --- /dev/null +++ b/docs/source/reference/nn/linear8bit.mdx @@ -0,0 +1,13 @@ +# 8-bit quantization + +[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit and quantized to Int8 before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output. + +## Linear8bitLt + +[[autodoc]] bitsandbytes.nn.Linear8bitLt + - __init__ + +## Int8Params + +[[autodoc]] bitsandbytes.nn.Int8Params + - __init__ diff --git a/docs/source/reference/optim/adagrad.mdx b/docs/source/reference/optim/adagrad.mdx new file mode 100644 index 000000000..8dddba04c --- /dev/null +++ b/docs/source/reference/optim/adagrad.mdx @@ -0,0 +1,18 @@ +# AdaGrad + +[AdaGrad (Adaptive Gradient)](https://jmlr.org/papers/v12/duchi11a.html) is an adaptive learning rate optimizer. AdaGrad stores a sum of the squared past gradients for each parameter and uses it to scale their learning rate. This allows the learning rate to be automatically lower or higher depending on the magnitude of the gradient, eliminating the need to manually tune the learning rate. + +## Adagrad[[api-class]] + +[[autodoc]] bitsandbytes.optim.Adagrad + - __init__ + +## Adagrad8bit + +[[autodoc]] bitsandbytes.optim.Adagrad8bit + - __init__ + +## Adagrad32bit + +[[autodoc]] bitsandbytes.optim.Adagrad32bit + - __init__ diff --git a/docs/source/reference/optim/adam.mdx b/docs/source/reference/optim/adam.mdx new file mode 100644 index 000000000..f367bc415 --- /dev/null +++ b/docs/source/reference/optim/adam.mdx @@ -0,0 +1,38 @@ +# Adam + +[Adam (Adaptive moment estimation)](https://hf.co/papers/1412.6980) is an adaptive learning rate optimizer, combining ideas from [`SGD`] with momentum and [`RMSprop`] to automatically scale the learning rate: + +- a weighted average of the past gradients to provide direction (first-moment) +- a weighted average of the *squared* past gradients to adapt the learning rate to each parameter (second-moment) + +bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted. + +## Adam[[api-class]] + +[[autodoc]] bitsandbytes.optim.Adam + - __init__ + +## Adam8bit + +[[autodoc]] bitsandbytes.optim.Adam8bit + - __init__ + +## Adam32bit + +[[autodoc]] bitsandbytes.optim.Adam32bit + - __init__ + +## PagedAdam + +[[autodoc]] bitsandbytes.optim.PagedAdam + - __init__ + +## PagedAdam8bit + +[[autodoc]] bitsandbytes.optim.PagedAdam8bit + - __init__ + +## PagedAdam32bit + +[[autodoc]] bitsandbytes.optim.PagedAdam32bit + - __init__ diff --git a/docs/source/reference/optim/adamw.mdx b/docs/source/reference/optim/adamw.mdx new file mode 100644 index 000000000..e3dd410de --- /dev/null +++ b/docs/source/reference/optim/adamw.mdx @@ -0,0 +1,34 @@ +# AdamW + +[AdamW](https://hf.co/papers/1711.05101) is a variant of the [`Adam`] optimizer that separates weight decay from the gradient update based on the observation that the weight decay formulation is different when applied to [`SGD`] and [`Adam`]. + +bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted. + +## AdamW[[api-class]] + +[[autodoc]] bitsandbytes.optim.AdamW + - __init__ + +## AdamW8bit + +[[autodoc]] bitsandbytes.optim.AdamW8bit + - __init__ + +## AdamW32bit + +[[autodoc]] bitsandbytes.optim.AdamW32bit + - __init__ + +## PagedAdamW + +[[autodoc]] bitsandbytes.optim.PagedAdamW + - __init__ +## PagedAdamW8bit + +[[autodoc]] bitsandbytes.optim.PagedAdamW8bit + - __init__ + +## PagedAdamW32bit + +[[autodoc]] bitsandbytes.optim.PagedAdamW32bit + - __init__ diff --git a/docs/source/reference/optim/lamb.mdx b/docs/source/reference/optim/lamb.mdx new file mode 100644 index 000000000..d581380ed --- /dev/null +++ b/docs/source/reference/optim/lamb.mdx @@ -0,0 +1,21 @@ +# LAMB + +[LAMB (Layerwise adaptive large batch optimization)](https://hf.co/papers/1904.00962) is an adaptive optimizer designed for training with large batch sizes to accelerate training, combining ideas from [`LARS`] and [`Adam`] to automatically scale the learning rate for each layer: + +- calculates a *trust ratio* between the weight and gradient norm in a layer and clips the ratio to prevent overly large or small updates +- updates weights with the first and second-moments + +## LAMB[[api-class]] + +[[autodoc]] bitsandbytes.optim.LAMB + - __init__ + +## LAMB8bit + +[[autodoc]] bitsandbytes.optim.LAMB8bit + - __init__ + +## LAMB32bit + +[[autodoc]] bitsandbytes.optim.LAMB32bit + - __init__ diff --git a/docs/source/reference/optim/lars.mdx b/docs/source/reference/optim/lars.mdx new file mode 100644 index 000000000..93b5c55c3 --- /dev/null +++ b/docs/source/reference/optim/lars.mdx @@ -0,0 +1,18 @@ +# LARS + +[LARS (Layer-wise Adaptive Rate Scaling)](https:/hf.co/papers/1708.03888) is an optimizer designed for training with large batch sizes to accelerate training. LARS uses a separate learning rate for each *layer* instead of each parameter. The learning rate is calculated from a *trust ratio* between the weight and gradient norm in a layer. This helps calibrate a stable update size. + +## LARS[[api-class]] + +[[autodoc]] bitsandbytes.optim.LARS + - __init__ + +## LARS8bit + +[[autodoc]] bitsandbytes.optim.LARS8bit + - __init__ + +## LARS32bit + +[[autodoc]] bitsandbytes.optim.LARS32bit + - __init__ diff --git a/docs/source/reference/optim/lion.mdx b/docs/source/reference/optim/lion.mdx new file mode 100644 index 000000000..8183c27e7 --- /dev/null +++ b/docs/source/reference/optim/lion.mdx @@ -0,0 +1,33 @@ +# Lion + +[Lion (Evolved Sign Momentum)](https://hf.co/papers/2302.06675) is a unique optimizer that uses the sign of the gradient to determine the update direction of the momentum. This makes Lion more memory-efficient and faster than [`AdamW`] which tracks and store the first and second-order moments. + +## Lion[[api-class]] + +[[autodoc]] bitsandbytes.optim.Lion + - __init__ + +## Lion8bit + +[[autodoc]] bitsandbytes.optim.Lion8bit + - __init__ + +## Lion32bit + +[[autodoc]] bitsandbytes.optim.Lion32bit + - __init__ + +## PagedLion + +[[autodoc]] bitsandbytes.optim.PagedLion + - __init__ + +## PagedLion8bit + +[[autodoc]] bitsandbytes.optim.PagedLion8bit + - __init__ + +## PagedLion32bit + +[[autodoc]] bitsandbytes.optim.PagedLion32bit + - __init__ diff --git a/docs/source/reference/optim/optim_overview.mdx b/docs/source/reference/optim/optim_overview.mdx new file mode 100644 index 000000000..48e12b544 --- /dev/null +++ b/docs/source/reference/optim/optim_overview.mdx @@ -0,0 +1,24 @@ +# Overview + +[8-bit optimizers](https://hf.co/papers/2110.02861) reduce the memory footprint of 32-bit optimizers without any performance degradation which means you can train large models with many parameters faster. At the core of 8-bit optimizers is block-wise quantization which enables quantization accuracy, computational efficiency, and stability. + +bitsandbytes provides 8-bit optimizers through the base [`Optimizer8bit`] class, and additionally provides [`Optimizer2State`] and [`Optimizer1State`] for 2-state (for example, [`Adam`]) and 1-state (for example, [`Adagrad`]) optimizers respectively. To provide custom optimizer hyperparameters, use the [`GlobalOptimManager`] class to configure the optimizer. + +## Optimizer8bit + +[[autodoc]] bitsandbytes.optim.optimizer.Optimizer8bit + - __init__ + +## Optimizer2State + +[[autodoc]] bitsandbytes.optim.optimizer.Optimizer2State + - __init__ + +## Optimizer1State + +[[autodoc]] bitsandbytes.optim.optimizer.Optimizer1State + - __init__ + +## Utilities + +[[autodoc]] bitsandbytes.optim.optimizer.GlobalOptimManager diff --git a/docs/source/reference/optim/rmsprop.mdx b/docs/source/reference/optim/rmsprop.mdx new file mode 100644 index 000000000..33d839f6b --- /dev/null +++ b/docs/source/reference/optim/rmsprop.mdx @@ -0,0 +1,15 @@ +# RMSprop + +RMSprop is an adaptive learning rate optimizer that is very similar to [`Adagrad`]. RMSprop stores a *weighted average* of the squared past gradients for each parameter and uses it to scale their learning rate. This allows the learning rate to be automatically lower or higher depending on the magnitude of the gradient, and it prevents the learning rate from diminishing. + +## RMSprop[[api-class]] + +[[autodoc]] bitsandbytes.optim.RMSprop + +## RMSprop8bit + +[[autodoc]] bitsandbytes.optim.RMSprop8bit + +## RMSprop32bit + +[[autodoc]] bitsandbytes.optim.RMSprop32bit diff --git a/docs/source/reference/optim/sgd.mdx b/docs/source/reference/optim/sgd.mdx new file mode 100644 index 000000000..a0d09d1e8 --- /dev/null +++ b/docs/source/reference/optim/sgd.mdx @@ -0,0 +1,20 @@ +# SGD + +Stochastic gradient descent (SGD) is a basic gradient descent optimizer to minimize loss given a set of model parameters and updates the parameters in the opposite direction of the gradient. The update is performed on a randomly sampled mini-batch of data from the dataset. + +bitsandbytes also supports momentum and Nesterov momentum to accelerate SGD by adding a weighted average of past gradients to the current gradient. + +## SGD[[api-class]] + +[[autodoc]] bitsandbytes.optim.SGD + - __init__ + +## SGD8bit + +[[autodoc]] bitsandbytes.optim.SGD8bit + - __init__ + +## SGD32bit + +[[autodoc]] bitsandbytes.optim.SGD32bit + - __init__ diff --git a/environment-bnb.yml b/environment-bnb.yml new file mode 100644 index 000000000..1214f7930 --- /dev/null +++ b/environment-bnb.yml @@ -0,0 +1,21 @@ +# for cmake build +name: bnb +channels: + - pytorch + - nvidia + - conda-forge + +dependencies: + - python + #- accelerate + #- einops + - scipy + #- transformers + - pytest + - pytest-cases + - ipython + - debugpy + - yapf + - monkeytype + - rich + - pytest-sugar diff --git a/environment.yml b/environment.yml index c0e07f153..af421b3c6 100644 --- a/environment.yml +++ b/environment.yml @@ -27,6 +27,7 @@ dependencies: - conda-forge::monkeytype # infer type annotations - conda-forge::rich # better, colored tracebacks, etc - conda-forge::pytest-sugar # better pytest output + # - conda-forge::nodejs # for `doc-builder preview` (optional) ## ENV CREATION - steps to reproduce: # mamba env remove -n bnb @@ -42,4 +43,4 @@ dependencies: ## ENV UPDATE: # # add new packages to environment.yml, then: -# mamba env update -n bnb -f environment.yml \ No newline at end of file +# mamba env update -n bnb -f environment.yml diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py index dc80a44db..2d4c77952 100644 --- a/examples/int8_inference_huggingface.py +++ b/examples/int8_inference_huggingface.py @@ -1,27 +1,19 @@ import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import LlamaForCausalLM, LlamaTokenizer MAX_NEW_TOKENS = 128 -model_name = 'decapoda-research/llama-7b-hf' +model_name = "meta-llama/Llama-2-7b-hf" -text = 'Hamburg is in which country?\n' -tokenizer = AutoTokenizer.from_pretrained(model_name) +text = "Hamburg is in which country?\n" +tokenizer = LlamaTokenizer.from_pretrained(model_name) input_ids = tokenizer(text, return_tensors="pt").input_ids -free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) -max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' +max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB" n_gpus = torch.cuda.device_count() max_memory = {i: max_memory for i in range(n_gpus)} -model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map='auto', - load_in_8bit=True, - max_memory=max_memory -) +model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory) + generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) - - - diff --git a/how_to_use_nonpytorch_cuda.md b/how_to_use_nonpytorch_cuda.md deleted file mode 100644 index b5f01fbe5..000000000 --- a/how_to_use_nonpytorch_cuda.md +++ /dev/null @@ -1,46 +0,0 @@ -## How to use a CUDA version that is different from PyTorch - -Some features of bitsandbytes may need a newer CUDA version than regularly supported by PyTorch binaries from conda / pip. In that case you can use the following instructions to load a precompiled bitsandbytes binary that works for you. - -## Installing or determining the CUDA installation - -Determine the path of the CUDA version that you want to use. Common paths paths are: -```bash -/usr/local/cuda -/usr/local/cuda-XX.X -``` - -where XX.X is the CUDA version number. - -You can also install CUDA version that you need locally with a script provided by bitsandbytes as follows: - -```bash -wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh -# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122} -# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True - -# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc - -bash cuda_install.sh 117 ~/local 1 -``` - -## Setting the environmental variables BNB_CUDA_VERSION, and LD_LIBRARY_PATH - -To manually override the PyTorch installed CUDA version you need to set to variable, like so: - -```bash -export BNB_CUDA_VERSION= -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: -``` - -For example, to use the local install path from above: - -```bash -export BNB_CUDA_VERSION=117 -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 -``` - -It is best to add these lines to the `.bashrc` file to make them permanent. - -If you now launch bitsandbytes with these environmental variables the PyTorch CUDA version will be overridden by the new CUDA version and a different bitsandbytes library is loaded (in this case version 117). diff --git a/howto_config_override.md b/howto_config_override.md deleted file mode 100644 index 55b24e3ab..000000000 --- a/howto_config_override.md +++ /dev/null @@ -1,40 +0,0 @@ -# How to override config hyperparameters for particular weights/parameters - -If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details - -For global overrides in many different places in your code you can do: -```python -import torch -import bitsandbytes as bnb - -mng = bnb.optim.GlobalOptimManager.get_instance() - -model = MyModel() -mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU - -model = model.cuda() -# use 8-bit optimizer states for all parameters -adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) - -# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam -mng.override_config(model.fc1.weight, 'optim_bits', 32) - -# 2b. override: the two special layers use -# sparse optimization + different learning rate + different Adam betas -mng.override_config([model.special.weight, model.also_special.weight], - key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) -``` -Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm` - -For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: -```python -class MyModule(torch.nn.Module): - def __init__(din, dout): - super(MyModule, self).__init__() - self.linear = torch.nn.Linear(din, dout) - # optimization will happen in 32-bit and - # learning rate will be set to 0.0001 independent of the main learning rate - config = {'optim_bits': 32, 'lr' : 0.0001} - GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) - -``` diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h index c97084904..7b40edea9 100644 --- a/include/Algo-Direct-Common.h +++ b/include/Algo-Direct-Common.h @@ -190,7 +190,7 @@ struct DirectInfo xi = xws; } else { - myassert(Gap==1, "if Gap>1 then X workspace must be provided"); + myassert((Gap==1), "if Gap>1 then X workspace must be provided"); xi = x; } diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h index 347ec9c5e..91dded6f4 100644 --- a/include/Algo-Direct2.h +++ b/include/Algo-Direct2.h @@ -52,6 +52,7 @@ struct AlgoVecBase::val private: typedef AlgoScalarBase base_t; +#ifdef USE_SSE2 FORCE_INLINE //NO_INLINE void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const @@ -135,6 +136,7 @@ struct AlgoVecBase::val pr[0] = u.ui32[0]; pr[1] = u.ui32[2]; } +#endif // USE_SSE2 #ifdef USE_AVX @@ -157,7 +159,7 @@ struct AlgoVecBase::val FVec vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float)); IVec ip = idxm; -#else // do not use gather instrucions +#else // do not use gather instructions union U { __m256i vec; diff --git a/include/Portable.h b/include/Portable.h index 1710b0502..090a25065 100644 --- a/include/Portable.h +++ b/include/Portable.h @@ -4,10 +4,40 @@ #include #include +#if defined(__aarch64__) +#ifdef __CUDACC__ +#undef USE_NEON // Doesn't work with nvcc, undefined symbols +#else +#include +#undef USE_NEON // Not yet implemented +#endif +#undef USE_AVX // x86_64 only +#undef USE_AVX2 // x86_64 only +#undef USE_SSE2 // x86_64 only +#undef USE_SSE41 // x86_64 only +#undef USE_SSE42 // x86_64 only +#undef USE_FMA // x86_64 only +#ifdef USE_NEON +typedef float32x4_t __m128; +typedef int32x4_t __m128i; +typedef float64x2_t __m128d; +#else +typedef struct {float a; float b; float c; float d;} __m128; +typedef struct {int a; int b; int c; int d;} __m128i; +typedef struct {double a; double b;} __m128d; +#endif +#else +#undef USE_NEON // ARM64 only #ifdef __FMA__ #define USE_FMA #endif +#if !defined(__SSE2__) && !defined(_MSC_VER) +#error Compiler must support SSE2 +#endif +#define USE_SSE2 +#if defined(__aarch64__) +#else #ifdef __AVX2__ #define USE_AVX2 #endif @@ -24,7 +54,8 @@ #ifdef __SSE4_2__ #define USE_SSE42 #endif - +#endif +#endif #ifndef _MSC_VER #include @@ -147,5 +178,5 @@ inline T prev(T x) return x; } -} // namepsace Details +} // namespace Details } // namespace BinSearch diff --git a/include/SIMD.h b/include/SIMD.h index a2ac1a9ae..e97f5fc33 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -2,6 +2,46 @@ #include "Portable.h" +#ifdef USE_SSE2 +#include +#if defined(USE_AVX) || defined(USE_AVX2) +#include +#else +#ifdef USE_SSE41 +#include +#endif +#endif +#endif + +namespace BinSearch { +namespace Details { + +template +struct FTOITraits{}; + +template +struct FVec; + +template +struct IVec; + +template +struct FVec1; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128d vec_t; +}; + +} +} + +#if !defined(__aarch64__) #ifdef USE_SSE42 #ifndef _MSC_VER #include @@ -26,29 +66,11 @@ FORCE_INLINE int popcnt32(int x32) } // namespace #endif -#if defined(USE_AVX) || defined(USE_AVX2) -#include -#else -#include -#ifdef USE_SSE41 -#include -#endif -#endif - #include "Type.h" namespace BinSearch { namespace Details { -template -struct FVec; - -template -struct IVec; - -template -struct FVec1; - template <> struct InstrIntTraits { typedef __m128i vec_t; @@ -64,8 +86,8 @@ template <> struct InstrFloatTraits typedef __m128d vec_t; }; -template -struct FTOITraits +template <> +struct FTOITraits { typedef IVec vec_t; }; @@ -285,9 +307,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec< FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_ps( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_ps( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttps_epi32(a); } +#if !defined(__clang__) || defined(__HIP_PLATFORM_AMD__) // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm_fmsub_ps(a, b, c); } #endif @@ -339,9 +363,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_pd( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_pd( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttpd_epi32(a); } +#if !defined(__clang__) || defined(__HIP_PLATFORM_AMD__) // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c ) { return _mm_fmsub_pd(a, b, c); } #endif @@ -558,5 +584,6 @@ FORCE_INLINE FVec mulSub(const FVec& a, const FVec=42", - "wheel" -] +requires = [ "setuptools", "wheel" ] build-backend = "setuptools.build_meta" [tool.ruff] @@ -11,9 +8,11 @@ src = [ "tests", "benchmarking" ] -fix = true +target-version = "py38" +line-length = 119 + +[tool.ruff.lint] select = [ - "A", # prevent using keywords that clobber python builtins "B", # bugbear: security warnings "E", # pycodestyle "F", # pyflakes @@ -22,16 +21,40 @@ select = [ "UP", # alert you when better syntax is available in your python version "RUF", # the ruff developer's own rules ] -target-version = "py38" ignore = [ - "E712", # Allow using if x == False, as it's not always equivalent to if x. + "B007", # Loop control variable not used within the loop body (TODO: enable) + "B028", # Warning without stacklevel (TODO: enable) "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. - "F401", + "E701", # Multiple statements on one line (TODO: enable) + "E712", # Allow using if x == False, as it's not always equivalent to if x. + "E731", # Do not use lambda + "F841", # Local assigned but not used (TODO: enable, these are likely bugs) + "RUF012", # Mutable class attribute annotations ] ignore-init-module-imports = true # allow to expose in __init__.py via imports -[tool.ruff.isort] +[tool.ruff.lint.extend-per-file-ignores] +"**/__init__.py" = ["F401"] # allow unused imports in __init__.py +"{benchmarking,tests}/**/*.py" = [ + "B007", + "B011", + "B023", + "E701", + "E731", + "F841", + "UP030", +] + +[tool.ruff.lint.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true -known-first-party = ["bitsandbytes"] \ No newline at end of file +known-first-party = ["bitsandbytes"] + +[[tool.mypy.overrides]] +module = "triton.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "scipy.stats" +ignore_missing_imports = true diff --git a/pytest.ini b/pytest.ini index 9902b98fa..ac6d72e63 100644 --- a/pytest.ini +++ b/pytest.ini @@ -7,4 +7,7 @@ addopts = -rP log_cli = True log_cli_level = INFO -log_file = logs/pytest.log \ No newline at end of file +log_file = logs/pytest.log +markers = + benchmark: mark test as benchmark + slow: mark test as slow diff --git a/requirements-ci.txt b/requirements-ci.txt new file mode 100644 index 000000000..61f92018a --- /dev/null +++ b/requirements-ci.txt @@ -0,0 +1,6 @@ +# Requirements used for GitHub actions +pytest==8.1.1 +einops==0.7.0 +lion-pytorch==0.1.2 +scipy==1.10.1; python_version < "3.9" +scipy==1.12.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..fc5449ba7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +# Requirements used for local development +setuptools>=63 +pytest~=8.1.1 +einops~=0.7.0 +wheel~=0.43.0 +lion-pytorch~=0.1.2 +scipy~=1.12.0 +pandas~=2.2.1 +matplotlib~=3.8.3 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 3bde2dc6a..000000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -lion-pytorch -pytest -scipy diff --git a/scripts/stale.py b/scripts/stale.py index b7f34c1fb..a65652aeb 100644 --- a/scripts/stale.py +++ b/scripts/stale.py @@ -15,13 +15,12 @@ Script to close stale issue. Taken in part from the AllenNLP repository. https://github.com/allenai/allennlp. """ + +from datetime import datetime as dt, timezone import os -from datetime import datetime as dt -from datetime import timezone from github import Github - # All labels that we don't want to touch LABELS_TO_EXEMPT = [ "feature-request", @@ -52,9 +51,9 @@ def main(): issue.create_comment( "This issue has been automatically marked as stale because it has not had " "recent activity. If you think this still needs to be addressed " - "please comment on this thread.\n\n" + "please comment on this thread.\n\n", ) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/setup.py b/setup.py index c07451d20..a51b3867c 100644 --- a/setup.py +++ b/setup.py @@ -6,9 +6,9 @@ import os from setuptools import find_packages, setup +from setuptools.dist import Distribution - -libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) +libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.*")) libs = [os.path.basename(p) for p in libs] print("libs:", libs) @@ -17,9 +17,15 @@ def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() +# Tested with wheel v0.29.0 +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + setup( - name=f"bitsandbytes", - version="0.42.0", + name="bitsandbytes", + version="0.44.0.dev", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", @@ -28,12 +34,16 @@ def read(fname): url="https://github.com/TimDettmers/bitsandbytes", packages=find_packages(), package_data={"": libs}, - install_requires=['torch', 'numpy', 'scipy'], - extras_require={'benchmark': ['pandas', 'matplotlib']}, + install_requires=["torch", "numpy"], + extras_require={ + "benchmark": ["pandas", "matplotlib"], + "test": ["scipy"], + }, long_description=read("README.md"), long_description_content_type="text/markdown", classifiers=[ "Development Status :: 4 - Beta", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], + distclass=BinaryDistribution, ) diff --git a/tests/conftest.py b/tests/conftest.py index 0b4b91225..17ffd281c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,19 @@ def pytest_runtest_call(item): try: item.runtest() + except NotImplementedError as nie: + if "NO_CUBLASLT" in str(nie): + pytest.skip("CUBLASLT not available") + raise except AssertionError as ae: if str(ae) == "Torch not compiled with CUDA enabled": pytest.skip("Torch not compiled with CUDA enabled") raise + except RuntimeError as re: + # CUDA-enabled Torch build, but no CUDA-capable device found + if "Found no NVIDIA driver on your system" in str(re): + pytest.skip("No NVIDIA driver found") + raise @pytest.fixture(scope="session") diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 000000000..e93c11b70 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,71 @@ +from io import BytesIO +from itertools import product +import random +from typing import Any, List + +import torch + +test_dims_rng = random.Random(42) + + +TRUE_FALSE = (True, False) +BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bool, bool, bool) +BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) + + +def torch_save_to_buffer(obj): + buffer = BytesIO() + torch.save(obj, buffer) + buffer.seek(0) + return buffer + + +def torch_load_from_buffer(buffer): + buffer.seek(0) + obj = torch.load(buffer) + buffer.seek(0) + return obj + + +def get_test_dims(min: int, max: int, *, n: int) -> List[int]: + return [test_dims_rng.randint(min, max) for _ in range(n)] + + +def format_with_label(label: str, value: Any) -> str: + if isinstance(value, bool): + formatted = "T" if value else "F" + elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): + formatted = "".join("T" if b else "F" for b in value) + else: + formatted = str(value) + return f"{label}={formatted}" + + +def id_formatter(label: str): + """ + Return a function that formats the value given to it with the given label. + """ + return lambda value: format_with_label(label, value) + + +DTYPE_NAMES = { + torch.bfloat16: "bf16", + torch.bool: "bool", + torch.float16: "fp16", + torch.float32: "fp32", + torch.float64: "fp64", + torch.int32: "int32", + torch.int64: "int64", + torch.int8: "int8", +} + + +def describe_dtype(dtype: torch.dtype) -> str: + return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] + + +def get_blocksizes(hip_env: bool) -> List[int]: + if not hip_env: + return [4096, 2048, 1024, 512, 256, 128, 64] + else: + return [4096, 2048, 1024, 512, 256, 128] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index f045fda4c..eafa01f0e 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,61 +1,47 @@ -from itertools import permutations, product +from typing import Tuple import pytest import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT - -n = 1 -k = 25 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() -funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] -str_funcs = ["bmm", "matmul"] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ["FF", "TF", "TT", "FT"] -transpose = [(False, False), (False, True), (True, True), (True, False)] -str_transpose = ["FF", "FT", "TT", "TF"] -dtype = [torch.float32, torch.float16] -values = list( - product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +from bitsandbytes.cextension import BNB_HIP_VERSION +from tests.helpers import ( + BOOLEAN_TRIPLES, + BOOLEAN_TUPLES, + TRUE_FALSE, + describe_dtype, + get_test_dims, + id_formatter, ) -str_values = list( - product( - dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose - ) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format( - *vals - ) - for vals in str_values -] + +TRANSPOSE_VALS = [(False, True), (False, False)] +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", - values, - ids=names, + "funcs", + [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], + ids=["func=bmm", "func=matmul"], ) -def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]): if dim2 > 0: dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) - for i in range(k): - + for i in range(25): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn( - size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] - ) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]) torch.nn.init.xavier_uniform_(B) if not transpose[0] and not transpose[1]: @@ -87,9 +73,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -97,18 +81,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: @@ -135,9 +115,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_close( - out_bnb, out_torch, atol=0.027, rtol=0.2 - ) + torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -149,9 +127,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -159,9 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -208,9 +182,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -218,9 +190,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -229,83 +199,23 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() < n * 0.02 -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -decomp = [0.0, 6.0] -funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)] -str_funcs = ["matmullt", 'switchback_bnb'] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.bfloat16, torch.float32] -has_fp16_weights = [True, False] -has_bias = [True, False] -values = list( - product( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -str_values = list( - product( - dim1, - dim2, - dim3, - dim4, - str_funcs, - dtype, - req_grad_str, - str_transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] - +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", - values, - ids=names, + "funcs", + [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], + ids=["func=matmul", "func=switchback_bnb"], ) -def test_matmullt( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias -): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") @@ -313,19 +223,14 @@ def test_matmullt( req_grad = list(req_grad) req_grad[2] = False - for i in range(k): - + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn( - size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype - ) + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn( - size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype - ) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) target = torch.randn( size=(dim2, dim4), device="cuda", @@ -335,7 +240,7 @@ def test_matmullt( bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -380,9 +285,7 @@ def test_matmullt( if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss( - out_bnb, target - ).mean() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad @@ -392,9 +295,7 @@ def test_matmullt( gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -405,9 +306,7 @@ def test_matmullt( bias.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() if dim2 > 0: @@ -421,53 +320,43 @@ def test_matmullt( assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -funcs = [(torch.matmul, bnb.matmul_4bit)] -str_funcs = ["matmul"] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -compress_statistics = [False, True] -has_fp16_weights = [True, False] -has_bias = [True, False] -quant_type = ['fp4', 'nf4'] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) -def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) +def test_matmul_4bit( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + has_bias, + compress_statistics, + quant_type, +): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) @@ -476,11 +365,15 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + B2, quant_state = bnb.functional.quantize_4bit( + B, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) @@ -499,7 +392,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if n > 0: assert err < 0.115 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -513,7 +406,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -524,38 +417,31 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias.grad = None if req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) -funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] -str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -has_fp16_weights = [True, False] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) -def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "funcs", + [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], + ids=["matmul_fp8_mixed", "matmul_fp8_global"], +) +def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) @@ -580,7 +466,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: assert err < 0.115 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -591,7 +477,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -599,7 +485,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() @@ -614,9 +500,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - grad_err = (gradB1-gradB2).abs().mean() + grad_err = (gradB1 - gradB2).abs().mean() assert grad_err.item() < 0.003 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) - + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 5e9ccf590..53dd25044 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,28 +1,41 @@ -import os import pytest -import torch -from pathlib import Path -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cuda_specs import CUDASpecs -# hardcoded test. Not good, but a sanity check for now -# TODO: improve this -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -def test_manual_override(requires_cuda): - manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) - pytorch_version = torch.version.cuda.replace('.', '') +@pytest.fixture +def cuda120_spec() -> CUDASpecs: + return CUDASpecs( + cuda_version_string="120", + highest_compute_capability=(8, 6), + cuda_version_tuple=(12, 0), + ) - assert pytorch_version != 122 # TODO: this will never be true... - os.environ['CUDA_HOME']='{manual_cuda_path}' - os.environ['BNB_CUDA_VERSION']='122' - #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] - import bitsandbytes as bnb - loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name - #assert loaded_lib == 'libbitsandbytes_cuda122.so' +@pytest.fixture +def cuda111_noblas_spec() -> CUDASpecs: + return CUDASpecs( + cuda_version_string="111", + highest_compute_capability=(7, 2), + cuda_version_tuple=(11, 1), + ) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) + assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): + monkeypatch.setenv("BNB_CUDA_VERSION", "110") + assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" + assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? + +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) + assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_functional.py b/tests/test_functional.py index 5dba4ef5f..04a898d4b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,31 +1,30 @@ +from itertools import product import math import random import time -from itertools import product import einops +import numpy as np import pytest +from scipy.stats import norm import torch -import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT -from scipy.stats import norm +from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT +from tests.helpers import BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, get_blocksizes, get_test_dims, id_formatter -torch.set_printoptions( - precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 -) +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) k = 20 def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: if throw: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) return sumval @@ -91,9 +90,8 @@ def setup(): def teardown(): pass -@pytest.mark.parametrize( - "dtype", [torch.float32, torch.float16], ids=["float", "half"] -) + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_estimate_quantiles(dtype): A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) @@ -110,6 +108,7 @@ def test_estimate_quantiles(dtype): diff = torch.abs(code - quantiles) assert (diff > 5e-02).sum().item() == 0 + def test_quantile_quantization(): for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") @@ -128,7 +127,6 @@ def test_quantile_quantization(): assert diff < 0.001 - def test_dynamic_quantization(): diffs = [] reldiffs = [] @@ -141,8 +139,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + print(sum(diffs) / len(diffs)) + print(sum(reldiffs) / len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") @@ -153,17 +151,12 @@ def test_dynamic_quantization(): assert diff < 0.004 -def get_blocksizes(hip_env=False): - if not hip_env: - return [4096, 2048, 1024, 512, 256, 128, 64] - else: - return [4096, 2048, 1024, 512, 256, 128] -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", get_blocksizes(HIP_ENVIRONMENT)) -@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False']) +@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - #print('') + # print('') diffs = [] reldiffs = [] for i in range(100): @@ -174,10 +167,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) assert abserr < 0.011 assert relerr < 0.018 assert A2.dtype == dtype @@ -192,9 +185,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) + # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) if signed: assert abserr < 0.0035 assert relerr < 0.015 @@ -202,14 +195,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): assert abserr < 0.00175 assert relerr < 0.012 assert A2.dtype == dtype - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) - -@pytest.mark.parametrize( - "gtype", [torch.float32, torch.float16], ids=["float", "half"] -) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_percentile_clipping(gtype): gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda") @@ -219,9 +209,7 @@ def test_percentile_clipping(gtype): for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping( - g, gnorm_vec2, step, percentile=percentile - ) + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 gnorm2 = torch.norm(g.float()) @@ -284,40 +272,28 @@ def mean(xx): return sum(xx) / float(len(xx)) -# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() -# dim2 = torch.randint(1,1024*4, size=(4,)).tolist() -dim1 = [1024 * 2] -dim2 = [1024 * 16] -methods = [ - ( +methods = { + "linear": ( lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant, - ) -] -methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) -# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) -method_names = ["linear", "vectorwise"] -batched = [False, True] -values = list(product(dim1, dim2, methods, batched)) -values_names = list(product(dim1, dim2, method_names, batched)) -names = [ - "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals) - for vals in values_names -] + ), + "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant), +} -@pytest.mark.parametrize( - "dim1, dim2, quant_methods, batched", values, ids=names -) +@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) +@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) +@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - #print("") + # print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -329,9 +305,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_close( - quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 - ) + torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -346,8 +320,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - #print(mean(errors)) - #print(mean(relerrors)) + # print(mean(errors)) + # print(mean(relerrors)) def test_stable_embedding(): @@ -355,36 +329,17 @@ def test_stable_embedding(): layer.reset_parameters() -n = 2 -hidden_dim = torch.randint(32, 256, size=(n,)).tolist() -batch_dim = torch.randint(16, 256, size=(n,)).tolist() -seq_dim = torch.randint(16, 256, size=(n,)).tolist() -transpose = [(False, False), (False, True), (True, False), (True, True)] -values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) -names = [ - "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names -) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = ( - (batch_dim, hidden_dim) - if not transpose[0] - else (hidden_dim, batch_dim) - ) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -404,11 +359,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -421,52 +372,27 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): torch.testing.assert_close(out.float(), out2) -n = 3 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -values = list(product(seq_dim, hidden_dim, batch_dim)) -names = [ - "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint( - -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ).to(torch.int8) - B = torch.randint( - -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" - ).to(torch.int8) + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty( - A.shape[2], B.shape[2], dtype=torch.int32, device=A.device - ) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) out = F.igemm(A, B, out=iout) torch.testing.assert_close(out.float(), out2) -n = 2 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -transpose = [False, True] -values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) -names = [ - "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names -) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -482,9 +408,7 @@ def min_max(x): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal( - 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ) + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") if transpose: B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: @@ -536,20 +460,11 @@ def min_max(x): assert mean(relerrs) < 0.3 -n = 2 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -dim4 = torch.randint(32, 256, size=(n,)).tolist() -transpose = [(False, False), (True, False), (False, True), (True, True)] -values = list(product(dim1, dim2, dim3, dim4, transpose)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_ibmm(dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -570,22 +485,15 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B) elif transpose[0] and transpose[1]: - out2 = torch.bmm( - A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() - ) + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_close(out.float(), out2.float()) -n = 1 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -values = list(product(dim1, dim2, dim3)) -names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values] - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -594,31 +502,24 @@ def test_vector_quant(dim1, dim2, dim3): qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) - - - + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) -n = 2 -dim1 = torch.randint(2, 256, size=(n,)).tolist() -dim2 = torch.randint(2, 256, size=(n,)).tolist() -dim3 = torch.randint(2, 256, size=(n,)).tolist() -# dim1, dim2 = (256,), (256,) -dtype = [torch.int8, torch.int32] -a_order = ["row"] -out_order = ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"] -transpose = [False] -dims = [2, 3] -values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) -names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize( + "orderOut", ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"], ids=id_formatter("orderOut") +) +@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) +@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and out_order != "col32": + if dims == 3 and orderOut != "col32": return - if dtype == torch.int32 and out_order != "col32": + if dtype == torch.int32 and orderOut != "col32": return try: func = F.get_transform_func(dtype, orderA, orderOut, transpose) @@ -628,9 +529,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( - dtype - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) out, S = F.nvidia_transform(A, to_order=orderOut) @@ -642,17 +541,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = ( - A.shape[0] - * A.shape[1] - * (A.shape[2] + (32 - (A.shape[2] % 32))) - ) + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) assert out.numel() == n elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( - A.shape[1] + (32 - (A.shape[1] % 32)) - ) + n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): @@ -661,9 +554,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ( - (row // 8) + (1 if row % 8 != 0 else 0) - ) * total_coltile + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 row2 = (row % 8) * 32 @@ -674,46 +565,23 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": - out2, S = F.nvidia_transform( - out, from_order=orderOut, to_order="row", state=S - ) + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) torch.testing.assert_close(A, out2) -n = 1 -dim1 = torch.randint(1, 256, size=(n,)).tolist() -dim2 = torch.randint(32, 512, size=(n,)).tolist() -dim3 = torch.randint(32, 1024, size=(n,)).tolist() -dim4 = torch.randint(32, 1024, size=(n,)).tolist() - -# dim1 = [2] -# dim2 = [2] -# dim3 = [2] -# dim4 = [2] - -dims = (2, 3) -ldb = [0] -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) +@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) elif dims == 3: - A = torch.randint( - -128, 127, size=(dim1, dim2, dim3), device="cuda" - ).to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") @@ -722,10 +590,8 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_close(C1, C3.float()) - ## transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( - torch.int8 - ) + # transpose + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) @@ -734,29 +600,18 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_close(C1, C3.float()) -dim1 = [32] -dim2 = [32] -dim3 = [32] -dim4 = [32] - -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): formatB = F.get_special_format_str() for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal( - 0, 0.5, size=(dim1, dim2, dim3), device="cuda" - ).half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -788,23 +643,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # torch.testing.assert_close(C1, C3.float()) -batch_size = 2 -seqdim = 512 -# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] -values = [ - (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), - (batch_size, seqdim, 5120, 3 * 5120), - (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), -] - - -# values = list(product(batch, seq, model, hidden)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [ + pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"), + pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"), + pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"), + ], +) +@pytest.mark.benchmark def test_bench_8bit_training(batch, seq, model, hidden): formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -825,7 +672,6 @@ def test_bench_8bit_training(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 # out2 = torch.matmul(out1, w2.t())# fc2 @@ -954,33 +800,23 @@ def test_bench_8bit_training(batch, seq, model, hidden): # print(t8) -n = 2 -dim1 = torch.randint(64, 256, size=(n,)).tolist() -dim4 = torch.randint(64, 1024, size=(n,)).tolist() - -#dim1 = [2*1024] -#dim4 = [2*1024] - -#dim1 = [4] -#dim4 = [4] - -dims = (2,) -formatB = ["col_turing", "col_ampere"] -has_bias = [True, False] -values = list(product(dim1, dim4, dims, formatB, has_bias)) -names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] - -@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) +@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() bias = None - if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) + if has_bias: + bias = torch.randn(dim4, device="cuda", dtype=torch.float16) formatB = F.get_special_format_str() for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) - if has_bias: C1 += bias + if has_bias: + C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -991,36 +827,27 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - if has_bias: C4 += bias + if has_bias: + C4 += bias # TODO: is something wrong here? If so, the problem goes deeper - #n = C1.numel() - #p = 0.06 + # n = C1.numel() + # p = 0.06 std = C1.std(0).view(1, -1) C1 /= std C4 /= std - #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) - #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" + # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) + # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() - assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) - - -n = 2 -dim1 = [1 * 1024] -dim2 = [1 * 1024] -# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() - -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dims)) -names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values] + assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) -@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_colrow_absmax(dim1, dim2, dims): for i in range(k): threshold = 3.0 @@ -1035,9 +862,7 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=threshold - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) A_blocked = einops.rearrange( torch.abs(A), @@ -1057,26 +882,15 @@ def test_colrow_absmax(dim1, dim2, dims): torch.testing.assert_close(row_stats1_trunc, row_stats2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=0.0 - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) assert nnz_block_ptr2 is None -n = 2 -# dim1 = [8*1024] -# dim2 = [4*1024] -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -1090,40 +904,34 @@ def test_double_quant(dim1, dim2): torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() - num_not_close_rows = ( - (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - ) - num_not_close_cols = ( - (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() - ) + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() # allow for 1:500 error due to rounding differences min_error = 1 / 500 if num_not_close_cols > (min_error * n): - print( - f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}") assert False if num_not_close_rows > (min_error * n): - print( - f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}") assert False torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) -n = 4 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) in zip( + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + ) + ), +) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): A = torch.randn(dim1, inner, device="cuda").half() @@ -1158,16 +966,17 @@ def test_integrated_igemmlt(dim1, dim4, inner): assert err2 <= err1 * 1.025 -n = 6 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) in zip( + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + ) + ), +) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): formatB = F.get_special_format_str() @@ -1190,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1234,17 +1041,17 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -dim1 = [1024, 2048] -inner = [12288 * 4, 4096 * 4] -dim4 = [12288, 4096] - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + [ + pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"), + pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"), + ], +) @pytest.mark.skip("Row scale has some bugs for ampere") +@pytest.mark.benchmark def test_row_scale_bench(dim1, dim4, inner): + formatB = F.get_special_format_str() err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] scale = 1 @@ -1273,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1289,43 +1094,20 @@ def test_row_scale_bench(dim1, dim4, inner): print("vector-wise", time.time() - t0) -n = 2 -dim1 = torch.randint(2, 1024, size=(n,)).tolist() -dim2 = torch.randint(2, 1024, size=(n,)).tolist() -# dim1 = [8*1024] -# dim2 = [4*1024] - -dim3 = [0] -dtype = [torch.int8] -a_order = ["row"] -out_order = ["col32", "col_turing", "col_ampere"] -transpose = [False, True] -dims = [2] -values = list( - product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format( - *vals - ) - for vals in values -] - -@pytest.mark.parametrize( - "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", - values, - ids=names, -) +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( - dtype - ) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint( - 10, 99, size=(dim1, dim2, dim3), device="cuda" - ).to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1343,22 +1125,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): torch.testing.assert_close(out1, out2) -n = 2 -# dim1 = torch.randint(2,1024, size=(n,)).tolist() -# dim2 = torch.randint(2,1024, size=(n,)).tolist() -dim1 = [1] -dim2 = [33] - -dtype = [torch.int8] -# a_order = ['col_turing', 'col_ampere'] -a_order = ["col_turing"] -out_order = ["row"] -values = list(product(dim1, dim2, dtype, a_order, out_order)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals) - for vals in values -] - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_overflow(): formatB = F.get_special_format_str() @@ -1374,17 +1140,8 @@ def test_overflow(): c2 = torch.matmul(a.float(), b.float().t()) -n = 2 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -# dim1 = [4] -# dim2 = [5] - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): @@ -1392,36 +1149,23 @@ def test_coo_double_quant(dim1, dim2): idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) - A2[ - coo_tensor.rowidx.long(), coo_tensor.colidx.long() - ] = coo_tensor.values + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close( - A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 - ) + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) -n = 2 -dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -# dim1 = [7] -# dim2 = [11] -transposed_B = [False, True] -values = list(product(dim1, dim2, transposed_B)) -names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values] - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) def test_spmm_coo(dim1, dim2, transposed_B): threshold = 1.5 dim3 = torch.randint(32, 128, size=(1,)).item() @@ -1437,9 +1181,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx if transposed_B: @@ -1453,6 +1195,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") +@pytest.mark.benchmark def test_spmm_bench(): batch = 2 model = 1024 * 1 @@ -1479,9 +1222,7 @@ def test_spmm_bench(): print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1496,14 +1237,9 @@ def test_spmm_bench(): print(tsp / t8) -n = 2 -dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 formatB = "col_turing" @@ -1521,9 +1257,7 @@ def test_integrated_sparse_decomp(dim1, dim2): out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) @@ -1553,23 +1287,10 @@ def test_matmuls(): print(err1, err2) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -dim2 = [12288] -# dim1 = [32] -# dim2 = [32] -# dtype = [torch.float16, torch.int8] -dtype = [torch.float16] -out_function = ["zeros", "ones"] -values = list(product(dim1, dim2, dtype, out_function)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): out_func = getattr(torch, out_func) @@ -1591,9 +1312,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) @@ -1608,9 +1327,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): std = out1.std() out1 /= std out2 /= std - assert_all_approx_close( - out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count - ) + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) @@ -1638,9 +1355,7 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] @@ -1658,9 +1373,7 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] @@ -1672,21 +1385,10 @@ def test_coo2csc(): torch.testing.assert_close(A2.t()[idx], cscA.values) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -# dim2 = [12288] -dim2 = [2048] -# dim1 = [2] -# dim2 = [2] -dtype = [torch.int8] -values = list(product(dim1, dim2, dtype)) -names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values] - - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048]) +@pytest.mark.parametrize("dim2", [2048]) +@pytest.mark.parametrize("dtype", [torch.int8]) def test_spmm_coo_dequant(dim1, dim2, dtype): threshold = 6.0 # threshold = 2.8 @@ -1706,9 +1408,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) @@ -1787,22 +1487,11 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 -values = [] -#values.append((batch_size, seqdim, 768, 4 * 768)) -#values.append((batch_size, seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) -#values.append((batch_size, seqdim, 5120, 4*5120)) -values.append((batch_size, seqdim, 6656, 4*6656)) -#values.append((batch_size, seqdim, 8192, 4*8192)) -#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) -names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], +) +@pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): iters = 1000 formatB = F.get_special_format_str() @@ -1823,8 +1512,8 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) - #linearMixedBit.eval() + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half() + # linearMixedBit.eval() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() @@ -1841,121 +1530,123 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print( + f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", + ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - #torch.cuda.synchronize() - #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - #torch.cuda.synchronize() - #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() - print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) torch.cuda.synchronize() - print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - + print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B, threshold=6.0) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - #C32A, SA = F.transform(CA, "col32") - #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - #CxB, SB = F.transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + # C32A, SA = F.transform(CA, "col32") + # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + # CxB, SB = F.transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - #torch.cuda.synchronize() - #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1) - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1) + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1) # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - #torch.cuda.synchronize() - #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # out = Cout * statsB * statsA * (1.0 / (127 * 127)) - #torch.cuda.synchronize() - #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linearMixedBit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linearMixedBit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linearMixedBit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train_thresh(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train_thresh(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + def test_zeropoint(): def quant_zp(x): @@ -1996,8 +1687,8 @@ def quant_zp(x): C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) - #print(ca.min(), ca.max()) - #print((ca - cza).min(), (ca - cza).max()) + # print(ca.min(), ca.max()) + # print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 @@ -2026,14 +1717,14 @@ def quant_zp(x): C7 -= zpa * zpb * A.shape[1] C7 /= qa * qb - #print("") + # print("") # print(C0.flatten()[:10]) - #print(C1.flatten()[:10]) - #print(C2.flatten()[:10]) - #print(C3.flatten()[:10]) - #print(C5.flatten()[:10]) - #print(C6.flatten()[:10]) - #print(C7.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) + # print(C3.flatten()[:10]) + # print(C5.flatten()[:10]) + # print(C6.flatten()[:10]) + # print(C7.flatten()[:10]) err1 = torch.abs(C1 - C2).mean().item() err2 = torch.abs(C1 - C3).mean().item() err3 = torch.abs(C1 - C4).mean().item() @@ -2043,6 +1734,7 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") def test_extract_outliers(): for i in range(k): shapeA = (4096, 4096 * 4) @@ -2070,16 +1762,15 @@ def test_extract_outliers(): torch.testing.assert_close(outliers1, outliers2) - def test_blockwise_cpu_large(): diffs = [] reldiffs = [] batch = 128 seq = 128 - for hidden in [128]:#, 14336]: + for hidden in [128]: # , 14336]: for blocksize in [4096, 16384]: for i in range(2): - A1 = torch.randn(batch, seq, hidden, device='cpu') + A1 = torch.randn(batch, seq, hidden, device="cpu") t0 = time.time() C, S = F.quantize_blockwise(A1, blocksize=blocksize) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) @@ -2093,10 +1784,9 @@ def test_blockwise_cpu_large(): # print(sum(reldiffs)/len(reldiffs)) - def test_fp8_quant(): for e_bits in range(1, 7): - p_bits = 7-e_bits + p_bits = 7 - e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() abserr = [] @@ -2106,12 +1796,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2120,12 +1810,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2134,50 +1824,48 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(3, sum(abserr)/len(abserr)) - #print(3, sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(3, sum(abserr)/len(abserr)) + # print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - - #print('') + # print('') for bits in range(2, 9): - #print('='*30, bits, '='*30) - for method in ['linear', 'fp8', 'dynamic', 'quantile']: + # print('='*30, bits, '='*30) + for method in ["linear", "fp8", "dynamic", "quantile"]: abserrs = [] relerrs = [] code = None - if method == 'linear': + if method == "linear": code = F.create_linear_map(True, total_bits=bits).cuda() - elif method == 'fp8': - ebits = math.ceil(bits/2) - pbits = bits-ebits-1 + elif method == "fp8": + ebits = math.ceil(bits / 2) + pbits = bits - ebits - 1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - elif method == 'dynamic': - code = F.create_dynamic_map(True, bits-0, bits).cuda() - elif method == 'quantile': - values = torch.randn(2048, 2048, device='cuda') + elif method == "dynamic": + code = F.create_dynamic_map(True, bits - 0, bits).cuda() + elif method == "quantile": + values = torch.randn(2048, 2048, device="cuda") code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero # for some data types we have one zero # for some data types we have two zeros - assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' - #print(method, (code==0).sum()) + assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" + # print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): - - values = torch.randn(1, 32, device='cuda') + values = torch.randn(1, 32, device="cuda") values /= values.abs().max() - #values[values.abs() < 1e-6] += 1e-5 + # values[values.abs() < 1e-6] += 1e-5 q1 = [] v1 = [] for v in values[0]: - idx = torch.abs(v-code).argmin() + idx = torch.abs(v - code).argmin() q1.append(idx.item()) v1.append(code[idx].item()) @@ -2188,65 +1876,65 @@ def test_few_bit_quant(): v2 = F.dequantize_blockwise(q2, S2) idx = torch.isclose(q1.int(), q2.int()) - err2 = torch.abs(v2-values) + err2 = torch.abs(v2 - values) abserrs.append(err2.mean().item()) - relerrs.append((err2/(1e-10+values).abs()).mean().item()) + relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) if idx.sum(): # some weird cases - err1 = torch.abs(v1-values).mean() - #assert err2.mean() <= err1 + err1 = torch.abs(v1 - values).mean() + # assert err2.mean() <= err1 else: torch.testing.assert_close(q1, q2) - #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) - #assert False + # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + # assert False def test_kbit_quantile_estimation(): for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 9): - p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) val1 = torch.Tensor(norm.ppf(p)).cuda() val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1-val2).mean() + err = torch.abs(val1 - val2).mean() assert err < 0.038 for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 4): - total_values = 2**bits-1 - p = np.linspace(0, 1, 2*total_values+1) - idx = np.arange(1, 2*total_values+1, 2) + total_values = 2**bits - 1 + p = np.linspace(0, 1, 2 * total_values + 1) + idx = np.arange(1, 2 * total_values + 1, 2) p = p[idx] - offset = 1/(2*total_values) - p = np.linspace(offset, 1-offset, total_values) + offset = 1 / (2 * total_values) + p = np.linspace(offset, 1 - offset, total_values) val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) - err = torch.abs(val1-val2).mean() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) + err = torch.abs(val1 - val2).mean() assert err < 0.035 +@pytest.mark.benchmark def test_bench_dequantization(): - a = torch.rand(1024, 1024, device='cuda').half() - code =F.create_fp8_map(True, 3, 0, 4).cuda() + a = torch.rand(1024, 1024, device="cuda").half() + code = F.create_fp8_map(True, 3, 0, 4).cuda() qa, SA = F.quantize_blockwise(a, code=code) print(qa.max()) - max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 - #print(max_theoretical_mu) + max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 + # print(max_theoretical_mu) torch.cuda.synchronize() t0 = time.time() for i in range(100): qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() - #print((time.time()-t0)/1e6) - + # print((time.time()-t0)/1e6) @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) def test_fp4_quant(dtype): vals = list(product([0, 1], repeat=4)) @@ -2255,26 +1943,28 @@ def test_fp4_quant(dtype): result = 0 bias = 3 sign, e1, e2, p1 = bits - idx = sign*8 + e1*4 + e2*2 + p1*1 + idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1 sign = -1.0 if sign else 1.0 - exp = e1*2 + e2*1 + exp = e1 * 2 + e2 * 1 if exp == 0: # sub-normal - if p1 == 0: result = 0 - else: result = sign*0.0625 + if p1 == 0: + result = 0 + else: + result = sign * 0.0625 else: # normal - exp = 2**(-exp + bias + 1) + exp = 2 ** (-exp + bias + 1) frac = 1.5 if p1 else 1.0 - result = sign*exp*frac + result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-8)).mean() + relerr = (err / (A1.abs().float() + 1e-8)).mean() idx = err > 1.0 err = err.mean() @@ -2283,32 +1973,30 @@ def test_fp4_quant(dtype): assert relerr.item() < 0.28 -@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) def test_4bit_compressed_stats(quant_type): blocksizes = [128, 64] if not HIP_ENVIRONMENT else [128] for blocksize in blocksizes: errs1 = [] errs2 = [] for i in range(10): - A1 = torch.randn(1024, 1024, device='cuda').half() + A1 = torch.randn(1024, 1024, device="cuda").half() q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) - err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs1.append(err.item()) - assert err.item() < 0.11 assert relerr.item() < 0.28 err = (A1 - A3).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs2.append(err.item()) @@ -2316,70 +2004,74 @@ def test_4bit_compressed_stats(quant_type): assert err.item() < 0.11 assert relerr.item() < 0.28 - #print(sum(errs1)/len(errs1), blocksize, quant_type) - #print(sum(errs2)/len(errs2), blocksize, quant_type) - - + # print(sum(errs1)/len(errs1), blocksize, quant_type) + # print(sum(errs2)/len(errs2), blocksize, quant_type) -#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -@pytest.mark.parametrize("quant_type", ['nf4']) +# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["nf4"]) +@pytest.mark.benchmark def test_bench_4bit_dequant(quant_type): blocksize = 256 - a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) - input_size = a.numel()/2 - output_size = a.numel()*2 - num_bytes = input_size+output_size - GB = num_bytes/1e9 - max_theoretical_s = GB/768 - #print(max_theoretical_s*1e6) - b = torch.randn(128, 1024*12, device='cuda').half() + input_size = a.numel() / 2 + output_size = a.numel() * 2 + num_bytes = input_size + output_size + GB = num_bytes / 1e9 + max_theoretical_s = GB / 768 + # print(max_theoretical_s*1e6) + b = torch.randn(128, 1024 * 12, device="cuda").half() iters = 100 torch.cuda.synchronize() t0 = time.time() for i in range(iters): F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - #b.copy_(a) + # b.copy_(a) torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) + # print((time.time()-t0)/iters*1e6) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # torch.matmul(b, a.t()) - #torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) - + # torch.cuda.synchronize() + # print((time.time()-t0)/iters*1e6) def test_normal_map_tree(): code = F.create_normal_map() - values =code[:8].tolist() + code[-8:].tolist() + values = code[:8].tolist() + code[-8:].tolist() num_pivots = 1 - #print(values) - while num_pivots <16: - idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) - #print(idx) + # print(values) + while num_pivots < 16: + idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots)) + # print(idx) num_pivots *= 2 pivots = [] for i in idx: - pivots.append((values[i-1]+values[i])/2) - #print(pivots) + pivots.append((values[i - 1] + values[i]) / 2) + # print(pivots) -@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False']) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32']) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64") +@pytest.mark.skipif( + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" +) +@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") +@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "quant_storage", + [torch.uint8, torch.float16, torch.bfloat16, torch.float32], + ids=describe_dtype, +) def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: - #for dim in [4*1024]: - #for dim in [1*16]: + # for dim in [4*1024]: + # for dim in [1*16]: errs1 = [] errs2 = [] errs3 = [] @@ -2390,38 +2082,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2 = [] max_errs3 = [] - for i in range(100): - if kind == 'fc1': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'fc2': - A = torch.randn(1, 4*dim, dtype=dtype, device='cuda') - B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn_packed': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=storage_type, + compress_statistics=double_quant, + quant_storage=quant_storage, + ) C3 = torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) - err1 = (C1-C2).abs().float() - err2 = (C3-C2).abs().float() - err3 = (C3-C1).abs().float() + err1 = (C1 - C2).abs().float() + err2 = (C3 - C2).abs().float() + err3 = (C3 - C1).abs().float() - mag1 = torch.abs(C1).float()+1e-5 - mag2 = torch.abs(C3).float()+1e-5 - mag3 = torch.abs(C3).float()+1e-5 + mag1 = torch.abs(C1).float() + 1e-5 + mag2 = torch.abs(C3).float() + 1e-5 + mag3 = torch.abs(C3).float() + 1e-5 - relerr1 = err1/mag1 - relerr2 = err2/mag2 - relerr3 = err3/mag3 + relerr1 = err1 / mag1 + relerr2 = err2 / mag2 + relerr3 = err3 / mag3 max_err1 = err1.max() max_err2 = err2.max() @@ -2439,34 +2135,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2.append(max_err2.item()) max_errs3.append(max_err3.item()) - c = int(C1.numel()*0.0014*(dim/256))+1 + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - err1 = sum(errs1)/len(errs1)/math.sqrt(dim) - err2 = sum(errs2)/len(errs2)/math.sqrt(dim) - err3 = sum(errs3)/len(errs3)/math.sqrt(dim) - relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim) - relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim) - relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim) - maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim) - maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim) - maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim) - absratio = err2/err3 - relratio = relerr2/relerr3 - maxratio = relerr2/relerr3 + err1 = sum(errs1) / len(errs1) / math.sqrt(dim) + err2 = sum(errs2) / len(errs2) / math.sqrt(dim) + err3 = sum(errs3) / len(errs3) / math.sqrt(dim) + relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) + relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) + relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) + maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) + maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) + maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) + absratio = err2 / err3 + relratio = relerr2 / relerr3 + maxratio = relerr2 / relerr3 # for debugging if the tests fails # - #print('='*80) - #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - #print(C1.flatten()[-20:]) - #print(C2.flatten()[-20:]) - #print(f'inference vs training abs: {err1}') - #print(f'inference vs training rel: {relerr1}') - #print(f'inference vs training max: {maxerr1}') - #print(f'inference vs training vs torch err ratio abs: {absratio}') - #print(f'inference vs training vs torch err ratio rel: {relratio}') - #print(f'inference vs training vs torch err ratio max: {maxratio}') + # print('='*80) + # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + # print(C1.flatten()[-20:]) + # print(C2.flatten()[-20:]) + # print(f'inference vs training abs: {err1}') + # print(f'inference vs training rel: {relerr1}') + # print(f'inference vs training max: {maxerr1}') + # print(f'inference vs training vs torch err ratio abs: {absratio}') + # print(f'inference vs training vs torch err ratio rel: {relratio}') + # print(f'inference vs training vs torch err ratio max: {maxratio}') if dtype == torch.float16: if dim <= 512: assert err1 < 7e-5 @@ -2502,56 +2198,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): - n = 32*10 + n = 32 * 10 A = F.get_paged(n, n, dtype=torch.float32) B = F.get_paged(n, n, dtype=torch.uint8) B2 = F.get_paged(n, n, dtype=torch.float32) assert A.is_paged assert B.is_paged - assert A.page_deviceid==0 - assert B.page_deviceid==0 + assert A.page_deviceid == 0 + assert B.page_deviceid == 0 F.fill(A, 17.0) F.fill(B, 17) F.fill(B2, 2) - assert (A==17).sum().item() == n*n - assert (B==17).sum().item() == n*n - C = A*B.float() - assert (C==289).sum().item() == n*n + assert (A == 17).sum().item() == n * n + assert (B == 17).sum().item() == n * n + C = A * B.float() + assert (C == 289).sum().item() == n * n F._mul(A, B2) F._mul(A, B2) F._mul(A, B2) - assert (A==17*(2**3)).sum().item() == n*n - # F.prefetch_tensor(A) - # F.prefetch_tensor(B) + assert (A == 17 * (2**3)).sum().item() == n * n + +# F.prefetch_tensor(A) +# F.prefetch_tensor(B) - # F.fill(B2, 17.0) - # F._mul(A, B2) - # F.prefetch_tensor(A, to_cpu=True) - # F.prefetch_tensor(B, to_cpu=True) - # F.prefetch_tensor(B2, to_cpu=True) - # torch.cuda.synchronize() +# F.fill(B2, 17.0) +# F._mul(A, B2) - # assert (A==17).sum().item() == n*n +# F.prefetch_tensor(A, to_cpu=True) +# F.prefetch_tensor(B, to_cpu=True) +# F.prefetch_tensor(B2, to_cpu=True) +# torch.cuda.synchronize() - # torch.testing.assert_close(A, torch.ones(A.shape)*289) +# assert (A==17).sum().item() == n*n +# torch.testing.assert_close(A, torch.ones(A.shape)*289) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) + +@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) - dims = torch.randint(0, 8192, size=(dims,)).tolist() - dims = [dim + (64-(dim % 64)) for dim in dims] - #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: + dims = get_test_dims(0, 8192, n=dims) + dims = [dim + (64 - (dim % 64)) for dim in dims] + # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda') - B = torch.eye(dim, dtype=dtype, device='cuda') + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") + B = torch.eye(dim, dtype=dtype, device="cuda") qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) C3 = torch.matmul(A, B.t()) @@ -2562,7 +2261,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C3) torch.testing.assert_close(A, C1) torch.testing.assert_close(A, C2) - #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) - #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) - - + # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) + # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) diff --git a/tests/test_generation.py b/tests/test_generation.py index 54ec10475..8e689261b 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,91 +1,83 @@ -import pytest -import torch -import math - from itertools import product +import math -import transformers -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GenerationConfig, - set_seed, - -) +import pytest +import torch -import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter + +transformers = pytest.importorskip("transformers") def get_4bit_config(): - return BitsAndBytesConfig( - load_in_4bit=True, - load_in_8bit=False, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4', - ) + return transformers.BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) def get_model_and_tokenizer(config): model_name_or_path, quant_type = config bnb_config = get_4bit_config() - if quant_type == '16bit': + if quant_type == "16bit": bnb_config.load_in_4bit = False else: - bnb_config.bnb_4bit_quant_type= quant_type - model = AutoModelForCausalLM.from_pretrained(model_name_or_path, + bnb_config.bnb_4bit_quant_type = quant_type + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name_or_path, quantization_config=bnb_config, - max_memory={0:'48GB'}, - device_map='auto', - torch_dtype=torch.bfloat16 - ).eval() + max_memory={0: "48GB"}, + device_map="auto", + torch_dtype=torch.bfloat16, + ).eval() tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) return model, tokenizer + def get_prompt_for_generation_eval(text, add_roles=True): description = ( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ) if add_roles: - prompt = f'{description} ### Human: {text} ### Assistant:' + prompt = f"{description} ### Human: {text} ### Assistant:" else: - prompt = f'{description} {text}' + prompt = f"{description} {text}" return prompt + def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): text = prompt_func(text) - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config) return tokenizer.decode(outputs[0], skip_special_tokens=True) -models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] -dtypes = ['nf4', 'fp4'] -load_in_4bit = [True, False] -values = list(product(models, dtypes)) -strfunc = lambda lst: [str(x) for x in lst] -ids = ['_'.join(strfunc(x)) for x in values] -@pytest.fixture(scope='session', params=values, ids=ids) + +models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"] +dtypes = ["nf4", "fp4"] + + +@pytest.fixture(scope="session", params=product(models, dtypes)) def model_and_tokenizer(request): model, tokenizer = get_model_and_tokenizer(request.param) yield request.param, model, tokenizer del model -@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False']) -@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False']) -#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): - print('') - dtype = torch.float16 +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") +@pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq")) +@pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel")) +@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) +@pytest.mark.slow +def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): fixture_config, model, tokenizer = model_and_tokenizer generation_config = transformers.GenerationConfig( @@ -96,20 +88,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): ) generation_config.max_new_tokens = 20 - - #text = 'Please write down the first 50 digits of pi.' - #text = get_prompt_for_generation_eval(text) - #text += ' Sure, here the first 50 digits of pi: 3.14159' + # text = 'Please write down the first 50 digits of pi.' + # text = get_prompt_for_generation_eval(text) + # text += ' Sure, here the first 50 digits of pi: 3.14159' n_cases = 6 - text = '3.14159' - if hasattr(model.config, 'quantization_config'): + text = "3.14159" + if hasattr(model.config, "quantization_config"): model.config.quantization_config.bnb_4bit_compute_dtype = dtype model.config.quantization_config.bnb_4bit_use_double_quant = DQ if not inference_kernel: - text = [text]*n_cases - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - x = inputs['input_ids'] + text = [text] * n_cases + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + x = inputs["input_ids"] outputs = [] if inference_kernel: for i in range(n_cases): @@ -120,18 +111,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): outputs = model.generate(x, generation_config=generation_config) outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] - assert len(outputs) == n_cases failure_count = 0 for i in range(n_cases): - if not outputs[i][:len(str(math.pi))] == str(math.pi): + if not outputs[i][: len(str(math.pi))] == str(math.pi): failure_count += 1 - failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) + failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 if failure_count > failure_max: print(math.pi) for out in outputs: print(out) - raise ValueError(f'Failure count: {failure_count}/{n_cases}') - - - + raise ValueError(f"Failure count: {failure_count}/{n_cases}") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 478255eee..bbbd05335 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -1,25 +1,28 @@ +import copy import os -from contextlib import nullcontext -from itertools import product +import pickle from tempfile import TemporaryDirectory import pytest import torch import bitsandbytes as bnb +from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer storage = { - 'uint8': torch.uint8, - 'float16': torch.float16, - 'bfloat16': torch.bfloat16, - 'float32': torch.float32 + "uint8": torch.uint8, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, } -@pytest.mark.parametrize( - "quant_type, compress_statistics, bias, quant_storage", - list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])), -) -def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): + +@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) +@pytest.mark.parametrize("bias", TRUE_FALSE) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE) +def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward): original_dtype = torch.float16 compute_dtype = None device = "cuda" @@ -81,7 +84,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_storage=storage[quant_storage], device="meta", ) - linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) + linear_qs.weight = bnb.nn.Params4bit( + data=linear.weight, + requires_grad=False, + quant_type=quant_type, + quant_storage=storage[quant_storage], + ) if bias: linear_qs.bias = torch.nn.Parameter(linear.bias) linear_qs = linear_qs.to(device) @@ -92,7 +100,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora q0 = a.quant_state q1 = b.quant_state - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0, attr), getattr(q1, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -100,7 +108,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert c == d, f"{c} != {d}" if q0.state2 is not None: - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0.state2, attr), getattr(q1.state2, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -113,6 +121,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert a.dtype == b.dtype assert torch.equal(a, b) + if save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + # Forward test x = torch.rand(42, layer_shape[0], device=device) a = linear_q(x) @@ -125,14 +136,23 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert torch.equal(a, b) assert torch.equal(a, c) + if not save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + linear_q3 = torch_load_from_buffer(bytes_4bit) + # Test moving to CPU and back to GPU - linear_q2.to('cpu') + linear_q2.to("cpu") linear_q2.to(device) d = linear_qs(x) assert c.dtype == d.dtype assert c.device == d.device assert torch.equal(c, d) + d = linear_q3(x) + assert c.dtype == d.dtype + assert c.device == d.device + assert torch.equal(c, d) + # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: state_path_4bit = os.path.join(tmpdir, "state_4bit.pth") @@ -140,10 +160,49 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora torch.save(linear.state_dict(), state_path) torch.save(linear_q.state_dict(), state_path_4bit) - size_orig, size_4 = os.path.getsize(state_path), os.path.getsize( - state_path_4bit + size_orig, size_4 = ( + os.path.getsize(state_path), + os.path.getsize(state_path_4bit), ) size_ratio = size_4 / size_orig - target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases - ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + target_compression = ( + 0.143 if original_dtype == torch.float32 else 0.29 + ) # these numbers get lower as weight shape increases + ratio_error_msg = ( + f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + ) assert size_ratio < target_compression, ratio_error_msg + + +def test_copy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + + shallow_copy_param = copy.copy(param) + assert param.quant_state is shallow_copy_param.quant_state + assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() + + +def test_deepcopy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + copy_param = copy.deepcopy(param) + assert param.quant_state is not copy_param.quant_state + assert param.data.data_ptr() != copy_param.data.data_ptr() + + +def test_params4bit_real_serialization(): + original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4") + + original_param.cuda(0) # move to CUDA to trigger quantization + + serialized_param = pickle.dumps(original_param) + deserialized_param = pickle.loads(serialized_param) + + assert torch.equal(original_param.data, deserialized_param.data) + assert original_param.requires_grad == deserialized_param.requires_grad == False + assert original_param.quant_type == deserialized_param.quant_type + assert original_param.blocksize == deserialized_param.blocksize + assert original_param.compress_statistics == deserialized_param.compress_statistics + assert original_param.quant_state == deserialized_param.quant_state diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 6d5fc6a82..ca52f312e 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,6 +1,5 @@ -import os from contextlib import nullcontext -from itertools import product +import os from tempfile import TemporaryDirectory import pytest @@ -9,12 +8,19 @@ import bitsandbytes as bnb from bitsandbytes import functional as F from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout -from bitsandbytes.nn.modules import Linear8bitLt from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.nn.modules import Linear8bitLt +from tests.helpers import ( + TRUE_FALSE, + id_formatter, + torch_load_from_buffer, + torch_save_to_buffer, +) # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), @@ -47,7 +53,9 @@ def test_linear_no_igemmlt(): linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + linear.weight.data.clone(), + requires_grad=False, + has_fp16_weights=False, ).to(linear.weight.dtype) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -69,9 +77,20 @@ def test_linear_no_igemmlt(): @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", - list(product([False, True], [False, True], [False, True], [False, True]))) -def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): +@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) +@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) +@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) +@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) +def test_linear_serialization( + has_fp16_weights, + serialize_before_forward, + deserialize_before_cuda, + force_no_igemmlt, + save_before_forward, + load_before_cuda, +): linear = torch.nn.Linear(32, 96) x = torch.randn(3, 32, dtype=torch.half) @@ -86,7 +105,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights + linear.weight.data.clone(), + requires_grad=has_fp16_weights, + has_fp16_weights=has_fp16_weights, ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -94,6 +115,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri if serialize_before_forward: state_dict_8bit = linear_custom.state_dict() + if save_before_forward: + bytes_8bit = torch_save_to_buffer(linear_custom) + x_first = x.clone().cuda().requires_grad_(True) fx_first = linear_custom(x_first).float() grad_proj = torch.randn_like(fx_first) @@ -102,6 +126,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri if not serialize_before_forward: state_dict_8bit = linear_custom.state_dict() + if not save_before_forward: + bytes_8bit = torch_save_to_buffer(linear_custom) + with TemporaryDirectory() as tmpdir: state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") state_path = os.path.join(tmpdir, "state.pth") @@ -128,16 +155,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): new_linear_custom.load_state_dict(new_state_dict, strict=True) + if load_before_cuda: + new_linear_custom2 = torch_load_from_buffer(bytes_8bit) + new_linear_custom = new_linear_custom.cuda() if not deserialize_before_cuda: new_linear_custom.load_state_dict(new_state_dict, strict=True) + if not load_before_cuda: + new_linear_custom2 = torch_load_from_buffer(bytes_8bit) + x_second = x.clone().cuda().requires_grad_(True) fx_second = new_linear_custom(x_second).float() (fx_second * grad_proj).mean().backward() + x_third = x.clone().cuda().requires_grad_(True) + fx_third = new_linear_custom2(x_third).float() + (fx_third * grad_proj).mean().backward() + # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised if has_fp16_weights or not deserialize_before_cuda: assert torch.allclose(fx_first, fx_second, atol=1e-5) assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) + assert torch.allclose(fx_first, fx_third, atol=1e-5) + assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) diff --git a/tests/test_modules.py b/tests/test_modules.py index 3e28a0f21..8235b600c 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,11 +1,14 @@ -from itertools import product +import math +import einops import pytest import torch from torch import nn import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import id_formatter + class MockArgs: def __init__(self, initial_data): @@ -17,12 +20,18 @@ class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super().__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim1, + dim2, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim2, + dim1, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) def forward(self, x): @@ -40,19 +49,17 @@ def get_args(): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -120,9 +127,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -131,9 +136,7 @@ def get_8bit_linear(x, stochastic=False): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -217,9 +220,7 @@ def forward(ctx, x, weight, bias=None, args=None): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant( - outputq, S1, S2, x.dtype, args.quant_type - ) + output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -248,37 +249,25 @@ def backward(ctx, grad_output): # weight and x are already 8bit # -> transform grad_output to 8-bit if args.use_8bit_training == "forward+wgrad": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - ) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) grad_input = grad_output.matmul(weight) elif args.use_8bit_training == "full": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - ) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=2 - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input = LinearFunction.dequant( - grad_input8, S1, S3, grad_output.dtype, args.quant_type - ) + grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) else: grad_input = grad_output.matmul(weight) @@ -310,13 +299,8 @@ def forward(self, x): return LinearFunction.apply(x, self.weight, self.bias, self.args) -threshold = [0.0, 3.0] -values = threshold -names = [f"threshold_{vals}" for vals in values] - - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("threshold", values, ids=names) +@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) def test_linear8bitlt_inference(threshold): l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() assert l1.weight.device.type == "cuda" @@ -361,12 +345,8 @@ def test_linear8bitlt_accumulated_gradient(): opt1.zero_grad(True) opt2.step() opt2.zero_grad(True) - assert_all_approx_close( - l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 - ) - assert_all_approx_close( - l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 - ) + assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) @@ -380,7 +360,17 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) + l1 = ( + bnb.nn.Linear8bitLt( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .cuda() + .half() + ) assert l1.weight.dtype == torch.int8 l1.eval() @@ -402,11 +392,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .cuda() - .half() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -419,11 +405,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .half() - .cuda() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -436,7 +418,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) + mlp = ( + MLP8bit( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .half() + .to("cuda") + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -452,8 +444,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc2.weight.device.type == "cuda" mlp = MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization @@ -488,7 +484,14 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert (idx == 0).sum().item() <= b1.numel() * 0.005 -@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +@pytest.mark.parametrize( + "module", + [ + lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), + bnb.nn.LinearFP4, + ], + ids=["Int8Lt", "FP4"], +) def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically l1 = module(32, 64).cuda() @@ -511,19 +514,22 @@ def test_linear_kbit_fp32_bias(module): o1 = l1(b1) assert l1.bias is None -modules = [] -modules.append(bnb.nn.Linear8bitLt) -modules.append(bnb.nn.Linear4bit) -modules.append(bnb.nn.LinearFP4) -modules.append(bnb.nn.LinearNF4) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) -modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16)) -names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16'] + +module_dict = { + "Int8Lt": bnb.nn.Linear8bitLt, + "4bit": bnb.nn.Linear4bit, + "FP4": bnb.nn.LinearFP4, + "NF4": bnb.nn.LinearNF4, + "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True), + "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True), + "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32), + "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16), + "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16), +} + + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("module", modules, ids=names) +@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(module): b = 17 dim1 = 37 @@ -540,7 +546,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() - kbit = kbit.half().to('cuda') + kbit = kbit.half().to("cuda") errs1 = [] errs2 = [] @@ -558,10 +564,10 @@ def test_kbit_backprop(module): bgrad1 = ref[0].bias.grad bgrad2 = kbit[0].bias.grad - err1 = (out1-out2).abs().float() - err2 = (grad1-grad2).abs().float() - relerr1 = (err1/(out1.abs().float()+1e-9)) - relerr2 = (err2/(grad1.abs().float()+1e-9)) + err1 = (out1 - out2).abs().float() + err2 = (grad1 - grad2).abs().float() + relerr1 = err1 / (out1.abs().float() + 1e-9) + relerr2 = err2 / (grad1.abs().float() + 1e-9) errs1.append(err1.mean().item()) errs2.append(err2.mean().item()) relerrs1.append(relerr1.mean().item()) @@ -578,20 +584,20 @@ def test_kbit_backprop(module): assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 - #print('out', sum(errs1)/len(errs1)) - #print('grad', sum(errs2)/len(errs2)) - #print('rel out', sum(relerrs1)/len(relerrs1)) - #print('rel grad', sum(relerrs2)/len(relerrs2)) + # print('out', sum(errs1)/len(errs1)) + # print('grad', sum(errs2)/len(errs2)) + # print('rel out', sum(relerrs1)/len(relerrs1)) + # print('rel grad', sum(relerrs2)/len(relerrs2)) -def test_fp8linear(): +def test_fp8linear(): b = 10 h = 1024 inp = torch.randn(b, h).cuda() - fp32 = torch.nn.Linear(h, h*2).cuda() - fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() - fp32b = torch.nn.Linear(h*2, h).cuda() - fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() + fp32 = torch.nn.Linear(h, h * 2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda() + fp32b = torch.nn.Linear(h * 2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda() fp8.weight.data.copy_(fp32.weight.data) fp8.bias.data.copy_(fp32.bias.data) @@ -601,34 +607,34 @@ def test_fp8linear(): a = fp32b(torch.nn.functional.gelu(fp32(inp))) b = fp8b(torch.nn.functional.gelu(fp8(inp))) - err = (a-b).abs().mean() + err = (a - b).abs().mean() a.mean().backward() b.mean().backward() - graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() - bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() + graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean() assert err < 0.05 assert graderr < 0.00002 assert bgraderr < 0.00002 + def test_4bit_warnings(): dim1 = 64 - with pytest.warns(UserWarning, match=r'inference or training'): + with pytest.warns(UserWarning, match=r"inference or training"): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() net(inp) - with pytest.warns(UserWarning, match=r'inference.'): + with pytest.warns(UserWarning, match=r"inference."): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(1, dim1).cuda().half() net(inp) with pytest.warns(UserWarning) as record: - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() @@ -640,6 +646,3 @@ def test_4bit_warnings(): net(inp) assert len(record) == 2 - - - diff --git a/tests/test_optim.py b/tests/test_optim.py index c373a4f14..d8c46e415 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,24 +1,22 @@ -import ctypes import os +from os.path import join import shutil import time import uuid -from itertools import product -from os.path import join -import pytest from lion_pytorch import Lion - +import pytest import torch import bitsandbytes as bnb import bitsandbytes.functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import describe_dtype, id_formatter # import apex k = 20 + def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() @@ -28,7 +26,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def get_temp_dir(): - path = f"/tmp/autoswap/{str(uuid.uuid4())}" + path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) return path @@ -36,6 +34,7 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) + str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) @@ -69,8 +68,14 @@ def rm_path(path): ) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) -str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = ( + torch.optim.AdamW, + lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True), +) +str2optimizers["paged_adam8bit_blockwise"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True), +) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( @@ -93,9 +98,18 @@ def rm_path(path): str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] -str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adamw8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] @@ -104,15 +118,16 @@ def rm_path(path): str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] -dim1 = [1024] -dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"] + + +@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() + if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -137,7 +152,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], @@ -148,7 +162,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -160,13 +174,17 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rm_path(path) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], - atol=atol, rtol=rtol, - max_error_count=10) + assert_most_approx_close( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + max_error_count=10, + ) if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit @@ -180,14 +198,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] -values = list(product(dim1, dim2, gtype)) -names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) def test_global_config(dim1, dim2, gtype): if dim1 == 1 and dim2 == 1: return @@ -201,13 +214,9 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config( - p3, "optim_bits", 8 - ) + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) - bnb.optim.GlobalOptimManager.get_instance().register_parameters( - [p1, p2, p3] - ) + bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) p1 = p1.cuda() p2 = p2.cuda() p3 = p3.cuda() @@ -233,10 +242,7 @@ def test_global_config(dim1, dim2, gtype): assert adam2.state[p3]["state2"].dtype == torch.uint8 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = [ +optimizer_names_8bit = [ "adam8bit", "lion8bit", "momentum8bit", @@ -246,15 +252,15 @@ def test_global_config(dim1, dim2, gtype): "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() + if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -306,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol - ) - == 0 - ) - #assert num_not_close.sum().item() < 20 + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 + # assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) - relerr = err / (torch.abs(p1)+1e-9) + relerr = err / (torch.abs(p1) + 1e-9) if g.dtype == torch.bfloat16: assert err.mean() < 0.00015 assert relerr.mean() < 0.0016 @@ -328,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() @@ -360,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) torch.testing.assert_close(s1cpy, s1) - num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -378,18 +377,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): # print(sum(relerrors)/len(relerrors)) -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32] -optim_bits = [32, 8] -values = list(product(dim1, dim2, gtype, optim_bits)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) +@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits")) +@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): if dim1 == 1 and dim2 == 1: return @@ -415,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): for i in range(50): step += 1 - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( - 0.01 * i - ) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( - g1, gnorm_vec, step, 5 - ) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 @@ -477,22 +464,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): adam2.load_state_dict(torch.load(join(path, "opt.pt"))) -dim1 = [4096] -dim2 = [4096] -gtype = [torch.float32, torch.float16] -# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] -# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] -# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] -# optimizer_names = ['lamb_apex', 'lamb8bit'] -# optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values +optimizer_names_benchmark = [ + "adam8bit_blockwise", + "paged_adam8bit_blockwise", + "paged_adamw8bit_blockwise", + "paged_lion8bit_blockwise", ] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) +@pytest.mark.benchmark def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: return @@ -517,39 +501,36 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): print(optim_name, gtype, s / params) # assert s < 3.9 -dim1 = [2*1024] -gtype = [torch.float16] -#mode = ['torch', 'bnb'] -mode = ['bnb'] -optimizer_names = ['paged_adamw'] -#optimizer_names = ['paged_adamw8bit_blockwise'] -values = list(product(dim1,gtype, optimizer_names, mode)) -names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) + +@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) +@pytest.mark.benchmark def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) layers1 = layers1.to(gtype) layers1 = layers1.cuda() large_tensor = None - if mode == 'torch': + if mode == "torch": optim = str2optimizers[optim_name][0](layers1.parameters()) else: optim = str2optimizers[optim_name][1](layers1.parameters()) # 12 GB - large_tensor = torch.empty((int(4.5e9),), device='cuda') + large_tensor = torch.empty((int(4.5e9),), device="cuda") torch.cuda.synchronize() time.sleep(5) num_batches = 5 - batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) - lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) + lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() for i in range(num_batches): print(i) b = batches[i] - if i ==2: + if i == 2: torch.cuda.synchronize() t0 = time.time() diff --git a/tests/test_triton.py b/tests/test_triton.py index 8890193fc..1c5422c0d 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -1,21 +1,26 @@ import pytest import torch -from bitsandbytes.triton.triton_utils import is_triton_available -from bitsandbytes.nn.triton_based_modules import SwitchBackLinear -from bitsandbytes.nn import Linear8bitLt from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.nn import Linear8bitLt +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear +from bitsandbytes.triton.triton_utils import is_triton_available +from tests.helpers import TRUE_FALSE + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, - reason="This test requires triton and a GPU with compute capability 8.0 or higher.") -@pytest.mark.parametrize("vector_wise_quantization", [False, True]) +@pytest.mark.skipif( + not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.", +) +@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) def test_switchback(vector_wise_quantization): for dim in [83]: for batch in [13]: - standard = torch.nn.Linear(dim, 4 * dim).cuda().half() - switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + switchback = ( + SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + ) baseline = Linear8bitLt(dim, 4 * dim).cuda().half() switchback.weight.data.copy_(standard.weight) switchback.bias.data.copy_(standard.bias) @@ -38,24 +43,23 @@ def test_switchback(vector_wise_quantization): err_sb = (out_standard - out_sb).abs().mean() err_baseline = (out_standard - out_baseline).abs().mean() - print('OUT', err_sb, err_baseline) + print("OUT", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() - print('GW2', err_sb, err_baseline) + print("GW2", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() - print('GW1', err_sb, err_baseline) + print("GW1", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (x1.grad - x2.grad).abs().mean() err_baseline = (x1.grad - x3.grad).abs().mean() - print('GX1', err_sb, err_baseline) + print("GX1", err_sb, err_baseline) assert err_sb < 2 * err_baseline -