Commit dc486d1

Make ORT callable from various Pytorch compilers (LazyTensor, TorchDynamo, etc) (microsoft#10460)

* Make ORT a Pytorch JIT backend

LORT likely doesn't work with ATen fallback, so we only test LORT in its own CI.

* Revert changes to enable external CUDA allocator. Will add it later.

Revert "Revert changes to enable external CUDA allocator. Will add it later."

This reverts commit d5487f2.

Fix external allocator

* Relax tolerance and remove commented code

* Print more information in CI

* Fix pointer

* Address comments.
1. Reuse ORT-eager mode's environment.
2. Remove unused ctor.

* Use Pytorch master branch now that all required PRs are merged

Fix

* Refine based on cpplint feedback

* Revert changes to allow custom CUDA allocator in public APIs

* Use torch.testing.assert_close

* Use unittest framework

* Switch docker repo

* Rename *.cpp to *.cc

* Address comments

* Add comment

* Use same pipeline file for eager and lort pipelines

* Address comments

* Add yaml comment

* Fix cmake files

* Address comments

* Rename flags, remove printing code, remove dead comment
wschin authored Aug 22, 2022
1 parent 53090f6 commit dc486d1
Showing 30 changed files with 2,263 additions and 44 deletions.
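Before the file-by-file diff, a hypothetical Python sketch of what the feature enables. `torch._lazy` is Pytorch's prototype LazyTensor API; the commented-out registration call is a placeholder name, since the entry point ORT actually exposes is not shown on this page.

    import torch
    import torch._lazy

    # Placeholder: the call that actually registers ORT as the LazyTensor/JIT
    # backend is added by this PR but is not visible on this page.
    # initialize_ort_lazy_backend()

    x = torch.randn(2, 3, device="lazy")  # ops on "lazy" tensors are traced, not run
    y = torch.nn.functional.relu(x @ x.T + 1.0)
    torch._lazy.mark_step()  # flush the recorded graph to the registered backend
    print(y.cpu())           # materializes the result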
21 changes: 20 additions & 1 deletion cmake/CMakeLists.txt
@@ -169,6 +169,9 @@ option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
# build eager mode
option(onnxruntime_ENABLE_EAGER_MODE "build ort eager mode")

# build Pytorch's LazyTensor support
cmake_dependent_option(onnxruntime_ENABLE_LAZY_TENSOR "Enable ORT as a LazyTensor backend in Pytorch." ON "onnxruntime_ENABLE_TRAINING" OFF)

# build separate library of schemas of (custom) ops used by ORT (for ONNX to MLIR translation)
option(onnxruntime_BUILD_OPSCHEMA_LIB "Build op schema library" ON)

@@ -1833,7 +1836,7 @@ if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch")
endif()
if (NOT WIN32)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --compiler-options -fPIC")
list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
endif()
# Options passed to cudafe
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=bad_friend_decl\"")
@@ -2041,6 +2044,22 @@ if (onnxruntime_ENABLE_EAGER_MODE)
add_compile_definitions(ENABLE_EAGER_MODE)
list(APPEND ONNXRUNTIME_TARGETS onnxruntime_eager)
endif()

if (onnxruntime_ENABLE_LAZY_TENSOR)
# To support LazyTensor, ORT needs to call Python functions from C/C++,
# so onnxruntime_ENABLE_PYTHON is required.
if (NOT onnxruntime_ENABLE_TRAINING OR NOT onnxruntime_ENABLE_PYTHON)
message(
FATAL_ERROR
"Option onnxruntime_ENABLE_LAZY_TENSOR can only be set when onnxruntime_ENABLE_TRAINING and onnxruntime_ENABLE_PYTHON are enabled")
endif()
add_compile_definitions(ENABLE_LAZY_TENSOR)
# TODO: In the future, we can compile LazyTensor into a standalone
# library target, onnxruntime_lazy_tensor, to make the build
# cleaner.
#list(APPEND ONNXRUNTIME_TARGETS onnxruntime_lazy_tensor)
endif()

foreach(target_name ${ONNXRUNTIME_TARGETS})
include(${target_name}.cmake)
endforeach()
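For readers unfamiliar with `cmake_dependent_option`, the new option's behavior can be modeled roughly as follows (a Python sketch of standard CMakeDependentOption semantics, not code from this PR):

    def lazy_tensor_enabled(enable_training, user_choice=None):
        # Models: cmake_dependent_option(onnxruntime_ENABLE_LAZY_TENSOR "..." ON
        #                                "onnxruntime_ENABLE_TRAINING" OFF)
        if not enable_training:
            return False  # dependency unmet: forced OFF, user choice ignored
        if user_choice is not None:
            return bool(user_choice)  # an explicit -D... on the command line wins
        return True  # the visible default is ON

    assert lazy_tensor_enabled(enable_training=False, user_choice=True) is False
    assert lazy_tensor_enabled(enable_training=True) is True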
80 changes: 66 additions & 14 deletions cmake/onnxruntime_python.cmake
@@ -28,25 +28,60 @@ if(onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc)
endif()

if (onnxruntime_ENABLE_EAGER_MODE)
# Add Pytorch as a library.
if (onnxruntime_ENABLE_LAZY_TENSOR OR onnxruntime_ENABLE_EAGER_MODE)
# Both Lazy Tensor and Eager Mode require Pytorch as a library.
list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_PREBUILT_PYTORCH_PATH})
# The following line may change ${CUDA_NVCC_FLAGS} and ${CMAKE_CUDA_FLAGS},
# if Pytorch is built from source.
# For example, pytorch/cmake/public/cuda.cmake and
# pytorch/torch/share/cmake/Caffe2/public/cuda.cmake both define
# ONNX_NAMESPACE in both CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS.
# Later, this ONNX_NAMESPACE may conflict with the ONNX_NAMESPACE set by ORT.
find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
# Let's remove ONNX_NAMESPACE from Torch.
list(FILTER CUDA_NVCC_FLAGS EXCLUDE REGEX "-DONNX_NAMESPACE=.+")
string(REGEX REPLACE "-DONNX_NAMESPACE=.+ " " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
endif()
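The two flag-scrubbing lines above behave roughly like this Python sketch (the sample flag values are made up for illustration):

    import re

    # List-style flags: drop any element that defines ONNX_NAMESPACE.
    nvcc_flags = ["-O2", "-DONNX_NAMESPACE=onnx_torch", "--expt-relaxed-constexpr"]
    nvcc_flags = [f for f in nvcc_flags if not re.match(r"-DONNX_NAMESPACE=.+", f)]

    # String-style flags: splice the definition out of the string
    # (the CMake line uses the regex "-DONNX_NAMESPACE=.+ ").
    cuda_flags = "-O2 -DONNX_NAMESPACE=onnx_torch -lineinfo"
    cuda_flags = re.sub(r"-DONNX_NAMESPACE=\S+ ", " ", cuda_flags)

    assert "ONNX_NAMESPACE" not in " ".join(nvcc_flags) + cuda_flags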

if (onnxruntime_ENABLE_EAGER_MODE)
file(GLOB onnxruntime_eager_extension_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_ROOT}/orttraining/eager/*.cpp"
)

if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_eager_extension_srcs
"${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc")
list(APPEND onnxruntime_pybind_srcs
${onnxruntime_eager_extension_srcs})
endif()

# Support ORT as a backend in Pytorch's LazyTensor.
if (onnxruntime_ENABLE_LAZY_TENSOR)
file(GLOB onnxruntime_lazy_tensor_extension_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_ROOT}/orttraining/lazy_tensor/*.cc")
file(GLOB onnxruntime_lazy_tensor_extension_headers CONFIGURE_DEPENDS
"${ORTTRAINING_ROOT}/orttraining/lazy_tensor/*.h")

if(NOT MSVC)
set_source_files_properties(${onnxruntime_lazy_tensor_extension_srcs} PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${onnxruntime_lazy_tensor_extension_headers} PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()

list(APPEND onnxruntime_pybind_srcs
${onnxruntime_eager_extension_srcs})
${onnxruntime_lazy_tensor_extension_srcs})
endif()

# onnxruntime_ENABLE_LAZY_TENSOR and onnxruntime_ENABLE_EAGER_MODE
# need DLPack code to pass tensors across the ORT and Pytorch boundary.
# TODO: consider making DLPack code a standalone library.
if (onnxruntime_ENABLE_LAZY_TENSOR OR onnxruntime_ENABLE_EAGER_MODE)
# If the DLPack code is not built elsewhere, add it to ORT's pybind target.
if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_pybind_srcs
"${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc")
endif()
endif()
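DLPack is the zero-copy tensor-exchange protocol used across the ORT/Pytorch boundary. A minimal sketch using only public torch APIs (ORT's side lives in dlpack_python.cc and is not shown here):

    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    capsule = to_dlpack(t)    # wrap the tensor in a DLPack capsule (no copy)
    u = from_dlpack(capsule)  # rebuild a tensor that aliases the same memory
    u[0, 0] = 42.0
    assert t[0, 0].item() == 42.0  # both views share one buffer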

onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_pybind_srcs})

if(MSVC)
target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
if(onnxruntime_ENABLE_TRAINING)
@@ -108,18 +143,28 @@ if (onnxruntime_ENABLE_TRAINING)
target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_training)
endif()

if (onnxruntime_ENABLE_EAGER_MODE)
# Eager mode and LazyTensor are both Pytorch backends, so their
# dependencies are set together below.
if (onnxruntime_ENABLE_EAGER_MODE OR onnxruntime_ENABLE_LAZY_TENSOR)
# Set library dependencies shared by aforementioned backends.

# TODO: this is because the prebuilt Pytorch may use a different version of protobuf headers,
# so force the build to find the protobuf headers ORT is using.
target_include_directories(onnxruntime_pybind11_state PRIVATE "${REPO_ROOT}/cmake/external/protobuf/src")
target_include_directories(onnxruntime_pybind11_state PRIVATE "${TORCH_INSTALL_PREFIX}/include" "${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include")
find_library(LIBTORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
find_library(LIBTORCH_CPU_LIBRARY torch_cpu PATHS "${TORCH_INSTALL_PREFIX}/lib")
find_library(LIBC10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_eager ${LIBTORCH_LIBRARY} ${LIBTORCH_CPU_LIBRARY} ${LIBC10_LIBRARY} ${TORCH_PYTHON_LIBRARY})
target_include_directories(onnxruntime_pybind11_state PRIVATE
"${REPO_ROOT}/cmake/external/protobuf/src"
${TORCH_INCLUDE_DIRS})

# Explicitly link torch_python to workaround https://github.com/pytorch/pytorch/issues/38122#issuecomment-694203281
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(onnxruntime_pybind11_state PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES})
if (onnxruntime_ENABLE_EAGER_MODE)
target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_eager)
endif()

# This part is eager-mode specific.
# ort_aten.g.cpp is generated by tools and currently has some limitations.
# TODO: fix this
if (NOT MSVC)
if (onnxruntime_ENABLE_EAGER_MODE AND NOT MSVC)
set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_aten.g.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_aten.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_guard.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
@@ -274,6 +319,13 @@ else()
set(ONNXRUNTIME_SETDLOPENFLAGS_LOCAL "")
endif()

if (onnxruntime_ENABLE_LAZY_TENSOR)
# Import torch so that onnxruntime's pybind can see its DLLs.
set(ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS "import torch")
else()
set(ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS "")
endif()

configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in
${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py)

4 changes: 2 additions & 2 deletions docs/ContribOperators.md
@@ -1357,14 +1357,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Whether B should be transposed</dd>
</dl>

#### Inputs
#### Inputs (2 - 3)

<dl>
<dt><tt>A</tt> : T</dt>
<dd>Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.</dd>
<dt><tt>B</tt> : T</dt>
<dd>Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.</dd>
<dt><tt>C</tt> : T</dt>
<dt><tt>C</tt> (optional) : T</dt>
<dd>Input tensor C. The shape of C should be unidirectional broadcastable to (M, N).</dd>
</dl>

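Numerically, making C optional means the op must degrade to a plain, optionally transposed matmul when C is absent. A NumPy sketch of that contract (`fused_gemm` is a stand-in name; activation attributes are omitted):

    import numpy as np

    def fused_gemm(A, B, C=None, trans_a=0, trans_b=0):
        A = A.T if trans_a else A  # A becomes (M, K)
        B = B.T if trans_b else B  # B becomes (K, N)
        Y = A @ B
        if C is not None:          # C broadcasts unidirectionally to (M, N)
            Y = Y + C
        return Y

    A = np.random.rand(4, 3).astype(np.float32)
    B = np.random.rand(3, 5).astype(np.float32)
    assert fused_gemm(A, B).shape == (4, 5)  # C omitted: plain matmul
    assert fused_gemm(A, B, C=np.ones((1, 5), np.float32)).shape == (4, 5)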
3 changes: 2 additions & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1295,7 +1295,8 @@ activation and leaky_relu_alpha.)DOC")
"C",
"Input tensor C. "
"The shape of C should be unidirectional broadcastable to (M, N).",
"T")
"T",
OpSchema::Optional)
.Output(0, "Y", "Output tensor of shape (M, N).", "T")
.TypeConstraint(
"T",
1 change: 0 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -127,7 +127,6 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi
false);

return CreateAllocator(default_memory_info);

} else {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId id) {
1 change: 1 addition & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -317,3 +317,4 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() { r

#define LOGS_DEFAULT(severity) \
LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime)

10 changes: 5 additions & 5 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -834,16 +834,16 @@ struct ProviderHostImpl : ProviderHost {
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }

// Tensor (wrapped)
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) override {
  return std::make_unique<Tensor>(p_type, shape, std::move(allocator));
}

std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) override {
  return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset);
}

std::unique_ptr<Tensor> Tensor__construct_default() override {
  return std::make_unique<Tensor>();
}

virtual void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept override {
@@ -1130,7 +1130,7 @@ std::unique_ptr<IAllocator> CreateROCMPinnedAllocator(int16_t device_id, const c

// Adapter to convert the legacy OrtCUDAProviderOptions to the latest OrtCUDAProviderOptionsV2
OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const OrtCUDAProviderOptions* legacy_cuda_options) {
OrtCUDAProviderOptionsV2 cuda_options_converted;
OrtCUDAProviderOptionsV2 cuda_options_converted{};

cuda_options_converted.device_id = legacy_cuda_options->device_id;
cuda_options_converted.cudnn_conv_algo_search = legacy_cuda_options->cudnn_conv_algo_search;
1 change: 1 addition & 0 deletions onnxruntime/python/_pybind_state.py.in
@@ -28,6 +28,7 @@ if platform.system() == "Windows":
"(other than %SystemRoot%\System32), "
"make sure it can be found by setting the correct path.")

@ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS@
@ONNXRUNTIME_SETDLOPENFLAGS_GLOBAL@
from .onnxruntime_pybind11_state import * # noqa
@ONNXRUNTIME_SETDLOPENFLAGS_LOCAL@
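When onnxruntime_ENABLE_LAZY_TENSOR is ON, the template above expands to roughly the following (an abbreviated sketch of the generated _pybind_state.py, not a verbatim copy):

    import torch  # noqa: F401 -- loads libtorch shared libraries into the process
    # first, so the ORT pybind module below can resolve the torch symbols it links against
    from .onnxruntime_pybind11_state import *  # noqa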
2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/function_ops_test.cc
@@ -114,4 +114,4 @@ TEST_F(ContribFunExpansionTest, FastGeluWithoutBias) {
}

} // namespace test
} // namespace onnxruntime