Commit
perf(torch): fast but unsafe buildATen & eliminating dispatches (DeepLink-org#1271)

* change at::xx_out to at::cuda::xx_out

* perf: superfast but unsafe buildATen

* modify func_ext buildATen

* build(torch): add option to switch unsafe buildATen

* style: format cpp

* docs(torch): refine naming and docs

---------

Co-authored-by: CoolKbh <[email protected]>
lljbash and CoolKbh authored Jun 26, 2024
1 parent 2b566db commit de8dfe7
Showing 8 changed files with 1,061 additions and 912 deletions.
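
The first commit bullet above ("change at::xx_out to at::cuda::xx_out") refers to calling the CUDA backend entry points directly instead of routing every op through the ATen dispatcher. Below is a minimal sketch of that idea, using add_out purely as an illustrative op; header names and exact signatures vary across PyTorch versions and this snippet is not part of the commit's diff.

// Illustrative sketch only; not part of this commit.
#include <ATen/ATen.h>
#include <ATen/ops/add_cuda_dispatch.h>  // declares at::cuda::add_out in recent PyTorch; header layout is version-dependent

void addOutSketch(at::Tensor& out, const at::Tensor& self, const at::Tensor& other) {
    // Before: routed through the ATen dispatcher, which re-derives the backend
    // from the tensors' dispatch keys on every call.
    at::add_out(out, self, other, /*alpha=*/1);

    // After: the CUDA kernel entry point is called directly, skipping dispatch.
    // This is only valid because the torch impl already knows the tensors live
    // on a CUDA(-compatible) device.
    at::cuda::add_out(out, self, other, /*alpha=*/1);
}
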
2 changes: 2 additions & 0 deletions .clang-tidy
@@ -50,6 +50,8 @@ CheckOptions:
value: "_"
- key: readability-identifier-naming.ParameterCase
value: "camelBack"
- key: readability-identifier-naming.ParameterIgnoredRegexp
value: "^([a-z]+_)*[a-z]+$"
- key: readability-identifier-naming.UnionCase
value: "camelBack"
- key: readability-identifier-naming.VariableCase
7 changes: 6 additions & 1 deletion impl/torch/CMakeLists.txt
@@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.14)
project(torch_impl)

option(HIP "Whether to use HIP when available" OFF)
option(DIOPI_TORCH_UNSAFE_BUILDATEN "Whether to use fast but unsafe buildATen (caution: only use this with DIPU)" OFF)

include(cmake/TorchBaseFunc.cmake)
InitFindTorch()
@@ -32,7 +33,7 @@ if (DYLOAD)
set(IMPL_SRC wrap_func.cpp)
endif()

file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp)
file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp build_aten.cpp)

# adaptor
set(USE_ADAPTOR ON)
@@ -104,6 +105,10 @@ if(USE_ADAPTOR)
add_dependencies(${DEVICEIMPL} adaptor_code_gen)
endif()

if(DIOPI_TORCH_UNSAFE_BUILDATEN)
target_compile_definitions(${DEVICEIMPL} PRIVATE DIOPI_TORCH_UNSAFE_BUILDATEN)
endif()

if (TEST)
add_subdirectory(test)
endif()
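
The new DIOPI_TORCH_UNSAFE_BUILDATEN option is OFF by default, so plain builds keep the safe path. A hypothetical sketch of how a superbuild (e.g. one driven by DIPU) could opt in before pulling in this project; the variable name is real, the surrounding fragment is illustrative only.

# Hypothetical superbuild fragment (illustrative); equivalently, pass
# -DDIOPI_TORCH_UNSAFE_BUILDATEN=ON on the cmake configure command line.
set(DIOPI_TORCH_UNSAFE_BUILDATEN ON CACHE BOOL "fast but unsafe buildATen" FORCE)
add_subdirectory(impl/torch)
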
154 changes: 154 additions & 0 deletions impl/torch/build_aten.cpp
@@ -0,0 +1,154 @@
#include "build_aten.hpp"

#include <ATen/Context.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/TensorBody.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <diopi/diopirt.h>

#include <utility>

#include "helper.hpp"

namespace impl::aten {

UnsafelyDeviceChangedTensorWrapper::UnsafelyDeviceChangedTensorWrapper(const at::Tensor& tensor) : at::Tensor(tensor) {
if (!defined() || is_cpu()) {
return;
}
saveForRevert_.emplace(unsafeGetTensorImpl(), device());
// NOTE: CUDA allocators may not have been initialized yet if we are using DIPU allocators.
// We have to initialize them explicitly here to cover potential allocations in op workspaces.
at::globalContext().lazyInitCUDA();
at::Device newDevice{at::DeviceType::CUDA, device().index()};
setTensorImplDeviceUnsafe({unsafeGetTensorImpl(), newDevice});
}

UnsafelyDeviceChangedTensorWrapper::~UnsafelyDeviceChangedTensorWrapper() {
if (saveForRevert_.has_value()) {
setTensorImplDeviceUnsafe(*saveForRevert_);
}
}

UnsafelyDeviceChangedTensorWrapper buildATenUnsafe(diopiConstTensorHandle_t tensor) {
if (tensor == nullptr) {
return {};
}
auto& atTensor = *reinterpret_cast<at::Tensor*>(const_cast<diopiTensorHandle_t>(tensor));
return UnsafelyDeviceChangedTensorWrapper::createFromTensor(atTensor);
}

void UnsafelyDeviceChangedTensorWrapper::setTensorImplDeviceUnsafe(const TensorImplAndDevice& tensorAndDevice) {
const auto& [tensorImpl, device] = tensorAndDevice;
auto& storage = const_cast<at::Storage&>(tensorImpl->unsafe_storage());
auto& dataPtr = const_cast<at::DataPtr&>(storage.data_ptr());
dataPtr.unsafe_set_device(device);
tensorImpl->set_storage_keep_dtype(std::move(storage));
tensorImpl->_change_backend_component_keys(device);
}

namespace {

template <diopiDevice_t>
class BuildATenDeviceApi {};

template <>
class BuildATenDeviceApi<diopi_host> {
public:
static void lazyInitDevice() {}
static at::Device device(diopiConstTensorHandle_t /*unused*/) { return {at::DeviceType::CPU}; }
static at::Tensor empty(at::IntArrayRef size, at::ScalarType dtype, at::Device /*unused*/) {
return at::detail::empty_cpu(size, dtype, /*pin_memory=*/false, /*memory_format_opt=*/c10::nullopt);
}
};

template <>
class BuildATenDeviceApi<diopi_device> {
public:
static void lazyInitDevice() { at::globalContext().lazyInitCUDA(); }
static at::Device device(diopiConstTensorHandle_t tensor) {
diopiDeviceIndex_t deviceIndex;
diopiGetTensorDeviceIndex(tensor, &deviceIndex);
return {at::DeviceType::CUDA, deviceIndex};
}
static at::Tensor empty(at::IntArrayRef size, at::ScalarType dtype, at::Device device) {
return at::detail::empty_cuda(size, dtype, device, /*memory_format_opt=*/c10::nullopt);
}
};

template <class DeviceImpl>
at::Tensor buildATenSafeImpl(diopiConstTensorHandle_t tensor) {
diopiSize_t shape;
diopiGetTensorShape(tensor, &shape);
at::IntArrayRef atSizes(shape.data, shape.len);

diopiDtype_t dtype;
diopiGetTensorDtype(tensor, &dtype);
auto atTypeMeta = getATenType(dtype);
auto atDtype = atTypeMeta.toScalarType();

auto atDevice = DeviceImpl::device(tensor);

// NOTE: storage offset has been handled in `diopiGetTensorData`
void* data = nullptr;
diopiGetTensorData(const_cast<diopiTensorHandle_t>(tensor), &data);

if (data == nullptr) {
return DeviceImpl::empty(atSizes, atDtype, atDevice);
}

// NOTE: CUDA allocators may not have been initialized yet if we are using DIPU allocators.
// We have to initialize them explicitly here to cover potential allocations in op workspaces.
DeviceImpl::lazyInitDevice();

// PERF: It would be faster if we could obtain and reuse the storage from the tensor.
// However, we cannot assume diopiTensorHandle_t to be a wrapper of at::Tensor,
// so we have to create a new storage (offset = 0) whose data_ptr points to
// the same address but has an empty deleter (to avoid double-free).

diopiSize_t stride;
diopiGetTensorStride(tensor, &stride);
at::IntArrayRef atStrides(stride.data, stride.len);

auto storageNBytes = at::detail::computeStorageNbytes(atSizes, atStrides, atTypeMeta.itemsize());

// NOTE: constructed this way, the data_ptr has an empty deleter, so the underlying memory is never freed here
at::Storage storage{at::Storage::use_byte_size_t{}, storageNBytes, /*data_ptr=*/{data, atDevice}};

auto dk = at::computeDispatchKey(atDtype, /*layout=*/c10::nullopt, atDevice);
at::Tensor atTensor = at::detail::make_tensor<at::TensorImpl>(std::move(storage), dk, atTypeMeta);
atTensor.unsafeGetTensorImpl()->set_sizes_and_strides(atSizes, atStrides);

return atTensor;
}

} // namespace

at::Tensor buildATenSafe(diopiConstTensorHandle_t tensor) {
if (tensor == nullptr) {
return at::Tensor();
}

diopiDevice_t device;
diopiGetTensorDevice(tensor, &device);
switch (device) {
case diopi_host:
return buildATenSafeImpl<BuildATenDeviceApi<diopi_host>>(tensor);
case diopi_device:
return buildATenSafeImpl<BuildATenDeviceApi<diopi_device>>(tensor);
default:
TORCH_CHECK(false, "Invalid device type encountered in buildATen: ", device);
return {};
}
}

} // namespace impl::aten
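
For context, here is a hypothetical sketch of how these builders are consumed by an op implementation in this backend; the entry-point name and signature below are illustrative only, and the real wrappers in functions/*.cpp differ in detail.

// Hypothetical DIOPI-style entry point, for illustration only.
#include <ATen/ATen.h>
#include <diopi/diopirt.h>

#include "build_aten.hpp"

diopiError_t diopiHypotheticalAbs(diopiContextHandle_t /*ctx*/, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
    // buildATen resolves to buildATenUnsafe or buildATenSafe at compile time,
    // depending on DIOPI_TORCH_UNSAFE_BUILDATEN.
    auto atInput = impl::aten::buildATen(input);
    auto atOut = impl::aten::buildATen(out);
    // In the unsafe mode these are UnsafelyDeviceChangedTensorWrapper objects;
    // keeping them as named locals guarantees the device change is not reverted
    // until after the ATen call returns.
    at::abs_out(atOut, atInput);
    return diopiSuccess;
}
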
99 changes: 99 additions & 0 deletions impl/torch/build_aten.hpp
@@ -0,0 +1,99 @@
#pragma once

#include <ATen/Context.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <utility>

#include "diopi/diopirt.h"

namespace impl::aten {

// This class is a wrapper around an at::Tensor. It changes the device and dispatch key of the wrapped at::TensorImpl from any device (e.g. XPU) to CUDA, and
// reverts the change when the wrapper is destroyed.
// The wrapper is designed to be implicitly convertible to an at::Tensor (via object slicing), so that it can be used in place of an at::Tensor.
class UnsafelyDeviceChangedTensorWrapper : public at::Tensor {
public:
static UnsafelyDeviceChangedTensorWrapper createFromTensor(const at::Tensor& tensor) { return UnsafelyDeviceChangedTensorWrapper(tensor); }
UnsafelyDeviceChangedTensorWrapper() = default;
~UnsafelyDeviceChangedTensorWrapper();
UnsafelyDeviceChangedTensorWrapper(const UnsafelyDeviceChangedTensorWrapper& other) : at::Tensor(other) {}
UnsafelyDeviceChangedTensorWrapper(UnsafelyDeviceChangedTensorWrapper&& other) : at::Tensor(std::move(other)) { saveForRevert_.swap(other.saveForRevert_); }
UnsafelyDeviceChangedTensorWrapper& operator=(const UnsafelyDeviceChangedTensorWrapper& other) = delete;
UnsafelyDeviceChangedTensorWrapper& operator=(UnsafelyDeviceChangedTensorWrapper&& other) {
at::Tensor::operator=(std::move(other));
saveForRevert_.swap(other.saveForRevert_);
return *this;
}
UnsafelyDeviceChangedTensorWrapper& operator=(const at::Tensor& other) {
at::Tensor::operator=(other);
return *this;
}
UnsafelyDeviceChangedTensorWrapper& operator=(at::Tensor&& other) {
at::Tensor::operator=(std::move(other));
return *this;
}

private:
explicit UnsafelyDeviceChangedTensorWrapper(const at::Tensor& tensor);
using TensorImplAndDevice = std::pair<at::TensorImpl*, at::Device>;
static void setTensorImplDeviceUnsafe(const TensorImplAndDevice& tensorAndDevice);
c10::optional<TensorImplAndDevice> saveForRevert_ = c10::nullopt;
};

// WARNING: This function is UNSAFE. It is the caller's responsibility to ensure that:
// 1. The returned wrapper is not destroyed when its sliced at::Tensor is still in use in DIOPI.
// 2. The input diopiConstTensorHandle_t is actually a reinterpret_cast of an at::Tensor*.
// 3. The input tensor and its storage are not used in another thread during the lifetime of the returned wrapper.
[[nodiscard]] UnsafelyDeviceChangedTensorWrapper buildATenUnsafe(diopiConstTensorHandle_t tensor);

[[nodiscard]] at::Tensor buildATenSafe(diopiConstTensorHandle_t tensor);

[[nodiscard]] inline auto buildATen(diopiConstTensorHandle_t tensor) {
#if DIOPI_TORCH_UNSAFE_BUILDATEN
return buildATenUnsafe(tensor);
#else
return buildATenSafe(tensor);
#endif
}

template <typename T>
[[nodiscard]] auto buildATenList(T* tensors, int64_t numTensors) {
using TensorType = decltype(buildATen(std::declval<diopiConstTensorHandle_t>()));
c10::SmallVector<TensorType, 4> vecAtTensor;
vecAtTensor.reserve(numTensors);
std::transform(tensors, tensors + numTensors, std::back_inserter(vecAtTensor), [](auto tensor) { return buildATen(tensor); });
return vecAtTensor;
}

// These macros are designed to avoid early destruction of the wrapper when building optional at::Tensor values.
#define DIOPI_IMPL_BUILD_ATEN_LIST(atTensors, diopiTensors, numTensors) \
auto atTensors##__MAYBE_WRAPPER = ::impl::aten::buildATenList(diopiTensors, numTensors); \
c10::SmallVector<at::Tensor, 4> atTensors; \
atTensors.reserve(numTensors); \
std::transform(atTensors##__MAYBE_WRAPPER.begin(), atTensors##__MAYBE_WRAPPER.end(), std::back_inserter(atTensors), [](auto& tensor) { \
return static_cast<at::Tensor>(tensor); \
});
#define DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atTensor, diopiTensor) \
auto atTensor##__MAYBE_WRAPPER = ::impl::aten::buildATen(diopiTensor); \
c10::optional<at::Tensor> atTensor; \
if (atTensor##__MAYBE_WRAPPER.defined()) { \
atTensor = atTensor##__MAYBE_WRAPPER; \
}
#define DIOPI_IMPL_BUILD_ATEN_OPTIONAL_LIST(atTensors, diopiTensors, numTensors) \
auto atTensors##__MAYBE_WRAPPER = ::impl::aten::buildATenList(diopiTensors, numTensors); \
c10::List<c10::optional<at::Tensor>> atTensors; \
atTensors.reserve(numTensors); \
std::transform(atTensors##__MAYBE_WRAPPER.begin(), atTensors##__MAYBE_WRAPPER.end(), std::back_inserter(atTensors), [](auto& tensor) { \
return tensor.defined() ? c10::optional<at::Tensor>(tensor) : c10::nullopt; \
});

} // namespace impl::aten
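
A hypothetical usage sketch of the optional-tensor macro follows; the op choice and function signature are illustrative, not from this commit. The point of the macro is that the <name>__MAYBE_WRAPPER local it declares keeps the unsafe wrapper alive for the whole scope, whereas assigning buildATen's result straight into a c10::optional<at::Tensor> would slice and destroy the wrapper immediately, reverting the device change before the ATen op even runs.

// Hypothetical illustration only.
#include <ATen/ATen.h>
#include <diopi/diopirt.h>

#include "build_aten.hpp"

void layerNormSketch(diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight) {
    auto atInput = impl::aten::buildATen(input);
    // Declares both atWeight (a c10::optional<at::Tensor>) and a hidden
    // atWeight__MAYBE_WRAPPER local that owns the (possibly unsafe) wrapper.
    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atWeight, weight);
    at::layer_norm(atInput, /*normalized_shape=*/{atInput.size(-1)}, atWeight);
    // atWeight__MAYBE_WRAPPER is still in scope here, so in unsafe mode the
    // device change is only reverted when this function returns.
}
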