Skip to content

Commit 9f9105a

Browse files
malfet authored and pytorchmergebot committed
[MPS] Write/Invoke Metal shaders from C++ (pytorch#141547)
By introducing `DynamicMetalShaderLibrary` and `MetalShaderFunction`. Add unittests that also serve as an example of how the API works. Using this primitive, one can compile and dispatch any 1D or 2D shader over an MPS tensor using the following pattern:

```cpp
auto x = torch::empty({8, 16}, at::device(at::kMPS));
DynamicMetalShaderLibrary lib(R"MTL(
  kernel void full(device float* t, constant ulong2& strides, uint2 idx [[thread_position_in_grid]]) {
    t[idx.x*strides.x + idx.y*strides.y] = idx.x + 33.0 * idx.y;
  }
)MTL");
auto func = lib.getKernelFunction("full");
func->runCommandBlock([&] {
  func->startEncoding();
  func->setArg(0, x);
  func->setArg(1, x.strides());
  func->dispatch({8, 16});
});
```

Pull Request resolved: pytorch#141547
Approved by: https://github.com/Skylion007
1 parent 5c2584a commit 9f9105a

File tree

5 files changed

+223
-18
lines changed

5 files changed

+223
-18
lines changed

aten/src/ATen/native/mps/MetalShaderLibrary.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,90 @@
44
typedef id<MTLLibrary> MTLLibrary_t;
55
typedef id<MTLFunction> MTLFunction_t;
66
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
7+
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
78
#else
89
typedef void MTLCompileOptions;
910
typedef void* MTLLibrary_t;
1011
typedef void* MTLFunction_t;
1112
typedef void* MTLComputePipelineState_t;
13+
typedef void* MTLComputeCommandEncoder_t;
1214
#endif
1315

16+
#include <functional>
17+
#include <optional>
18+
#include <type_traits>
1419
#include <unordered_map>
1520
#include <vector>
1621

22+
// Forward declaration of TensorBase
23+
namespace at {
24+
class TensorBase;
25+
}
26+
1727
namespace at::native::mps {
28+
29+
namespace detail {
30+
template <typename T>
31+
class has_size_type {
32+
template <typename U>
33+
static constexpr std::true_type check(typename U::size_type*);
34+
template <typename>
35+
static constexpr std::false_type check(...);
36+
37+
public:
38+
static constexpr bool value = decltype(check<T>(nullptr))::value;
39+
};
40+
41+
template <typename T>
42+
constexpr bool has_size_type_v = has_size_type<T>::value;
43+
44+
} // namespace detail
45+
46+
class MetalKernelFunction {
47+
public:
48+
MetalKernelFunction(MTLComputePipelineState_t cps_);
49+
~MetalKernelFunction();
50+
MetalKernelFunction(MetalKernelFunction&) = delete;
51+
// Shader properties
52+
uint64_t getMaxThreadsPerThreadgroup() const;
53+
uint64_t getThreadExecutionWidth() const;
54+
uint64_t getStaticThreadGroupMemoryLength() const;
55+
void runCommandBlock(std::function<void(void)> f);
56+
// Methods below should be called from runCommandBlock functionT
57+
void startEncoding();
58+
void setArg(unsigned idx, const at::TensorBase& t);
59+
void setArg(unsigned idx, const void* ptr, uint64_t size);
60+
template <
61+
typename T,
62+
typename = std::enable_if_t<
63+
std::is_integral_v<T> || std::is_same_v<T, float> ||
64+
(std::is_class_v<T> && std::is_trivially_copyable_v<T> &&
65+
!detail::has_size_type_v<T>)>>
66+
inline void setArg(unsigned idx, const T val) {
67+
setArg(idx, &val, sizeof(T));
68+
}
69+
70+
template <
71+
typename Container,
72+
typename = std::enable_if_t<detail::has_size_type_v<Container>>>
73+
inline void setArg(unsigned idx, const Container& values) {
74+
setArg(
75+
idx,
76+
values.data(),
77+
values.size() * sizeof(typename Container::value_type));
78+
}
79+
void dispatch(
80+
uint64_t length,
81+
std::optional<uint64_t> groupSize = std::nullopt);
82+
void dispatch(
83+
std::array<uint64_t, 2> length,
84+
std::optional<std::array<uint64_t, 2>> groupSize = std::nullopt);
85+
86+
private:
87+
MTLComputePipelineState_t cps;
88+
MTLComputeCommandEncoder_t encoder = nullptr;
89+
};
90+
1891
class MetalShaderLibrary {
1992
public:
2093
MetalShaderLibrary(const std::string& src)
@@ -31,6 +104,8 @@ class MetalShaderLibrary {
31104
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
32105
virtual ~MetalShaderLibrary() = default;
33106
std::vector<std::string> getFunctionNames();
107+
std::shared_ptr<MetalKernelFunction> getKernelFunction(
108+
const std::string& name);
34109
inline MTLComputePipelineState_t getPipelineStateForFunc(
35110
const std::string& fname) {
36111
return getLibraryPipelineState(getLibrary(), fname).first;
@@ -71,4 +146,13 @@ class MetalShaderLibrary {
71146
cplMap;
72147
};
73148

149+
class DynamicMetalShaderLibrary : public MetalShaderLibrary {
150+
public:
151+
DynamicMetalShaderLibrary(const std::string& src) : MetalShaderLibrary(src) {
152+
// Compile right away
153+
getLibrary();
154+
}
155+
~DynamicMetalShaderLibrary();
156+
};
157+
74158
} // namespace at::native::mps

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -343,23 +343,6 @@ inline bool is_dense_in_storage(const TensorBase& t) {
343343
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
344344
}
345345

346-
namespace detail {
347-
template <typename T>
348-
class has_size_type {
349-
template <typename U>
350-
static constexpr std::true_type check(typename U::size_type*);
351-
template <typename>
352-
static constexpr std::false_type check(...);
353-
354-
public:
355-
static constexpr bool value = decltype(check<T>(nullptr))::value;
356-
};
357-
358-
template <typename T>
359-
constexpr bool has_size_type_v = has_size_type<T>::value;
360-
361-
} // namespace detail
362-
363346
template <typename encoder_t,
364347
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
365348
std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
// Copyright © 2022 Apple Inc.
2+
#include <ATen/core/TensorBase.h>
3+
#include <ATen/native/mps/MetalShaderLibrary.h>
4+
#include <functional>
25
#include <stdexcept>
36
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
47
#include <ATen/TensorIterator.h>
@@ -868,6 +871,10 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {}
868871
return rc;
869872
}
870873

874+
std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const std::string& name) {
875+
return std::make_shared<MetalKernelFunction>(getPipelineStateForFunc(name));
876+
}
877+
871878
class BundledShaderLibary : public MetalShaderLibrary {
872879
public:
873880
BundledShaderLibary() : MetalShaderLibrary("") {}
@@ -916,4 +923,63 @@ static dispatch_data_t getSectionData(const std::string& name) {
916923
return l;
917924
}
918925

926+
// DynamicMetalShaderLibrary implementation
927+
DynamicMetalShaderLibrary::~DynamicMetalShaderLibrary() {
928+
[library release];
929+
}
930+
931+
// MetalKernelFunction implementation
932+
MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_) : cps([cps_ retain]) {}
933+
934+
MetalKernelFunction::~MetalKernelFunction() {
935+
[cps release];
936+
}
937+
938+
void MetalKernelFunction::runCommandBlock(std::function<void(void)> run) {
939+
dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() {
940+
@autoreleasepool {
941+
run();
942+
}
943+
});
944+
}
945+
946+
void MetalKernelFunction::startEncoding() {
947+
encoder = getCurrentMPSStream()->commandEncoder();
948+
[encoder setComputePipelineState:cps];
949+
}
950+
951+
void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> group_size) {
952+
auto group_size_val = group_size.value_or(std::min(length, getMaxThreadsPerThreadgroup()));
953+
[encoder dispatchThreads:MTLSizeMake(length, 1, 1) threadsPerThreadgroup:MTLSizeMake(group_size_val, 1, 1)];
954+
}
955+
956+
void MetalKernelFunction::dispatch(std::array<uint64_t, 2> length, std::optional<std::array<uint64_t, 2>> group_size) {
957+
auto group_size_val =
958+
group_size.value_or(std::array<uint64_t, 2>{std::min(length[0], getMaxThreadsPerThreadgroup()), 1});
959+
[encoder dispatchThreads:MTLSizeMake(length[0], length[1], 1)
960+
threadsPerThreadgroup:MTLSizeMake(group_size_val[0], group_size_val[1], 1)];
961+
}
962+
963+
void MetalKernelFunction::setArg(unsigned idx, const at::TensorBase& t) {
964+
TORCH_CHECK(t.device().type() == kMPS, "Tensor must be on GPU");
965+
mtl_setBuffer(encoder, t, idx);
966+
}
967+
968+
void MetalKernelFunction::setArg(unsigned idx, const void* ptr, uint64_t size) {
969+
TORCH_CHECK(size > 0);
970+
[encoder setBytes:ptr length:size atIndex:idx];
971+
}
972+
973+
uint64_t MetalKernelFunction::getMaxThreadsPerThreadgroup() const {
974+
return [cps maxTotalThreadsPerThreadgroup];
975+
}
976+
977+
uint64_t MetalKernelFunction::getThreadExecutionWidth() const {
978+
return [cps threadExecutionWidth];
979+
}
980+
981+
uint64_t MetalKernelFunction::getStaticThreadGroupMemoryLength() const {
982+
return [cps staticThreadgroupMemoryLength];
983+
}
984+
919985
} // namespace at::native::mps

aten/src/ATen/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ list(APPEND ATen_VEC_TEST_SRCS
110110

111111
list(APPEND ATen_MPS_TEST_SRCS
112112
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_print.cpp
113-
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp)
113+
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp
114+
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_metal_library.cpp)
114115
if(APPLE AND USE_MPS)
115116
list(APPEND ATen_MPS_TEST_SRCS
116117
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_objc_interface.mm)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#include <gtest/gtest.h>
2+
#include <stdexcept>
3+
#include <torch/torch.h>
4+
#include <ATen/native/mps/MetalShaderLibrary.h>
5+
6+
using namespace at::native::mps;
7+
TEST(MPSTestMetalLibrary, ShaderCreation) {
8+
MetalShaderLibrary lib("// Empty library");
9+
ASSERT_EQ(lib.getFunctionNames().size(), 0);
10+
}
11+
12+
TEST(MPSTestMetalLibrary, SyntaxErrorThrows) {
13+
ASSERT_THROW(new DynamicMetalShaderLibrary("printf(x);"), c10::Error);
14+
}
15+
16+
TEST(MPSTestMetalLibrary, ArangeShader) {
17+
auto y = torch::arange(10.0, at::device(at::kMPS));
18+
auto x = torch::empty(10, at::device(at::kMPS));
19+
DynamicMetalShaderLibrary lib(R"MTL(
20+
kernel void foo(device float* t, uint idx [[thread_position_in_grid]]) {
21+
t[idx] = idx;
22+
}
23+
)MTL");
24+
auto func = lib.getKernelFunction("foo");
25+
func->runCommandBlock([&] {
26+
func->startEncoding();
27+
func->setArg(0, x);
28+
func->dispatch(x.numel());
29+
});
30+
ASSERT_TRUE((x==y).all().item().toBool());
31+
}
32+
33+
TEST(MPSTestMetalLibrary, ArangeWithArgsShader) {
34+
const auto size = 10;
35+
const float start = .25;
36+
const float step = .4;
37+
auto x = torch::empty(size, at::device(at::kMPS));
38+
auto y = torch::arange(start, start + size * step, step, at::device(at::kMPS));
39+
ASSERT_EQ(x.numel(), y.numel());
40+
DynamicMetalShaderLibrary lib(R"MTL(
41+
kernel void foo(device float* t, constant float& start, constant float& step, uint idx [[thread_position_in_grid]]) {
42+
t[idx] = start + idx * step;
43+
}
44+
)MTL");
45+
auto func = lib.getKernelFunction("foo");
46+
func->runCommandBlock([&] {
47+
func->startEncoding();
48+
func->setArg(0, x);
49+
func->setArg(1, start);
50+
func->setArg(2, step);
51+
func->dispatch(x.numel());
52+
});
53+
ASSERT_TRUE((x==y).all().item().toBool());
54+
}
55+
TEST(MPSTestMetalLibrary, Arange2DShader) {
56+
const auto size = 16;
57+
auto x = torch::empty({size, size}, at::device(at::kMPS));
58+
DynamicMetalShaderLibrary lib(R"MTL(
59+
kernel void full(device float* t, constant ulong2& strides, uint2 idx [[thread_position_in_grid]]) {
60+
t[idx.x*strides.x + idx.y*strides.y] = idx.x + 33.0 * idx.y;
61+
}
62+
)MTL");
63+
auto func = lib.getKernelFunction("full");
64+
func->runCommandBlock([&] {
65+
func->startEncoding();
66+
func->setArg(0, x);
67+
func->setArg(1, x.strides());
68+
func->dispatch({static_cast<uint64_t>(x.size(0)), static_cast<uint64_t>(x.size(1))});
69+
});
70+
ASSERT_EQ(x.sum().item().to<int>(), 65280);
71+
}

0 commit comments

Comments (0)