135 changes: 107 additions & 28 deletions examples/77_blackwell_fmha/77_blackwell_fmha.cu
@@ -125,6 +125,10 @@ struct Options {
bool verify = false;
bool verbose = false;

int window_size_left = -1;
int window_size_right = -1;
bool local = false;
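// Note: window_size_left/right default to -1 (unset); --mask=local
// requires both to be set explicitly (validated in main_single below).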

bool causal = false;
bool causal_q_begin = true;
bool residual = false;
@@ -261,6 +265,9 @@ struct Options {
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);

cmd.get_cmd_line_argument("window_size_left", window_size_left, defaults.window_size_left);
cmd.get_cmd_line_argument("window_size_right", window_size_right, defaults.window_size_right);

verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
persistent = cmd.check_cmd_line_flag("persistent");
@@ -270,12 +277,13 @@ struct Options {
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
local = causal = residual = false;
if (varlen) {
residual = true;
}
}
else if (mask == "causal") {
local = false;
residual = false;
causal = true;
if(causal_type == "qend") {
@@ -285,9 +293,24 @@ struct Options {
}
}
else if (mask == "residual") {
local = false;
residual = true;
causal = false;
}
else if (mask == "local") {
local = true;
residual = false;
causal = false;
if(causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
if (varlen) {
residual = true;
}
}
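// Local attention: a query at position q attends to keys in
// [q - window_size_left, q + window_size_right], clamped to the valid
// key range (this mirrors the FLOP accounting in FwdRunner below).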

cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
@@ -318,9 +341,11 @@ struct Options {
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --window_size_left=<int> Window size left for local attention\n"
<< " --window_size_right=<int> Window size right for local attention\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --mask=<no|residual|causal|local> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
@@ -423,7 +448,7 @@ struct FwdRunner {
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;

using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK;
@@ -433,7 +458,7 @@ struct FwdRunner {
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;

using Mainloop =
cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
Element, ElementAccumulatorQK, ElementAccumulatorPV,
TileShape, StrideQ, StrideK, StrideV,
@@ -494,7 +519,7 @@ struct FwdRunner {
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer, Options const & options) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
@@ -514,12 +539,12 @@ struct FwdRunner {
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);

auto [Q, K, D, HB] = problem_shape;

auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);

fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask(options.window_size_left, options.window_size_right));

cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
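The reference check above now forwards the window sizes into the mask functor. As a minimal sketch of the intended visibility rule (an assumption inferred from the FLOP accounting further down in this file, not the functor's actual implementation), a LocalMask admits a (q, k) position like this:

// Hypothetical helper illustrating LocalMask visibility; the name and
// signature are ours, not part of this PR.
inline bool local_mask_visible(int q_idx, int k_idx,
                               int window_size_left, int window_size_right) {
  // Query q_idx attends key k_idx iff k_idx lies inside the window.
  return (k_idx >= q_idx - window_size_left) &&
         (k_idx <= q_idx + window_size_right);
}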
@@ -538,15 +563,15 @@ struct FwdRunner {

bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
std::cerr << "failed O: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}

reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);

bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}

@@ -562,7 +587,7 @@ struct FwdRunner {

// generate Q as --b draws from a
// gaussian (--Q, --Q / 2), sampled positive;
// track cumulative lengths
std::mt19937 rng(0x202305151552ull);
std::normal_distribution<double> dist_q(get<0>(problem_size), get<0>(problem_size) / 2);
std::normal_distribution<double> dist_kv(get<1>(problem_size), get<1>(problem_size) / 2);
@@ -585,7 +610,7 @@ struct FwdRunner {
int max_seqlen_kv = 0;

for (int i = 0; i < num_batches; i++) {
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
kVarlenSame ? get<0>(problem_size) :
generate_positive_int(dist_q, rng);
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
@@ -626,7 +651,7 @@ struct FwdRunner {
int h_r = options.h / options.h_k;
assert(options.h % options.h_k == 0);
auto problem_shape_in = cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));

ProblemShapeType problem_shape;
decltype(problem_shape_in) problem_size;

@@ -690,7 +715,7 @@ struct FwdRunner {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};

buffers.push_back(std::make_unique<DeviceBuffer>());
@@ -710,20 +735,40 @@ struct FwdRunner {
return problem_shape;
}

auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index, const Options& options) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
problem_shape_, // 1st field: Problem dimensions

// 2nd field: Mainloop arguments - input tensor data and scaling parameters
{
// Nested Load arguments for tensor pointers and strides
{ buffers[buffer_index]->block_Q.get(), stride_Q, // Query tensor pointer and stride
buffers[buffer_index]->block_K.get(), stride_K, // Key tensor pointer and stride
buffers[buffer_index]->block_V.get(), stride_V, // Value tensor pointer and stride
options.window_size_left, // window_size_left: for local attention
options.window_size_right // window_size_right: for local attention
},

// Scaling parameters for attention computation
0.0f, // scale_softmax: 0.0f means use default 1/sqrt(D)
1.0f, // scale_q: scaling factor for Q tensor dequantization
1.0f, // scale_k: scaling factor for K tensor dequantization
1.0f, // scale_v: scaling factor for V tensor dequantization
1.0f, // inv_scale_o: inverse scaling factor for O tensor quantization
options.window_size_left, // window_size_left: for local attention
options.window_size_right // window_size_right: for local attention
},

// 3rd field: Epilogue arguments - output tensors O, LSE with their memory pointers and strides
{ buffers[buffer_index]->block_O.get(), stride_O, // Output tensor pointer and stride
buffers[buffer_index]->block_LSE.get(), stride_LSE }, // Log-sum-exp tensor pointer and stride

hw_info // 4th field: Hardware info (SM count, etc.)
};
return arguments;
}
Expand All @@ -733,7 +778,7 @@ struct FwdRunner {
ProblemShapeType problem_shape = initialize(options);

int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index, options);

Operation op;

@@ -769,7 +814,7 @@ struct FwdRunner {
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
arguments = get_arguments(problem_shape, hw_info, buffer_index, options);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
@@ -814,7 +859,7 @@ struct FwdRunner {
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
arguments = get_arguments(problem_shape, hw_info, buffer_index, options);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
@@ -866,7 +911,30 @@ struct FwdRunner {
flops *= static_cast<double>(size<1>(problem_shape));
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);

double flops_ratio = 1.0;
if (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>>) {
flops_ratio = 0.5;
}
if (std::is_same_v<ActiveMask, LocalMask<true>> || std::is_same_v<ActiveMask, LocalMask<false>>) {
// For regular (fixed-length) sequences, count the unmasked (row, col) pairs
int seqlen_q = size<0>(problem_shape);
int seqlen_k = size<1>(problem_shape);

double total_valid_pairs = 0.0;
for (int row_idx = 0; row_idx < seqlen_q; row_idx++) {
int col_left = std::max(row_idx - options.window_size_left, 0);
int col_right = std::min(row_idx + options.window_size_right, seqlen_k - 1);
// Valid positions in this row
if (col_right >= col_left) {
total_valid_pairs += (col_right - col_left + 1);
}
}
double total_positions = static_cast<double>(seqlen_q) * static_cast<double>(seqlen_k);
flops_ratio = (total_positions > 0) ? (total_valid_pairs / total_positions) : 1.0;
flops_ratio = std::min(flops_ratio, 1.0);
}
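// For sequences much longer than the window, flops_ratio approaches
// (window_size_left + window_size_right + 1) / seqlen_k, e.g. a
// (128, 0) window at seqlen 8192 gives roughly 129/8192 ~= 0.016.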
flops *= 4.0 * flops_ratio;
flops *= static_cast<double>(size<2>(problem_shape));
flops *= static_cast<double>(size<3,0>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
@@ -883,7 +951,7 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape, *buffers[0]);
passed = verify(problem_shape, *buffers[0], options);
if (passed) example_result.verified = true;
}

@@ -935,7 +1003,7 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
}
else
{
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
@@ -968,7 +1036,7 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
}
else
{
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
@@ -1083,6 +1151,7 @@ int main_single(int argc, char const **args) {

std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
std::cout << (options.local ? ("Local with window size " + std::to_string(options.window_size_left) + " " + std::to_string(options.window_size_right)) : "Not local") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

auto with_mask = [&](auto fn) {
@@ -1093,6 +1162,16 @@ int main_single(int argc, char const **args) {
fn(CausalMask<false>{});
}
}
else if (options.local) {
if (options.window_size_left == -1 || options.window_size_right == -1) {
throw std::runtime_error("Error: --window_size_left and --window_size_right must be set for local attention.");
}
if (options.causal_q_begin) {
fn(LocalMask<true>{});
} else {
fn(LocalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});
}