-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathflash_attention2_v5.cu
More file actions
executable file
·471 lines (432 loc) · 13.6 KB
/
flash_attention2_v5.cu
File metadata and controls
executable file
·471 lines (432 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
#include <iostream>
#include <fstream>
#include <iomanip>
#include <limits>
#include <cassert>
#include <cuda_runtime.h>
/*
Implementation of Flash Attention 2
https://arxiv.org/pdf/2307.08691
Q: Mxd
K: Nxd
V: Nxd
O: Mxd
l,m: Mx1
*/
// Largest dynamic shared-memory allocation an A10G allows per block: 99 KB, in bytes.
constexpr int A10G_SRAM_SIZE{99 * 1024};
// Shared-memory budget for the algorithm's tiles, measured in floats (not bytes).
// 25000 floats = 100000 bytes, a little under the 101376-byte limit above, which
// leaves room for the extra li/mi vectors.
constexpr int SRAM_SIZE{25000};
// Starting value for the running row maximum. This is the most negative *finite*
// float rather than -infinity (presumably so differences such as mi_prev - mi_cur
// never become inf - inf = NaN).
constexpr float NEGATIVE_INF{std::numeric_limits<float>::lowest()};
/**
 * Integer ceiling division, (a + b - 1) / b.
 * For a >= 0 and b > 0 this is the smallest q with q * b >= a.
 */
int ceildiv(int a, int b) {
    const int rounded_up = a + b - 1;
    return rounded_up / b;
}
/**
 * Cooperatively copy row-tile `block_idx` of a row-major M x N matrix.
 *
 * The tile is rows [block_idx * block_size, (block_idx + 1) * block_size),
 * i.e. block_size * N consecutive elements of src. Positions that fall past
 * the end of the matrix (M * N elements) are written as 0, so dst always
 * receives exactly block_size * N values.
 * All threads of the block take part; tile element k is handled by thread
 * k % blockDim.x.
 * src intended to be global, dst intended to be shared memory.
 */
__device__ void matrix_block_load(
    float* dst,
    const float* src,
    int M,
    int N,
    int block_size,
    int block_idx
) {
    const int stride = blockDim.x;
    const int total = M * N;
    const int tile_elems = block_size * N;
    const int offset = block_idx * block_size * N;
    for (int k = threadIdx.x; k < tile_elems; k += stride) {
        const int g = offset + k;  // index into the full matrix
        dst[k] = (g < total) ? src[g] : 0.0f;
    }
}
/**
 * Cooperatively copy row-tile `block_idx` of a row-major M x N matrix into
 * dst in transposed form.
 *
 * Source element (row r of the tile, column c) lands at
 * dst[c * (loop_block_size + dst_padding) + r]. dst is therefore N rows with
 * a leading dimension of loop_block_size + dst_padding. A non-zero
 * dst_padding only widens that row stride; the padding slots are not written.
 * loop_block_size is the number of rows in this tile: block_size, or fewer
 * for the final tile when block_size does not divide M.
 * Unlike matrix_block_load, rows past the end of the matrix are skipped
 * rather than zero-filled.
 * src intended to be global, dst intended to be shared memory.
 */
__device__ void matrix_block_load_transpose(
    float* dst,
    const float* src,
    int M,
    int N,
    int block_size,
    int loop_block_size,
    int block_idx,
    int dst_padding
) {
    const int stride = blockDim.x;
    const int row_stride = loop_block_size + dst_padding;  // leading dimension of dst
    const int first = block_idx * block_size * N;
    // Exclusive end, clamped so the final (partial) tile never reads past src.
    const int last = min(first + block_size * N, M * N);
    for (int g = first + (int)threadIdx.x; g < last; g += stride) {
        const int row = (g - first) / N;
        const int col = g % N;
        dst[col * row_stride + row] = src[g];
    }
}
/**
 * Cooperatively write a row-tile back into a row-major M x N matrix.
 *
 * src holds block_size x N values. They are copied to rows
 * [block_idx * block_size, ...) of dst. Writes are clamped at M * N, so a
 * final partial tile only stores the rows that exist.
 * src intended to be shared memory, dst intended to be global.
 */
__device__ void matrix_block_store(
    float* dst,
    const float* src,
    int M,
    int N,
    int block_size,
    int block_idx
) {
    const int stride = blockDim.x;
    const int offset = block_idx * block_size * N;
    const int count = min(M * N, offset + block_size * N) - offset;
    for (int k = threadIdx.x; k < count; k += stride) {
        dst[offset + k] = src[k];
    }
}
/**
 * Set the first N entries of `array` to fill_value.
 * All threads of the block cooperate, each writing every blockDim.x-th entry.
 */
__device__ void array_fill(
    float* array,
    float fill_value,
    int N
) {
    const int stride = blockDim.x;
    int idx = threadIdx.x;
    while (idx < N) {
        array[idx] = fill_value;
        idx += stride;
    }
}
/**
 * Block-cooperative naive GEMM: C = A * B (or C += A * B).
 *
 * A is M x K, row-major and densely packed (row stride K).
 * B is K x N, row-major with row stride NPad >= N; the extra columns are
 * ignored, which allows B to be a padded buffer.
 * C is M x N, densely packed (row stride N).
 * With add_to_output = true the product is accumulated into C instead of
 * overwriting it. Output element o is computed by thread o % blockDim.x.
 * Straightforward and unoptimized: one dot product per output element.
 */
template <bool add_to_output = false>
__device__ void matrix_multiply(
    const float* A,
    const float* B,
    float* C,
    int M,
    int N,
    int K,
    int NPad
) {
    const int stride = blockDim.x;
    const int total = M * N;
    for (int out = threadIdx.x; out < total; out += stride) {
        const int row = out / N;
        const int col = out % N;
        const float* a_row = A + row * K;
        const float* b_col = B + col;
        float acc = 0;
        for (int k = 0; k < K; ++k) {
            acc += a_row[k] * b_col[k * NPad];
        }
        if constexpr (add_to_output) {
            C[out] += acc;
        } else {
            C[out] = acc;
        }
    }
}
/**
 * Divide each of the first N entries of `array` by scalar, in place.
 * Uses true division (not a reciprocal multiply). All threads of the block cooperate.
 */
__device__ void divide_by_scalar(
    float* array,
    float scalar,
    int N
) {
    const int stride = blockDim.x;
    for (int idx = threadIdx.x; idx < N; idx += stride) {
        array[idx] = array[idx] / scalar;
    }
}
/**
 * Running row maximum: mi_cur[r] = max(mi_prev[r], max_c Si[r][c]).
 *
 * Si is Br x Bc, densely packed (row stride Bc).
 * mi_prev and mi_cur are length-Br vectors, expected to be distinct buffers.
 * One thread scans each row; rows are distributed over the block's threads.
 */
__device__ void mi_update(
    float* mi_cur,
    const float* mi_prev,
    const float* Si,
    int Br,
    int Bc
) {
    const int stride = blockDim.x;
    for (int row = threadIdx.x; row < Br; row += stride) {
        const float* s_row = Si + row * Bc;
        float running_max = mi_prev[row];
        for (int col = 0; col < Bc; ++col) {
            running_max = max(running_max, s_row[col]);
        }
        mi_cur[row] = running_max;
    }
}
/**
 * In place, turn scores into unnormalized probabilities:
 * SiPi[r][c] = exp(SiPi[r][c] - mi[r]).
 *
 * SiPi is Br x Bc, densely packed (row stride Bc); mi has Br entries.
 * All threads of the block cooperate element-wise.
 */
__device__ void si_to_pi(
    float* SiPi,
    const float* mi,
    int Br,
    int Bc
) {
    const int stride = blockDim.x;
    const int total = Br * Bc;
    for (int idx = threadIdx.x; idx < total; idx += stride) {
        const float row_max = mi[idx / Bc];
        SiPi[idx] = exp(SiPi[idx] - row_max);
    }
}
/**
 * Running softmax denominator:
 * li[r] = exp(mi_prev[r] - mi_cur[r]) * li[r] + sum_c Pi[r][c].
 *
 * Pi is Br x Bc, densely packed (row stride Bc); li, mi_prev and mi_cur have
 * Br entries each. The exp factor rescales the old sum to the new row
 * maximum. One thread per row; rows are distributed over the block.
 */
__device__ void li_update(
    float* li,
    const float* Pi,
    const float* mi_prev,
    const float* mi_cur,
    int Br,
    int Bc
) {
    const int stride = blockDim.x;
    for (int row = threadIdx.x; row < Br; row += stride) {
        const float* p_row = Pi + row * Bc;
        float row_sum = 0;
        for (int col = 0; col < Bc; ++col) {
            row_sum += p_row[col];
        }
        const float rescale = exp(mi_prev[row] - mi_cur[row]);
        li[row] = rescale * li[row] + row_sum;
    }
}
/**
 * Running (unnormalized) output:
 * Oi = diag(exp(mi_prev - mi_cur)) * Oi + Pi * V.
 *
 * Oi is Br x d and Pi is Br x Bc, both densely packed.
 * Despite its name, the `VT` argument holds V in plain row-major form
 * (Bc x d, row stride d). It is passed to matrix_multiply as the right-hand
 * operand with NPad = d.
 * No barrier is needed between the rescale loop and the GEMM: both give
 * element o of Oi to thread o % blockDim.x, so each thread only touches
 * entries it has just rescaled itself.
 */
__device__ void Oi_update(
    float* Oi,
    const float* Pi,
    const float* VT,
    const float* mi_prev,
    const float* mi_cur,
    int Br,
    int Bc,
    int d
) {
    const int stride = blockDim.x;
    const int count = Br * d;
    for (int idx = threadIdx.x; idx < count; idx += stride) {
        const int row = idx / d;
        const float correction = exp(mi_prev[row] - mi_cur[row]);
        Oi[idx] *= correction;
    }
    matrix_multiply<true>(Pi, VT, Oi, Br, d, Bc, d);
}
/**
 * Final softmax normalization: divide row r of Oi by li[r], in place.
 * Oi is Br x d, densely packed; li has Br entries.
 * All threads of the block cooperate element-wise.
 */
__device__ void Oi_scale(
    float* Oi,
    const float* li,
    int Br,
    int d
) {
    const int stride = blockDim.x;
    const int count = Br * d;
    for (int idx = threadIdx.x; idx < count; idx += stride) {
        Oi[idx] /= li[idx / d];
    }
}
/**
 * Flash Attention 2 forward kernel: O = softmax(Q K^T / sqrt(d)) V, row-wise.
 *
 * Launch layout: one thread block per Br-row tile of Q (tile index =
 * blockIdx.x; flash_attention_2 launches Tr blocks). Every helper strides over
 * its work by blockDim.x, so any block size is valid.
 *
 * Dynamic shared memory: (4 * alloc_size + 3 * Br) floats, laid out as
 *   Oi | Qi | Kj^T (later Vj) | Sij (later Pij)   -- alloc_size floats each
 *   li | mi | mi2                                -- Br floats each
 *
 * Q (M x d), K and V (N x d) and O (M x d) are row-major device buffers.
 * Tr is not read by the kernel body.
 */
__global__ void flash_attention_2_kernel(
    const float* Q,
    const float* K,
    const float* V,
    float* O,
    const int M,
    const int N,
    const int d,
    const int Br,
    const int Bc,
    const int Tr,
    const int Tc,
    const int alloc_size
) {
    extern __shared__ float s[];
    float *Oi = s;
    float *Qi = &s[alloc_size];
    // will first store Kj^T, then get overwritten with Vj
    float *KiVi = &s[2 * alloc_size];
    // will first store Sij, then get overwritten with Pij
    float *SiPi = &s[3 * alloc_size];
    float *li = &s[4 * alloc_size];
    float *mi = &s[4 * alloc_size + Br];
    float *mi2 = &s[4 * alloc_size + 2 * Br];
    float* mi_prev = mi; // m(i,j-1)
    float* mi_cur = mi2; // m(i,j)
    // Softmax temperature; loop-invariant, so compute it once.
    const float score_scale = sqrtf(d);
    int i = blockIdx.x;
    int loopBr = min(Br, M - i * Br);
    matrix_block_load(Qi, Q, M, d, Br, i);
    array_fill(Oi, 0, loopBr * d);
    array_fill(li, 0, loopBr);
    array_fill(mi_prev, NEGATIVE_INF, loopBr);
    __syncthreads();
    for (int j = 0; j < Tc; j++) {
        int loopBc = min(Bc, N - j * Bc);
        // The transposed K tile is written with a row stride of loopBc + padding.
        // Adjacent threads write addresses one stride apart, so an odd stride
        // spreads those writes over different shared-memory banks.
        // The padded tile needs d * (loopBc + 1) floats. For a full tile that can
        // exceed alloc_size (e.g. alloc_size == Bc * d). The overflow would then
        // land in SiPi, which matrix_multiply overwrites while other threads are
        // still reading Kj^T from it, a data race. So only pad when it fits.
        int dst_padding = (loopBc % 2 == 0 && (loopBc + 1) * d <= alloc_size) ? 1 : 0;
        matrix_block_load_transpose(KiVi, K, N, d, Bc, loopBc, j, dst_padding);
        __syncthreads();
        matrix_multiply(Qi, KiVi, SiPi, loopBr, loopBc, d, loopBc + dst_padding);
        __syncthreads();
        divide_by_scalar(SiPi, score_scale, loopBr * loopBc);
        __syncthreads();
        mi_update(mi_cur, mi_prev, SiPi, loopBr, loopBc);
        __syncthreads();
        si_to_pi(SiPi, mi_cur, loopBr, loopBc);
        __syncthreads();
        // li_update only reads SiPi/mi and writes li; the V load only writes KiVi
        // (K^T is no longer needed), so these two can run without a barrier between them.
        li_update(li, SiPi, mi_prev, mi_cur, loopBr, loopBc);
        matrix_block_load(KiVi, V, N, d, Bc, j);
        __syncthreads();
        Oi_update(Oi, SiPi, KiVi, mi_prev, mi_cur, loopBr, loopBc, d);
        __syncthreads();
        // swap mi_prev / mi_cur
        auto tmp = mi_prev;
        mi_prev = mi_cur;
        mi_cur = tmp;
    }
    Oi_scale(Oi, li, loopBr, d);
    __syncthreads();
    matrix_block_store(O, Oi, M, d, Br, i);
}
// Q, K, V, O are device pointers
/**
 * Host launcher for flash_attention_2_kernel.
 *
 * Q is M x d, K and V are N x d, O is M x d: row-major float device buffers.
 *
 * Tiling:
 *   - Bc rows of K/V are processed per step, with Bc chosen so that four
 *     Bc x d tiles take roughly SRAM_SIZE floats.
 *   - Br = min(Bc, d) rows of Q are handled per thread block (Tr blocks total).
 *
 * The launch is asynchronous on the default stream; callers must synchronize
 * (e.g. a blocking cudaMemcpy) before reading O. The kernel uses more than
 * 48 KB of dynamic shared memory, so the caller must first raise
 * cudaFuncAttributeMaxDynamicSharedMemorySize; main does this with
 * A10G_SRAM_SIZE.
 *
 * Launch errors are reported on stderr via cudaPeekAtLastError. That call does
 * not clear the error, so callers can still observe it with cudaGetLastError.
 */
void flash_attention_2(const float* Q, const float* K, const float* V, float* O, int M, int N, int d) {
    if (M == 0) {
        return;  // no query rows: nothing to compute (a 0-block launch would be invalid)
    }
    if (M < 0 || N <= 0 || d <= 0) {
        // d == 0 would divide by zero below; N == 0 would normalize by an empty sum.
        std::cerr << "flash_attention_2: invalid sizes M=" << M << ", N=" << N << ", d=" << d << std::endl;
        return;
    }
    int Bc = ceildiv(SRAM_SIZE, 4 * d);
    int Br = min(Bc, d);
    int Tr = ceildiv(M, Br);
    int Tc = ceildiv(N, Bc);
    int alloc_size = max(Br * Bc, Bc * d);
    int shmem_needed = (4 * alloc_size + 3 * Br) * sizeof(float);
    // call kernel
    const int threadsPerBlock = 1024;
    const int blocksPerGrid = Tr;
    std::cout << "Shared memory needed: " << shmem_needed << " bytes" << std::endl;
    flash_attention_2_kernel<<<blocksPerGrid, threadsPerBlock, shmem_needed>>>(
        Q, K, V, O, M, N, d, Br, Bc, Tr, Tc, alloc_size
    );
    // Kernel launches return no status; configuration errors (such as asking for
    // more shared memory than the opt-in limit) only show up here.
    cudaError_t launch_err = cudaPeekAtLastError();
    if (launch_err != cudaSuccess) {
        std::cerr << "flash_attention_2 kernel launch failed (" << shmem_needed
                  << " bytes shared memory requested): " << cudaGetErrorString(launch_err) << std::endl;
    }
}
/**
 * Benchmark driver.
 *
 * Steps:
 *   1. Print device properties and opt the kernel in to A10G_SRAM_SIZE bytes
 *      of dynamic shared memory.
 *   2. Build deterministic M x d / N x d inputs and run flash_attention_2.
 *   3. Write the output to argv[1] if it is given, otherwise print the first
 *      10 values to stdout.
 *
 * Returns 0 on success, 1 if any CUDA call, launch or kernel execution fails.
 */
int main(int argc, char* argv[]) {
    // Reports a failed CUDA runtime call on stderr; returns true on success so
    // checks can be chained with &&.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err != cudaSuccess) {
            std::cerr << "CUDA error in " << what << ": " << cudaGetErrorString(err) << std::endl;
            return false;
        }
        return true;
    };
    // Print device properties
    int device_id = 0;
    cudaDeviceProp device_prop;
    if (!cuda_ok(cudaGetDeviceProperties(&device_prop, device_id), "cudaGetDeviceProperties")) {
        return 1;
    }
    std::cout << "Device ID: " << device_id << std::endl;
    std::cout << "Compute capability: " << device_prop.major << "." << device_prop.minor << std::endl;
    std::cout << "Device name: " << device_prop.name << std::endl;
    std::cout << "Total global memory: " << device_prop.totalGlobalMem / (1024 * 1024) << " MB" << std::endl;
    std::cout << "Total number of multiprocessors: " << device_prop.multiProcessorCount << std::endl;
    // sharedMemPerBlock is in bytes; convert so the value matches the "KB" label.
    std::cout << "Shared memory per block: " << device_prop.sharedMemPerBlock / 1024 << " KB" << std::endl;
    std::cout << "Max threads per block: " << device_prop.maxThreadsPerBlock << std::endl;
    std::cout << "Max threads dim: " << device_prop.maxThreadsDim[0] << ", " << device_prop.maxThreadsDim[1] << ", " << device_prop.maxThreadsDim[2] << std::endl;
    std::cout << "Max grid size: " << device_prop.maxGridSize[0] << ", " << device_prop.maxGridSize[1] << ", " << device_prop.maxGridSize[2] << std::endl;
    std::cout << "Warp size: " << device_prop.warpSize << std::endl;
    std::cout << "Max threads per multiprocessor: " << device_prop.maxThreadsPerMultiProcessor << std::endl;
    std::cout << "Max shared memory per multiprocessor: " << device_prop.sharedMemPerMultiprocessor / (1024) << " KB" << std::endl;
    std::cout << "Max registers per multiprocessor: " << device_prop.regsPerMultiprocessor << std::endl;
    // The kernel needs more dynamic shared memory than the 48 KB default, so opt in
    // to the A10G's 99 KB per-block limit. Devices with a smaller limit reject this.
    if (!cuda_ok(cudaFuncSetAttribute(flash_attention_2_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, A10G_SRAM_SIZE),
                 "cudaFuncSetAttribute")) {
        return 1;
    }
    // Benchmark parameters
    constexpr int M = 8192;
    constexpr int N = 8192;
    constexpr int d = 32;
    std::cout << "M: " << M << ", N: " << N << ", d: " << d << std::endl;
    // Initialize a testcase
    float* Q = new float[M * d];
    float* K = new float[N * d];
    float* V = new float[N * d];
    float* O = new float[M * d];
    for (int i = 0; i < M * d; ++i) {
        Q[i] = static_cast<float>(i) / (M * d);
    }
    for (int i = 0; i < N * d; ++i) {
        K[i] = static_cast<float>(i) * 2 / (N * d);
        V[i] = static_cast<float>(i) * 3 / (N * d);
    }
    std::cout << "Matrices initialized on CPU." << std::endl;
    // Allocate memory on the device. Pointers start as nullptr so the cleanup
    // below is safe even if an allocation fails (cudaFree(nullptr) is a no-op).
    float *d_Q = nullptr, *d_K = nullptr, *d_V = nullptr, *d_O = nullptr;
    bool ok = cuda_ok(cudaMalloc((void**)&d_Q, M * d * sizeof(float)), "cudaMalloc(d_Q)")
        && cuda_ok(cudaMalloc((void**)&d_K, N * d * sizeof(float)), "cudaMalloc(d_K)")
        && cuda_ok(cudaMalloc((void**)&d_V, N * d * sizeof(float)), "cudaMalloc(d_V)")
        && cuda_ok(cudaMalloc((void**)&d_O, M * d * sizeof(float)), "cudaMalloc(d_O)");
    // Copy input vectors from host to device
    ok = ok
        && cuda_ok(cudaMemcpy(d_Q, Q, M * d * sizeof(float), cudaMemcpyHostToDevice), "cudaMemcpy(d_Q)")
        && cuda_ok(cudaMemcpy(d_K, K, N * d * sizeof(float), cudaMemcpyHostToDevice), "cudaMemcpy(d_K)")
        && cuda_ok(cudaMemcpy(d_V, V, N * d * sizeof(float), cudaMemcpyHostToDevice), "cudaMemcpy(d_V)");
    if (ok) {
        std::cout << "Matrices copied to GPU, running kernel..." << std::endl;
        // Call Flash Attention 2
        flash_attention_2(d_Q, d_K, d_V, d_O, M, N, d);
        // Launch-configuration errors show up in cudaGetLastError. Faults during
        // execution only show up at the next synchronizing call.
        ok = cuda_ok(cudaGetLastError(), "flash_attention_2 launch")
            && cuda_ok(cudaDeviceSynchronize(), "flash_attention_2 execution")
            // Copy result from device to host
            && cuda_ok(cudaMemcpy(O, d_O, M * d * sizeof(float), cudaMemcpyDeviceToHost), "cudaMemcpy(O)");
    }
    // Print or write output (only if every step above succeeded)
    if (ok) {
        if (argc > 1) {
            std::ofstream outfile(argv[1]);
            if (outfile.is_open()) {
                for (int i = 0; i < M * d; ++i) {
                    outfile << std::setprecision(10) << O[i];
                    if ((i + 1) % d == 0) {
                        outfile << "\n";
                    } else {
                        outfile << " ";
                    }
                }
                outfile.close();
                std::cout << "Output written to " << argv[1] << std::endl;
            } else {
                std::cerr << "Failed to open file: " << argv[1] << std::endl;
            }
        } else {
            std::cout << "Output:" << std::endl;
            for (int i = 0; i < 10 && i < M * d; ++i) {
                std::cout << std::setprecision(10) << O[i] << " ";
            }
            std::cout << "..." << std::endl;
        }
    }
    // Free allocated memory
    cudaFree(d_Q);
    cudaFree(d_K);
    cudaFree(d_V);
    cudaFree(d_O);
    // Free dynamically allocated memory
    delete[] Q;
    delete[] K;
    delete[] V;
    delete[] O;
    return ok ? 0 : 1;
}