Implement custom memory management for internal and output allocations #94

Draft · wants to merge 6 commits into base: main
4 changes: 2 additions & 2 deletions mlir-tensorrt/README.md
@@ -23,7 +23,7 @@ We currently support only building on Linux x86 systems.
We support building several different ways (only via CMake) depending on use-case.

In each case, the LLVM-Project version that we are currently aligned to is
-given in `build_tools/cmake/LLVMCommit.txt`.
+given in `build_tools/cmake/LLVMCommit.cmake`.

Note that currently we provide an LLVM patch which essentially cherry-picks the
bug fixes from [this open MLIR PR](https://github.com/llvm/llvm-project/pull/91524).
@@ -82,7 +82,7 @@ git clone https://github.com/llvm/llvm-project.git llvm-project
# Checkout the right commit. Of course, you may try
# a newer commit or your own modified LLVM-Project.
cd llvm-project
-git checkout $(cat build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")')
+git checkout $(cat ../build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")')

# Apply patch from llvm-project PR 91524
git apply ../build_tools/llvm-project.patch
@@ -32,6 +32,8 @@
#include <stddef.h>
#include <stdint.h>

#include "cuda_runtime.h"

#ifdef __cplusplus
extern "C" {
#endif
@@ -93,6 +95,50 @@ static inline bool mtrtDeviceIsNull(MTRT_Device device) { return !device.ptr; }
/// arguments are optional in functions below.
static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{nullptr}; }

//===----------------------------------------------------------------------===//
// MTRT_GpuAllocator
//===----------------------------------------------------------------------===//

// Function pointer types for the allocate and deallocate callbacks.
typedef void *(*AllocateFunc)(void *self, uint64_t size, uint64_t alignment,
                              uint32_t flags, cudaStream_t *stream);
typedef bool (*DeallocateFunc)(void *self, void *memory, cudaStream_t *stream);

typedef struct MTRT_GpuAllocator {
  void *ptr; // Pointer to the implementation (PyGpuAllocatorTrampoline in our
             // case).
  // Function pointers to methods.
  AllocateFunc allocate;
  DeallocateFunc deallocate;
} MTRT_GpuAllocator;
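
One way to satisfy this C ABI from C++ is a pair of static thunks that forward to the `GpuAllocator` interface introduced in `Allocators.h` below. A minimal sketch, assuming that header is included; `mtrtWrapGpuAllocator` is a hypothetical helper, not something added by this PR:

// Forward the C callbacks to a C++ mlirtrt::GpuAllocator held in `self`.
static void *allocateThunk(void *self, uint64_t size, uint64_t alignment,
                           uint32_t flags, cudaStream_t *stream) {
  return static_cast<mlirtrt::GpuAllocator *>(self)->allocate(size, alignment,
                                                              flags, stream);
}

static bool deallocateThunk(void *self, void *memory, cudaStream_t *stream) {
  return static_cast<mlirtrt::GpuAllocator *>(self)->deallocate(memory, stream);
}

// Hypothetical helper: package a C++ allocator behind the C ABI struct.
static MTRT_GpuAllocator mtrtWrapGpuAllocator(mlirtrt::GpuAllocator *impl) {
  return MTRT_GpuAllocator{impl, allocateThunk, deallocateThunk};
}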

//===----------------------------------------------------------------------===//
// MTRT_OutputAllocator
//===----------------------------------------------------------------------===//

// Function pointer types for the output-allocator callbacks.
typedef void (*SetGpuAllocator)(void *self, MTRT_GpuAllocator gpuAllocator);
typedef void (*SetTensorName)(void *self, const char *tensorName);
typedef void (*SetCurrentMemory)(void *self, void *currentMemory);
typedef void (*SetOutputSize)(void *self, const int64_t outputSize);
typedef void *(*ReallocateOutputAsync)(void *self, char const *tensorName,
                                       void *currentMemory, uint64_t size,
                                       uint64_t alignment,
                                       cudaStream_t *stream);
typedef void (*NotifyShape)(void *self, char const *tensorName,
                            const int64_t *dims, int64_t nbDims);

typedef struct MTRT_OutputAllocator {
  void *ptr; // Pointer to the implementation (PyOutputAllocatorTrampoline in
             // our case).
  // Function pointers to methods.
  SetGpuAllocator setGpuAllocator;
  SetTensorName setTensorName;
  SetCurrentMemory setCurrentMemory;
  SetOutputSize setOutputSize;
  ReallocateOutputAsync reallocateOutputAsync;
  NotifyShape notifyShape;
} MTRT_OutputAllocator;
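
The shape of this interface mirrors TensorRT's `IOutputAllocator` pattern for data-dependent output shapes. A hypothetical illustration of the order in which a runtime would drive the callbacks (the tensor name, sizes, and dims are invented for the example):

void exampleOutputFlow(MTRT_OutputAllocator alloc, cudaStream_t stream) {
  // Before execution: identify the tensor; no buffer exists yet.
  alloc.setTensorName(alloc.ptr, "result0");
  alloc.setCurrentMemory(alloc.ptr, nullptr);
  // During execution: request a buffer once the byte size is known.
  void *mem = alloc.reallocateOutputAsync(alloc.ptr, "result0",
                                          /*currentMemory=*/nullptr,
                                          /*size=*/1024, /*alignment=*/256,
                                          &stream);
  // After execution: the final shape is reported back.
  const int64_t dims[2] = {16, 16};
  alloc.notifyShape(alloc.ptr, "result0", dims, /*nbDims=*/2);
  (void)mem;
}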

//===----------------------------------------------------------------------===//
// MTRT_MemRefValue
//===----------------------------------------------------------------------===//
@@ -170,6 +216,9 @@ typedef struct MTRT_MemRefValueInfo {
MLIR_CAPI_EXPORTED MTRT_Status
mtrtMemRefValueGetInfo(MTRT_MemRefValue memref, MTRT_MemRefValueInfo *info);

MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueSetOutputAllocator(
    MTRT_MemRefValue memrefValue, MTRT_OutputAllocator pyOutputAllocator);

/// Create DL Managed tensor from MemRefValue.
MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefValueGetDLPackManagedTensor(
    MTRT_MemRefValue memrefValue, MTRT_DLPackManagedTensor *outTensor);
@@ -360,7 +409,7 @@ typedef struct MTRT_RuntimeSession {
/// constant data. Therefore the Executable must outlive the RuntimeSession.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionCreate(
    MTRT_RuntimeSessionOptions options, MTRT_Executable executable,
-   MTRT_RuntimeSession *result);
+   MTRT_GpuAllocator allocator, MTRT_RuntimeSession *result);
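
A hypothetical call site, assuming `options`, `executable`, and a `myAllocator` implementation (e.g. the `CustomTensorRTAllocator` from `Allocators.h`) already exist; the wrapper helper is the sketch from earlier:

MTRT_RuntimeSession session;
MTRT_Status status = mtrtRuntimeSessionCreate(
    options, executable, mtrtWrapGpuAllocator(&myAllocator), &session);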

/// Destroy the session. This does not destroy the associated Executable, which
/// may be shared among many sessions.
@@ -372,6 +421,10 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
  return !session.ptr;
}

MLIR_CAPI_EXPORTED MTRT_Status mtrtAddMemRefOutputAllocatorSessionRegistry(
    MTRT_MemRefValue memrefValue, MTRT_OutputAllocator pyOutputAllocator);

/// Using `session`, execute the public function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
29 changes: 27 additions & 2 deletions mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h
@@ -667,6 +667,12 @@ class MemRefValue : public RuntimeValue {
    return v->getKind() == Kind::MemRef;
  }

  void setOutputAllocator(OutputAllocator *_outputAllocator) {
    outputAllocator = _outputAllocator;
  }

  OutputAllocator *getOutputAllocator() { return outputAllocator; }

  const std::optional<ScalarType> &getScalarType() const { return scalarType; }

  RuntimeClient *getClient() { return client; }
@@ -691,6 +697,7 @@
  /// address.
  std::optional<const Device *> device;
  std::optional<ScalarType> scalarType{};
  OutputAllocator *outputAllocator{nullptr};
};

//===----------------------------------------------------------------------===//
@@ -867,7 +874,9 @@ class RuntimeSession {
      sol::state state,
      std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
      std::unique_ptr<AllocTracker> allocTracker,
-     std::unique_ptr<ResourceTracker> resourceTracker);
+     std::unique_ptr<ResourceTracker> resourceTracker,
+     std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker,
+     std::unique_ptr<GpuAllocator> gpuAllocator);

  ExecutableView getExecutable() const { return executable; }

@@ -881,14 +890,21 @@

  ResourceTracker &getResourceTracker() { return *resourceTracker; }

  OutputAllocatorTracker &getOutputAllocatorTracker() {
    return *outputAllocatorTracker;
  }

  GpuAllocator &getGpuAllocator() { return *gpuAllocator; }

private:
  RuntimeSessionOptions options;
  ExecutableView executable;

  std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator;
  std::unique_ptr<AllocTracker> allocTracker;
  std::unique_ptr<ResourceTracker> resourceTracker;

  std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker;
  std::unique_ptr<GpuAllocator> gpuAllocator;
  sol::state state;
};

@@ -970,6 +986,14 @@ class RuntimeClient {
    return pinnedMemoryAllocator;
  }

  void addOutputAllocator(std::unique_ptr<OutputAllocator> outputAllocator) {
    outputAllocators.emplace_back(std::move(outputAllocator));
  }

  OutputAllocator *getLastOutputAllocator() {
    return outputAllocators.back().get();
  }

private:
  RuntimeClient(llvm::SmallVector<std::unique_ptr<Device>> devices)
      : devices(std::move(devices)) {}
@@ -978,6 +1002,7 @@
  PinnedMemoryAllocator pinnedMemoryAllocator;
  AllocTracker allocTracker;
  ResourceTracker resourceTracker;
  std::vector<std::unique_ptr<OutputAllocator>> outputAllocators;
};
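
These two members let the client own output allocators so they outlive the MemRefs that reference them; a minimal usage sketch, assuming a `client` and a `memref` obtained through the usual APIs:

// The client keeps the allocator alive; the MemRef only holds a raw pointer.
client.addOutputAllocator(std::make_unique<CustomTensorRTOuputAllocator>());
OutputAllocator *latest = client.getLastOutputAllocator();
memref->setOutputAllocator(latest);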

//===----------------------------------------------------------------------===//
@@ -37,6 +37,8 @@ void registerLuaRuntimeMethods(lua_State *state,
                               const RuntimeSessionOptions &options,
                               PinnedMemoryAllocator *pinnedMemoryAllocator,
                               AllocTracker *allocTracker,
-                              ResourceTracker *resourceTracker);
+                              ResourceTracker *resourceTracker,
+                              OutputAllocatorTracker *outputAllocatorTracker,
+                              GpuAllocator *allocator);

} // namespace mlirtrt::runtime
@@ -36,7 +36,8 @@ namespace mlirtrt::runtime {
/// `main` function. It is assumed that `main` takes no arguments and returns an
/// integer result (which is returned if the execution is successful).
/// TODO: this should take a handle to a function for streaming output/errors.
-StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript);
+StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript,
+                                       GpuAllocator *allocator);

/// Synchronously run a serialized executor Executable one time. An `Executable`
/// is essentially a Lua script packaged with metadata and serialized constants
@@ -48,12 +49,15 @@ StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript);
/// execution is successful).
/// TODO: this should take a handle to a function for
/// streaming output/errors.
-StatusOr<int64_t> runExecutorExecutable(std::unique_ptr<Executable> executable);
+StatusOr<int64_t>
+runExecutorExecutable(std::unique_ptr<Executable> executable,
+                      std::unique_ptr<GpuAllocator> allocator);

/// Create an execution state. This will setup a Lua environment and invoke
/// global initialization.
StatusOr<std::unique_ptr<RuntimeSession>>
createRuntimeSessionWithLuaBackend(ExecutableView executable,
                                   std::unique_ptr<GpuAllocator> allocator,
                                   const RuntimeSessionOptions &options);
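
A hypothetical call site, assuming an `ExecutableView view` is in hand and that `RuntimeSessionOptions` can be default-constructed:

RuntimeSessionOptions options; // assumption: default options suffice
StatusOr<std::unique_ptr<RuntimeSession>> session =
    createRuntimeSessionWithLuaBackend(
        view, std::make_unique<CustomTensorRTAllocator>(), options);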

/// Set the primary stream for the loaded executable to use.
@@ -37,7 +37,8 @@ class ResourceTracker;
/// Lua state.
void registerExecutorTensorRTModuleLuaRuntimeMethods(
    lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
-   AllocTracker *allocTracker, ResourceTracker *resourceTracker);
+   AllocTracker *allocTracker, ResourceTracker *resourceTracker,
+   OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator);

} // namespace mlirtrt::runtime

133 changes: 133 additions & 0 deletions mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -32,6 +32,139 @@ namespace mlirtrt {

struct EventPool;

//===----------------------------------------------------------------------===//
// GpuAllocator and CustomTensorRTAllocator
//===----------------------------------------------------------------------===//

class GpuAllocator {
public:
  GpuAllocator() = default;
  virtual ~GpuAllocator() = default;
  virtual void *allocate(uint64_t const size, uint64_t const alignment,
                         uint32_t flags, cudaStream_t *stream) {
    return nullptr;
  }
  virtual bool deallocate(void *const memory, cudaStream_t *stream) {
    return false;
  }
};

class CustomTensorRTAllocator : public GpuAllocator {
public:
  CustomTensorRTAllocator() = default;
  ~CustomTensorRTAllocator() = default;
  void *allocate(uint64_t const size, uint64_t const alignment, uint32_t flags,
                 cudaStream_t *stream) override;
  bool deallocate(void *const memory, cudaStream_t *stream) override;
};
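
The definitions live in the implementation file, which this header-only diff does not show. A plausible sketch of what the stream-ordered path could look like (an assumption, not the PR's actual code; `alignment` and `flags` are accepted but not acted on here):

void *CustomTensorRTAllocator::allocate(uint64_t const size,
                                        uint64_t const alignment,
                                        uint32_t flags, cudaStream_t *stream) {
  void *memory{nullptr};
  if (stream && *stream) {
    // Stream-ordered allocation (CUDA 11.2+).
    cudaMallocAsync(&memory, size, *stream);
  } else {
    cudaMalloc(&memory, size);
  }
  return memory; // nullptr if allocation failed
}

bool CustomTensorRTAllocator::deallocate(void *const memory,
                                         cudaStream_t *stream) {
  cudaError_t err = (stream && *stream) ? cudaFreeAsync(memory, *stream)
                                        : cudaFree(memory);
  return err == cudaSuccess;
}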

//===----------------------------------------------------------------------===//
// OutputAllocator and CustomTensorRTOuputAllocator
//===----------------------------------------------------------------------===//

//!
//! Class to allocate memory for outputs with data-dependent shapes. Their
//! sizes are unknown in advance, so pre-allocation is not possible.
//!
class OutputAllocator {
public:
  virtual ~OutputAllocator() = default;
  virtual void setGpuAllocator(GpuAllocator *gpuAllocator) = 0;
  virtual void setTensorName(const char *tensorName) = 0;
  virtual void setCurrentMemory(void *currentMemory) = 0;
  virtual void setOutputSize(const int64_t outputSize) = 0;
  virtual void *reallocateOutputAsync(char const *tensorName,
                                      void *currentMemory, uint64_t size,
                                      uint64_t alignment,
                                      cudaStream_t * /*stream*/) = 0;
  virtual void notifyShape(char const *tensorName, const int64_t *dims,
                           int64_t nbDims) = 0;
};

class CustomTensorRTOuputAllocator : public OutputAllocator {
public:
  CustomTensorRTOuputAllocator() = default;
  ~CustomTensorRTOuputAllocator() {
    if (mOutputPtr != nullptr) {
      cudaFree(mOutputPtr);
    }
  }

  void setGpuAllocator(GpuAllocator *gpuAllocator) override {
    mGpuAllocator = gpuAllocator;
  }

  //! Methods are called just after construction. TODO: can they be called
  //! during construction?
  void setTensorName(const char *tensorName) override {
    mTensorName = tensorName;
  }

  void setCurrentMemory(void *currentMemory) override {
    mCurrentMemory = currentMemory;
  }

  void setOutputSize(int64_t outputSize) override { mOutputSize = outputSize; }

  void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
                              uint64_t size, uint64_t alignment,
                              cudaStream_t * /*stream*/) override;

  void notifyShape(char const *tensorName, const int64_t *dims,
                   int64_t nbDims) override;

  //! nullptr if memory could not be allocated.
  void *mOutputPtr{nullptr};

  //! Size of allocation pointed to by output.
  uint64_t mOutputSize{0};

  bool mReallocateOutputCalled{false};

  bool mNotifyShapeCalled{false};

  //! Dimensions of tensor.
  std::vector<int64_t> mOutputDims;

private:
  GpuAllocator *mGpuAllocator;
  const char *mTensorName;
  void *mCurrentMemory;
};
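
Before enqueueing, each data-dependent output would be configured through the setters; a sketch of the assumed sequence (names invented):

CustomTensorRTOuputAllocator outAlloc;
outAlloc.setGpuAllocator(&gpuAllocator); // route reallocation through GpuAllocator
outAlloc.setTensorName("result0");       // hypothetical tensor name
outAlloc.setCurrentMemory(nullptr);      // no pre-allocated buffer yet
outAlloc.setOutputSize(0);               // size unknown until execution
// TensorRT later calls reallocateOutputAsync(...) once the byte size is
// known, then notifyShape(...) with the final dimensions.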

class OutputAllocatorTracker {
public:
  OutputAllocatorTracker() = default;
  ~OutputAllocatorTracker() = default;

  OutputAllocatorTracker(const OutputAllocatorTracker &) = delete;
  OutputAllocatorTracker &operator=(const OutputAllocatorTracker &) = delete;
  OutputAllocatorTracker(OutputAllocatorTracker &&) = default;
  OutputAllocatorTracker &operator=(OutputAllocatorTracker &&) = default;

  // Add a new OutputAllocator.
  void addAllocator(void *ptr, OutputAllocator *allocator) {
    mOutputAllocatorRegistry.emplace_back(std::make_pair(ptr, allocator));
  }

  // Look up the OutputAllocator associated with `ptr`; nullptr if absent.
  OutputAllocator *getAllocator(void *ptr) {
    auto it = std::find_if(
        mOutputAllocatorRegistry.begin(), mOutputAllocatorRegistry.end(),
        [ptr](const auto &pair) { return pair.first == ptr; });

    if (it != mOutputAllocatorRegistry.end())
      return it->second;
    return nullptr;
  }

private:
  std::vector<std::pair<void *, OutputAllocator *>> mOutputAllocatorRegistry;
};
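
The tracker keys allocators on an opaque pointer (the output MemRef in this PR); a minimal usage sketch continuing the example above, with a stand-in key:

OutputAllocatorTracker tracker;
void *memrefKey = nullptr; // hypothetical: the output MemRef's opaque pointer
tracker.addAllocator(memrefKey, &outAlloc);
if (OutputAllocator *found = tracker.getAllocator(memrefKey)) {
  const int64_t dims[2] = {16, 16};
  found->notifyShape("result0", dims, /*nbDims=*/2);
}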

//===----------------------------------------------------------------------===//
// PoolTrackedCudaEvent
//===----------------------------------------------------------------------===//