diff --git a/mlir-tensorrt/README.md b/mlir-tensorrt/README.md index 3916be7b0..206d1e32e 100644 --- a/mlir-tensorrt/README.md +++ b/mlir-tensorrt/README.md @@ -23,7 +23,7 @@ We currently support only building on Linux x86 systems. We support building several different ways (only via CMake) depending on use-case. In each case, the LLVM-Project version that we are currently aligned to is -given in `build_tools/cmake/LLVMCommit.txt`. +given in `build_tools/cmake/LLVMCommit.cmake`. Note that currently we provide an LLVM patch which essentially cherry-picks the bug fixes from [this open MLIR PR](https://github.com/llvm/llvm-project/pull/91524). @@ -82,7 +82,7 @@ git clone https://github.com/llvm/llvm-project.git llvm-project # Checkout the right commit. Of course, you may try # a newer commit or your own modified LLVM-Project. cd llvm-project -git checkout $(cat build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")') +git checkout $(cat ../build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")') # Apply patch from llvm-project PR 91524 git apply ../build_tools/llvm-project.patch diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 11ae93519..22c87c56e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -32,6 +32,8 @@ #include #include +#include "cuda_runtime.h" + #ifdef __cplusplus extern "C" { #endif @@ -93,6 +95,50 @@ static inline bool mtrtDeviceIsNull(MTRT_Device device) { return !device.ptr; } /// arguments are optional in functions below. static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{nullptr}; } +//===----------------------------------------------------------------------===// +// MTRT_GpuAllocator +//===----------------------------------------------------------------------===// + +// Function pointer types for the allocate and deallocate callbacks. +typedef void *(*AllocateFunc)(void *self, uint64_t size, uint64_t alignment, uint32_t flags, cudaStream_t* stream); +typedef bool (*DeallocateFunc)(void *self, void *memory, cudaStream_t* stream); + +typedef struct MTRT_GpuAllocator { + void *ptr; // Pointer to the implementation (PyGpuAllocatorTrampoline in our + // case.) + // Function pointers to methods. + AllocateFunc allocate; + DeallocateFunc deallocate; +} MTRT_GpuAllocator; + +//===----------------------------------------------------------------------===// +// MTRT_OutputAllocator +//===----------------------------------------------------------------------===// + +// Function pointer types for the allocate and deallocate callbacks. +typedef void (*SetGpuAllocator)(void *self, MTRT_GpuAllocator gpuAllocator); +typedef void (*SetTensorName)(void *self, const char *tensorName); +typedef void (*SetCurrentMemory)(void *self, void *currentMemory); +typedef void (*SetOutputSize)(void *self, const int64_t outputSize); +typedef void *(*ReallocateOutputAsync)(void *self, char const *tensorName, + void *currentMemory, uint64_t size, + uint64_t alignment, + cudaStream_t *stream); +typedef void (*NotifyShape)(void *self, char const *tensorName, const int64_t *dims, + int64_t nbDims); + +typedef struct MTRT_OutputAllocator { + void *ptr; // Pointer to the implementation (PyOutputAllocatorTrampoline in + // our case.) + // Function pointers to methods. 
+ SetGpuAllocator setGpuAllocator; + SetTensorName setTensorName; + SetCurrentMemory setCurrentMemory; + SetOutputSize setOutputSize; + ReallocateOutputAsync reallocateOutputAsync; + NotifyShape notifyShape; +} MTRT_OutputAllocator; + //===----------------------------------------------------------------------===// // MTRT_MemRefValue //===----------------------------------------------------------------------===// @@ -170,6 +216,9 @@ typedef struct MTRT_MemRefValueInfo { MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetInfo(MTRT_MemRefValue memref, MTRT_MemRefValueInfo *info); +MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueSetOutputAllocator( + MTRT_MemRefValue memrefValue, MTRT_OutputAllocator pyOutputAllocator); + /// Create DL Managed tensor from MemRefValue. MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetDLPackManagedTensor( MTRT_MemRefValue memrefValue, MTRT_DLPackManagedTensor *outTensor); @@ -360,7 +409,7 @@ typedef struct MTRT_RuntimeSession { /// constant data. Therefore the Executable must outlive the RuntimeSession. MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionCreate( MTRT_RuntimeSessionOptions options, MTRT_Executable executable, - MTRT_RuntimeSession *result); + MTRT_GpuAllocator allocator, MTRT_RuntimeSession *result); /// Destory the session. This does not destroy the associated Executable, which /// may be shared among many sessions. @@ -372,6 +421,10 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) { return !session.ptr; } +MLIR_CAPI_EXPORTED MTRT_Status mtrtAddMemRefOutputAllocatorSessionRegistry( + MTRT_MemRefValue memrefValue, + MTRT_OutputAllocator pyOutputAllocator); + /// Using `session`, execute the pubic function with the specified name. /// The `inArgs` and `outArgs` are arrays for input arguments and destination /// arguments, respectively. Input arguments may be MemRefs or scalars, but diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index d3672c149..d48b80d51 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -667,6 +667,12 @@ class MemRefValue : public RuntimeValue { return v->getKind() == Kind::MemRef; } + void setOutputAllocator(OutputAllocator* _outputAllocator) { + outputAllocator = _outputAllocator; + } + + OutputAllocator *getOutputAllocator() { return outputAllocator; } + const std::optional &getScalarType() const { return scalarType; } RuntimeClient *getClient() { return client; } @@ -691,6 +697,7 @@ class MemRefValue : public RuntimeValue { /// address. 
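[Editorial aside — illustrative only, not part of this patch.] A minimal C++ sketch of how a client of the C API above might fill the new `MTRT_GpuAllocator` callback struct and hand it to the extended `mtrtRuntimeSessionCreate`. The `MyAllocatorState`, `myAllocate`, `myDeallocate`, and `createSessionWithCallbacks` names are hypothetical; `options` and `executable` are assumed to exist already, and error checking is omitted. Note that a null `ptr` field means "no custom allocator", so opaque state must be supplied for the callbacks to be used.

```cpp
#include "mlir-executor-c/Runtime/Runtime.h"
#include <cuda_runtime.h>

namespace {
struct MyAllocatorState { /* opaque user state; must be non-null in MTRT_GpuAllocator::ptr */ };

void *myAllocate(void *self, uint64_t size, uint64_t alignment, uint32_t flags,
                 cudaStream_t *stream) {
  void *ptr = nullptr;
  // Stream-ordered allocation when a stream is supplied, blocking otherwise.
  if (stream && *stream)
    cudaMallocAsync(&ptr, size, *stream);
  else
    cudaMalloc(&ptr, size);
  return ptr;
}

bool myDeallocate(void *self, void *memory, cudaStream_t *stream) {
  if (stream && *stream)
    return cudaFreeAsync(memory, *stream) == cudaSuccess;
  return cudaFree(memory) == cudaSuccess;
}
} // namespace

MTRT_Status createSessionWithCallbacks(MTRT_RuntimeSessionOptions options,
                                       MTRT_Executable executable,
                                       MTRT_RuntimeSession *session) {
  static MyAllocatorState state;
  // Order matches the struct layout: ptr, allocate, deallocate.
  MTRT_GpuAllocator allocator{&state, myAllocate, myDeallocate};
  return mtrtRuntimeSessionCreate(options, executable, allocator, session);
}
```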
std::optional device; std::optional scalarType{}; + OutputAllocator *outputAllocator{nullptr}; }; //===----------------------------------------------------------------------===// @@ -867,7 +874,9 @@ class RuntimeSession { sol::state state, std::unique_ptr pinnedMemoryAllocator, std::unique_ptr allocTracker, - std::unique_ptr resourceTracker); + std::unique_ptr resourceTracker, + std::unique_ptr outputAllocatorTracker, + std::unique_ptr gpuAllocator); ExecutableView getExecutable() const { return executable; } @@ -881,6 +890,12 @@ class RuntimeSession { ResourceTracker &getResourceTracker() { return *resourceTracker; } + OutputAllocatorTracker &getOutputAllocatorTracker() { + return *outputAllocatorTracker; + } + + GpuAllocator &getGpuAllocator() { return *gpuAllocator; } + private: RuntimeSessionOptions options; ExecutableView executable; @@ -888,7 +903,8 @@ class RuntimeSession { std::unique_ptr pinnedMemoryAllocator; std::unique_ptr allocTracker; std::unique_ptr resourceTracker; - + std::unique_ptr outputAllocatorTracker; + std::unique_ptr gpuAllocator; sol::state state; }; @@ -970,6 +986,14 @@ class RuntimeClient { return pinnedMemoryAllocator; } + void addOutputAllocator(std::unique_ptr outputAllocator) { + outputAllocators.emplace_back(std::move(outputAllocator)); + } + + OutputAllocator* getLastOutputAllocator() { + return outputAllocators.back().get(); + } + private: RuntimeClient(llvm::SmallVector> devices) : devices(std::move(devices)) {} @@ -978,6 +1002,7 @@ class RuntimeClient { PinnedMemoryAllocator pinnedMemoryAllocator; AllocTracker allocTracker; ResourceTracker resourceTracker; + std::vector> outputAllocators; }; //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h index b5fed9c3d..9dd689de8 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h @@ -37,6 +37,8 @@ void registerLuaRuntimeMethods(lua_State *state, const RuntimeSessionOptions &options, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, - ResourceTracker *resourceTracker); + ResourceTracker *resourceTracker, + OutputAllocatorTracker *outputAllocatorTracker, + GpuAllocator *allocator); } // namespace mlirtrt::runtime diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index f39eabd7b..e7251580f 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -36,7 +36,8 @@ namespace mlirtrt::runtime { /// `main` function. It is assumed that `main` takes no arguments and returns an /// integer result (which is returned if the execution is successful). /// TODO: this should take a handle to a function for streaming output/errors. -StatusOr runExecutorLuaScript(std::string_view luaScript); +StatusOr runExecutorLuaScript(std::string_view luaScript, + GpuAllocator *allocator); /// Synchronously run a serialized executor Executable one time. 
An `Executable` /// is essentially a Lua script packaged with metadata and serialized constants @@ -48,12 +49,15 @@ StatusOr runExecutorLuaScript(std::string_view luaScript); /// execution is successful). /// TODO: this should take a handle to a function for /// streaming output/errors. -StatusOr runExecutorExecutable(std::unique_ptr executable); +StatusOr +runExecutorExecutable(std::unique_ptr executable, + std::unique_ptr allocator); /// Create an execution state. This will setup a Lua environment and invoke /// global initialization. StatusOr> createRuntimeSessionWithLuaBackend(ExecutableView executable, + std::unique_ptr allocator, const RuntimeSessionOptions &options); /// Set the primary stream for the loaded executable to use. diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h index 37d8de629..54655ddf7 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h @@ -37,7 +37,8 @@ class ResourceTracker; /// Lua state. void registerExecutorTensorRTModuleLuaRuntimeMethods( lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, ResourceTracker *resourceTracker); + AllocTracker *allocTracker, ResourceTracker *resourceTracker, + OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator); } // namespace mlirtrt::runtime diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h index 180dbf09e..054bbcf04 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h @@ -32,6 +32,139 @@ namespace mlirtrt { struct EventPool; +//===----------------------------------------------------------------------===// +// GpuAllocator and CustomTensorRTAllocator +//===----------------------------------------------------------------------===// + +class GpuAllocator { +public: + GpuAllocator() = default; + virtual ~GpuAllocator() = default; + virtual void *allocate(uint64_t const size, uint64_t const alignment, + uint32_t flags, cudaStream_t* stream) { + return nullptr; + } + virtual bool deallocate(void *const memory, + cudaStream_t* stream) { + return false; + } +}; + +class CustomTensorRTAllocator : public GpuAllocator { +public: + CustomTensorRTAllocator() = default; + ~CustomTensorRTAllocator() = default; + void *allocate(uint64_t const size, uint64_t const alignment, uint32_t flags, + cudaStream_t* stream) override; + bool deallocate(void *const memory, + cudaStream_t* stream) override; +}; + +//===----------------------------------------------------------------------===// +// OutputAllocator and CustomTensorRTOuputAllocator +//===----------------------------------------------------------------------===// + +//! +//! Class to allocate memory for outputs with data-dependent shapes. The sizes +//! of those are unknown so pre-allocation is not possible. +//! 
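[Editorial aside — illustrative only, not part of this patch.] The `GpuAllocator` base class added above is intended to be subclassed on the C++ side, as `CustomTensorRTAllocator` does. Below is a sketch of an alternative subclass that counts outstanding device bytes while delegating to the CUDA runtime; `TrackingGpuAllocator` is a hypothetical name, the bookkeeping is not thread-safe, and error handling is elided.

```cpp
#include "mlir-executor/Support/Allocators.h"
#include <cstdint>
#include <cuda_runtime_api.h>
#include <unordered_map>

class TrackingGpuAllocator : public mlirtrt::GpuAllocator {
public:
  void *allocate(uint64_t size, uint64_t alignment, uint32_t flags,
                 cudaStream_t *stream) override {
    void *ptr = nullptr;
    if (stream && *stream)
      cudaMallocAsync(&ptr, size, *stream); // stream-ordered path
    else
      cudaMalloc(&ptr, size);               // blocking fallback
    if (ptr) {
      sizes[ptr] = size;
      outstanding += size;
    }
    return ptr;
  }

  bool deallocate(void *memory, cudaStream_t *stream) override {
    auto it = sizes.find(memory);
    if (it != sizes.end()) {
      outstanding -= it->second;
      sizes.erase(it);
    }
    return (stream && *stream ? cudaFreeAsync(memory, *stream)
                              : cudaFree(memory)) == cudaSuccess;
  }

  uint64_t outstandingBytes() const { return outstanding; }

private:
  std::unordered_map<void *, uint64_t> sizes; // sketch only; not thread-safe
  uint64_t outstanding = 0;
};
```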
+class OutputAllocator { +public: + virtual ~OutputAllocator() = default; + virtual void setGpuAllocator(GpuAllocator* gpuAllocator) = 0; + virtual void setTensorName(const char *tensorName) = 0; + virtual void setCurrentMemory(void *currentMemory) = 0; + virtual void setOutputSize(const int64_t outputSize) = 0; + virtual void *reallocateOutputAsync(char const *tensorName, + void *currentMemory, uint64_t size, + uint64_t alignment, + cudaStream_t * /*stream*/) = 0; + virtual void notifyShape(char const *tensorName, const int64_t *dims, + int64_t nbDims) = 0; +}; + +class CustomTensorRTOuputAllocator : public OutputAllocator { +public: + CustomTensorRTOuputAllocator() = default; + ~CustomTensorRTOuputAllocator() { + if (mOutputPtr != nullptr) { + cudaFree(mOutputPtr); + } + } + + void setGpuAllocator(GpuAllocator* gpuAllocator) override { + mGpuAllocator = gpuAllocator; + } + + //! Methods are called just after construction. TODO: can they be called + //! during construction? + void setTensorName(const char *tensorName) override { + mTensorName = tensorName; + } + + void setCurrentMemory(void *currentMemory) override { + mCurrentMemory = currentMemory; + } + + void setOutputSize(int64_t outputSize) override { mOutputSize = outputSize; } + + void *reallocateOutputAsync(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment, + cudaStream_t * /*stream*/) override; + + void notifyShape(char const *tensorName, const int64_t *dims, + int64_t nbDims) override; + + //! nullptr if memory could not be allocated + void *mOutputPtr{nullptr}; + + //! Size of allocation pointed to by output. + uint64_t mOutputSize{0}; + + bool mReallocateOutputCalled{false}; + + bool mNotifyShapeCalled{false}; + + //! Dimensions of tensor. + std::vector mOutputDims; + +private: + GpuAllocator* mGpuAllocator; + const char *mTensorName; + void *mCurrentMemory; +}; + +class OutputAllocatorTracker { +public: + OutputAllocatorTracker() = default; + ~OutputAllocatorTracker() = default; + + OutputAllocatorTracker(const OutputAllocatorTracker &) = delete; + OutputAllocatorTracker &operator=(const OutputAllocatorTracker &) = delete; + OutputAllocatorTracker(OutputAllocatorTracker &&) = default; + OutputAllocatorTracker &operator=(OutputAllocatorTracker &&) = default; + + // Add a new OutputAllocator + void addAllocator(void *ptr, OutputAllocator *allocator) { + mOutputAllocatorRegistry.emplace_back(std::make_pair(ptr, allocator)); + } + + // Get a reference to an OutputAllocator + OutputAllocator *getAllocator(void *ptr) { + auto it = std::find_if( + mOutputAllocatorRegistry.begin(), mOutputAllocatorRegistry.end(), + [ptr](const auto &pair) { return pair.first == ptr; }); + + if (it != mOutputAllocatorRegistry.end()) { + return it->second; + } + return nullptr; + } + +private: + std::vector> mOutputAllocatorRegistry; +}; + //===----------------------------------------------------------------------===// // PoolTrackedCudaEvent //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index 41b4db1b2..abd3b06c7 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -27,6 +27,7 @@ #include "mlir-executor/Runtime/API/API.h" #include "mlir-executor/Runtime/API/ExecutableFlatbuffer.h" #include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h" +#include "mlir-executor/Support/Allocators.h" #include 
"mlir-executor/Support/Status.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -414,7 +415,64 @@ static void dlpackManagedTensorDeleter(DLManagedTensor *tensor) { } } -MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetDLPackManagedTensor( + +class OutputAllocatorWrapper : public OutputAllocator { +private: + MTRT_OutputAllocator mPyOutputAllocator; + +public: + OutputAllocatorWrapper(MTRT_OutputAllocator outputAllocator) + : mPyOutputAllocator(outputAllocator) {} + + void setGpuAllocator(GpuAllocator *gpuAllocator) override { + return mPyOutputAllocator.setGpuAllocator( + mPyOutputAllocator.ptr, + MTRT_GpuAllocator{gpuAllocator, nullptr, nullptr}); + } + + void setTensorName(const char *tensorName) override { + return mPyOutputAllocator.setTensorName(mPyOutputAllocator.ptr, tensorName); + } + + void setCurrentMemory(void *currentMemory) override { + return mPyOutputAllocator.setCurrentMemory(mPyOutputAllocator.ptr, + currentMemory); + } + + void setOutputSize(const int64_t outputSize) override { + return mPyOutputAllocator.setOutputSize(mPyOutputAllocator.ptr, outputSize); + } + + void *reallocateOutputAsync(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment, + cudaStream_t *stream) override { + return mPyOutputAllocator.reallocateOutputAsync(mPyOutputAllocator.ptr, + tensorName, currentMemory, + size, alignment, stream); + } + + void notifyShape(char const *tensorName, const int64_t *dims, + int64_t nbDims) override { + return mPyOutputAllocator.notifyShape(mPyOutputAllocator.ptr, tensorName, + dims, nbDims); + } + + // Static method to create a OutputAllocator from MTRT_OutputAllocator + static std::unique_ptr + create(MTRT_OutputAllocator outputAllocator) { + if (!outputAllocator.ptr || !outputAllocator.setGpuAllocator || + !outputAllocator.setTensorName || !outputAllocator.setCurrentMemory || + !outputAllocator.setOutputSize || + !outputAllocator.reallocateOutputAsync || + !outputAllocator.notifyShape) { + llvm::errs() << "Invalid MTRT_OutputAllocator passed to create()"; + return nullptr; + } + return std::make_unique(outputAllocator); + } +}; + +MTRT_Status mtrtMemRefValueGetDLPackManagedTensor( MTRT_MemRefValue memrefValue, MTRT_DLPackManagedTensor *outTensor) { MemRefValue memref = *unwrap(memrefValue); @@ -461,7 +519,7 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetDLPackManagedTensor( return mtrtStatusGetOk(); } -MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetDLPackDevice( +MTRT_Status mtrtMemRefValueGetDLPackDevice( MTRT_MemRefValue memrefValue, int32_t *device_type, int32_t *device_id) { MemRefValue memref = *unwrap(memrefValue); int device = memref.getDevice().has_value() @@ -625,14 +683,71 @@ mtrtRuntimeSessionOptionsDestroy(MTRT_RuntimeSessionOptions options) { // MTRT_RuntimeSession //===----------------------------------------------------------------------===// -MTRT_Status mtrtRuntimeSessionCreate(MTRT_RuntimeSessionOptions options, - MTRT_Executable executable, - MTRT_RuntimeSession *result) { +MTRT_Status mtrtAddMemRefOutputAllocatorSessionRegistry( + MTRT_MemRefValue memrefValue, MTRT_OutputAllocator pyOutputAllocator) { + auto memref = unwrap(memrefValue); + + std::unique_ptr outputAllocator; + if (pyOutputAllocator.ptr) { + outputAllocator.reset( + OutputAllocatorWrapper::create(pyOutputAllocator).release()); + } + + // Client should own the output allocator. + memref->getClient()->addOutputAllocator(std::move(outputAllocator)); + + // Store the output allocator reference. 
+ memref->setOutputAllocator(memref->getClient()->getLastOutputAllocator()); + + return mtrtStatusGetOk(); +} + +// A wrapper class for MTRT_GpuAllocator implementing the GpuAllocator +// interface. It encapsulates GPU memory allocation and deallocation operations, +// ensuring correct routing of callbacks from C++ to Python. +class GpuAllocatorWrapper : public GpuAllocator { +private: + MTRT_GpuAllocator mPyGpuAllocator; + +public: + GpuAllocatorWrapper(MTRT_GpuAllocator gpuAllocator) + : mPyGpuAllocator(gpuAllocator) {} + + void *allocate(uint64_t size, uint64_t alignment, uint32_t flags, cudaStream_t* stream) override { + return mPyGpuAllocator.allocate(mPyGpuAllocator.ptr, size, alignment, flags, stream); + } + + bool deallocate(void *ptr, cudaStream_t* stream) override { + return mPyGpuAllocator.deallocate(mPyGpuAllocator.ptr, ptr, stream); + } + + // Static method to create a GpuAllocator from MTRT_GpuAllocator + static std::unique_ptr create(MTRT_GpuAllocator gpuAllocator) { + if (!gpuAllocator.ptr || !gpuAllocator.allocate || + !gpuAllocator.deallocate) { + llvm::errs() << "Invalid MTRT_GpuAllocator passed to create()"; + return nullptr; + } + return std::make_unique(gpuAllocator); + } +}; + +MTRT_Status +mtrtRuntimeSessionCreate(MTRT_RuntimeSessionOptions options, + MTRT_Executable executable, + MTRT_GpuAllocator gpuAllocator, + MTRT_RuntimeSession *result) { RuntimeSessionOptions *cppOptions = unwrap(options); Executable *cppExecutable = unwrap(executable); + std::unique_ptr allocator; + if (gpuAllocator.ptr) { + allocator.reset(GpuAllocatorWrapper::create(gpuAllocator).release()); + } + StatusOr> session = - createRuntimeSessionWithLuaBackend(cppExecutable->getView(), *cppOptions); + createRuntimeSessionWithLuaBackend(cppExecutable->getView(), + std::move(allocator), *cppOptions); if (session.isError()) return wrap(session.getStatus()); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 6c10d1f99..583dc344b 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -353,16 +353,19 @@ RuntimeSessionOptions::createUsingSingleHostMpi() { //===----------------------------------------------------------------------===// // RuntimeSession //===----------------------------------------------------------------------===// - RuntimeSession::RuntimeSession( RuntimeSessionOptions options, ExecutableView exe, sol::state state, std::unique_ptr pinnedMemoryAllocator, std::unique_ptr allocTracker, - std::unique_ptr resourceTracker) + std::unique_ptr resourceTracker, + std::unique_ptr outputAllocatorTracker, + std::unique_ptr gpuAllocator) : options(std::move(options)), executable(exe), pinnedMemoryAllocator(std::move(pinnedMemoryAllocator)), allocTracker(std::move(allocTracker)), - resourceTracker(std::move(resourceTracker)), state(std::move(state)) {} + resourceTracker(std::move(resourceTracker)), + outputAllocatorTracker(std::move(outputAllocatorTracker)), + gpuAllocator(std::move(gpuAllocator)), state(std::move(state)) {} //===----------------------------------------------------------------------===// // AllocTracker diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index 7596c9da7..81897d8c2 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -72,7 +72,8 @@ static void 
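[Editorial aside — illustrative only, not part of this patch.] A sketch of how a C/C++ caller might populate `MTRT_OutputAllocator` and attach it to a destination MemRef with `mtrtAddMemRefOutputAllocatorSessionRegistry` before execution. `OutputState` and the callback names are hypothetical; the wrapper above rejects the struct if any callback is null, and the state must outlive session execution. Error checking is omitted.

```cpp
#include "mlir-executor-c/Runtime/Runtime.h"
#include <cuda_runtime.h>
#include <vector>

namespace {
struct OutputState {
  void *buffer = nullptr;     // device buffer TensorRT writes into
  uint64_t capacity = 0;      // bytes currently allocated
  std::vector<int64_t> shape; // final shape reported via notifyShape
};

void setGpuAllocatorCb(void *, MTRT_GpuAllocator) {}
void setTensorNameCb(void *, const char *) {}
void setCurrentMemoryCb(void *self, void *mem) {
  static_cast<OutputState *>(self)->buffer = mem;
}
void setOutputSizeCb(void *self, int64_t size) {
  static_cast<OutputState *>(self)->capacity = size;
}
void *reallocateOutputAsyncCb(void *self, const char *, void *currentMemory,
                              uint64_t size, uint64_t alignment,
                              cudaStream_t *) {
  auto *s = static_cast<OutputState *>(self);
  if (size <= s->capacity)
    return currentMemory;   // existing buffer already large enough
  cudaFree(s->buffer);      // grow-only policy, mirroring the C++ fallback
  cudaMalloc(&s->buffer, size);
  s->capacity = s->buffer ? size : 0;
  return s->buffer;
}
void notifyShapeCb(void *self, const char *, const int64_t *dims,
                   int64_t nbDims) {
  static_cast<OutputState *>(self)->shape.assign(dims, dims + nbDims);
}
} // namespace

MTRT_Status attachOutputAllocator(MTRT_MemRefValue resultMemRef) {
  static OutputState state; // must outlive execution of the session
  MTRT_OutputAllocator oa{&state,
                          setGpuAllocatorCb,
                          setTensorNameCb,
                          setCurrentMemoryCb,
                          setOutputSizeCb,
                          reallocateOutputAsyncCb,
                          notifyShapeCb};
  return mtrtAddMemRefOutputAllocatorSessionRegistry(resultMemRef, oa);
}
```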
registerDefaultDeviceDependentMethods(lua_State *state, static void registerLuaRuntimeMethodsCommon( lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, ResourceTracker *resourceTracker) { + AllocTracker *allocTracker, ResourceTracker *resourceTracker, + GpuAllocator *allocator, OutputAllocatorTracker *outputAllocatorTracker) { registerExecutorCoreModuleLuaRuntimeMethods(state, pinnedMemoryAllocator, allocTracker); registerExecutorCUDAModuleLuaRuntimeMethods( @@ -84,15 +85,15 @@ static void registerLuaRuntimeMethodsCommon( #endif registerExecutorTensorRTModuleLuaRuntimeMethods( - state, pinnedMemoryAllocator, allocTracker, resourceTracker); + state, pinnedMemoryAllocator, allocTracker, resourceTracker, outputAllocatorTracker, allocator); } void mlirtrt::runtime::registerLuaRuntimeMethods( lua_State *state, const RuntimeSessionOptions &options, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, - ResourceTracker *resourceTracker) { + ResourceTracker *resourceTracker, OutputAllocatorTracker* outputAllocatorTracker, GpuAllocator* allocator) { registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker, - resourceTracker); + resourceTracker, allocator, outputAllocatorTracker); #ifdef MLIR_EXECUTOR_ENABLE_NCCL registerExecutorNCCLModuleLuaRuntimeMethods(state, resourceTracker); registerDeviceDependentNCCLMethods(state, options.getNumDevices(), @@ -107,8 +108,8 @@ void mlirtrt::runtime::registerLuaRuntimeMethods( #endif } -StatusOr -mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript) { +StatusOr mlirtrt::runtime::runExecutorLuaScript( + std::string_view luaScript, GpuAllocator *allocator) { ADD_RUNTIME_MODULE_RANGE("runtime_runExecutorLuaScript"); StatusOr> client = RuntimeClient::create(); @@ -117,10 +118,11 @@ mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript) { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::string); - registerLuaRuntimeMethods(lua.lua_state(), RuntimeSessionOptions(), - &(*client)->getPinnedMemorAllocator(), - &(*client)->getAllocTracker(), - &(*client)->getResourceTracker()); + registerLuaRuntimeMethods( + lua.lua_state(), RuntimeSessionOptions(), + &(*client)->getPinnedMemorAllocator(), &(*client)->getAllocTracker(), + &(*client)->getResourceTracker(), nullptr /* Output allocator */, + allocator /* can this be nullptr as well */); sol::protected_function_result result = lua.script(luaScript); if (!result.valid()) { @@ -171,7 +173,8 @@ static Status maybeCheckForValidNcclUuid(const RuntimeSessionOptions &options) { /// global initialization. 
StatusOr> mlirtrt::runtime::createRuntimeSessionWithLuaBackend( - ExecutableView executable, const RuntimeSessionOptions &options) { + ExecutableView executable, std::unique_ptr allocator, + const RuntimeSessionOptions &options) { ADD_RUNTIME_MODULE_RANGE("runtime_loadExecutable"); MTRT_RETURN_IF_ERROR(maybeCheckForValidNcclUuid(options)); @@ -179,12 +182,13 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend( auto pinnedMemoryAllocator = std::make_unique(); auto allocTracker = std::make_unique(); auto resourceTracker = std::make_unique(); + auto outputAllocatorTracker = std::make_unique(); sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::string); - registerLuaRuntimeMethods(lua.lua_state(), options, - pinnedMemoryAllocator.get(), allocTracker.get(), - resourceTracker.get()); + registerLuaRuntimeMethods( + lua.lua_state(), options, pinnedMemoryAllocator.get(), allocTracker.get(), + resourceTracker.get(), outputAllocatorTracker.get(), allocator.get()); // Load globals into the context. // TODO: eliminate this copy, we already own the executable. @@ -225,11 +229,13 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend( } return std::make_unique( options, executable, std::move(lua), std::move(pinnedMemoryAllocator), - std::move(allocTracker), std::move(resourceTracker)); + std::move(allocTracker), std::move(resourceTracker), + std::move(outputAllocatorTracker), std::move(allocator)); } StatusOr mlirtrt::runtime::runExecutorExecutable( - std::unique_ptr executable) { + std::unique_ptr executable, + std::unique_ptr allocator) { StatusOr> client = RuntimeClient::create(); if (!client.isOk()) @@ -245,7 +251,8 @@ StatusOr mlirtrt::runtime::runExecutorExecutable( return options.getStatus(); StatusOr> session = - createRuntimeSessionWithLuaBackend(executable->getView(), *options); + createRuntimeSessionWithLuaBackend(executable->getView(), + std::move(allocator), *options); if (!session.isOk()) return session.getStatus(); @@ -465,6 +472,8 @@ runtime::executeFunctionWithLuaBackend( // Call the main function, if present. sol::state_view lua(session.getLuaState()); AllocTracker &tracker = session.getAllocTracker(); + OutputAllocatorTracker &outputAllocatorTracker = session.getOutputAllocatorTracker(); + sol::protected_function funcObj = lua[name]; if (funcObj.get_type() != sol::type::function) return getStatusWithMsg(StatusCode::InternalError, "no function named \"", @@ -523,6 +532,12 @@ runtime::executeFunctionWithLuaBackend( for (auto [idx, rv] : llvm::enumerate(outputArgs)) { if (MemRefValue *memref = llvm::dyn_cast(rv)) { MTRT_RETURN_IF_ERROR(pushMemRefTableArg(lua, tracker, args, *memref)); + + // Creating a mapping from memref pointer to output allocator tracker. 
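[Editorial aside — illustrative only, not part of this patch.] From the C++ runtime side, the new `createRuntimeSessionWithLuaBackend` signature lets a caller inject the `CustomTensorRTAllocator` added in this patch. A minimal sketch, assuming `executable` and `options` are already constructed; a null allocator falls back to TensorRT's default allocator, since the `GpuAllocator` is optional.

```cpp
#include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h"
#include "mlir-executor/Support/Allocators.h"
#include <memory>

using namespace mlirtrt;
using namespace mlirtrt::runtime;

StatusOr<std::unique_ptr<RuntimeSession>>
makeSessionWithCustomAllocator(ExecutableView executable,
                               const RuntimeSessionOptions &options) {
  // Route all TensorRT device allocations through the custom allocator.
  std::unique_ptr<GpuAllocator> allocator =
      std::make_unique<CustomTensorRTAllocator>();
  return createRuntimeSessionWithLuaBackend(executable, std::move(allocator),
                                            options);
}
```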
+ if (memref->getOutputAllocator()) { + outputAllocatorTracker.addAllocator(memref->getVoidPtr(), memref->getOutputAllocator()); + } + continue; } return getInvalidArgStatus("output (destination) argument #{0} to function " diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index 1b96eac44..1aed00592 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -64,6 +64,131 @@ class StdioLogger : public nvinfer1::ILogger { bool verbose; }; +//===----------------------------------------------------------------------===// +// TensorRTCallBackOutputAllocator +//===----------------------------------------------------------------------===// + +static bool isSubByte(nvinfer1::DataType t) { + return t == nvinfer1::DataType::kINT4; +} + +static int32_t elementSizeInBits(nvinfer1::DataType t) { + switch (t) { + case nvinfer1::DataType::kINT64: + return 64; + case nvinfer1::DataType::kINT32: + return 32; + case nvinfer1::DataType::kFLOAT: + return 32; + case nvinfer1::DataType::kHALF: + return 16; + case nvinfer1::DataType::kBF16: + return 16; + case nvinfer1::DataType::kINT8: + return 8; + case nvinfer1::DataType::kBOOL: + return 8; + case nvinfer1::DataType::kUINT8: + return 8; + case nvinfer1::DataType::kFP8: + return 8; + case nvinfer1::DataType::kINT4: + return 4; + } + return 0; +} + +static int32_t elementeSizeInBytes(nvinfer1::DataType dtype) { + if (!isSubByte(dtype)) { + auto bits = elementSizeInBits(dtype); + assert(bits % 8 == 0); + return bits / 8; + } + if (dtype == nvinfer1::DataType::kINT4) { + return 1; + } + return -1; +} + +static int64_t volume(nvinfer1::Dims64 const& d) +{ + int64_t v = 1; + for (int64_t i = 0; i < d.nbDims; i++) + { + v *= d.d[i]; + } + return v; +} + +class TensorRTCallBackOutputAllocator final + : public nvinfer1::IOutputAllocator { +public: + TensorRTCallBackOutputAllocator(GpuAllocator* gpuAllocator, OutputAllocator *outputAllocator, + const char *tensorName, void *currentMemory, + nvinfer1::Dims64 dims, + nvinfer1::DataType dtype) + : nvinfer1::IOutputAllocator(), + mOutputAllocatorCallBack(outputAllocator) { + mOutputAllocatorCallBack->setGpuAllocator(gpuAllocator); + mOutputAllocatorCallBack->setTensorName(tensorName); + mOutputAllocatorCallBack->setCurrentMemory(currentMemory); + mOutputAllocatorCallBack->setOutputSize(volume(dims) * + elementeSizeInBytes(dtype)); + } + + void *reallocateOutput(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment) noexcept override { + return mOutputAllocatorCallBack->reallocateOutputAsync( + tensorName, currentMemory, size, alignment, nullptr); + } + + //! IMirroredBuffer does not implement Async allocation, hence this is just a + //! 
wrap around + void *reallocateOutputAsync(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment, + cudaStream_t stream) noexcept override { + + return mOutputAllocatorCallBack->reallocateOutputAsync( + tensorName, currentMemory, size, alignment, &stream); + } + + void notifyShape(char const *tensorName, + nvinfer1::Dims const &dims) noexcept override { + return mOutputAllocatorCallBack->notifyShape(tensorName, &dims.d[0], dims.nbDims); + } + + ~TensorRTCallBackOutputAllocator() override {} + +private: + OutputAllocator *mOutputAllocatorCallBack; +}; + +//===----------------------------------------------------------------------===// +// TensorRTCallBackAllocator +//===----------------------------------------------------------------------===// + +class TensorRTCallBackAllocator final : public nvinfer1::IGpuAsyncAllocator { +public: + TensorRTCallBackAllocator(GpuAllocator *gpuAllocator) + : nvinfer1::IGpuAsyncAllocator(), mGpuAllocatorCallBack(gpuAllocator) {} + + void *allocateAsync(uint64_t const size, uint64_t const alignment, + uint32_t flags, cudaStream_t stream) noexcept final { + void *result = + mGpuAllocatorCallBack->allocate(size, alignment, flags, &stream); + return result; + } + + bool deallocateAsync(void *const memory, + cudaStream_t stream) noexcept override { + bool result = mGpuAllocatorCallBack->deallocate(memory, &stream); + return result; + } + +private: + GpuAllocator *mGpuAllocatorCallBack; +}; + } // namespace static StdioLogger logger(/*verbose=*/false); @@ -88,13 +213,40 @@ struct Signature { } }; +class NvInferRuntimeWrapper { +public: + explicit NvInferRuntimeWrapper(GpuAllocator* gpuAllocator) { + runtime = std::shared_ptr( + nvinfer1::createInferRuntime(logger), [](nvinfer1::IRuntime *runtime) { + MTRT_DBGF("freeing tensorrt runtime at %lu", + reinterpret_cast(runtime)); + delete runtime; + }); + // GpuAllocator is optional. + if (gpuAllocator) { + callbackAllocatorPair = + std::make_pair(std::shared_ptr( + new TensorRTCallBackAllocator(gpuAllocator)), + gpuAllocator); + runtime->setGpuAllocator(callbackAllocatorPair.first.get()); + } + } + + nvinfer1::IRuntime *operator*() { return runtime.get(); } + nvinfer1::IRuntime *operator->() { return runtime.get(); } + + std::shared_ptr runtime; + std::pair, GpuAllocator*> callbackAllocatorPair; +}; + class NvInferEngineWrapper { public: - explicit NvInferEngineWrapper(std::shared_ptr &runtime, + explicit NvInferEngineWrapper(std::shared_ptr runtime, uintptr_t pointer, size_t size) : runtime(runtime) { engine = std::shared_ptr( - runtime->deserializeCudaEngine(reinterpret_cast(pointer), size), + runtime->runtime->deserializeCudaEngine( + reinterpret_cast(pointer), size), [](nvinfer1::ICudaEngine *engine) { MTRT_DBGF("freeing cuda engine at %lu", reinterpret_cast(engine)); @@ -105,7 +257,7 @@ class NvInferEngineWrapper { nvinfer1::ICudaEngine *operator*() { return engine.get(); } nvinfer1::ICudaEngine *operator->() { return engine.get(); } - std::shared_ptr runtime; + std::shared_ptr runtime; std::shared_ptr engine; }; @@ -183,6 +335,22 @@ class NvInferExecContextWrapper { /// Returned the pre-allocated host staging buffers. std::vector &getHostIOBuffers() { return hostIOBuffers; } + /// Add a call back output allocator. + void addCallBackAllocators( + std::unique_ptr allocator) { + outputAllocators.emplace_back(std::move(allocator)); + } + + /// Return the last call back output allocator pointer. 
+ TensorRTCallBackOutputAllocator *getLastCallBackAllocatorPtr() { + return outputAllocators.back().get(); + } + + /// Return registered callback gpu allocator. + GpuAllocator *getGpuAllocator() { + return engine->runtime->callbackAllocatorPair.second; + } + private: // We keep a reference to the cuda engine to keep it from going out of scope. // The standard TensorRTRuntime-to-Executor lowering only creates globals for @@ -196,13 +364,14 @@ class NvInferExecContextWrapper { /// A set of pinned host buffers one per input host buffer (shape tensor) to /// the TRT network. std::vector hostIOBuffers; + std::vector> outputAllocators; }; } // namespace -static Status setTensorAddressesOrReport( +static Status setTensorAddressesAndOutputAllocatorsOrReport( NvInferExecContextWrapper &context, const std::vector> - &buffers) { + &buffers, OutputAllocatorTracker &outputAllocatorTracker) { ADD_TENSORRT_MODULE_RANGE("set_tensor_addresses"); unsigned idx = 0; for (auto &[name, ptr, dims] : buffers) { @@ -215,9 +384,10 @@ static Status setTensorAddressesOrReport( bool result = context->setTensorAddress(name.c_str(), reinterpret_cast(ptr)); + const nvinfer1::ICudaEngine &engine = context->getEngine(); + if (!result) { std::stringstream ss; - const nvinfer1::ICudaEngine &engine = context->getEngine(); ss << "Failed to set tensor address for IO tensor: " << name << " at position " << idx << "; the IO tensors are:\n"; for (int64_t i = 0; i < engine.getNbIOTensors(); i++) { @@ -238,6 +408,37 @@ static Status setTensorAddressesOrReport( return getInternalErrorStatus("failed to set input shape"); } + // Set output allocators + if (engine.getTensorIOMode(name.c_str()) == + nvinfer1::TensorIOMode::kOUTPUT and + engine.getTensorLocation(name.c_str()) == + nvinfer1::TensorLocation::kDEVICE) { + + // Since setting output allocator is optional. + if (outputAllocatorTracker.getAllocator(reinterpret_cast(ptr)) != + nullptr) { + context.addCallBackAllocators( + std::make_unique( + context.getGpuAllocator(), + outputAllocatorTracker.getAllocator( + reinterpret_cast(ptr)), + name.c_str(), reinterpret_cast(ptr), dims, + engine.getTensorDataType(name.c_str()))); + context->setOutputAllocator(name.c_str(), + static_cast( + context.getLastCallBackAllocatorPtr())); + } else { + // It is possible that previous call with same output name and different + // memref pointer would have set output allocator. Due to "hacky" naming + // scheme, outputs are always named as "result0", "result1", .... If not + // tracker is found for a given pointer, let's unset the output + // allocator. + if (context->getOutputAllocator(name.c_str())) { + context->setOutputAllocator(name.c_str(), nullptr); + } + } + } + MTRT_DBGF("Set tensor address [%d] = %lu", idx, ptr); idx++; } @@ -339,6 +540,7 @@ prepareBuffers(const AllocTracker &allocTracker, static Status enqueueV3Wrapper(AllocTracker &tracker, ResourceTracker &resourceTracker, + OutputAllocatorTracker &outputAllocatorTracker, NvInferExecContextWrapper &context, CudaStreamPtr stream, sol::table &va) { StatusOr>> @@ -347,8 +549,8 @@ static Status enqueueV3Wrapper(AllocTracker &tracker, return getStatusWithMsg(StatusCode::InternalError, "failed to prepare buffers: ", buffers.getString()); - MTRT_RETURN_IF_ERROR(setTensorAddressesOrReport(context, *buffers)); + MTRT_RETURN_IF_ERROR(setTensorAddressesAndOutputAllocatorsOrReport(context, *buffers, outputAllocatorTracker)); // Create an event that we can wait on for releasing any host-pinned staging // allocations we made. 
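[Editorial aside — illustrative only, not part of this patch.] In isolation, the pointer-keyed `OutputAllocatorTracker` used by the enqueue path above behaves as in the sketch below; `outputAllocatorTrackerSketch` and `resultDevicePtr` are hypothetical names.

```cpp
#include "mlir-executor/Support/Allocators.h"

void outputAllocatorTrackerSketch(void *resultDevicePtr) {
  using namespace mlirtrt;
  OutputAllocatorTracker tracker;
  CustomTensorRTOuputAllocator alloc; // concrete allocator added in Allocators.h

  // Key the registry by the destination MemRef's device pointer, exactly as
  // executeFunctionWithLuaBackend does for each output MemRef argument.
  tracker.addAllocator(resultDevicePtr, &alloc);

  // At enqueue time the TensorRT module looks the pointer up: a non-null result
  // is wrapped in a TensorRTCallBackOutputAllocator and installed on the
  // execution context, while a null result clears any allocator left over from
  // a previous call that reused the same output tensor name.
  OutputAllocator *registered = tracker.getAllocator(resultDevicePtr);
  (void)registered;
}
```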
MTRT_ASSIGN_OR_RETURN(CudaEventPtr inputConsumedEvent, @@ -375,19 +577,21 @@ static Status enqueueV3Wrapper(AllocTracker &tracker, //===----------------------------------------------------------------------===// void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods( lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, ResourceTracker *resourceTracker) { + AllocTracker *allocTracker, ResourceTracker *resourceTracker, + OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator) { sol::state_view lua(luaState); - lua["_trtrt_create_runtime"] = [](sol::this_state state) { + lua["_trtrt_create_runtime"] = + [allocator](sol::this_state state) -> std::shared_ptr { ADD_TENSORRT_MODULE_RANGE("trtrt_create_runtime"); MTRT_DBGF("%s", "creating nvinfer runtime"); - return std::shared_ptr( - nvinfer1::createInferRuntime(logger)); + return std::make_shared(allocator); }; lua["_trtrt_load"] = [allocTracker]( - sol::this_state state, std::shared_ptr &runtime, + sol::this_state state, + std::shared_ptr &runtime, uintptr_t pointer) -> std::shared_ptr { ADD_TENSORRT_MODULE_RANGE("trtrt_load"); const AllocTracker &tracker = *allocTracker; @@ -411,16 +615,17 @@ void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods( }; lua["_trtrt_enqueue"] = - [allocTracker, - resourceTracker](sol::this_state state, - std::shared_ptr context, - CudaStreamPtr stream, sol::table va) { + [allocTracker, resourceTracker, outputAllocatorTracker]( + sol::this_state state, + std::shared_ptr context, + CudaStreamPtr stream, sol::table va) { ADD_TENSORRT_MODULE_RANGE("trtrt_enqueue"); sol::state_view luaState(state); assert(context != nullptr); assert(stream != nullptr && "expected valid stream"); - Status result = enqueueV3Wrapper(*allocTracker, *resourceTracker, - *context, stream, va); + Status result = + enqueueV3Wrapper(*allocTracker, *resourceTracker, + *outputAllocatorTracker, *context, stream, va); SET_LUA_ERROR_IF_ERROR(result, state); }; } diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp index 2eadd2cca..8def6eb90 100644 --- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp +++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp @@ -23,11 +23,14 @@ //===----------------------------------------------------------------------===// #include "mlir-executor/Support/Allocators.h" #include "mlir-executor/Support/Status.h" +#include "mlir-executor/Runtime/Support/Support.h" +#include "cuda_runtime_api.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" +#include #include #include #include @@ -39,6 +42,120 @@ using namespace mlirtrt; DEBUG_WITH_TYPE("allocators", fprintf(stderr, "%s:%d " fmt "\n", __FILE__, \ __LINE__, __VA_ARGS__)) +//===----------------------------------------------------------------------===// +// CustomTensorRTOutputAllocator +//===----------------------------------------------------------------------===// + +inline uint64_t roundUp(uint64_t m, uint64_t n) { + return ((m + n - 1) / n) * n; +} + +void *CustomTensorRTOuputAllocator::reallocateOutputAsync( + char const *tensorName, void *currentMemory, uint64_t size, + uint64_t alignment, cudaStream_t *stream) { + + assert(currentMemory == mCurrentMemory && "output buffer mismatch"); + assert(strcmp(tensorName, mTensorName) == 0 && "tensor name mismatch"); + assert(!mReallocateOutputCalled && 
"duplicate call to reallocateOutput"); + mReallocateOutputCalled = true; + // Some memory allocators return nullptr when allocating zero bytes, but + // TensorRT requires a non-null ptr even for empty tensors, so allocate a + // dummy byte. + size = std::max(size, static_cast(1)); + + // Check if reallocation is required. + if (size > mOutputSize) { + size = roundUp(size, alignment); + + if (mOutputPtr) { + if (mGpuAllocator) { + // Use registeted call back GPU allocator for output allocations. + mGpuAllocator->deallocate(mOutputPtr, stream); + } else { + // Fall-back to local memory management. + cudaFree(mOutputPtr); + } + } + + mOutputPtr = nullptr; + mOutputSize = 0; + + void *memory; + if (mGpuAllocator) { + // Use registeted call back GPU allocator for output allocations. + memory = mGpuAllocator->allocate(size, alignment, 0 /* flags */, stream); + } else { + // Fall-back to local memory management. + cudaMalloc(&memory, size); + } + mOutputPtr = memory; + if (mOutputPtr != nullptr) { + mOutputSize = size; + } + return mOutputPtr; + } + return mCurrentMemory; +} + +void CustomTensorRTOuputAllocator::notifyShape(char const *tensorName, + const int64_t *dims, int64_t nbDims) { + assert(mReallocateOutputCalled && + "TensorRT must invoke reallocateOutput first"); + assert(!mNotifyShapeCalled && "duplicate call to notifyShape"); + assert(tensorName == mTensorName); + + mNotifyShapeCalled = true; + mOutputDims.resize(nbDims); + std::copy_n(dims, nbDims, mOutputDims.begin()); +} + +//===----------------------------------------------------------------------===// +// CustomTensorRTAllocator +//===----------------------------------------------------------------------===// + + +void* +CustomTensorRTAllocator::allocate(uint64_t const size, uint64_t const alignment, + uint32_t /*flags*/, + cudaStream_t* stream) { + uint8_t *memory; + assert(alignment > 0 && (alignment & (alignment - 1)) == 0 && + "Memory alignment has to be power of 2"); + if (stream && *stream != nullptr) { + auto status = cudaMallocAsync(reinterpret_cast(&memory), size, *stream); + assert(status == cudaSuccess); + MTRT_DBGF("[CustomTensorRTAllocator][allocate]: Asynchronously allocated %lx bytes at 0x%lx on stream %lx", size, + reinterpret_cast(memory), + reinterpret_cast(*stream)); + } else { + auto status = cudaMalloc(reinterpret_cast(&memory), size); + assert(status == cudaSuccess); + MTRT_DBGF("[CustomTensorRTAllocator][allocate]: Synchronously allocated %lx bytes at 0x%lx", size, + reinterpret_cast(memory)); + } + assert(reinterpret_cast(memory) % alignment == 0); + return memory; +} + +bool CustomTensorRTAllocator::deallocate(void *const memory, + cudaStream_t* stream) { + if (stream && *stream != nullptr) { + MTRT_DBGF("[CustomTensorRTAllocator][deallocate]: Asynchronously freeing CUDA device memory 0x%lx on stream %lx", + reinterpret_cast(memory), + reinterpret_cast(*stream)); + cudaError_t status = cudaFreeAsync(memory, *stream); + assert(status == cudaSuccess); + } else { + MTRT_DBGF("[CustomTensorRTAllocator][deallocate]: Synchronously freeing CUDA device/pinned host memory 0x%lx ptr " + "on stream %lx", + reinterpret_cast(memory), + reinterpret_cast(*stream)); + cudaError_t status = cudaFree(memory); + assert(status == cudaSuccess); + } + return true; +} + //===----------------------------------------------------------------------===// // PoolTrackedCudaEvent //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp 
b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp index 3241de0da..d06b59618 100644 --- a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp +++ b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp @@ -98,7 +98,12 @@ struct Options { cl::values(clEnumValN(Lua, "lua", "interpret the input as Lua code")), cl::values(clEnumValN(ExecutorRuntimeExecutable, "rtexe", "load the input file as an Executor executable"))}; + + cl::opt useCustomAllocator{"use-custom-allocator", + cl::desc("Use custom allocator"), + cl::init(false)}; }; + } // namespace LogicalResult @@ -168,13 +173,17 @@ executor::ExecutorRunnerMain(int argc, char **argv, if (result != cudaSuccess) return emitError(loc) << "cudaFree failed: " << cudaGetErrorString(result); - // Read the buffer as a Lua script and execute. + std::unique_ptr allocator{nullptr}; + if (options.useCustomAllocator) { + // Create an optional runtime GPU allocator + allocator.reset(new CustomTensorRTAllocator()); + } if (options.inputType == Lua) { assert(!options.dumpFunctionSignature && "Can not dump function signature for Lua input type."); mlirtrt::StatusOr result = - mlirtrt::runtime::runExecutorLuaScript(input->getBuffer()); + mlirtrt::runtime::runExecutorLuaScript(input->getBuffer(), allocator.get()); if (!result.isOk()) return emitError(UnknownLoc::get(&context)) << result.getString(); return success(*result == 0); @@ -202,7 +211,8 @@ executor::ExecutorRunnerMain(int argc, char **argv, } mlirtrt::StatusOr executionResult = - mlirtrt::runtime::runExecutorExecutable(std::move(*executable)); + mlirtrt::runtime::runExecutorExecutable(std::move(*executable), + std::move(allocator)); if (!executionResult.isOk()) return emitError(UnknownLoc::get(&context)) << "failed to load and run executable: " diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index fef5ad868..efe2bf7ec 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -31,6 +31,8 @@ #include #include +#include "cuda_runtime.h" + namespace py = pybind11; using namespace mlirtrt; @@ -126,6 +128,7 @@ class PyStream : public PyMTRTWrapper { public: using Base::Base; DECLARE_WRAPPER_CONSTRUCTORS(PyStream); + static constexpr auto kMethodTable = CAPITable{ mtrtStreamIsNull, mtrtStreamDestroy, mtrtPythonCapsuleToStream, mtrtPythonStreamToCapsule}; @@ -184,6 +187,201 @@ class PyRuntimeValue : public PyMTRTWrapper { mtrtPythonCapsuleToRuntimeValue, mtrtPythonRuntimeValueToCapsule}; }; +// Abstract base class for Python-implemented GPU allocators. +// Provides a C++ interface for Python classes and handles C-style callback +// routing. +class PyGpuAllocator { +public: + py::object pySelf; + // This ensure that PyGpuAllocator is not deallocated before corresponding Python object lives. + PyGpuAllocator(py::object self) : pySelf(self) {} + + virtual ~PyGpuAllocator() = default; + virtual std::uintptr_t allocate(uint64_t size, uint64_t alignment, + uint32_t flags) = 0; + virtual bool deallocate(std::uintptr_t ptr) = 0; + + // Creates a C-compatible struct for interfacing with lower-level APIs. + MTRT_GpuAllocator getCApiObject() { return createWithPythonCallbacks(this); } + +private: + // Trampoline function: Routes C-style allocation calls to C++ virtual method. 
+ static void *pyGpuAllocatorAllocate(void *self, uint64_t size, + uint64_t alignment, uint32_t flags, + cudaStream_t* /*stream*/) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + std::uintptr_t ptr = allocator->allocate(size, alignment, flags); + return reinterpret_cast(ptr); + } + + // Trampoline function: Routes C-style deallocation calls to C++ virtual + // method. + static bool pyGpuAllocatorDeallocate(void *self, void *memory, + cudaStream_t* /*stream*/) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return allocator->deallocate(reinterpret_cast(memory)); + } + + // Constructs MTRT_GpuAllocator with this instance's methods as callbacks. + static MTRT_GpuAllocator + createWithPythonCallbacks(PyGpuAllocator *allocator) { + MTRT_GpuAllocator capi_allocator; + capi_allocator.ptr = allocator; + capi_allocator.allocate = pyGpuAllocatorAllocate; + capi_allocator.deallocate = pyGpuAllocatorDeallocate; + return capi_allocator; + } +}; + +// Pybind11 trampoline class for PyGpuAllocator. +// Enables Python subclasses to override virtual methods of PyGpuAllocator. +class PyGpuAllocatorTrampoline : public PyGpuAllocator { +public: + using PyGpuAllocator::PyGpuAllocator; // Inherit constructors + + // Trampoline for allocate: Dispatches call to Python implementation if + // overridden. + uintptr_t allocate(uint64_t size, uint64_t alignment, uint32_t flags) override { + PYBIND11_OVERRIDE_PURE(uintptr_t, // Return type + PyGpuAllocator, // Parent class + allocate, // Name of function in C++ + size, // Arguments + alignment, flags); + } + + // Trampoline for deallocate: Dispatches call to Python implementation if + // overridden. + bool deallocate(uintptr_t ptr) override { + PYBIND11_OVERRIDE_PURE(bool, // Return type + PyGpuAllocator, // Parent class + deallocate, // Name of function in C++ + ptr); // Arguments + } +}; + +class PyOutputAllocator { +public: + py::object pySelf; + // This ensure that PyOutputAllocator is not deallocated before corresponding + // Python object lives. + PyOutputAllocator(py::object self) : pySelf(self) {} + + virtual ~PyOutputAllocator() = default; + virtual void setTensorName(const char *tensorName) = 0; + virtual void setCurrentMemory(uintptr_t currentMemory) = 0; + virtual void setOutputSize(const int64_t outputSize) = 0; + virtual uintptr_t reallocateOutputAsync(char const *tensorName, + uintptr_t currentMemory, uint64_t size, + uint64_t alignment) = 0; + virtual void notifyShape(char const *tensorName, const int64_t *dims, + int64_t nbDims) = 0; + // Creates a C-compatible struct for interfacing with lower-level APIs. + MTRT_OutputAllocator getCApiObject() { return createWithPythonCallbacks(this); } + +private: + static void PySetGpuAllocator(void *self, MTRT_GpuAllocator gpuAllocator) { + // Let user use the default available gpu allocator for now. 
+ } + + static void PySetTensorName(void *self, const char *tensorName) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return allocator->setTensorName(tensorName); + } + + static void PySetCurrentMemory(void *self, void *currentMemory) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return allocator->setCurrentMemory( + reinterpret_cast(currentMemory)); + } + + static void PySetOutputSize(void *self, const int64_t outputSize) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return allocator->setOutputSize(outputSize); + } + + static void *PyReallocateOutputAsync(void *self, char const *tensorName, + void *currentMemory, uint64_t size, + uint64_t alignment, + cudaStream_t * /*stream*/) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return reinterpret_cast(allocator->reallocateOutputAsync( + tensorName, reinterpret_cast(currentMemory), size, + alignment)); + } + + static void PyNotifyShape(void *self, char const *tensorName, const int64_t *dims, + int64_t nbDims) { + py::gil_scoped_acquire acquire; + auto *allocator = static_cast(self); + return allocator->notifyShape(tensorName, dims, nbDims); + } + + // Constructs MTRT_GpuAllocator with this instance's methods as callbacks. + static MTRT_OutputAllocator + createWithPythonCallbacks(PyOutputAllocator *allocator) { + MTRT_OutputAllocator capi_allocator; + capi_allocator.ptr = allocator; + capi_allocator.setGpuAllocator = PySetGpuAllocator; + capi_allocator.setTensorName = PySetTensorName; + capi_allocator.setCurrentMemory = PySetCurrentMemory; + capi_allocator.setOutputSize = PySetOutputSize; + capi_allocator.reallocateOutputAsync = PyReallocateOutputAsync; + capi_allocator.notifyShape = PyNotifyShape; + return capi_allocator; + } +}; + +// Pybind11 trampoline class for PyOutputAllocator. +// Enables Python subclasses to override virtual methods of PyOutputAllocator. +class PyOutputAllocatorTrampoline : public PyOutputAllocator { +public: + using PyOutputAllocator::PyOutputAllocator; // Inherit constructors + + // Trampoline for setTensorName: Dispatches call to Python implementation if + // overridden. 
+ void setTensorName(const char *tensorName) override { + PYBIND11_OVERRIDE_PURE(void, // Return type + PyOutputAllocator, // Parent class + set_tensor_name, // Name of function in Python + tensorName); // Arguments + } + void setCurrentMemory(uintptr_t currentMemory) override { + PYBIND11_OVERRIDE_PURE(void, // Return type + PyOutputAllocator, // Parent class + set_current_memory,// Name of function in Python + currentMemory); // Arguments + } + void setOutputSize(const int64_t outputSize) override { + PYBIND11_OVERRIDE_PURE(void, // Return type + PyOutputAllocator, // Parent class + set_output_size, // Name of function in Python + outputSize); // Arguments + } + uintptr_t reallocateOutputAsync(char const *tensorName, + uintptr_t currentMemory, uint64_t size, + uint64_t alignment) override { + PYBIND11_OVERRIDE_PURE(uintptr_t, // Return type + PyOutputAllocator, // Parent class + reallocate_output, // Name of function in Python + tensorName, // Arguments + currentMemory, size, alignment); + } + void notifyShape(char const *tensorName, const int64_t *dims, + int64_t nbDims) override { + PYBIND11_OVERRIDE_PURE(void, // Return type + PyOutputAllocator, // Parent class + notify_shape, // Name of function in C++ + tensorName, // Arguments + dims, nbDims); + } +}; + /// Python object type wrapper for `MTRT_StableHLOToExecutableOptions`. class PyRuntimeSessionOptions : public PyMTRTWrapper(m, "GpuAllocator") + .def(py::init<>( + [](py::object self) { return new PyGpuAllocatorTrampoline(self); })) + .def("allocate", &PyGpuAllocator::allocate) + .def("deallocate", &PyGpuAllocator::deallocate) + .def("get_capi_object", &PyGpuAllocator::getCApiObject); + + py::class_(m, + "OutputAllocator") + .def(py::init<>( + [](py::object self) { return new PyOutputAllocatorTrampoline(self); })) + .def("set_tensor_name", &PyOutputAllocator::setTensorName) + .def("set_current_memory", &PyOutputAllocator::setCurrentMemory) + .def("set_output_size", &PyOutputAllocator::setOutputSize) + .def("rellocate_output_async", &PyOutputAllocator::reallocateOutputAsync) + .def("notify_shape", &PyOutputAllocator::notifyShape) + .def("get_capi_object", &PyOutputAllocator::getCApiObject); + py::class_(m, "RuntimeSession", py::module_local()) - .def(py::init<>([](PyRuntimeSessionOptions &options, PyExecutable &exe) { + .def(py::init<>([](PyRuntimeSessionOptions &options, PyExecutable &exe, + py::object gpu_allocator = py::none()) { MTRT_RuntimeSession session; - MTRT_Status s = mtrtRuntimeSessionCreate(options, exe, &session); + MTRT_Status s; + + if (gpu_allocator.is_none()) { + // Create session without custom allocator + s = mtrtRuntimeSessionCreate( + options, exe, MTRT_GpuAllocator{nullptr, nullptr, nullptr}, + &session); + } else { + try { + PyGpuAllocator &allocator = + gpu_allocator.cast(); + MTRT_GpuAllocator capi_allocator = allocator.getCApiObject(); + s = mtrtRuntimeSessionCreate(options, exe, capi_allocator, + &session); + } catch (const py::cast_error &) { + throw py::type_error( + "gpu_allocator must be a GpuAllocator object or None"); + } + } THROW_IF_MTRT_ERROR(s); return new PyRuntimeSession(session); }), - py::arg("options"), py::arg("executable")) + py::arg("options"), py::arg("executable"), + py::arg("gpu_allocator") = py::none()) .def( "execute_function", [](PyRuntimeSession &self, std::string name, diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir index 
84d8e714a..0d3688a17 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir @@ -3,7 +3,7 @@ // RUN: stablehlo-clustering-pipeline, \ // RUN: post-clustering-pipeline, \ // RUN: executor-lowering-pipeline)" \ -// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe +// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator #profile = #tensorrt.shape_profile #profile1 = #tensorrt.shape_profile diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir index 28718291a..949ce1b7d 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir @@ -3,7 +3,7 @@ // RUN: stablehlo-clustering-pipeline, \ // RUN: post-clustering-pipeline, \ // RUN: executor-lowering-pipeline)" \ -// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe +// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator #profile0 = #tensorrt.shape_profile #profile1 = #tensorrt.shape_profile diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir index 15f652aac..f825f236d 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xbf16, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir index 73c1cd690..0b0230b83 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s func.func @run_with_shape_2d(%arg0: memref, %arg1: memref<2xindex>) { %c0 = arith.constant 0 : index diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir index 448b88c6f..526c8162f 100644 --- 
a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf16, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir index 0d16f189a..6196e5317 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf32, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir index 7b3ae4765..6e93ac265 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir @@ -2,7 +2,7 @@ // REQUIRES: all-gpus-support-fp8 // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf8E4M3FN, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir index f44da93c5..61a74dfdf 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xi1, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir index 766bec84f..5d917af19 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir +++ 
b/mlir-tensorrt/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xi4, #plan.memory_space> diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy-strided.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy-strided.mlir index 0abcfec01..8cab67749 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy-strided.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy-strided.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s func.func @main() -> index { %c0 = arith.constant 0 : index diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy.mlir index f750810c8..dd336d8ce 100644 --- a/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy.mlir +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/memcpy.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe --use-custom-allocator | FileCheck %s func.func @main() -> i32 { %c0_i32 = arith.constant 0 : i32 diff --git a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py index 480ce74d4..fce25bac3 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py @@ -33,7 +33,7 @@ def test_stablehlo_add( exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) - session = runtime.RuntimeSession(session_options, exe) + session = runtime.RuntimeSession(session_options, exe, None) session.execute_function( "main", in_args=test.in_args, out_args=test.out_args, stream=stream diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py index 1687a8f1b..8ebc2d3c4 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py @@ -62,7 +62,7 @@ def create_scalar(self, value): return self.client.create_scalar(value, runtime.ScalarTypeCode.i64) def execute(self, arg: runtime.RuntimeValue): - session = runtime.RuntimeSession(self.session_options, self.exe) + session = runtime.RuntimeSession(self.session_options, self.exe, None) try: 
session.execute_function( "main", in_args=[arg], out_args=[arg], stream=self.stream diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py index 2c95a3081..1e7289b71 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py @@ -5,6 +5,7 @@ import mlir_tensorrt.compiler.ir as ir import mlir_tensorrt.runtime.api as runtime import numpy as np +import cupy as cp ASM = """ func.func @main(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { @@ -14,6 +15,57 @@ """ +class CupyGPUAllocator(runtime.GpuAllocator): + def __init__(self): + super().__init__(self) + self.allocations = {} # Keep track of allocations + + def allocate(self, size, alignment, flags): + # Allocate memory on the GPU using CuPy + mem = cp.cuda.alloc(size) + ptr = int(mem.ptr) # Convert to integer + # Store the CuPy memory object + self.allocations[ptr] = mem + return ptr + + def deallocate(self, ptr): + if ptr in self.allocations: + # Remove the reference to the CuPy memory object + # This will trigger deallocation if there are no other references + del self.allocations[ptr] + return True + return False + + +class CupyOutputAllocator(runtime.OutputAllocator): + def __init__(self): + super().__init__(self) + + def set_tensor_name(self, tensor_name): + self.tensor_name = tensor_name + + def set_current_memory(self, memory): + self.memory = memory + + def set_output_size(self, size): + self.size = size + + def reallocate_output(self, tensor_name, memory, size, alignment): + assert self.tensor_name == tensor_name + assert self.memory == memory + + if size > self.size: + # For now just fail if reallocation is required. + assert 0 + + return self.memory + + def notify_shape(self, tensor_name, dims, nb_dims): + assert self.tensor_name == tensor_name + self.dims = dims + self.nb_dims = nb_dims + + def stablehlo_add(): # Build/parse the main function. 
with ir.Context() as context: @@ -36,8 +88,11 @@ def stablehlo_add(): if len(devices) == 0: return + # Create an instance of the custom allocator + allocator = CupyGPUAllocator() + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) - session = runtime.RuntimeSession(session_options, exe) + session = runtime.RuntimeSession(session_options, exe, gpu_allocator=allocator) arg0 = client.create_memref( np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data, @@ -49,6 +104,10 @@ def stablehlo_add(): device=devices[0], stream=stream, ) + + output_allocator = CupyOutputAllocator() + arg1.set_output_allocator(output_allocator) + session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) data = np.asarray(client.copy_to_host(arg1, stream=stream)) @@ -62,12 +121,12 @@ def stablehlo_add(): start_time = time.time() for _ in range(0, num_iter): session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream) - data = np.asarray(client.copy_to_host(arg1, stream=stream)) + data = np.asarray(client.copy_to_host(arg0, stream=stream)) stream.sync() end_time = time.time() elapsed = end_time - start_time - print(np.asarray(client.copy_to_host(arg0))) + print(np.asarray(data)) print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration") diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py index 35515e054..b8c56a6df 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py @@ -86,6 +86,35 @@ def infer_output_shape(client, session, exe, input_shape): return output_shape +class CupyOutputAllocator(runtime.OutputAllocator): + def __init__(self): + super().__init__(self) + + def set_tensor_name(self, tensor_name): + self.tensor_name = tensor_name + + def set_current_memory(self, memory): + self.memory = memory + + def set_output_size(self, size): + self.size = size + + def reallocate_output(self, tensor_name, memory, size, alignment): + assert self.tensor_name == tensor_name + assert self.memory == memory + + if size > self.size: + # For now just fail if reallocation is required. + assert 0 + + return self.memory + + def notify_shape(self, tensor_name, dims, nb_dims): + assert self.tensor_name == tensor_name + self.dims = dims + self.nb_dims = nb_dims + + def test_program(program: str, input_shape: Iterable[int], debug: bool = True): # Build/parse the main function. with ir.Context() as context: @@ -115,7 +144,7 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True): return session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) - session = runtime.RuntimeSession(session_options, exe) + session = runtime.RuntimeSession(session_options, exe, gpu_allocator=None) arg0 = client.create_memref( np.ones(input_shape, dtype=np.float32).data, @@ -134,6 +163,15 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True): stream=stream, ) + # # Preallocate dummy memory for 1 element. + # arg2 = client.create_memref( + # np.zeros((1, 1, 1), dtype=np.float32).data, + # device=devices[0], + # stream=stream, + # ) + # output_allocator = CupyOutputAllocator() + # arg1.set_output_allocator(output_allocator) + session.execute_function( "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream )
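
Taken together, the pieces above are meant to be wired up roughly as in the following sketch. It is not part of the change: it assumes the Python bindings introduced in this diff, and reuses the CupyGPUAllocator and CupyOutputAllocator classes as well as the exe, arg0, arg1, and stream objects set up in test_stablehlo_add.py above.

    import mlir_tensorrt.runtime.api as runtime

    # Route the session's device-side allocations through the CuPy-backed
    # GpuAllocator instead of the runtime's default allocator.
    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
    session = runtime.RuntimeSession(
        session_options, exe, gpu_allocator=CupyGPUAllocator()
    )

    # Attach an output allocator to the destination memref so the runtime can
    # report the output's final size and shape (or ask for a reallocation when
    # the preallocated buffer is too small).
    output_allocator = CupyOutputAllocator()
    arg1.set_output_allocator(output_allocator)

    session.execute_function(
        "main", in_args=[arg0], out_args=[arg1], stream=stream
    )
    stream.sync()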