From dc486d146b125db780fa2b71eda16875d01cd5e0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 22 Aug 2022 09:40:40 -0700 Subject: [PATCH] Make ORT callable from various Pytorch compilers (LazyTensor, TorchDynamo, etc) (#10460) * Make ORT as Pytorch JIT backend LORT likely doesn't work with aten fallback so we only test LORT in its own CI. * Revert changes to enable external CUDA allocator. Will add it later. Revert "Revert changes to enable external CUDA allocator. Will add it later." This reverts commit d5487f2e193014c805505afae8fb577c53667658. Fix external allocator * Relax tolerance and remove commented code * Print more information in CI * Fix pointer * Address comments. 1. Reuse ORT-eager mode's environment. 2. Remove unused ctor. * Use Pytorch master branch as all PRs are merged Fix * Refine based on cpplint feedbacks * Revert changes to allow custom CUDA allocator in public APIs * Use torch.testing.assert_close * Use unittest framework * Switch docker repo * Rename *.cpp to *.cc * Address comments * Add comment * Use same pipeline file for eager and lort pipelines * Address comments * Add yaml comment * Fix cmake files * Address comments * Rename flags, remove printing code, remove dead comment --- cmake/CMakeLists.txt | 21 +- cmake/onnxruntime_python.cmake | 80 ++- docs/ContribOperators.md | 4 +- .../core/graph/contrib_ops/contrib_defs.cc | 3 +- .../providers/cuda/cuda_execution_provider.cc | 1 - .../providers/shared_library/provider_api.h | 1 + .../core/session/provider_bridge_ort.cc | 10 +- onnxruntime/python/_pybind_state.py.in | 1 + .../test/contrib_ops/function_ops_test.cc | 2 +- .../orttraining/lazy_tensor/accelerator.cc | 487 ++++++++++++++++++ .../orttraining/lazy_tensor/accelerator.h | 65 +++ orttraining/orttraining/lazy_tensor/bridge.cc | 295 +++++++++++ orttraining/orttraining/lazy_tensor/bridge.h | 31 ++ .../orttraining/lazy_tensor/cuda_tool.cc | 61 +++ .../orttraining/lazy_tensor/cuda_tool.h | 47 ++ orttraining/orttraining/lazy_tensor/debug.cc | 136 +++++ orttraining/orttraining/lazy_tensor/debug.h | 24 + orttraining/orttraining/lazy_tensor/flags.cc | 79 +++ orttraining/orttraining/lazy_tensor/flags.h | 68 +++ orttraining/orttraining/lazy_tensor/fusion.cc | 405 +++++++++++++++ orttraining/orttraining/lazy_tensor/fusion.h | 13 + .../orttraining/lazy_tensor/register.cc | 90 ++++ .../python/orttraining_python_module.cc | 21 +- .../python/training/experimental/exporter.py | 26 + .../python/orttraining_ortmodule_tests.py | 5 +- .../test/python/orttraining_test_lort.py | 119 +++++ tools/ci_build/build.py | 12 +- .../linux-cpu-eager-pipeline.yml | 80 ++- .../docker/Dockerfile.manylinux2014_lort_cpu | 10 + .../scripts/manylinux/install_deps_lort.sh | 110 ++++ 30 files changed, 2263 insertions(+), 44 deletions(-) create mode 100644 orttraining/orttraining/lazy_tensor/accelerator.cc create mode 100644 orttraining/orttraining/lazy_tensor/accelerator.h create mode 100644 orttraining/orttraining/lazy_tensor/bridge.cc create mode 100644 orttraining/orttraining/lazy_tensor/bridge.h create mode 100644 orttraining/orttraining/lazy_tensor/cuda_tool.cc create mode 100644 orttraining/orttraining/lazy_tensor/cuda_tool.h create mode 100644 orttraining/orttraining/lazy_tensor/debug.cc create mode 100644 orttraining/orttraining/lazy_tensor/debug.h create mode 100644 orttraining/orttraining/lazy_tensor/flags.cc create mode 100644 orttraining/orttraining/lazy_tensor/flags.h create mode 100644 orttraining/orttraining/lazy_tensor/fusion.cc create mode 100644 
orttraining/orttraining/lazy_tensor/fusion.h create mode 100644 orttraining/orttraining/lazy_tensor/register.cc create mode 100644 orttraining/orttraining/python/training/experimental/exporter.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_lort.py create mode 100644 tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_lort_cpu create mode 100755 tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7599a408e8f86..6199c8575048f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -169,6 +169,9 @@ option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF) # build eager mode option(onnxruntime_ENABLE_EAGER_MODE "build ort eager mode") +# build Pytorch's LazyTensor support +cmake_dependent_option(onnxruntime_ENABLE_LAZY_TENSOR "Enable ORT as a LazyTensor backend in Pytorch." ON "onnxruntime_ENABLE_TRAINING" OFF) + # build separate library of schemas of (custom) ops used by ORT (for ONNX to MLIR translation) option(onnxruntime_BUILD_OPSCHEMA_LIB "Build op schema library" ON) @@ -1833,7 +1836,7 @@ if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch") endif() if (NOT WIN32) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --compiler-options -fPIC") + list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC) endif() # Options passed to cudafe set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=bad_friend_decl\"") @@ -2041,6 +2044,22 @@ if (onnxruntime_ENABLE_EAGER_MODE) add_compile_definitions(ENABLE_EAGER_MODE) list(APPEND ONNXRUNTIME_TARGETS onnxruntime_eager) endif() + +if (onnxruntime_ENABLE_LAZY_TENSOR) + # To support LazyTensor, ORT needs to call Python function from C/C++. + # so onnxruntime_ENABLE_PYTHON is required. + if (NOT onnxruntime_ENABLE_TRAINING OR NOT onnxruntime_ENABLE_PYTHON) + message( + FATAL_ERROR + "Option onnxruntime_ENABLE_LAZY_TENSOR can only be set when onnxruntime_ENABLE_TRAINING and onnxruntime_ENABLE_PYTHON are enabled") + endif() + add_compile_definitions(ENABLE_LAZY_TENSOR) + # TODO: In the future, we can compile LazyTensor into a standalone + # library target, onnxruntime_lazy_tensor, to make the buid + # cleaner. + #list(APPEND ONNXRUNTIME_TARGETS onnxruntime_lazy_tensor) +endif() + foreach(target_name ${ONNXRUNTIME_TARGETS}) include(${target_name}.cmake) endforeach() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index f7ae96672a3fb..aab345b7e9668 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -28,25 +28,60 @@ if(onnxruntime_ENABLE_TRAINING) list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc) endif() -if (onnxruntime_ENABLE_EAGER_MODE) +# Add Pytorch as a library. +if (onnxruntime_ENABLE_LAZY_TENSOR OR onnxruntime_ENABLE_EAGER_MODE) + # Both Lazy Tensor and Eager Mode require Pytorch as a library. list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_PREBUILT_PYTORCH_PATH}) + # The following line may change ${CUDA_NVCC_FLAGS} and ${CMAKE_CUDA_FLAGS}, + # if Pytorch is built from source. + # For example, pytorch/cmake/public/cuda.cmake and + # pytorch/torch/share/cmake/Caffe2/public/cuda.cmake both defines + # ONNX_NAMESPACE for both CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS. + # Later, this ONNX_NAMESPACE may conflicts with ONNX_NAMESPACE set by ORT. 
find_package(Torch REQUIRED) - find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") + # Let's remove ONNX_NAMESPACE from Torch. + list(FILTER CUDA_NVCC_FLAGS EXCLUDE REGEX "-DONNX_NAMESPACE=.+") + string(REGEX REPLACE "-DONNX_NAMESPACE=.+ " " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +endif() +if (onnxruntime_ENABLE_EAGER_MODE) file(GLOB onnxruntime_eager_extension_srcs CONFIGURE_DEPENDS "${ORTTRAINING_ROOT}/orttraining/eager/*.cpp" ) - if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - list(APPEND onnxruntime_eager_extension_srcs - "${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc") + list(APPEND onnxruntime_pybind_srcs + ${onnxruntime_eager_extension_srcs}) +endif() + +# Support ORT as a backend in Pytorch's LazyTensor. +if (onnxruntime_ENABLE_LAZY_TENSOR) + file(GLOB onnxruntime_lazy_tensor_extension_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_ROOT}/orttraining/lazy_tensor/*.cc") + file(GLOB onnxruntime_lazy_tensor_extension_headers CONFIGURE_DEPENDS + "${ORTTRAINING_ROOT}/orttraining/lazy_tensor/*.h") + + if(NOT MSVC) + set_source_files_properties(${onnxruntime_lazy_tensor_extension_srcs} PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) + set_source_files_properties(${onnxruntime_lazy_tensor_extension_headers} PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) endif() list(APPEND onnxruntime_pybind_srcs - ${onnxruntime_eager_extension_srcs}) + ${onnxruntime_lazy_tensor_extension_srcs}) +endif() + +# onnxruntime_ENABLE_LAZY_TENSOR and onnxruntime_ENABLE_EAGER_MODE +# need DLPack code to pass tensors cross ORT and Pytorch boundary. +# TODO: consider making DLPack code a standalone library. +if (onnxruntime_ENABLE_LAZY_TENSOR OR onnxruntime_ENABLE_EAGER_MODE) + # If DLPack code is not built, add it to ORT's pybind target. + if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + list(APPEND onnxruntime_pybind_srcs + "${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc") + endif() endif() onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_pybind_srcs}) + if(MSVC) target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") if(onnxruntime_ENABLE_TRAINING) @@ -108,18 +143,28 @@ if (onnxruntime_ENABLE_TRAINING) target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_training) endif() -if (onnxruntime_ENABLE_EAGER_MODE) +# Eager mode and LazyTensor are both Pytorch's backends, so their +# dependencies are set together below. +if (onnxruntime_ENABLE_EAGER_MODE OR onnxruntime_ENABLE_LAZY_TENSOR) + # Set library dependencies shared by aforementioned backends. + # todo: this is because the prebuild pytorch may use a different version of protobuf headers. # force the build to find the protobuf headers ort using. 
- target_include_directories(onnxruntime_pybind11_state PRIVATE "${REPO_ROOT}/cmake/external/protobuf/src") - target_include_directories(onnxruntime_pybind11_state PRIVATE "${TORCH_INSTALL_PREFIX}/include" "${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include") - find_library(LIBTORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib") - find_library(LIBTORCH_CPU_LIBRARY torch_cpu PATHS "${TORCH_INSTALL_PREFIX}/lib") - find_library(LIBC10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib") - target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_eager ${LIBTORCH_LIBRARY} ${LIBTORCH_CPU_LIBRARY} ${LIBC10_LIBRARY} ${TORCH_PYTHON_LIBRARY}) + target_include_directories(onnxruntime_pybind11_state PRIVATE + "${REPO_ROOT}/cmake/external/protobuf/src" + ${TORCH_INCLUDE_DIRS}) + + # Explicitly link torch_python to workaround https://github.com/pytorch/pytorch/issues/38122#issuecomment-694203281 + find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") + target_link_libraries(onnxruntime_pybind11_state PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES}) + if (onnxruntime_ENABLE_EAGER_MODE) + target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_eager) + endif() + + # This part is eager-mode specific. # the ort_aten.g.cpp is generated from tools. currently it has some limitations. # todo: fix this - if (NOT MSVC) + if (onnxruntime_ENABLE_EAGER_MODE AND NOT MSVC) set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_aten.g.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_aten.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_guard.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) @@ -274,6 +319,13 @@ else() set(ONNXRUNTIME_SETDLOPENFLAGS_LOCAL "") endif() +if (onnxruntime_ENABLE_LAZY_TENSOR) + # Import torch so that onnxruntime's pybind can see its DLLs. + set(ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS "import torch") +else() + set(ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS "") +endif() + configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 2dca2eca8b6cd..703949b6ca6b4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1357,14 +1357,14 @@ This version of the operator has been available since version 1 of the 'com.micr
Whether B should be transposed
-#### Inputs
+#### Inputs (2 - 3)
A : T
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
B : T
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
-C : T
+C (optional) : T
Input tensor C. The shape of C should be unidirectional broadcastable to (M, N).
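For context, the documentation and schema hunks above make input C optional for this Gemm-style contrib operator. A minimal usage sketch (not part of this patch) of how a model could now construct the node without a bias input, using the standard onnx.helper API; the op name "FusedGemm", the "com.microsoft" domain, and the opset versions are assumptions inferred from the surrounding contrib_defs.cc hunk:

import onnx
from onnx import TensorProto, helper

# Hypothetical sketch: build the Gemm-style contrib node without the
# now-optional bias input C. Op name and domain are assumed, not confirmed
# by this patch.
node = helper.make_node(
    "FusedGemm", inputs=["A", "B"], outputs=["Y"],
    domain="com.microsoft", transA=0, transB=0)

graph = helper.make_graph(
    [node], "fused_gemm_no_bias",
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]),
     helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"])])

model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 14),
                          helper.make_opsetid("com.microsoft", 1)])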
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index d1f051f46f381..921d44716f12b 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1295,7 +1295,8 @@ activation and leaky_relu_alpha.)DOC") "C", "Input tensor C. " "The shape of C should be unidirectional broadcastable to (M, N).", - "T") + "T", + OpSchema::Optional) .Output(0, "Y", "Output tensor of shape (M, N).", "T") .TypeConstraint( "T", diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 57e0f1aaf957b..b19d1db233ba7 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -127,7 +127,6 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi false); return CreateAllocator(default_memory_info); - } else { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId id) { diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index cb6dc880c4870..7caaa25cb4ccb 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -317,3 +317,4 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { r #define LOGS_DEFAULT(severity) \ LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index fd35f57877a62..bd75d1ff31f12 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -834,16 +834,16 @@ struct ProviderHostImpl : ProviderHost { const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); } // Tensor (wrapped) - std::unique_ptr Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator) override { - return std::make_unique(p_type, shape, std::move(allocator)); + std::unique_ptr Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator) override { + return std::make_unique(p_type, shape, std::move(allocator)); } std::unique_ptr Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) override { - return std::make_unique(p_type, shape, p_data, alloc, offset); + return std::make_unique(p_type, shape, p_data, alloc, offset); } std::unique_ptr Tensor__construct_default() override { - return std::make_unique(); + return std::make_unique(); } virtual void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept override { @@ -1130,7 +1130,7 @@ std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const c // Adapter to convert the legacy OrtCUDAProviderOptions to the latest OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const OrtCUDAProviderOptions* legacy_cuda_options) { - OrtCUDAProviderOptionsV2 cuda_options_converted; + OrtCUDAProviderOptionsV2 cuda_options_converted{}; cuda_options_converted.device_id = legacy_cuda_options->device_id; cuda_options_converted.cudnn_conv_algo_search = legacy_cuda_options->cudnn_conv_algo_search; diff --git a/onnxruntime/python/_pybind_state.py.in 
b/onnxruntime/python/_pybind_state.py.in index f8bdbd9a59e82..ee625a0f7fb6c 100644 --- a/onnxruntime/python/_pybind_state.py.in +++ b/onnxruntime/python/_pybind_state.py.in @@ -28,6 +28,7 @@ if platform.system() == "Windows": "(other than %SystemRoot%\System32), " "make sure it can be found by setting the correct path.") +@ONNXRUNTIME_IMPORT_PYTORCH_TO_RESOLVE_DLLS@ @ONNXRUNTIME_SETDLOPENFLAGS_GLOBAL@ from .onnxruntime_pybind11_state import * # noqa @ONNXRUNTIME_SETDLOPENFLAGS_LOCAL@ diff --git a/onnxruntime/test/contrib_ops/function_ops_test.cc b/onnxruntime/test/contrib_ops/function_ops_test.cc index 065f2bba98736..fa373edd166cd 100644 --- a/onnxruntime/test/contrib_ops/function_ops_test.cc +++ b/onnxruntime/test/contrib_ops/function_ops_test.cc @@ -114,4 +114,4 @@ TEST_F(ContribFunExpansionTest, FastGeluWithoutBias) { } } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/accelerator.cc b/orttraining/orttraining/lazy_tensor/accelerator.cc new file mode 100644 index 0000000000000..3e711f0eff1ea --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/accelerator.cc @@ -0,0 +1,487 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/lazy_tensor/accelerator.h" +// C++ +#include +#include +#include +#include +#include +#include +// Pytorch. +#include +#include +#include +// ORT friends. +#include "core/common/logging/sinks/clog_sink.h" +#include "core/framework/execution_providers.h" +#include "core/framework/session_options.h" +#include "core/session/environment.h" +#include "python/onnxruntime_pybind_state_common.h" +// Lazy tensor specific. +#include "orttraining/lazy_tensor/bridge.h" +#include "orttraining/lazy_tensor/cuda_tool.h" +#include "orttraining/lazy_tensor/debug.h" +#include "orttraining/lazy_tensor/flags.h" + +namespace onnxruntime { + +namespace python { +Environment& GetTrainingORTEnv(); +} + +namespace lazytensor { + +namespace py = pybind11; +namespace aten = torch::jit::aten; +namespace prim = torch::jit::prim; + +bool Accelerator::Supported(const torch::jit::Node* node) { + if (!node) { + return false; + } + + switch (node->kind()) { + // TODO(wechi): add as many ops as possible. + case aten::embedding: + case aten::tanh: + case aten::slice: + case aten::bmm: + case aten::gelu: + case aten::native_layer_norm: + case aten::native_dropout: + case aten::expand: + case aten::add: + case aten::convolution: + case aten::reshape: + case aten::max_pool2d_with_indices: + case aten::_log_softmax: + case aten::relu: + case aten::mul: + case aten::sub: + case aten::div: + case aten::gt: + case aten::lt: + case aten::eq: + case aten::sqrt: + case aten::permute: + case aten::mm: + case aten::ne: + case aten::abs: + case aten::max: + case aten::min: { + if (DumpAtenOpHistory()) { + std::cout << "Supported op: " + << ToString(*node) << std::endl; + } + return true; + } + default: { + if (DumpAtenOpHistory()) { + std::cout << "Unsupported op: " + << ToString(*node) << std::endl; + // To check sub-graph in specific symbol such as prim::TensorExprGroup, + // uncomment and extend the following code. 
+ // + // if (node->kind() == prim::TensorExprGroup || node->kind() == prim::FallbackGraph) { + // auto subgraph = node->g(torch::jit::attr::Subgraph); + // std::cout << "Node's subgraph: " << *subgraph; + // } + } + + return false; + } + } +} + +void Accelerator::OrtRun(torch::jit::Stack& stack) { +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + // Uncomment the following if you want to see the + // sub-graph in Nsys profiling result. This is useful + // for debugging. + // + // NvtxRange range_graph(subgraph_->toString(true)); + if (DumpGraph()) { + std::cout << "[ORT,Graph]\n" + << subgraph_->toString(true); + } + + // Get these inputs from the stack. + at::ArrayRef inputs = torch::jit::last(stack, subgraph_->inputs().size()); + // If we haven't compiled for the shape/device of these inputs before, + // do so now. + // Compile a callable to execute "subgraph_" on the inputs. + // If such input schema appears before, we can reuse a cached compiled callable. + torch::jit::CompleteArgumentSpec spec{false, inputs}; + if (cache_.find(spec) == cache_.end()) { + cache_.emplace(spec, Compile(spec, inputs)); + } + + if (DumpInputsOutputs()) { + std::cout << "[ORT,Input] " << ToString(inputs) << std::endl; + } + + // Run the compiled function! + auto outputs = cache_[spec].code(inputs); + + // Discard used inputs. + torch::jit::drop(stack, inputs.size()); + + // Return results to caller. + for (auto& output : outputs) { + stack.push_back(output); + } + + if (DumpInputsOutputs()) { + at::ArrayRef outputs = torch::jit::last(stack, subgraph_->outputs().size()); + std::cout << "[ORT,Output] " << ToString(outputs) << std::endl; + } +} + +void Accelerator::PytorchRun(torch::jit::Stack& stack) { + DynamicSettings::GetInstance().SetOnnxFusionFlag(false); +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + if (DumpGraph()) { + std::cout << "[Pytorch,Graph]\n" + << subgraph_->toString(true); + } + if (DumpInputsOutputs()) { + at::ArrayRef inputs = torch::jit::last( + stack, subgraph_->inputs().size()); + std::cout << "[PyTorch,Input] " << ToString(inputs) << std::endl; + } + + torch::jit::GraphExecutor executor(subgraph_, ""); + executor.run(stack); + + if (DumpInputsOutputs()) { + at::ArrayRef outputs = torch::jit::last( + stack, subgraph_->outputs().size()); + std::cout << "[PyTorch,Output] " << ToString(outputs) << std::endl; + } + DynamicSettings::GetInstance().SetOnnxFusionFlag(true); +} + +void Accelerator::DebugRun(torch::jit::Stack& stack) { +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + torch::jit::Stack copy; + copy = stack; + OrtRun(stack); + PytorchRun(copy); + ORT_ENFORCE(CompareStack(stack, copy), + "ORT and Pytorch must generate the same results " + "but tensor types, shapes or content are different. " + "Use, e.g., LORT_RELATIVE_TOLERANCE=1e-3 and " + "LORT_ABSOLUTE_TOLERANCE=1e-4 " + "to increase the content tolerance, if " + "the difference is due to numerical errors."); +} + +void Accelerator::Run(torch::jit::Stack& stack) { + const auto run_type = RunType(); + if (run_type == "debug") { + // Run both ORT and Pytorch to execute the subgraph + // and compare their output types and shapes. + DebugRun(stack); + } else if (run_type == "ort") { + OrtRun(stack); + } else if (run_type == "pytorch") { + PytorchRun(stack); + } else { + ORT_THROW("Unknown run type: ", run_type); + } +} + +static void CheckArgs( + const at::ArrayRef& inputs) { + // TODO(wechi): remove this check. 
+#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + TORCH_CHECK(inputs.size(), "Need at least one input."); + for (const auto& input : inputs) { + TORCH_CHECK(input.isTensor() || input.isScalar(), "Compiler can only handle Tensor or Scalar inputs."); + } +} + +// Store input types in sub-graph so that +// ONNX exporter can use them. Input types +// are required when executing ONNX model +// in ORT. +// TODO(wechi): Allow ORT to accept models without +// input types. Then, we can remove this function. +static void SetArgTypes( + const at::ArrayRef& inputs, + std::shared_ptr graph) { + TORCH_CHECK(graph->inputs().size() == inputs.size(), + "Number of provided inputs must match captured sub-graph's schema."); + for (size_t i = 0; i < graph->inputs().size(); ++i) { + auto input_symbol = graph->inputs()[i]; + auto input_value = inputs[i]; + if (!input_value.isTensor()) { + // The allowed IR components in ONNX exporter and Pytorch + // are a little different. I am not confident to fill + // types other than tensor, because of the ambiguous scalar + // representations in Pytorch. + continue; + } + input_symbol->setType(input_value.type()); + } +} + +// ONNX exporter is written in Python, so +// this function may call some Python functions. +// Be aware of GIL issue. +// The returned value is the path to exported +// ONNX file. +static std::string ExportToOnnx( + std::shared_ptr graph, + const at::ArrayRef& args) { +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + // ONNX exporter modifies the graph in-place, so we + // need to clone it to avoid interaction between + // Pytorch's JIT mechanism and ONNX graph. + std::shared_ptr new_subgraph(graph->copyUnique().release()); + // Acquire GIL since Python is not multi-threading. + pybind11::gil_scoped_acquire guard{}; + // Retrieve Python exporter function. + pybind11::function export_to_onnx = + pybind11::reinterpret_borrow( + pybind11::module::import("onnxruntime.training.experimental.exporter") + .attr("_export_jit_graph_to_onnx_model_proto")); + // Fill types up. The sub-graphp from LazyTensor doesn't + // contain input shapes. + SetArgTypes(args, new_subgraph); + // Execute Python function. + auto result = export_to_onnx(new_subgraph, ::torch::onnx::OperatorExportTypes::ONNX); + return result.cast(); +} + +// Create an empty session object. +// Models will be loaded later. +static std::unique_ptr CreateSession() { +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + // Enviroment shared by all sessions. + static onnxruntime::Environment& pybind_default_env = onnxruntime::python::GetTrainingORTEnv(); + // All sessions use the same config. + static onnxruntime::SessionOptions sess_opts; + return std::make_unique(sess_opts, pybind_default_env); +} + +static OrtDevice CheckAndGetTensorDevice(const at::ArrayRef& values) { + // This memory info must be shared by all tensors; + // for example, all tensors on CPU or all on a specific GPU. + // When all values are not tensors, we assume CPU device. + // c10::Device's index is default to -1. + c10::Device unique_tensor_device(c10::DeviceType::CPU); + bool assigned = false; + for (auto value : values) { + if (!value.isTensor()) { + continue; + } + auto tensor = value.toTensor(); + if (assigned) { + // A device has been recorded, so we compare + // it with the current tensor's device. + TORCH_CHECK(unique_tensor_device == tensor.device(), + "All tensors must be on the same device."); + } else { + // Record the 1st tensor device. 
+ unique_tensor_device = tensor.device(); + assigned = true; + } + } + return CreateOrtDevice(unique_tensor_device); +} + +// Initialize empty session with ONNX model. +static void InitializeSession( + const OrtDevice device, + const std::string& serialized_model, + onnxruntime::InferenceSession& sess) { + // Add EPs. +#ifdef USE_CUDA + NvtxRange range(__func__); + // When CUDA is enabled, some CUDA-only graph graph fusions are enabled. + // If we don't add CUDA EP, ONNX Runtime may throw even when running MNIST. + // Information needed to construct CUDA execution providers. + // Note that CUDA is enabled by setting LTC_TS_CUDA=1 when running LazyTensor. + if (device.Type() == OrtDevice::GPU) { + ORT_THROW_IF_ERROR(sess.RegisterExecutionProvider( + CUDAExecutionProviderPool::GetInstance().GetExecutionProvider(device.Id()))); + } +#endif + ORT_THROW_IF_ERROR(sess.Load(serialized_model.data(), serialized_model.size())); + ORT_THROW_IF_ERROR(sess.Initialize()); +} + +void Accelerator::ExampleRun(at::ArrayRef inputs) { +#ifdef USE_CUDA + NvtxRange range(__func__); +#endif + torch::jit::Stack stack; + for (auto input : inputs) { + stack.push_back(input); + } + + // Run graph and store input and output types. + // 1. Store input types. + // This check prevent some unexpected modification on input_types_. + ORT_ENFORCE(input_types_.size() == inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + c10::TypePtr type = inputs.at(i).type(); + // LazyTensor should only capture graph with numerical inputs and outputs. + // If this assumption is broken, please use Accelerator::Supported to filter + /// out unsupported types and operators. + ORT_ENFORCE(type->isSubtypeOf(*c10::TensorType::get()) || + type->isSubtypeOf(*c10::NumberType::get()), + "ONNX only support tensor, float, int, bool as graph's input types"); + input_types_.at(i) = type; + } + + // 2. Run graph. + torch::jit::GraphExecutor executor(subgraph_, ""); + executor.run(stack); + + // 3. Store output types. + at::ArrayRef outputs = torch::jit::last(stack, subgraph_->outputs().size()); + ORT_ENFORCE(output_types_.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + c10::TypePtr type = outputs.at(i).type(); + ORT_ENFORCE(type->isSubtypeOf(*c10::TensorType::get()) || + type->isSubtypeOf(*c10::NumberType::get()), + "ONNX only support tensor, float, int, bool as graph's output types. But got ", + type->str()); + output_types_.at(i) = type; + } +} + +CompiledObject Accelerator::Compile( + torch::jit::CompleteArgumentSpec spec, at::ArrayRef& args) { + CheckArgs(args); + DynamicSettings::GetInstance().SetOnnxFusionFlag(false); + ExampleRun(args); + DynamicSettings::GetInstance().SetOnnxFusionFlag(true); + // Storage of compilation. + CompiledObject compiled; + // Create an empty session. + compiled.sess = CreateSession(); + // Let's get the empty session and initialize it. + onnxruntime::InferenceSession& sess = *compiled.sess; + // Export subgraph_ to ONNX. + // The exporter should never fail. If it does, please modify + // Accelerator::Supported to filter out unsupported operators. + const std::string serialized_model = ExportToOnnx(subgraph_, args); + // Memory info for all tensors. + // Assume all inputs are on the same device. + OrtDevice shared_device = CheckAndGetTensorDevice(args); + // Load ONNX model into session, register + // EPs and finally initialize session. 
+ InitializeSession(shared_device, serialized_model, sess); + + onnxruntime::RunOptions run_options; + std::vector feed_names; + std::vector fetch_names; + + for (auto node_arg : *sess.GetModelInputs().second) { + feed_names.push_back(node_arg->Name()); + } + for (auto node_arg : *sess.GetModelOutputs().second) { + fetch_names.push_back(node_arg->Name()); + } + + // Duplicate device info for putting output tensors on the shared device. + std::vector fetches_device_info(fetch_names.size(), shared_device); + + // Create a callable which feeds inputs to ORT + // session's Run(...) and returns outputs. + auto code = [this, run_options, + feed_names, fetch_names, + fetches_device_info, &sess](at::ArrayRef& args) { + // Inputs of ORT session. + std::vector feeds; + // Outputs of ORT session. + std::vector fetches; + + { +#ifdef USE_CUDA + NvtxRange range("Prepare inputs"); +#endif + // Prepare inputs. + const auto num_inputs = subgraph_->inputs().size(); + for (size_t i = 0; i < num_inputs; ++i) { + // The value can be either tensor or scalar. + // Scalar is a tensor with empty shape vector. + // Create ORT tensor from Pytorch tensor without copy. + if (args.at(i).isScalar()) { + // Scalar. + // ORT_ENFORCE(subgraph_->inputs().at(i)->type()->kind() == c10::TypeKind::TensorType); + feeds.push_back(CreateOrtScalarValue(args.at(i).toScalar())); + } else if (args.at(i).isTensor()) { + // Tensor. + ORT_ENFORCE(subgraph_->inputs().at(i)->type()->kind() == c10::TypeKind::TensorType); + feeds.push_back(CreateOrtTensorValue(args.at(i).toTensor())); + } else { + // Looks like LTC only passes scalars and tensors into backend, so we don't care + // other types for now. + ORT_THROW("Only tensor inputs are supported."); + } + } + } + + { +#ifdef USE_CUDA + NvtxRange range("Call sess.Run"); +#endif + // Inputs are ready. Let's run ORT. + ORT_THROW_IF_ERROR(sess.Run( + run_options, + feed_names, feeds, + fetch_names, &fetches, &fetches_device_info)); + } + + std::vector outputs; + { +#ifdef USE_CUDA + NvtxRange range("Convert outputs"); +#endif + // Convert ORT output to Pytorch format. + for (size_t i = 0; i < fetches.size(); ++i) { + // Get the expected type of the i-th output. + const c10::TypePtr type = output_types_.at(i); + // Convert ORTValue to IValue. + if (type->isSubtypeOf(*c10::TensorType::get())) { + ORT_ENFORCE(fetches.at(i).IsTensor(), "Only ORT tensor can be translated to Pytorch tensor."); + auto value = CreateC10IvalueTensor(fetches.at(i)); + auto expected_scalar_type = output_types_.at(i)->cast()->scalarType().value(); + outputs.push_back(value.toTensor().to(expected_scalar_type)); + } else if (type->isSubtypeOf(*c10::NumberType::get())) { + // ORT represents scalar as tensor without shape. + ORT_ENFORCE(fetches.at(i).IsTensor(), "Only ORT tensor can be translated to Pytorch scalar."); + auto value = CreateC10IvalueScalar(fetches.at(i)); + outputs.push_back(value); + } else { + ORT_ENFORCE(false, "Unsupported c10::Type ", type->str()); + } + } + } + + return outputs; + }; + + compiled.code = code; + return compiled; +} +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/accelerator.h b/orttraining/orttraining/lazy_tensor/accelerator.h new file mode 100644 index 0000000000000..0b32348db60cd --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/accelerator.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace lazytensor { + +// Type of JIT compilation result. +struct CompiledObject { + // Callable to execute the computation represented by torch::jit::Graph. + // It processes tensors across ORT and Pytorch and invokes "sess". + std::function(at::ArrayRef&)> code; + // Session used in the "code" above. + std::unique_ptr sess; +}; + +// Custom JIT engine called by Pytorch. +class Accelerator { + public: + Accelerator(const torch::jit::Node* node) + : subgraph_(node->g(torch::jit::attr::Subgraph)), + input_types_(subgraph_->inputs().size()), + output_types_(subgraph_->outputs().size()) {} + // Execute a call to the torch::jit::Graph represented by "subgraph_". + // This function could compile the graph and cache the result + // for repeated uses. + void Run(torch::jit::Stack& stack); + // Determine if this node can be translated to ONNX. + static bool Supported(const torch::jit::Node* node); + + private: + // This function runs the "subgraph_" using PyTorch JIT executor + // to get expected doutput schema. No ORT is involved. + void ExampleRun(at::ArrayRef inputs); + // This function calls "OrtRun" and "PytorchRun" to execute the graph + // and compare their results. It may fail if their results are different + // types or shapes. + void DebugRun(torch::jit::Stack& stack); + // Execute the graph represented by "subgraph_" using ORT. + // Inputs are popped out from stack and outputs are pushed to stack. + void OrtRun(torch::jit::Stack& stack); + // Similar to "OrtRun" but uses Pytorch as executor. + void PytorchRun(torch::jit::Stack& stack); + // Create callable to execute "subgraph_" given "args" as inputs. + // This calllable is cached for repeated uses. + CompiledObject Compile( + torch::jit::CompleteArgumentSpec spec, at::ArrayRef& args); + // The graph to be compiled and executed by ORT. + std::shared_ptr subgraph_; + // Previously compiled results. + std::unordered_map cache_; + // Types of the inputs (typed to IValue) we got when compile the subgraph. + // Since the subgraph is compiled for these type, feeding + // inputs with different types may fail. + std::vector input_types_; + // Types of the outputs (typed to IValue) by running the subgraph with + // torch::jit::GraphExecutor. + std::vector output_types_; +}; +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/bridge.cc b/orttraining/orttraining/lazy_tensor/bridge.cc new file mode 100644 index 0000000000000..9f930caa6b6b4 --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/bridge.cc @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "bridge.h" +#include + +namespace onnxruntime { +namespace lazytensor { +c10::ScalarType CreateC10ScalarType(const onnxruntime::PrimitiveDataTypeBase* elem_type) { + ORT_ENFORCE(elem_type, "Element type pointer cannot be NULL."); + switch (static_cast(elem_type->GetDataType())) { + case onnxruntime::utils::ToTensorProtoElementType(): { + return c10::kFloat; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return c10::kDouble; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return at::kHalf; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return c10::kBFloat16; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return at::kBool; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return at::kShort; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return at::kInt; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + return at::kLong; + } + default: + ORT_THROW("Unsupport ORT scalar type."); + } +} + +onnxruntime::MLDataType CreateOrtScalarType( + at::ScalarType dtype) { + switch (dtype) { + case at::kFloat: + return onnxruntime::DataTypeImpl::GetType(); + case at::kDouble: + return onnxruntime::DataTypeImpl::GetType(); + case at::kHalf: + return onnxruntime::DataTypeImpl::GetType(); + case at::kBFloat16: + return onnxruntime::DataTypeImpl::GetType(); + case at::kInt: + return onnxruntime::DataTypeImpl::GetType(); + case at::kShort: + return onnxruntime::DataTypeImpl::GetType(); + case at::kLong: + return onnxruntime::DataTypeImpl::GetType(); + case at::kBool: + return onnxruntime::DataTypeImpl::GetType(); + default: + ORT_THROW("Unsupport aten scalar type: ", dtype); + } +} + +OrtDevice CreateOrtDevice(const c10::Device device) { + // c10::Device's ID can be negative, which means current device. + // Assumptions: + // 1. c10::Device::CPU is always indexed by -1. + // Thus, it's mapped to OrtDevice::CPU with index 0. + // 2. c10::Device::GPU always has non-negative index. + // If the index is i, it's mapped to OrtDevice::GPU with index i. + + // For each case, assert if our assumptions are true and then do the work. + if (device.type() == c10::DeviceType::CPU) { + ORT_ENFORCE(device.index() == -1); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0); + } else if (device.type() == c10::DeviceType::CUDA) { + ORT_ENFORCE(device.index() >= 0); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device.index()); + } else { + ORT_THROW("Unsupport Pytorch c10 device.", + " Type: ", c10::DeviceTypeName(device.type()), ",", + " ID: ", device.index()); + } +} + +c10::Device CreateC10Device(const OrtDevice& device) { + // Handles CPU, GPU, and throws otherwise. + switch (device.Type()) { + case OrtDevice::CPU: { + ORT_ENFORCE(device.Id() == 0, "ORT CPU device ID must be 0 but got ", device.Id()); + // No need to specify index when creating c10 CPU. + return c10::Device(c10::DeviceType::CPU); + } + case OrtDevice::GPU: { + ORT_ENFORCE(device.Id() >= 0, "ORT GPU device ID must be >= 0 but got ", device.Id()); + // c10 GPU can have negative index (means current device), + // but only using non-negative index is enough to cover all ORT cases. + return c10::Device(c10::DeviceType::CUDA, device.Id()); + } + default: { + // Got unsupported device. Throws. 
+ const char* device_str = nullptr; + if (device.Type() == OrtDevice::CPU) { + device_str = "CPU"; + } else if (device.Type() == OrtDevice::GPU) { + device_str = "GPU"; + } else { + device_str = "Unknown"; + } + ORT_THROW("Unsupport ORT device: ", device_str, ", ID: ", device.Id()); + } + } +} + +OrtValue CreateOrtTensorValue(const at::Tensor& tensor) { + onnxruntime::MLDataType element_type = CreateOrtScalarType(tensor.scalar_type()); + onnxruntime::TensorShape shape(tensor.sizes().vec()); + OrtDevice device = CreateOrtDevice(tensor.device()); + OrtMemoryInfo memory_info = OrtMemoryInfo("LTC", OrtAllocatorType::OrtDeviceAllocator, device, device.Id()); + // This tensor's life time is controlled by Pytorch. + // TODO: consider to let ORT also own that tensor. + std::unique_ptr ort_tensor = std::make_unique( + element_type, shape, + tensor.data_ptr(), memory_info); + + OrtValue ort_value; + ort_value.Init( + ort_tensor.release(), + onnxruntime::DataTypeImpl::GetType(), + onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); + return ort_value; +} + +c10::IValue CreateC10IvalueTensor(OrtValue value) { + onnxruntime::Tensor* tensor = value.GetMutable(); + const OrtDevice& device = tensor->Location().device; + auto options = torch::TensorOptions() + .dtype(CreateC10ScalarType(tensor->DataType()->AsPrimitiveDataType())) + .layout(torch::kStrided) + .device(CreateC10Device(device)) + .requires_grad(false); + + // Extract shape from onnxruntime::TensorShape as a vector. + auto create_shape_vector = [](const onnxruntime::TensorShape& shape) { + std::vector new_shape(shape.NumDimensions()); + shape.CopyDims(new_shape.data(), shape.NumDimensions()); + return new_shape; + }; + + at::Tensor new_tensor = torch::from_blob( + tensor->MutableDataRaw(), + create_shape_vector(tensor->Shape()), + // Capture-by-value means + // 1. A new OrtValue is direct-initialized from "value". + // 2. The new OrtValue and "value" share the same underlying tensor, so + // the tensor's lifetime is controlled by both of them, whichever is longer. + // 3. The new OrtValue's lifetime is the same as this lambda function. + // 4. This lambda function is deleted by "new_tensor"'s dtor, which also ends + // the underlying tensor's life. + [value, tensor](void* p) { + // std::cout << "ORT-LR fake deletes Pytorch tensor wrapping ORT tensor @ " << tensor << std::endl; + }, + options); + + return c10::IValue(new_tensor); +} + +OrtValue CreateOrtScalarValue(const at::Scalar& scalar) { + // This tensor's life time is controlled by Pytorch. + // TODO: consider to let ORT also own that tensor. 
+ void* data_ptr = nullptr; + std::function data_deleter; + switch (scalar.type()) { + case at::kFloat: { + data_ptr = new float; + *reinterpret_cast(data_ptr) = scalar.toFloat(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kDouble: { + data_ptr = new double; + *reinterpret_cast(data_ptr) = scalar.toDouble(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kBFloat16: { + at::BFloat16 valBFloat16 = scalar.toBFloat16(); + onnxruntime::BFloat16* valOrtBFloat16 = reinterpret_cast(&valBFloat16); + data_ptr = new onnxruntime::BFloat16; + *reinterpret_cast(data_ptr) = *valOrtBFloat16; + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kShort: { + data_ptr = new int16_t; + *reinterpret_cast(data_ptr) = scalar.toShort(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kInt: { + data_ptr = new int; + *reinterpret_cast(data_ptr) = scalar.toInt(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kLong: { + data_ptr = new int64_t; + *reinterpret_cast(data_ptr) = scalar.toLong(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + case at::kBool: { + data_ptr = new bool; + *reinterpret_cast(data_ptr) = scalar.toBool(); + data_deleter = [=]() { + delete reinterpret_cast(data_ptr); + }; + break; + } + default: + ORT_THROW("Unsupport aten scalar type: ", scalar.type()); + } + + OrtDevice cpu_device = CreateOrtDevice(c10::Device(c10::DeviceType::CPU)); + OrtMemoryInfo memory_info = OrtMemoryInfo("LTC", OrtAllocatorType::OrtDeviceAllocator, cpu_device, cpu_device.Id()); + + onnxruntime::MLDataType element_type = CreateOrtScalarType(scalar.type()); + onnxruntime::TensorShape shape({}); + std::unique_ptr ort_tensor = std::make_unique( + element_type, shape, + data_ptr, memory_info); + + std::function deleter = [=](void* p) { + data_deleter(); + onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()(p); + }; + + OrtValue ort_value; + ort_value.Init( + ort_tensor.release(), + onnxruntime::DataTypeImpl::GetType(), + deleter); + return ort_value; +} + +c10::IValue CreateC10IvalueScalar(OrtValue value) { + onnxruntime::Tensor* tensor = value.GetMutable(); + // Here we assume tensors with empty shape are at::Scalar's. + // Assert on CPU because at::Scalar is always on CPU. + const OrtDevice::DeviceType& device_type = tensor->Location().device.Type(); + ORT_ENFORCE(device_type == OrtDevice::CPU); + + // Scalar ORT values must be put on CPU. Otherwise, the following code may throw. + // TODO: relax this constraint by + // 1. specifying output devices when calling ORT session, or + // 2. manually copying GPU OrtValue to CPU here. 
+ c10::IValue new_value; + ORT_ENFORCE(tensor->Location().device.Type() == OrtDevice::CPU); + switch (static_cast(tensor->DataType()->AsPrimitiveDataType()->GetDataType())) { + case onnxruntime::utils::ToTensorProtoElementType(): { + new_value = at::IValue(static_cast(*tensor->Data())); + break; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + new_value = at::IValue(*tensor->Data()); + break; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + new_value = at::IValue(*tensor->Data()); + break; + } + case onnxruntime::utils::ToTensorProtoElementType(): { + new_value = at::IValue(*tensor->Data()); + break; + } + default: + ORT_THROW("Unsupport aten scalar type."); + } + + return new_value; +} +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/bridge.h b/orttraining/orttraining/lazy_tensor/bridge.h new file mode 100644 index 0000000000000..3c55ee5eb987f --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/bridge.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/framework/ortdevice.h" +#include "core/framework/ort_value.h" + +namespace onnxruntime { +namespace lazytensor { +// Scalar type translation from ONNX to Pytorch. +c10::ScalarType CreateC10ScalarType(const onnxruntime::PrimitiveDataTypeBase* elem_type); +// Scalar type translation from Pytorch to ORT. +onnxruntime::MLDataType CreateOrtScalarType(at::ScalarType dtype); +// Device translation from Pytorch to ORT. +OrtDevice CreateOrtDevice(const c10::Device device); +// Device translation from ORT to Pytorch. +c10::Device CreateC10Device(const OrtDevice& device); +// Create a tensor from a Pytorch tensor. No memory copy. +// Conceptually, the returned tensor is a view of the input tensor. +OrtValue CreateOrtTensorValue(const at::Tensor& tensor); +// Similarly, create a Pytorch tensor from an OrtValue without +// memory copy. +// The created at::Tensor and onnxruntime::Tensor have +// the same lifetime. +c10::IValue CreateC10IvalueTensor(OrtValue value); +// Map Pytorch scalar to tensor with empty shape in ORT. +OrtValue CreateOrtScalarValue(const at::Scalar& scalar); +// Wrap ORT scalar as c10::IValue (a scalar). +c10::IValue CreateC10IvalueScalar(OrtValue value); +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/cuda_tool.cc b/orttraining/orttraining/lazy_tensor/cuda_tool.cc new file mode 100644 index 0000000000000..40c7248f25598 --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/cuda_tool.cc @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_CUDA +#include "cuda_tool.h" +// CUDA +#include "cuda.h" +#include "cuda_runtime.h" +#include "nvToolsExt.h" +// Pytorch +#include +// ORT +#include "core/providers/cuda/cuda_provider_options.h" +#include "core/providers/provider_factory_creators.h" +#include "orttraining/python/orttraining_pybind_common.h" + +namespace onnxruntime { +namespace lazytensor { + +NvtxRange::NvtxRange(const char* name) { + nvtxRangePush(name); +} +NvtxRange::NvtxRange(const std::string& name) { + nvtxRangePush(name.c_str()); +} +NvtxRange::~NvtxRange() { + nvtxRangePop(); +} + +// Wrapper of memory allocation function. +void* CudaAllocDelegate(size_t nbytes) { + auto allocator = at::cuda::getCUDADeviceAllocator(); + return allocator->raw_allocate(nbytes); +} + +// Wrapper of memory de-allocation function. 
+void CudaFreeDelegate(void* ptr) { + auto allocator = at::cuda::getCUDADeviceAllocator(); + allocator->raw_deallocate(ptr); +} + +void CUDAExecutionProviderPool::Initialize() { + int device_count = 0; + cudaGetDeviceCount(&device_count); + for (int i = 0; i < device_count; ++i) { + onnxruntime::ProviderOptions options; + options["device_id"] = std::to_string(i); + options["do_copy_in_default_stream"] = "true"; + options["gpu_external_alloc"] = std::to_string(reinterpret_cast(&CudaAllocDelegate)); + options["gpu_external_free"] = std::to_string(reinterpret_cast(&CudaFreeDelegate)); + + ProviderInfo_CUDA* cuda_provider_info = TryGetProviderInfo_CUDA(); + CUDAExecutionProviderInfo info; + cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(options, info); + cuda_execution_providers_.emplace_back(std::move(cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider())); + } +} + +} // namespace lazytensor +} // namespace onnxruntime +#endif diff --git a/orttraining/orttraining/lazy_tensor/cuda_tool.h b/orttraining/orttraining/lazy_tensor/cuda_tool.h new file mode 100644 index 0000000000000..7f8e6b1f6088a --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/cuda_tool.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_CUDA +#include +#include +#include +#include "core/framework/execution_providers.h" + +namespace onnxruntime { +namespace lazytensor { + +class NvtxRange { + public: + NvtxRange(const char* name); + NvtxRange(const std::string& name); + ~NvtxRange(); +}; + +// Class holding the CUDA EPs (one unique EP per device) +// shared by all sessions. +class CUDAExecutionProviderPool { + public: + static CUDAExecutionProviderPool& GetInstance() { + static CUDAExecutionProviderPool instance; + return instance; + } + + std::shared_ptr GetExecutionProvider(const int device_id) { + return cuda_execution_providers_.at(device_id); + } + + private: + CUDAExecutionProviderPool() { + Initialize(); + }; + ~CUDAExecutionProviderPool() = default; + CUDAExecutionProviderPool(const CUDAExecutionProviderPool&) = delete; + CUDAExecutionProviderPool& operator=(const CUDAExecutionProviderPool&) = delete; + void Initialize(); + + std::vector> cuda_execution_providers_; +}; + +} // namespace lazytensor +} // namespace onnxruntime +#endif diff --git a/orttraining/orttraining/lazy_tensor/debug.cc b/orttraining/orttraining/lazy_tensor/debug.cc new file mode 100644 index 0000000000000..eac07af8aeeff --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/debug.cc @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "debug.h" +#include +#include "core/common/common.h" +#include "flags.h" + +namespace onnxruntime { +namespace lazytensor { +std::string ToString(const c10::IValue& value) { + std::stringstream ss; + if (value.isTensor()) { + // Produce, e.g., Tensor(1024, 128)@cpu. + const auto& tensor = value.toTensor(); + ss << "Tensor" + << "<" << c10::toString(tensor.scalar_type()) << ">"; + if (tensor.sizes().empty()) { + } else { + ss << "("; + for (int i = 0; i < tensor.dim(); i++) { + ss << tensor.sizes()[i]; + if (i != tensor.dim() - 1) { + ss << ","; + } + } + ss << ")"; + } + ss << "@" << tensor.device(); + } else if (value.isScalar()) { + // Produce, e.g., Scalar, which is always on CPU. 
+ ss << "Scalar<" << c10::toString(value.toScalar().type()) << ">"; + } else { + ORT_THROW("Unsupported type."); + } + return ss.str(); +} + +// Print elements in the stack. +std::string ToString(const at::ArrayRef& values) { + std::stringstream ss; + for (size_t i = 0; i < values.size(); i++) { + ss << ToString(values.at(i)); + if (i != values.size() - 1) { + ss << ", "; + } + } + return ss.str(); +} + +std::string ToString(const torch::jit::Value& value) { + auto type = value.type(); + return type->str(); +} + +std::string ToString(const torch::jit::Node& node) { + std::stringstream ss; + ss << node.kind().toDisplayString() << "("; + for (size_t i = 0; i < node.inputs().size(); i++) { + ss << ToString(*node.inputs().at(i)); + if (i != node.inputs().size() - 1) { + ss << ", "; + } + } + ss << ") -> ("; + for (size_t i = 0; i < node.outputs().size(); i++) { + ss << ToString(*node.outputs().at(i)); + if (i != node.outputs().size() - 1) { + ss << ", "; + } + } + ss << ")"; + return ss.str(); +} + +bool CompareTensor( + const at::Tensor& left, const at::Tensor& right) { + if (left.sizes() != right.sizes()) { + return false; + } + if (left.scalar_type() != right.scalar_type()) { + return false; + } + if (left.device() != right.device()) { + return false; + } + if (CheckTensorContent() && + !at::allclose(left, right, RelativeTolerance(), AbsoluteTolerance())) { + return false; + } + return true; +} + +bool CompareScalar( + const at::Scalar& left, const at::Scalar& right) { + if (left.type() != right.type()) { + return false; + } + if (CheckTensorContent()) { + if (left.isFloatingPoint()) { + return left.toDouble() == right.toDouble(); + } else if (left.isIntegral(false)) { + return left.toLong() == right.toLong(); + } else if (left.isBoolean()) { + return left.toBool() == right.toBool(); + } else { + return false; + } + } + return true; +} + +bool Compare(const c10::IValue& left, const c10::IValue& right) { + if (left.isTensor() && right.isTensor()) { + return CompareTensor(left.toTensor(), right.toTensor()); + } else if (left.isScalar() && right.isScalar()) { + return CompareScalar(left.toScalar(), right.toScalar()); + } else { + return false; + } +} + +bool CompareStack( + const torch::jit::Stack& left, const torch::jit::Stack& right) { + if (left.size() != right.size()) { + return false; + } + for (size_t i = 0; i < left.size(); i++) { + if (!Compare(left[i], right[i])) { + return false; + } + } + return true; +} +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/debug.h b/orttraining/orttraining/lazy_tensor/debug.h new file mode 100644 index 0000000000000..384a46745303f --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/debug.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +namespace onnxruntime { +namespace lazytensor { +// This function contains function for comparing values +// and printing values. They are mainly used for debugging. + +bool CompareTensor(const at::Tensor& left, const at::Tensor& right); +bool CompareScalar(const at::Scalar& left, const at::Scalar& right); +bool Compare(const c10::IValue& left, const c10::IValue& right); +bool CompareStack(const torch::jit::Stack& left, const torch::jit::Stack& right); +std::string ToString(const c10::IValue& value); +std::string ToString(const at::ArrayRef& values); +// "torch::jit::Value" is abstract symbol in torch::jit::Graph. 
+// It represents inputs and outputs for graph, block, and node. +// Note that the actual computation reuslt's type is c10::IValue. +std::string ToString(const torch::jit::Value& value); +std::string ToString(const torch::jit::Node& node); +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/flags.cc b/orttraining/orttraining/lazy_tensor/flags.cc new file mode 100644 index 0000000000000..1549820d58c3a --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/flags.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "flags.h" +#include +#include +#include "core/common/common.h" + +namespace onnxruntime { +namespace lazytensor { +bool IsEnvironmentVariableOne(const char* name) { + const auto flag = std::getenv(name); + if (flag == nullptr) { + return false; + } + const auto is_one = std::strcmp(flag, "1") == 0; + const auto is_zero = std::strcmp(flag, "0") == 0; + ORT_ENFORCE(is_one || is_zero, + "Must set ", name, "=0, ", name, "=1, or unset ", name); + return is_one; +} + +double GetEnvironmentVariableDoubleOrDefault(const char* name, const double default_value) { + const auto number = std::getenv(name); + if (!number) { + return default_value; + } + return std::atof(number); +} + +std::string RunType() { + const auto run_type = std::getenv("LORT_RUN_TYPE"); + if (!run_type) { + return "ort"; + } + return run_type; +} + +bool DumpInputsOutputs() { + return IsEnvironmentVariableOne("LORT_DUMP_INPUTS_OUTPUTS"); +} + +bool DumpGraph() { + return IsEnvironmentVariableOne("LORT_DUMP_GRAPH"); +} + +bool CheckBaseline() { + return IsEnvironmentVariableOne("LORT_CHECK_BASELINE"); +} + +bool DumpAtenOpHistory() { + return IsEnvironmentVariableOne("LORT_DUMP_ATEN_OP_HISTORY"); +} + +bool CheckTensorContent() { + ORT_ENFORCE(CheckBaseline(), "Must set LORT_CHECK_BASELINE=1 to check tensor content."); + return IsEnvironmentVariableOne("LORT_CHECK_TENSOR_CONTENT"); +} + +double AbsoluteTolerance() { + ORT_ENFORCE(CheckBaseline() && CheckTensorContent(), + "Do not set LORT_ABSOLUTE_TOLERANCE unless \ + LORT_CHECK_TENSOR_CONTENT and LORT_CHECK_BASELINE are set."); + return GetEnvironmentVariableDoubleOrDefault("LORT_ABSOLUTE_TOLERANCE", 1e-8); +} + +double RelativeTolerance() { + ORT_ENFORCE(CheckBaseline() && CheckTensorContent(), + "Do not set LORT_RELATIVE_TOLERANCE unless \ + LORT_CHECK_TENSOR_CONTENT and LORT_CHECK_BASELINE are set."); + return GetEnvironmentVariableDoubleOrDefault("LORT_RELATIVE_TOLERANCE", 1e-5); +} + +bool DumpOnnxFusion() { + return IsEnvironmentVariableOne("LORT_DUMP_ONNX_FUSION"); +} + +} // namespace lazytensor +} // namespace onnxruntime diff --git a/orttraining/orttraining/lazy_tensor/flags.h b/orttraining/orttraining/lazy_tensor/flags.h new file mode 100644 index 0000000000000..b849f9f9a0a3e --- /dev/null +++ b/orttraining/orttraining/lazy_tensor/flags.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +namespace onnxruntime { +namespace lazytensor { +// This file contains environment variables that control +// the behavior of ORT as LazyTensor's backend. +// Most variables are for debug purpose. 
+// Example:
+// LORT_CHECK_TENSOR_CONTENT=1 LORT_DUMP_GRAPH=1
+// LORT_DUMP_INPUTS_OUTPUTS=1 LORT_CHECK_BASELINE=1
+// LORT_RELATIVE_TOLERANCE=1e-3 python main.py
+
+// When returning true, we dump the inputs and outputs
+// when ORT (and Pytorch when LORT_CHECK_BASELINE is set to 1)
+// executes the subgraph.
+bool DumpInputsOutputs();
+// Returns true to dump the torch::jit::Graph ORT receives
+// from LazyTensor.
+bool DumpGraph();
+// If returned value is true, run torch::jit::GraphExecutor
+// and compare its outputs with ORT's outputs.
+// Only types and shapes are compared. The user can control
+// the checking mechanism. For example, set
+// LORT_CHECK_TENSOR_CONTENT=1 to compare tensor elements.
+//
+// Related functions' dependency graph:
+// CheckBaseline -> CheckTensorContent -> AbsoluteTolerance
+// '---> RelativeTolerance
+bool CheckBaseline();
+// Returns the executor type used to run the sub-graph. Defaults to
+// "ort"; override with the LORT_RUN_TYPE environment variable.
+std::string RunType();
+// If this function returns true, all aten ops seen by ORT
+// will be printed. We also tag if these are supported or not.
+bool DumpAtenOpHistory();
+// If this function returns true, check tensor elements
+// when CheckBaseline() returns true.
+bool CheckTensorContent();
+// The "absolute_tol" in
+// |value-expected| <= |expected| * relative_tol + absolute_tol
+double AbsoluteTolerance();
+// The "relative_tol" in
+// |value-expected| <= |expected| * relative_tol + absolute_tol
+double RelativeTolerance();
+bool DumpOnnxFusion();
+
+// Runtime-adjustable settings shared by the LazyTensor bridge.
+// GetOnnxFusionFlag() tells the pass registered in register.cc
+// whether to apply ONNX fusion to the incoming torch::jit::Graph.
+class DynamicSettings {
+ public:
+ static DynamicSettings& GetInstance() {
+ static DynamicSettings instance;
+ return instance;
+ }
+ DynamicSettings(DynamicSettings const&) = delete;
+ void operator=(DynamicSettings const&) = delete;
+ bool GetOnnxFusionFlag() const {
+ return onnx_fusion_status_;
+ }
+ void SetOnnxFusionFlag(bool status) {
+ onnx_fusion_status_ = status;
+ }
+
+ private:
+ DynamicSettings() : onnx_fusion_status_(true){};
+ bool onnx_fusion_status_;
+};
+
+} // namespace lazytensor
+} // namespace onnxruntime
diff --git a/orttraining/orttraining/lazy_tensor/fusion.cc b/orttraining/orttraining/lazy_tensor/fusion.cc
new file mode 100644
index 0000000000000..88c6d7fe0968e
--- /dev/null
+++ b/orttraining/orttraining/lazy_tensor/fusion.cc
@@ -0,0 +1,405 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "fusion.h"
+#include
+#include
+#include
+#include
+#include
+
+namespace onnxruntime {
+namespace lazytensor {
+
+struct OrtFuser {
+ using FusionCallback = std::function<bool(OrtFuser*, torch::jit::Node*)>;
+
+ torch::jit::Block* block_;
+ torch::jit::AliasDb* aliasDb_;
+ std::shared_ptr<torch::jit::Graph> graph_;
+ FusionCallback callback_;
+ torch::jit::Symbol kind_;
+ bool strict_fuser_check_ = false;
+
+ // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
+ // The specific limit is a function of constant memory size, amount available
+ // to pass arguments, and some implementation dependence. Select a safe
+ // limit here.
+ // This limit is also applied to other devices in the fuser by default.
+ // Change with setInputArgLimit + size_t subgraph_arg_limit_; + + // Custom passes require kind to specified + OrtFuser( + torch::jit::AliasDb* aliasDb, + torch::jit::Block* block, + FusionCallback callback, + torch::jit::Symbol kind, + bool strict_fuser_check = false, + size_t subgraph_arg_limit = 128) + : block_(block), + aliasDb_(aliasDb), + callback_(std::move(callback)), + kind_(kind), + strict_fuser_check_(strict_fuser_check), + subgraph_arg_limit_(subgraph_arg_limit) {} + + torch::jit::value_list tensorInputs(torch::jit::Node* node) { + return filter(node->inputs(), [](torch::jit::Value* v) { + return v->type()->isSubtypeOf(*c10::TensorType::get()); + }); + } + + bool isFusable(torch::jit::Node* node) { + return callback_(this, node); + } + + bool calculatesSize(torch::jit::Node* node) { + return node->matches("aten::size(Tensor self) -> int[]"); + } + + bool allUsersAreThisConsumerOrCalcSizes(torch::jit::Node* consumer, torch::jit::Value* producer) { + auto defining_node = producer->node(); + for (auto o : defining_node->outputs()) { + for (auto u : o->uses()) { + if (u.user != consumer && !calculatesSize(u.user)) + return false; + } + } + return true; + } + + torch::jit::Graph& getSubgraph(torch::jit::Node* n) { + AT_ASSERT(n->kind() == kind_); + return *n->g(torch::jit::attr::Subgraph); + } + + void mergeFusionGroups(torch::jit::Node* consumer_group, torch::jit::Node* producer_group) { + // Now we have two fusion groups! + // Revert the fusion - place all inner nodes of producer back in the outer + // graph. + std::vector temporary_nodes; + auto producer_subgraph = &getSubgraph(producer_group); + + // Initialize a map of inner graph values to outer graph values + std::unordered_map inner_to_outer; + auto inner_inputs = producer_subgraph->inputs(); + auto outer_inputs = producer_group->inputs(); + for (const auto i : c10::irange(inner_inputs.size())) { + inner_to_outer[inner_inputs[i]] = outer_inputs[i]; + } + + // Clone all nodes + for (auto inner : producer_subgraph->nodes()) { + torch::jit::Node* outer = block_->owningGraph()->createClone( + inner, [&](torch::jit::Value* k) -> torch::jit::Value* { return inner_to_outer.at(k); }); + outer->insertBefore(producer_group); + temporary_nodes.emplace_back(outer); + auto inner_outputs = inner->outputs(); + auto outer_outputs = outer->outputs(); + for (const auto i : c10::irange(inner_outputs.size())) { + inner_to_outer[inner_outputs[i]] = outer_outputs[i]; + } + } + + // Replace uses of producer_group outputs and destroy the producer + auto subgraph_outputs = producer_subgraph->outputs(); + for (const auto i : c10::irange(subgraph_outputs.size())) { + auto outer_output = inner_to_outer.at(subgraph_outputs[i]); + producer_group->outputs()[i]->replaceAllUsesWith(outer_output); + // new producer outputs have same aliasing properties as outer_output + aliasDb_->replaceWithNewValue(producer_group->outputs()[i], outer_output); + } + producer_group->destroy(); + producer_group = + nullptr; // Just to get a clear error in case someone uses it + + // Inline the temporary nodes into the first group + auto consumer_subgraph = &getSubgraph(consumer_group); + for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); + ++it) { + torch::jit::Node* node = *it; + torch::jit::Node* merged = mergeNodeIntoGroup(consumer_group, node); + // If any of the outputs are still used then we need to add them + auto outputs = node->outputs(); + for (const auto i : c10::irange(outputs.size())) { + auto output = outputs[i]; + if (output->uses().size() 
== 0) + continue; + consumer_subgraph->registerOutput(merged->outputs()[i]); + auto new_output = consumer_group->addOutput(); + new_output->setType(output->type()); + output->replaceAllUsesWith(new_output); + aliasDb_->replaceWithNewValue(output, new_output); + } + node->destroy(); + } + } + + // insert a producer node into a consuming fusion group. + // DOES NOT WORK if n is a consumer of an output of the fusion group + // returns the node _inside_ the group that represents the node + torch::jit::Node* mergeNodeIntoGroup(torch::jit::Node* group, torch::jit::Node* n) { + AT_ASSERT(n->kind() != kind_); + auto& subgraph = getSubgraph(group); + // map from nodes in the surrounding graph to parameters in the fusion + // group's subgraph that correspond to them + std::unordered_map inputs_map; + AT_ASSERT(group->inputs().size() == subgraph.inputs().size()); + for (size_t i = 0; i < group->inputs().size(); ++i) { + // outer scope input -> inner scope (inside subgraph) input + inputs_map[group->inputs().at(i)] = subgraph.inputs().at(i); + } + // add n's inputs to the fusion group's input list if we don't already have + // them + // we insert tensors first because the fuser assumes that to be the case + // (as a legacy from tensors only) + torch::jit::WithInsertPoint guard(*subgraph.nodes().begin()); + for (auto input : n->inputs()) { + if (inputs_map.count(input) == 0) { + if (input->type()->isSubtypeOf(*c10::TensorType::get())) { + group->addInput(input); + // Add the corresponding input to subgraph's input list. + auto inner_input = subgraph.addInput(); + inner_input->setType(input->type()); + // Update outer-to-inner value mapping. + inputs_map[input] = inner_input; + } else if ( + (input->type()->isSubtypeOf(*c10::FloatType::get()) && + input->node()->kind() != torch::jit::prim::Constant) || + (n->kind() == torch::jit::aten::_grad_sum_to_size && + input->type()->isSubtypeOf(*c10::ListType::ofInts()))) { + group->addInput(input); + auto inner_input = subgraph.addInput(); + inner_input->setType(input->type()); + inputs_map[input] = inner_input; + } else if ( + input->type()->isSubtypeOf(*c10::IntType::get()) && + input->node()->kind() != torch::jit::prim::Constant) { + group->addInput(input); + auto inner_input = subgraph.addInput(); + inner_input->setType(input->type()); + inputs_map[input] = inner_input; + } else { + // We don't support passing in scalars as arguments to fused kernels, + // so we generally don't allow fusing tensor-scalar operations unless + // the scalar is constant. In those cases we inline the constants + // directly in the body of the fused group. + AT_ASSERT(input->node()->kind() == torch::jit::prim::Constant); + torch::jit::Node* in_const = + subgraph.createClone(input->node(), [](torch::jit::Value*) -> torch::jit::Value* { + throw std::runtime_error("unexpected input"); + }); + subgraph.insertNode(in_const); + inputs_map[input] = in_const->output(); + } + } + } + // copy n into the graph, remapping its inputs to internal nodes + torch::jit::Node* n_in_graph = subgraph.createClone( + n, [&](torch::jit::Value* k) -> torch::jit::Value* { return inputs_map[k]; }); + // if n's outputs are already inputs to the fusion group, + // we need to remove them because n is now inside the fusion group. + // + // i.e., + // x = f(w); group(x, y, z) becomes group(w, y, z). + // x, y, z = f(w); group(x, y, z) becomes group(w). 
+ // + // remapping nodes that used the input to the newly-merged node + // n is not an input when the fusion group is empty + auto inputs = group->inputs(); + for (size_t i = 0; i < n->outputs().size(); ++i) { + auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]); + if (it != inputs.end()) { + size_t p = it - inputs.begin(); + group->removeInput(p); + subgraph.inputs()[p]->replaceAllUsesWith(n_in_graph->outputs()[i]); + subgraph.eraseInput(p); + } + } + return subgraph.insertNode(n_in_graph); + } + + // turn consumer node n into a fusion group with just n inside + // to prepare for fusion and replace uses of n with the new group + torch::jit::Node* createSingletonFusionGroup(torch::jit::Node* n) { + auto group = block_->owningGraph()->createWithSubgraph(kind_); + // propagate position information for the new node so we can always + // have a valid mapping + group->insertBefore(n); + torch::jit::Node* mergedNode = mergeNodeIntoGroup(group, n); + // Now n's outputs should be generated by the new node (aka mergedNode) + // in the fusion group. Let's connect mergedNode to the outer graph. + for (size_t i = 0; i < mergedNode->outputs().size(); ++i) { + // Connect the i-th inner output to outer graph. + getSubgraph(group).registerOutput(mergedNode->output(i)); + auto new_outer_output = group->addOutput(); + // Copy metadata from old outer output to new outer output. + new_outer_output->copyMetadata(n->output(i)); + aliasDb_->replaceWithNewValue(n->output(i), new_outer_output); + } + // Now group is a single-op subgraph containing the clone of n. + AT_ASSERT(n->outputs().size() == group->outputs().size()); + n->replaceAllUsesWith(group); + n->destroy(); + return group; + } + + at::optional tryFuse(torch::jit::Node* consumer, torch::jit::Value* producer) { + // this handles cases where producer can be moved _into_ the fusion group of + // consumer. + // TODO: extend to fusion of consumer into _producer's_ fusion blob + // if the consumer allInputsAreThisProducer(consumer,producer) + // we can move the consumer up into the producer. + // but this requires better handling of merging fusion groups so it is not + // done now + bool shouldFuse = isFusable(producer->node()) && + // Rearrange nodes such that all uses of producer are after the + // consumer. Fusion will rewrite those later uses to use the version of + // producer generated by the fused blob. In this case, producer becomes + // an output of the fusion group. 
+ aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer); + + if (!shouldFuse) { + return at::nullopt; + } + + if ((consumer->inputs().size() + consumer->outputs().size() + + producer->node()->inputs().size() + + producer->node()->outputs().size()) > subgraph_arg_limit_) { + return at::nullopt; + } + + auto group = consumer; + if (consumer->kind() != kind_) { + group = createSingletonFusionGroup(consumer); + } + + if (producer->node()->kind() == kind_) { + mergeFusionGroups(group, producer->node()); + return group; + } + // AT_ASSERT(producer->node()->outputs().size() == 1); + torch::jit::Node* merged = mergeNodeIntoGroup(group, producer->node()); + // remaining uses of this producer can occur because we allow + // fusion in cases where uses remain after the consumer + // if these exist, re-route them to the version of producer + // created in FusionGroup + size_t i = -1; + for (auto output : producer->node()->outputs()) { + ++i; + if (output->uses().size() == 0) { + continue; + } + getSubgraph(group).registerOutput(merged->outputs()[i]); + torch::jit::Value* new_output = group->addOutput(); + new_output->copyMetadata(new_output); + aliasDb_->replaceWithNewValue(output, new_output); + output->replaceAllUsesWith(new_output); + } + producer->node()->destroy(); + return group; + } + + torch::jit::value_list sortReverseTopological(torch::jit::ArrayRef inputs) { + torch::jit::value_list result; + for (auto i : inputs) { + if (i->node()->owningBlock() == block_) { + result.push_back(i); + } + } + // Sort in reverse topological order + std::sort(result.begin(), result.end(), [&](torch::jit::Value* a, torch::jit::Value* b) { + return a->node()->isAfter(b->node()); + }); + return result; + } + + // returns where to continue scanning, and whether any fusion was made + std::pair scanNode(torch::jit::Node* consumer) { + if (isFusable(consumer)) { + // handle inputs in reverse topological order as well... + // otherwise in f(a,a+b) it will appear a is used twice if we consider + // the f-a fusion before the f-(a+b) fusion first. + auto inputs = sortReverseTopological(consumer->inputs()); + for (auto producer : inputs) { + auto fusion_group = tryFuse(consumer, producer); + if (fusion_group) { + // after fusion, consumer moves into a FusionGroup, so inputs is no + // longer valid so we rescan the new FusionGroup for more fusions... + return std::make_pair(fusion_group.value()->reverseIterator(), true); + } + } + } + return std::make_pair(++consumer->reverseIterator(), false); + } + + void optimizeFusedGraphs() { + for (torch::jit::Node* node : block_->nodes()) { + if (node->kind() != torch::jit::prim::FusionGroup) { + continue; + } + auto subgraph = node->g(torch::jit::attr::Subgraph); + EliminateDeadCode(subgraph); + EliminateCommonSubexpression(subgraph); + ConstantPooling(subgraph); + } + } + + void run() { + // Run the pass until no changes are made. + // This is necessary, because the algorithm can miss out on certain fusion + // opportunities if ran only once. Consider this graph: + // + // %1 = f(...) + // %2 = g(%1) + // %3 = h(%1) + // %4 = l(%3) + // return (%4, %2) + // + // where f, g, h, l are simple map ops. + // The first iteration will fuse %4 and %3, and see that %1 is an input, but + // can't be fused, because it has a different use before the fusion group + // in our topological ordering. Then, %2 will be considered, and fused with + // %1. If we do another iteration, the algorithm will consider the fusion of + // these two groups and fix the situation. 
+ bool any_changed = true;
+ while (any_changed) {
+ any_changed = false;
+ for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ bool changed;
+ std::tie(it, changed) = scanNode(*it);
+ any_changed |= changed;
+ }
+ }
+
+ optimizeFusedGraphs();
+
+ for (torch::jit::Node* node : block_->nodes()) {
+ for (torch::jit::Block* sub_block : node->blocks()) {
+ OrtFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_)
+ .run();
+ }
+ }
+ }
+};
+
+void OrtFuseGraph(
+ std::shared_ptr<torch::jit::Graph>& graph,
+ const std::function<bool(torch::jit::Node*)>& fn,
+ torch::jit::Symbol kind,
+ size_t arg_limit) {
+ torch::jit::AliasDb db(graph);
+ auto g = OrtFuser(
+ &db,
+ graph->block(),
+ [=](OrtFuser* gf, torch::jit::Node* n) { return fn(n) || n->kind() == kind; },
+ kind, false, arg_limit);
+
+ g.run();
+ torch::jit::Lint(&db);
+}
+
+} // namespace lazytensor
+} // namespace onnxruntime
diff --git a/orttraining/orttraining/lazy_tensor/fusion.h b/orttraining/orttraining/lazy_tensor/fusion.h
new file mode 100644
index 0000000000000..599a7d7784df6
--- /dev/null
+++ b/orttraining/orttraining/lazy_tensor/fusion.h
@@ -0,0 +1,13 @@
+#include
+
+namespace onnxruntime {
+namespace lazytensor {
+
+void OrtFuseGraph(
+ std::shared_ptr<torch::jit::Graph>& graph,
+ const std::function<bool(torch::jit::Node*)>& fn,
+ torch::jit::Symbol kind,
+ size_t arg_limit = std::numeric_limits<size_t>::max());
+
+} // namespace lazytensor
+} // namespace onnxruntime
diff --git a/orttraining/orttraining/lazy_tensor/register.cc b/orttraining/orttraining/lazy_tensor/register.cc
new file mode 100644
index 0000000000000..f0b5ddfd3e7ae
--- /dev/null
+++ b/orttraining/orttraining/lazy_tensor/register.cc
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+// Instead of torch/torch.h, include torch/extension.h
+// for extra Python headers.
+#include
+#include
+#include "torch/csrc/jit/passes/shape_analysis.h"
+#include
+#include "accelerator.h"
+#include "flags.h"
+#include "fusion.h"
+#include "core/common/logging/logging.h"
+
+namespace onnxruntime {
+namespace lazytensor {
+// This function registers a new torch::jit::Symbol, ort::graph,
+// in Pytorch's JIT executor.
+// A custom callable is registered to be called when Pytorch's JIT
+// executor encounters this symbol.
+void register_ort_as_torch_jit_executor() {
+ // Pytorch's JIT symbol to be executed by ORT.
+ const auto accelerator_symbol =
+ torch::jit::Symbol::fromQualString("ort::graph");
+ // First, register a pass that will coalesce supported consecutive operators
+ // into a single symbol (it contains a subgraph). Encountering an unsupported
+ // operator will result in two separate symbols (i.e., two independent sub-graphs).
+ // Note that torch::jit::Symbol is an analogy of NodeProto in ONNX.
+ //
+ // TODO: Allow single-op fusion in Pytorch so ORT can receive single-op sub-graphs.
+ // We should extend OrtFuseGraph and OrtFuser to fuse single ops into ort::graph.
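+ //
+ // Illustrative example (hypothetical ops, not from this patch): for a
+ // chain A -> B -> C where ORT supports A and C but not B, the pass
+ // below yields ort::graph(A) -> B -> ort::graph(C), i.e., the
+ // unsupported operator splits the work into two ort::graph sub-graphs.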
+ torch::jit::RegisterPass pass([accelerator_symbol](std::shared_ptr& g) { + if (!DynamicSettings::GetInstance().GetOnnxFusionFlag()) { + if (DumpOnnxFusion()) { + std::cout << "[No fusion]\n" + << *g; + } + return; + } + + if (DumpOnnxFusion()) { + std::cout << "[Before fusion]\n" + << *g; + } + + std::shared_ptr new_subgraph_with_shapes(g->copyUnique().release()); + + OrtFuseGraph(g, Accelerator::Supported, accelerator_symbol); + if (DumpOnnxFusion()) { + std::cout << "[After fusion]\n" + << *g; + } + }); + + // Define a function to generate actual computation code for a + // symbol (type: torch::jit::Node). + torch::jit::OperationCreator op_creator = + [](const torch::jit::Node* node) -> torch::jit::Operation { + // Construct an accelerator instance. It's responsible + // for executing the "node". Note that the "node" is a sub-graph. + auto accelerator = std::make_shared(node); + return [accelerator](torch::jit::Stack& stack) { + accelerator->Run(stack); + }; + }; + + // Tell Pytorch to use "op_creator" to execute "accelerator_symbol" + // when executing JIT graph. + torch::jit::RegisterOperators op({torch::jit::Operator( + accelerator_symbol, op_creator, + c10::AliasAnalysisKind::PURE_FUNCTION)}); +} +} // namespace lazytensor +} // namespace onnxruntime + +namespace onnxruntime { +namespace python { + +void addObjectMethodsForLazyTensor(pybind11::module& m) { + LOGS_DEFAULT(INFO) << "pybind11 module init for lazy tensor"; + m.def( + "register_ort_as_torch_jit_executor", + []() { + onnxruntime::lazytensor::register_ort_as_torch_jit_executor(); + }); +} + +} // namespace python +} // namespace onnxruntime diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 8917dd2a623e4..6b84959309c39 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -39,6 +39,9 @@ void addGlobalMethods(py::module& m, Environment& env); void addObjectMethods(py::module& m, Environment& env, ExecutionProviderRegistrationFn ep_registration_fn); void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); void addObjectMethodsForEager(py::module& m); +#ifdef ENABLE_LAZY_TENSOR +void addObjectMethodsForLazyTensor(py::module& m); +#endif void InitArray(); @@ -159,7 +162,7 @@ void ORTTrainingPythonEnv::AddExecutionProvider(const std::string& provider_type std::move(execution_provider)}); } -void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type, +void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type, const std::string& provider_lib_path, const ProviderOptions& default_options){ ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}}); @@ -309,7 +312,7 @@ void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::ve PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.doc() = "pybind11 stateful interface to ORTTraining"; RegisterExceptions(m); - + Environment& env = GetTrainingORTEnv(); addGlobalMethods(m, env); addObjectMethods(m, env, ORTTrainingRegisterExecutionProviders); @@ -325,25 +328,29 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { LOGS(default_logger, WARNING) << "Init provider bridge failed."; } #endif - + addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders); #ifdef ENABLE_EAGER_MODE addObjectMethodsForEager(m); #endif - - 
m.def("_register_provider_lib", [](const std::string& name, + +#ifdef ENABLE_LAZY_TENSOR + addObjectMethodsForLazyTensor(m); +#endif + + m.def("_register_provider_lib", [](const std::string& name, const std::string& provider_shared_lib_path, const ProviderOptions& default_options) { GetTrainingEnv().RegisterExtExecutionProviderInfo(name, provider_shared_lib_path, default_options); }); m.def( - "get_available_providers", []() -> const std::vector& { + "get_available_providers", []() -> const std::vector& { return GetTrainingEnv().GetAvailableTrainingExecutionProviderTypes(); }, "Return list of available Execution Providers in this installed version of Onnxruntime. " "The order of elements represents the default priority order of Execution Providers " "from highest to lowest."); - + m.def("clear_training_ep_instances", []() -> void { ort_training_env->ClearExecutionProviderInstances(); }, diff --git a/orttraining/orttraining/python/training/experimental/exporter.py b/orttraining/orttraining/python/training/experimental/exporter.py new file mode 100644 index 0000000000000..8c5ccd1119576 --- /dev/null +++ b/orttraining/orttraining/python/training/experimental/exporter.py @@ -0,0 +1,26 @@ +import torch +import torch.onnx.symbolic_helper +import torch.onnx.utils + + +def _export_jit_graph_to_onnx_model_proto(graph: torch._C.Graph, operator_export_type: int): + from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_operator_export_type, _set_opset_version + + _set_onnx_shape_inference(True) + _set_operator_export_type(operator_export_type) + torch._C._jit_pass_run_decompositions(graph) + graph = torch.onnx.utils._optimize_graph(graph, operator_export_type, params_dict={}) + proto, _, _, _ = graph._export_onnx( + {}, + torch.onnx._globals.GLOBALS.export_onnx_opset_version, + {}, + False, + operator_export_type, + False, + False, + {}, + True, + "", + {}, + ) + return proto diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 3b03829cc25e8..0c8a24d20e3fc 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -2,13 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import sys import argparse +import logging +import sys from _test_commons import run_subprocess -import logging - logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger("ORTModuleTests") diff --git a/orttraining/orttraining/test/python/orttraining_test_lort.py b/orttraining/orttraining/test/python/orttraining_test_lort.py new file mode 100644 index 0000000000000..8202ceef7445c --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_lort.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Ask lazy backend to use Pytorch's JIT as +# lazy backend's executor. +from torch._lazy.ts_backend import init as init_ts_backend + +init_ts_backend() + +# Handle ORT dependencies. +import onnxruntime as ort +from onnxruntime.capi import _pybind_state as C + +# Set up ORT as torch.jit's sub-executor. +C.register_ort_as_torch_jit_executor() + +# Make computation deterministic. 
+torch.manual_seed(42)
+ort.set_seed(1)
+
+
+class TestOrtLazyTensor(unittest.TestCase):
+ def test_elementwise_model(self):
+ def run_elementwise_model():
+ # A function to test.
+ def elementwise_model(x):
+ w = x.relu()
+ y = w * w + 1.5
+ z = y + x
+ p = z * x
+ q = p.relu()
+ return q
+
+ def run(fun, device, x):
+ x = torch.tensor(x, device=device, dtype=torch.float32).requires_grad_()
+ y = fun(x)
+ y.sum().backward()
+ return x, y, x.grad
+
+ # Baseline.
+ x, y, g_x = run(elementwise_model, "cpu", [-1.0, 2.0])
+ # ORT result.
+ x_new, y_new, g_x_new = run(elementwise_model, "lazy", [-1.0, 2.0])
+
+ torch.testing.assert_close(x.to("lazy"), x_new)
+ torch.testing.assert_close(y.to("lazy"), y_new)
+ torch.testing.assert_close(g_x.to("lazy"), g_x_new)
+
+ for _ in range(5):
+ run_elementwise_model()
+
+ def test_mnist_model(self):
+ def run_mnist_model():
+ class MNISTModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
+ self.fc1 = nn.Linear(9216, 128, bias=False)
+ self.fc2 = nn.Linear(128, 10, bias=False)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.fc2(x)
+ output = F.log_softmax(x, dim=1)
+ return output
+
+ def run(model, device, x, y):
+ for param in model.parameters():
+ param.grad = None
+ model.to(device)
+ x = x.to(device)
+ y = y.to(device)
+ output = model(x)
+ loss = F.nll_loss(output, y)
+ # return loss
+ loss.backward()
+ return loss, (param.grad for param in model.parameters())
+
+ x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
+ y = torch.randint(0, 9, (64,), dtype=torch.int64)
+ model = MNISTModel()
+
+ # Baseline.
+ loss, grads = run(model, "cpu", x, y)
+ # ORT result.
+ loss_new, grads_new = run(model, "lazy", x, y)
+
+ print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).")
+ torch.testing.assert_close(loss.to("lazy"), loss_new, rtol=1e-2, atol=1e-5)
+ for g, g_new in zip(grads, grads_new):
+ torch.testing.assert_close(g.to("lazy"), g_new)
+
+ for _ in range(5):
+ run_mnist_model()
+
+
+if __name__ == "__main__":
+ # For a specific model, the first 1 or 2 runs of Pytorch
+ # JIT are actually eager mode. As a Pytorch JIT sub-executor,
+ # ORT won't be invoked unless we run multiple times. Thus, each
+ # test function repeats its core test function multiple times.
+ # Here we repeat 5 times because we want our test to resemble
+ # a training loop.
+ # TODO: we should force the torch.jit executor to use ORT at the first run.
+ unittest.main()
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 4e9e268a0d8bf..62c3c20796873 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -613,7 +613,14 @@ def convert_arg_line_to_args(self, arg_line):
 parser.add_argument(
 "--code_coverage", action="store_true", help="Generate code coverage when targetting Android (only)."
 )
+
+ # lazy tensor support.
+ parser.add_argument(
+ "--enable_lazy_tensor", action="store_true", help="Enable using ORT as a backend in Pytorch LazyTensor."
+ )
+
 parser.add_argument("--ms_experimental", action="store_true", help="Build microsoft experimental operators.")
+
 # eager mode
 parser.add_argument("--build_eager_mode", action="store_true", help="Build ONNXRuntime micro-benchmarks.")
 parser.add_argument(
@@ -906,6 +913,7 @@ def generate_build_tree(
 "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"),
 "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"),
 "-Donnxruntime_ENABLE_EAGER_MODE=" + ("ON" if args.build_eager_mode else "OFF"),
+ "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"),
 "-Donnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS=" + ("ON" if args.enable_external_custom_op_schemas else "OFF"),
 "-Donnxruntime_ENABLE_CUDA_PROFILING=" + ("ON" if args.enable_cuda_profiling else "OFF"),
@@ -1195,7 +1203,7 @@ def generate_build_tree(
 else:
 add_default_definition(cmake_extra_defines, "onnxruntime_PYBIND_EXPORT_OPSCHEMA", "OFF")

- if args.build_eager_mode:
+ if args.build_eager_mode or args.enable_lazy_tensor:
 import torch

 cmake_args += ["-Donnxruntime_PREBUILT_PYTORCH_PATH=%s" % os.path.dirname(torch.__file__)]
@@ -1767,7 +1775,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
 # Adding the torch lib path for loading DLLs for onnxruntime in eager mode
 # This works for Python 3.7 and below, and doesn't work for Python 3.8+
 # User will need to import torch before onnxruntime and it will work for all versions
- if args.build_eager_mode and is_windows():
+ if (args.build_eager_mode or args.enable_lazy_tensor) and is_windows():
 import torch

 dll_path_list.append(os.path.join(os.path.dirname(torch.__file__), "lib"))
diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml
index e0e36e501f469..7756efc76dd22 100644
--- a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml
@@ -7,7 +7,7 @@ resources:
 ref: a8099af1b3e25f0489717ad9c4f9a2e25a8c5b36

 jobs:
-- job: Linux_Build
+- job: BuildAndTestEagerMode
 timeoutInMinutes: 120
 workspace:
 clean: all
@@ -17,10 +17,6 @@ jobs:
 clean: true
 submodules: recursive

- - task: NodeTool@0
- inputs:
- versionSpec: '12.16.3'
-
 - template: templates/get-docker-image-steps.yml
 parameters:
 Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_eager_cpu
@@ -50,7 +46,6 @@ jobs:
 --parallel \
 --build_eager_mode --enable_training --build_wheel --skip_test
 workingDirectory: $(Build.SourcesDirectory)
-
 - task: CmdLine@2
 displayName: 'install ortmodule extension and test'
 inputs:
@@ -75,4 +70,75 @@ jobs:
 --build_eager_mode --enable_training --build_wheel --test"
 workingDirectory: $(Build.SourcesDirectory)

- - template: templates/clean-agent-build-directory-step.yml
\ No newline at end of file
+ - template: templates/clean-agent-build-directory-step.yml
+
+# This pipeline builds the latest PyTorch commit from source
+# and uses it in ORT tests. See Dockerfile.manylinux2014_lort_cpu
+# for the installation steps. Ideally, we should only use one pipeline
+# for eager mode and LazyTensor, but we split them due to recent
+# breaking changes in PyTorch.
+#
+# TODO: once ORT eager mode can run with the latest PyTorch commit, we
+# should
+# 1. Set --build_eager_mode when running build.py in the
+# first "task" below.
+# 2. Copy the second "task" above as the third task below.
+- job: BuildAndTestLazyTensor + timeoutInMinutes: 120 + workspace: + clean: all + pool: Linux-CPU-2019 + steps: + - checkout: self + clean: true + submodules: recursive + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_lort_cpu + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildlort + + - task: CmdLine@2 + displayName: 'Build ORT for Python 3.9' + inputs: + script: | + docker run --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuildlort \ + /opt/python/cp39-cp39/bin/python3.9 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --enable_lazy_tensor --enable_training --build_wheel --skip_test \ + workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + displayName: 'Test LORT with Python 3.9' + inputs: + script: | + docker run --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuildlort \ + bash -c " + export LORT_CHECK_BASELINE=1 && \ + export LORT_DUMP_GRAPH=1 && \ + export LORT_DUMP_ATEN_OP_HISTORY=1 && \ + export PYTHONPATH=/build/Release && \ + /opt/python/cp39-cp39/bin/python3.9 -m pip install /build/Release/dist/*.whl && \ + /opt/python/cp39-cp39/bin/python3.9 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_lort.py" + workingDirectory: $(Build.SourcesDirectory) + condition: succeededOrFailed() + + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_lort_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_lort_cpu new file mode 100644 index 0000000000000..43d51fb1d6fc7 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_lort_cpu @@ -0,0 +1,10 @@ +FROM quay.io/pypa/manylinux2014_x86_64:latest + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps_lort.sh && rm -rf /tmp/scripts + +ARG BUILD_UID=1002 +ARG BUILD_USER=onnxruntimedev +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh new file mode 100755 index 0000000000000..9274b1e8278f9 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -0,0 +1,110 @@ +#!/bin/bash +set -e -x + +# Development tools and libraries +yum -y install \ + graphviz + +# Download a file from internet +function GetFile { + local uri=$1 + local path=$2 + local force=${3:-false} + local download_retries=${4:-5} + local retry_wait_time_seconds=${5:-30} + + if [[ -f $path ]]; then + if [[ $force = false ]]; then + echo "File '$path' already exists. Skipping download" + return 0 + else + rm -rf $path + fi + fi + + if [[ -f $uri ]]; then + echo "'$uri' is a file path, copying file to '$path'" + cp $uri $path + return $? 
+ fi + + echo "Downloading $uri" + # Use aria2c if available, otherwise use curl + if command -v aria2c > /dev/null; then + aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" + else + curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail + fi + + return $? +} + +os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) + +SYS_LONG_BIT=$(getconf LONG_BIT) +mkdir -p /tmp/src +GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) + +DISTRIBUTOR=$(lsb_release -i -s) + +if [[ "$DISTRIBUTOR" = "CentOS" && $SYS_LONG_BIT = "64" ]]; then + LIBDIR="lib64" +else + LIBDIR="lib" +fi + +cd /tmp/src + +echo "Installing azcopy" +mkdir -p /tmp/azcopy +GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz +tar --strip 1 -xf /tmp/azcopy/azcopy.tar.gz -C /tmp/azcopy +cp /tmp/azcopy/azcopy /usr/bin + +echo "Installing Ninja" +GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz +tar -zxf ninja-linux.tar.gz +cd ninja-1.10.0 +cmake -Bbuild-cmake -H. +cmake --build build-cmake +mv ./build-cmake/ninja /usr/bin + +echo "Installing Node.js" +GetFile https://nodejs.org/dist/v16.14.2/node-v16.14.2-linux-x64.tar.gz /tmp/src/node-v16.14.2-linux-x64.tar.gz +tar --strip 1 -xf /tmp/src/node-v16.14.2-linux-x64.tar.gz -C /usr + +echo "Installing gradle" +cd /tmp/src +GetFile https://downloads.gradle-dn.com/distributions/gradle-6.3-bin.zip /tmp/src/gradle-6.3-bin.zip +unzip /tmp/src/gradle-6.3-bin.zip +mv /tmp/src/gradle-6.3 /usr/local/gradle + +if ! [ -x "$(command -v protoc)" ]; then + source ${0/%install_deps_lort\.sh/..\/install_protobuf.sh} +fi + +export ONNX_ML=1 +export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" + +cd /usr/local +echo "Cloning Pytorch" +git clone --recursive https://github.com/pytorch/pytorch.git +cd pytorch +echo "Installing Pytorch requirements" +/opt/python/cp39-cp39/bin/python3.9 -m pip install -r requirements.txt +/opt/python/cp39-cp39/bin/python3.9 -m pip install flatbuffers cerberus h5py onnx +echo "Building and installing Pytorch" +VERBOSE=1 BUILD_LAZY_TS_BACKEND=1 /opt/python/cp39-cp39/bin/python3.9 setup.py develop +/opt/python/cp39-cp39/bin/python3.9 -c "import torch; print(f'Installed Pytorch: {torch.__version__}')" + +echo "Installing valgrind" +cd /tmp/src +GetFile 'https://sourceware.org/pub/valgrind/valgrind-3.16.1.tar.bz2' /tmp/src/valgrind-3.16.1.tar.bz2 +tar -jxvf valgrind-3.16.1.tar.bz2 +cd valgrind-3.16.1 +./configure --prefix=/usr --libdir=/usr/lib64 --enable-only64bit --enable-tls +make -j$(getconf _NPROCESSORS_ONLN) +make install + +cd / +rm -rf /tmp/src
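
A minimal usage sketch, assuming ORT was built with --enable_lazy_tensor and the resulting wheel is installed; it mirrors orttraining_test_lort.py and the LORT_* flags defined in flags.cc above, and any detail beyond those files is an assumption:

    # Route lazily-traced Pytorch programs through ORT, as in orttraining_test_lort.py.
    import torch
    from torch._lazy.ts_backend import init as init_ts_backend

    init_ts_backend()  # Use Pytorch's TorchScript-based lazy backend.

    from onnxruntime.capi import _pybind_state as C

    C.register_ort_as_torch_jit_executor()  # Send fused JIT sub-graphs to ORT.

    # The first 1-2 iterations may still run eagerly (see the test's comment),
    # so repeat a few times as the test does.
    for _ in range(5):
        x = torch.tensor([-1.0, 2.0], device="lazy", requires_grad=True)
        y = (x.relu() * x + 1.5).sum()
        y.backward()
        print(y.to("cpu"), x.grad.to("cpu"))  # Materializing forces execution.

Debugging flags such as LORT_DUMP_GRAPH=1, LORT_CHECK_BASELINE=1, or LORT_DUMP_ATEN_OP_HISTORY=1 (see flags.cc and the LORT CI job above) can be exported before running the script to inspect the received graphs and compare against the Pytorch baseline.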