-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmm.metal
73 lines (62 loc) · 2.68 KB
/
mm.metal
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <metal_stdlib>
using namespace metal;
// Dimensions of the two operand matrices, passed to the kernels through `params`.
// All four fields are plain ints, stored in the order a_rows, a_cols, b_rows, b_cols.
struct MatrixParams {
    int a_rows;  // rows of A; the kernels use it as the upper bound for gid.x
    int a_cols;  // columns of A; the kernels use it as the dot-product length
    int b_rows;  // rows of B
    int b_cols;  // columns of B; the kernels use it as the row length of C
};
// Computes one element of C = A * B per thread.
//
// As the indexing below shows, A is read with a_cols elements per row, B with
// b_cols elements per row, and C is written with b_cols elements per row.
//
//   params  matrix dimensions (see MatrixParams)
//   A, B    input matrices
//   C       output matrix; each in-range thread writes exactly one element
//   gid     position of this thread in the dispatch grid; x selects the row of
//           C and y the column
kernel void matrix_multiply_naive(
    device const MatrixParams *params,
    constant float *A,
    constant float *B,
    device float *C,
    // Supplied by Metal itself rather than by the host code: the 2D coordinate
    // (fields x and y) of this thread within the whole grid being executed.
    uint2 gid [[thread_position_in_grid]]
) {
    // Threads that fall outside the output matrix have nothing to compute.
    const bool inside_output = gid.x < params->a_rows && gid.y < params->b_cols;
    if (!inside_output) {
        return;
    }

    const int inner = params->a_cols;     // length of the dot product
    const int out_cols = params->b_cols;  // row length of both B and C
    const uint a_row = gid.x * inner;     // offset of this thread's row within A
    const uint col = gid.y;

    float acc = 0.0f;
    int i = 0;

    // Four products per iteration (manual unrolling, reported by the original
    // author to be notably faster). The products are still accumulated one
    // after another, in increasing index order.
    while (i <= inner - 4) {
        acc += A[a_row + i] * B[i * out_cols + col];
        acc += A[a_row + i + 1] * B[(i + 1) * out_cols + col];
        acc += A[a_row + i + 2] * B[(i + 2) * out_cols + col];
        acc += A[a_row + i + 3] * B[(i + 3) * out_cols + col];
        i += 4;
    }

    // Up to three leftover products when a_cols is not a multiple of four.
    while (i < inner) {
        acc += A[a_row + i] * B[i * out_cols + col];
        ++i;
    }

    C[gid.x * out_cols + col] = acc;
}
// Computes one element of C = A * B per thread, reading the right-hand operand
// along its rows: for output column gid.y the thread walks B contiguously over
// the dot-product index, so B is expected to be stored transposed.
//
//   params  matrix dimensions (see MatrixParams)
//   A       left operand, a_cols elements per row
//   B       right operand stored transposed: one row per output column, each
//           row holding a_cols contiguous elements
//   C       output matrix, b_cols elements per row; each in-range thread writes
//           exactly one element
//   gid     position of this thread in the dispatch grid; x selects the row of
//           C and y the column
//
// NOTE(review): assumes the transposed B is tightly packed, so its row length
// equals the dot-product length a_cols (== b_rows for a valid product) — confirm
// against the host code that fills the buffer.
kernel void matrix_multiply_transpose(
    device const MatrixParams *params,
    constant float *A,
    constant float *B,
    device float *C,
    // Supplied by Metal itself rather than by the host code: the 2D coordinate
    // (fields x and y) of this thread within the whole grid being executed.
    uint2 gid [[thread_position_in_grid]]
) {
    if (gid.x >= params->a_rows || gid.y >= params->b_cols) {
        return; // This thread is outside the output matrix, do nothing
    }

    // Read the loop-invariant values once instead of on every iteration.
    const int inner = params->a_cols;  // length of the dot product
    const uint a_row = gid.x * inner;  // offset of this thread's row within A
    // Row gid.y of the transposed B starts gid.y * inner elements in. Using
    // b_cols as the stride here is only correct for square operands and
    // otherwise reads the wrong elements (or past the end of the buffer).
    const uint b_row = gid.y * inner;

    float sum = 0.0f;
    int k;
    // Loop unrolling; reported by the original author to improve performance
    // by a notable margin
    for (k = 0; k <= inner - 4; k += 4) {
        sum += A[a_row + k] * B[b_row + k];
        sum += A[a_row + k + 1] * B[b_row + k + 1];
        sum += A[a_row + k + 2] * B[b_row + k + 2];
        sum += A[a_row + k + 3] * B[b_row + k + 3];
    }
    // Handle any remaining elements when a_cols is not a multiple of four
    for (; k < inner; ++k) {
        sum += A[a_row + k] * B[b_row + k];
    }

    C[gid.x * params->b_cols + gid.y] = sum;
}