diff --git a/advanced_source/cpp_custom_ops.rst b/advanced_source/cpp_custom_ops.rst index 512c39b2a68..c19a815f43a 100644 --- a/advanced_source/cpp_custom_ops.rst +++ b/advanced_source/cpp_custom_ops.rst @@ -16,7 +16,7 @@ Custom C++ and CUDA Operators .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites :class-card: card-prerequisites - * PyTorch 2.4 or later + * PyTorch 2.4 or later (or PyTorch 2.10 or later if using the stable ABI) * Basic understanding of C++ and CUDA programming .. note:: @@ -27,7 +27,7 @@ PyTorch offers a large library of operators that work on Tensors (e.g. torch.add However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the blessed path to authoring a custom operator written in C++/CUDA. -For our tutorial, we’ll demonstrate how to author a fused multiply-add C++ +For our tutorial, we'll demonstrate how to author a fused multiply-add C++ and CUDA operator that composes with PyTorch subsystems. The semantics of the operation are as follows: @@ -37,48 +37,103 @@ the operation are as follows: return a * b + c You can find the end-to-end working example for this tutorial -`here `_ . +in the `extension-cpp `_ repository, +which contains two parallel implementations: + +- `extension_cpp/ `_: + Uses the standard ATen/LibTorch API. +- `extension_cpp_stable/ `_: + Uses APIs supported by the LibTorch Stable ABI (recommended for PyTorch 2.10+). + +**Which API should you use?** + +- **ABI-Stable LibTorch API** (recommended): If you are using PyTorch 2.10+, we recommend using + the ABI-stable API. It allows you to build a single wheel that works across multiple PyTorch + versions (2.10, 2.11, 2.12, etc.), reducing the maintenance burden of supporting multiple + PyTorch releases. See the :ref:`libtorch-stable-abi` section below for more details. + +- **Non-ABI-Stable LibTorch API**: Use this if you need APIs not yet available in the stable ABI, + or if you are targeting PyTorch versions older than 2.10. Note that you will need to build + separate wheels for each PyTorch version you want to support. + +The code snippets below show both implementations using tabs, with the ABI-stable API shown by default. Setting up the Build System --------------------------- If you are developing custom C++/CUDA code, it must be compiled. -Note that if you’re interfacing with a Python library that already has bindings +Note that if you're interfacing with a Python library that already has bindings to precompiled C++/CUDA code, you might consider writing a custom Python operator instead (:ref:`python-custom-ops-tutorial`). Use `torch.utils.cpp_extension `_ -to compile custom C++/CUDA code for use with PyTorch +to compile custom C++/CUDA code for use with PyTorch. C++ extensions may be built either "ahead of time" with setuptools, or "just in time" via `load_inline `_; -we’ll focus on the "ahead of time" flavor. - -Using ``cpp_extension`` is as simple as writing the following ``setup.py``: - -.. 
code-block:: python - - from setuptools import setup, Extension - from torch.utils import cpp_extension - - setup(name="extension_cpp", - ext_modules=[ - cpp_extension.CppExtension( - "extension_cpp", - ["muladd.cpp"], - # define Py_LIMITED_API with min version 3.9 to expose only the stable - # limited API subset from Python.h - extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]}, - py_limited_api=True)], # Build 1 wheel across multiple Python versions - cmdclass={'build_ext': cpp_extension.BuildExtension}, - options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version - ) +we'll focus on the "ahead of time" flavor. + +Using ``cpp_extension`` is as simple as writing a ``setup.py``: + +.. tab-set:: + + .. tab-item:: ABI-Stable LibTorch API + + .. code-block:: python + + from setuptools import setup, Extension + from torch.utils import cpp_extension + + setup(name="extension_cpp", + ext_modules=[ + cpp_extension.CppExtension( + "extension_cpp", + ["muladd.cpp"], + extra_compile_args={ + "cxx": [ + # define Py_LIMITED_API with min version 3.9 to expose only the stable + # limited API subset from Python.h + "-DPy_LIMITED_API=0x03090000", + # define TORCH_TARGET_VERSION with min version 2.10 to expose only the + # stable API subset from torch + "-DTORCH_TARGET_VERSION=0x020a000000000000", + ] + }, + py_limited_api=True)], # Build 1 wheel across multiple Python versions + cmdclass={'build_ext': cpp_extension.BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version + ) + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. code-block:: python + + from setuptools import setup, Extension + from torch.utils import cpp_extension + + setup(name="extension_cpp", + ext_modules=[ + cpp_extension.CppExtension( + "extension_cpp", + ["muladd.cpp"], + extra_compile_args={ + "cxx": [ + "-DPy_LIMITED_API=0x03090000", + ] + }, + py_limited_api=True)], + cmdclass={'build_ext': cpp_extension.BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}} + ) If you need to compile CUDA code (for example, ``.cu`` files), then instead use `torch.utils.cpp_extension.CUDAExtension `_. Please see `extension-cpp `_ for an example for how this is set up. -The above example represents what we refer to as a CPython agnostic wheel, meaning we are +CPython Agnosticism +^^^^^^^^^^^^^^^^^^^ + +The above examples represent what we refer to as a CPython agnostic wheel, meaning we are building a single wheel that can be run across multiple CPython versions (similar to pure Python packages). CPython agnosticism is desirable in minimizing the number of wheels your custom library needs to support and release. The minimum version we'd like to support is @@ -98,7 +153,7 @@ minimum CPython version you would like to support: Defining the ``Py_LIMITED_API`` flag helps verify that the extension is in fact only using the `CPython Stable Limited API `_, -which is a requirement for the building a CPython agnostic wheel. If this requirement +which is a requirement for building a CPython agnostic wheel. If this requirement is not met, it is possible to build a wheel that looks CPython agnostic but will crash, or worse, be silently incorrect, in another CPython environment. Take care to avoid using unstable CPython APIs, for example APIs from libtorch_python (in particular @@ -148,34 +203,101 @@ like so: cmdclass={'build_ext': cpp_extension.BuildExtension}, ) +.. 
_libtorch-stable-abi: + +LibTorch Stable ABI (PyTorch Agnosticism) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In addition to CPython agnosticism, there is a second axis of wheel compatibility: +LibTorch agnosticism. While CPython agnosticism allows building a single wheel +that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism +allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.). +These two concepts are orthogonal and can be combined. + +To achieve LibTorch agnosticism, you must use the ABI stable LibTorch API, which provides +a stable API for interacting with PyTorch tensors and operators. For example, instead of +using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive +documentation on the stable ABI, including migration guides, supported types, and +stack-based API conventions, see the +`LibTorch Stable ABI documentation `_. + +The stable ABI setup.py includes ``TORCH_TARGET_VERSION=0x020a000000000000``, which indicates that +the extension targets the LibTorch Stable ABI with a minimum supported PyTorch version of 2.10. The version format is: +``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``, so 2.10.0 = ``0x020a000000000000``. + +If the stable API/ABI does not contain what you need, you can use the Non-ABI-stable LibTorch API, +but you will need to build separate wheels for each PyTorch version you want to support. + Defining the custom op and adding backend implementations --------------------------------------------------------- First, let's write a C++ function that computes ``mymuladd``: -.. code-block:: cpp - - at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) { - TORCH_CHECK(a.sizes() == b.sizes()); - TORCH_CHECK(a.dtype() == at::kFloat); - TORCH_CHECK(b.dtype() == at::kFloat); - TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); - TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); - at::Tensor a_contig = a.contiguous(); - at::Tensor b_contig = b.contiguous(); - at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); - const float* a_ptr = a_contig.data_ptr(); - const float* b_ptr = b_contig.data_ptr(); - float* result_ptr = result.data_ptr(); - for (int64_t i = 0; i < result.numel(); i++) { - result_ptr[i] = a_ptr[i] * b_ptr[i] + c; - } - return result; - } - -In order to use this from PyTorch’s Python frontend, we need to register it -as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically -bind the operator to Python. +.. tab-set:: + + .. tab-item:: ABI-Stable LibTorch API + + .. 
code-block:: cpp + + #include + #include + #include + #include + #include + + torch::stable::Tensor mymuladd_cpu( + const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + double c) { + STD_TORCH_CHECK(a.sizes().equals(b.sizes())); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU); + STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU); + + torch::stable::Tensor a_contig = torch::stable::contiguous(a); + torch::stable::Tensor b_contig = torch::stable::contiguous(b); + torch::stable::Tensor result = torch::stable::empty_like(a_contig); + + const float* a_ptr = a_contig.const_data_ptr(); + const float* b_ptr = b_contig.const_data_ptr(); + float* result_ptr = result.mutable_data_ptr(); + + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + return result; + } + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. code-block:: cpp + + #include + #include + #include + + at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + return result; + } + +In order to use this from PyTorch's Python frontend, we need to register it +as a PyTorch operator using the ``TORCH_LIBRARY`` (or ``STABLE_TORCH_LIBRARY``) macro. +This will automatically bind the operator to Python. Operator registration is a two step-process: @@ -188,7 +310,7 @@ Defining an operator To define an operator, follow these steps: 1. select a namespace for an operator. We recommend the namespace be the name of your top-level - project; we’ll use "extension_cpp" in our tutorial. + project; we'll use "extension_cpp" in our tutorial. 2. provide a schema string that specifies the input/output types of the operator and if an input Tensors will be mutated. We support more types in addition to Tensor and float; please see `The Custom Operators Manual `_ @@ -197,57 +319,151 @@ To define an operator, follow these steps: * If you are authoring an operator that can mutate its input Tensors, please see here (:ref:`mutable-ops`) for how to specify that. -.. code-block:: cpp +.. tab-set:: - TORCH_LIBRARY(extension_cpp, m) { - // Note that "float" in the schema corresponds to the C++ double type - // and the Python float type. - m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); - } + .. tab-item:: ABI-Stable LibTorch API -This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``. + .. code-block:: cpp -Registering backend implementations for an operator -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator. 
+ STABLE_TORCH_LIBRARY(extension_cpp, m) { + // Note that "float" in the schema corresponds to the C++ double type + // and the Python float type. + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + } -.. code-block:: cpp + .. tab-item:: Non-ABI-Stable LibTorch API - TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { - m.impl("mymuladd", &mymuladd_cpu); - } + .. code-block:: cpp -If you also have a CUDA implementation of ``myaddmul``, you can register it -in a separate ``TORCH_LIBRARY_IMPL`` block: + TORCH_LIBRARY(extension_cpp, m) { + // Note that "float" in the schema corresponds to the C++ double type + // and the Python float type. + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + } -.. code-block:: cpp +This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``. - __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel) result[idx] = a[idx] * b[idx] + c; - } +Registering backend implementations for an operator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Use ``TORCH_LIBRARY_IMPL`` (or ``STABLE_TORCH_LIBRARY_IMPL``) to register a backend implementation for the operator. - at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) { - TORCH_CHECK(a.sizes() == b.sizes()); - TORCH_CHECK(a.dtype() == at::kFloat); - TORCH_CHECK(b.dtype() == at::kFloat); - TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA); - at::Tensor a_contig = a.contiguous(); - at::Tensor b_contig = b.contiguous(); - at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); - const float* a_ptr = a_contig.data_ptr(); - const float* b_ptr = b_contig.data_ptr(); - float* result_ptr = result.data_ptr(); - - int numel = a_contig.numel(); - muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr); - return result; - } +.. tab-set:: - TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { - m.impl("mymuladd", &mymuladd_cuda); - } + .. tab-item:: ABI-Stable LibTorch API + + Note that we wrap the function pointer with ``TORCH_BOX()`` - this is required for + stable ABI functions to handle argument boxing/unboxing correctly. + + .. code-block:: cpp + + STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu)); + } + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. code-block:: cpp + + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", &mymuladd_cpu); + } + +If you also have a CUDA implementation of ``mymuladd``, you can register it +in a separate ``TORCH_LIBRARY_IMPL`` (or ``STABLE_TORCH_LIBRARY_IMPL``) block: + +.. tab-set:: + + .. tab-item:: ABI-Stable LibTorch API + + .. 
code-block:: cpp + + #include + #include + #include + #include + #include + #include + + __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) result[idx] = a[idx] * b[idx] + c; + } + + torch::stable::Tensor mymuladd_cuda( + const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + double c) { + STD_TORCH_CHECK(a.sizes().equals(b.sizes())); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA); + STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA); + + torch::stable::Tensor a_contig = torch::stable::contiguous(a); + torch::stable::Tensor b_contig = torch::stable::contiguous(b); + torch::stable::Tensor result = torch::stable::empty_like(a_contig); + + const float* a_ptr = a_contig.const_data_ptr(); + const float* b_ptr = b_contig.const_data_ptr(); + float* result_ptr = result.mutable_data_ptr(); + + int numel = a_contig.numel(); + + // For now, we rely on the raw shim API to get the current CUDA stream. + // This will be improved in a future release. + // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to + // check the error code and throw an appropriate runtime_error otherwise. + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + + muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr); + return result; + } + + STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { + m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda)); + } + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. code-block:: cpp + + #include + #include + #include + #include + #include + #include + + __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) result[idx] = a[idx] * b[idx] + c; + } + + at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = at::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + + int numel = a_contig.numel(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr); + return result; + } + + TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { + m.impl("mymuladd", &mymuladd_cuda); + } Adding ``torch.compile`` support for an operator ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -327,7 +543,7 @@ three ways: for more details: .. code-block:: cpp - + #include extern "C" { @@ -380,8 +596,7 @@ three ways: Adding training (autograd) support for an operator ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Use ``torch.library.register_autograd`` to add training support for an operator. 
Prefer -this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``; -you must use those in a very specific way to avoid silent incorrectness (see +this over directly using Python ``torch.autograd.Function`` (see `The Custom Operators Manual `_ for more details). @@ -418,39 +633,76 @@ it must be wrapped into a custom operator. If we had our own custom ``mymul`` kernel, we would need to wrap it into a custom operator and then call that from the backward: -.. code-block:: cpp - - // New! a mymul_cpu kernel - at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) { - TORCH_CHECK(a.sizes() == b.sizes()); - TORCH_CHECK(a.dtype() == at::kFloat); - TORCH_CHECK(b.dtype() == at::kFloat); - TORCH_CHECK(a.device().type() == at::DeviceType::CPU); - TORCH_CHECK(b.device().type() == at::DeviceType::CPU); - at::Tensor a_contig = a.contiguous(); - at::Tensor b_contig = b.contiguous(); - at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); - const float* a_ptr = a_contig.data_ptr(); - const float* b_ptr = b_contig.data_ptr(); - float* result_ptr = result.data_ptr(); - for (int64_t i = 0; i < result.numel(); i++) { - result_ptr[i] = a_ptr[i] * b_ptr[i]; - } - return result; - } - - TORCH_LIBRARY(extension_cpp, m) { - m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); - // New! defining the mymul operator - m.def("mymul(Tensor a, Tensor b) -> Tensor"); - } - - - TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { - m.impl("mymuladd", &mymuladd_cpu); - // New! registering the cpu kernel for the mymul operator - m.impl("mymul", &mymul_cpu); - } +.. tab-set:: + + .. tab-item:: ABI-Stable LibTorch API + + .. code-block:: cpp + + torch::stable::Tensor mymul_cpu( + const torch::stable::Tensor& a, + const torch::stable::Tensor& b) { + STD_TORCH_CHECK(a.sizes().equals(b.sizes())); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU); + STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU); + + torch::stable::Tensor a_contig = torch::stable::contiguous(a); + torch::stable::Tensor b_contig = torch::stable::contiguous(b); + torch::stable::Tensor result = torch::stable::empty_like(a_contig); + + const float* a_ptr = a_contig.const_data_ptr(); + const float* b_ptr = b_contig.const_data_ptr(); + float* result_ptr = result.mutable_data_ptr(); + + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i]; + } + return result; + } + + STABLE_TORCH_LIBRARY(extension_cpp, m) { + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + m.def("mymul(Tensor a, Tensor b) -> Tensor"); + } + + STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu)); + m.impl("mymul", TORCH_BOX(&mymul_cpu)); + } + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. 
code-block:: cpp + + at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i]; + } + return result; + } + + TORCH_LIBRARY(extension_cpp, m) { + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + m.def("mymul(Tensor a, Tensor b) -> Tensor"); + } + + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", &mymuladd_cpu); + m.impl("mymul", &mymul_cpu); + } .. code-block:: python @@ -528,46 +780,93 @@ behavior. If there are multiple mutated Tensors, use different names (for exampl Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``. -.. code-block:: cpp - - // An example of an operator that mutates one of its inputs. - void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { - TORCH_CHECK(a.sizes() == b.sizes()); - TORCH_CHECK(b.sizes() == out.sizes()); - TORCH_CHECK(a.dtype() == at::kFloat); - TORCH_CHECK(b.dtype() == at::kFloat); - TORCH_CHECK(out.dtype() == at::kFloat); - TORCH_CHECK(out.is_contiguous()); - TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); - TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); - TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU); - at::Tensor a_contig = a.contiguous(); - at::Tensor b_contig = b.contiguous(); - const float* a_ptr = a_contig.data_ptr(); - const float* b_ptr = b_contig.data_ptr(); - float* result_ptr = out.data_ptr(); - for (int64_t i = 0; i < out.numel(); i++) { - result_ptr[i] = a_ptr[i] + b_ptr[i]; - } - } - -When defining the operator, we must specify that it mutates the out Tensor in the schema: - -.. code-block:: cpp - - TORCH_LIBRARY(extension_cpp, m) { - m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); - m.def("mymul(Tensor a, Tensor b) -> Tensor"); - // New! - m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()"); - } - - TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { - m.impl("mymuladd", &mymuladd_cpu); - m.impl("mymul", &mymul_cpu); - // New! - m.impl("myadd_out", &myadd_out_cpu); - } +.. tab-set:: + + .. tab-item:: ABI-Stable LibTorch API + + .. 
code-block:: cpp + + void myadd_out_cpu( + const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + torch::stable::Tensor& out) { + STD_TORCH_CHECK(a.sizes().equals(b.sizes())); + STD_TORCH_CHECK(b.sizes().equals(out.sizes())); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU); + STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU); + STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU); + + torch::stable::Tensor a_contig = torch::stable::contiguous(a); + torch::stable::Tensor b_contig = torch::stable::contiguous(b); + + const float* a_ptr = a_contig.const_data_ptr(); + const float* b_ptr = b_contig.const_data_ptr(); + float* result_ptr = out.mutable_data_ptr(); + + for (int64_t i = 0; i < out.numel(); i++) { + result_ptr[i] = a_ptr[i] + b_ptr[i]; + } + } + + When defining the operator, we must specify that it mutates the out Tensor in the schema: + + .. code-block:: cpp + + STABLE_TORCH_LIBRARY(extension_cpp, m) { + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + m.def("mymul(Tensor a, Tensor b) -> Tensor"); + m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()"); + } + + STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu)); + m.impl("mymul", TORCH_BOX(&mymul_cpu)); + m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu)); + } + + .. tab-item:: Non-ABI-Stable LibTorch API + + .. code-block:: cpp + + void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(b.sizes() == out.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_CHECK(out.dtype() == at::kFloat); + TORCH_CHECK(out.is_contiguous()); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = out.data_ptr(); + for (int64_t i = 0; i < out.numel(); i++) { + result_ptr[i] = a_ptr[i] + b_ptr[i]; + } + } + + When defining the operator, we must specify that it mutates the out Tensor in the schema: + + .. code-block:: cpp + + TORCH_LIBRARY(extension_cpp, m) { + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + m.def("mymul(Tensor a, Tensor b) -> Tensor"); + m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()"); + } + + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", &mymuladd_cpu); + m.impl("mymul", &mymul_cpu); + m.impl("myadd_out", &myadd_out_cpu); + } .. note:: @@ -577,6 +876,6 @@ When defining the operator, we must specify that it mutates the out Tensor in th Conclusion ---------- In this tutorial, we went over the recommended approach to integrating Custom C++ -and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly +and CUDA operators with PyTorch. The ``TORCH_LIBRARY/STABLE_TORCH_LIBRARY`` and ``torch.library`` APIs are fairly low-level. For more information about how to use the API, see `The Custom Operators Manual `_.
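+
+As a quick end-to-end sanity check (a minimal sketch that assumes the ``extension_cpp``
+package built above has been installed into your environment), you can call the
+registered operator from Python and compare it against the equivalent eager computation:
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp  # assumes the extension built above is installed; importing it loads and registers the ops
+
+   a = torch.randn(8)
+   b = torch.randn(8)
+   c = 1.0
+
+   # torch.ops.extension_cpp.mymuladd dispatches to the CPU (or CUDA) kernel registered above
+   out = torch.ops.extension_cpp.mymuladd(a, b, c)
+   torch.testing.assert_close(out, a * b + c)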