Skip to content

Commit e879bf8

Browse files
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True (#2677)
Changes in this commit:
* cuDNN now returns Stats always, and Max only with `return_max_logit=true` — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Fix a typo that caused a bug — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Update doc strings — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix more docs — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Fixes from the feedback — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Update cudnn-frontend to v1.19.1 — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Update the cudnn frontend — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Fix a wrong omission — Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4ead776 commit e879bf8

File tree

5 files changed

+41
-59
lines changed

5 files changed

+41
-59
lines changed

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 22 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
112112
}
113113

114114
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
115-
bool generate_stats = !return_max_logit;
115+
bool generate_stats = true; // Always return stats
116116
try {
117117
FADescriptor_v1 descriptor{
118118
b,
@@ -343,7 +343,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
343343
sdpa_options.set_sink_token(softmax_offset);
344344
}
345345

346-
std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
346+
std::shared_ptr<fe::graph::Tensor_attributes> Max;
347347
if (use_ragged_stats) {
348348
offset_stats =
349349
mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -357,19 +357,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
357357
.set_name("Max")
358358
.set_dim({b, h, s_q, 1})
359359
.set_data_type(fe::DataType_t::FLOAT));
360-
Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
361-
.set_name("Sum_Exp")
362-
.set_dim({b, h, s_q, 1})
363-
.set_data_type(fe::DataType_t::FLOAT));
364360
if (use_ragged_stats) {
365361
Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
366-
Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
367362
} else {
368363
Max->set_stride({h * s_q, s_q, 1, 1});
369-
Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
370364
}
371365
sdpa_options.set_logit_max(Max);
372-
sdpa_options.set_score_sum_exp(Sum_Exp);
373366
}
374367

375368
auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options));
@@ -387,13 +380,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
387380
O->set_ragged_offset(offset_o);
388381
}
389382

390-
if (!return_max_logit) {
391-
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
392-
if (use_ragged_stats) {
393-
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
394-
} else {
395-
Stats->set_stride({h * s_q, s_q, 1, 1});
396-
}
383+
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
384+
if (is_ragged_q && cudnn_runtime_version >= 90600) {
385+
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
386+
} else {
387+
Stats->set_stride({h * s_q, s_q, 1, 1});
397388
}
398389

399390
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
@@ -403,7 +394,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
403394
std::shared_ptr<fe::graph::Tensor_attributes>> // O
404395
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
405396
auto Stats_tuple =
406-
generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp);
397+
return_max_logit ? std::make_tuple(Stats, Max) : std::make_tuple(Stats, nullptr);
407398
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
408399
auto softmax_offset_tuple =
409400
is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
@@ -1137,6 +1128,16 @@ void fused_attn_arbitrary_seqlen_fwd(
11371128
size_t i = 0;
11381129
if (Aux_CTX_Tensors->size == 0) {
11391130
const auto cudnn_runtime_version = cudnnGetVersion();
1131+
1132+
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1133+
output_S->data.dptr = nullptr;
1134+
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
1135+
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
1136+
} else {
1137+
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
1138+
}
1139+
output_S->data.dtype = DType::kFloat32;
1140+
11401141
if (return_max_logit) {
11411142
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
11421143
output_Max->data.dptr = nullptr;
@@ -1147,25 +1148,6 @@ void fused_attn_arbitrary_seqlen_fwd(
11471148
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
11481149
}
11491150
output_Max->data.dtype = DType::kFloat32;
1150-
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1151-
output_Sum_Exp->data.dptr = nullptr;
1152-
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
1153-
(sm_arch_ != 120)) {
1154-
output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
1155-
} else {
1156-
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
1157-
}
1158-
output_Sum_Exp->data.dtype = DType::kFloat32;
1159-
} else {
1160-
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1161-
output_S->data.dptr = nullptr;
1162-
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
1163-
(sm_arch_ != 120)) {
1164-
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
1165-
} else {
1166-
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
1167-
}
1168-
output_S->data.dtype = DType::kFloat32;
11691151
}
11701152

11711153
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
@@ -1189,14 +1171,12 @@ void fused_attn_arbitrary_seqlen_fwd(
11891171

11901172
Aux_CTX_Tensors->size = i;
11911173
} else if (Aux_CTX_Tensors->size >= 2) {
1174+
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1175+
devPtrS1 = output_S->data.dptr;
1176+
11921177
if (return_max_logit) {
11931178
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1194-
devPtrS1 = output_Max->data.dptr;
1195-
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1196-
devPtrS2 = output_Sum_Exp->data.dptr;
1197-
} else {
1198-
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
1199-
devPtrS1 = output_S->data.dptr;
1179+
devPtrS2 = output_Max->data.dptr;
12001180
}
12011181
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
12021182
output_rng_state->data.dptr = rng_state->data.dptr;

transformer_engine/common/fused_attn/utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,23 @@ struct FADescriptor_v1 {
118118
cudnn_frontend::DataType_t o_tensor_type;
119119
cudnn_frontend::DataType_t do_tensor_type;
120120
cudnn_frontend::DataType_t dqkv_tensor_type;
121-
bool generate_max_sum_exp;
121+
bool return_max_logit;
122122

123123
bool operator<(const FADescriptor_v1 &rhs) const {
124124
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
125125
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
126126
bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
127127
softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
128128
deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
129-
dqkv_tensor_type, generate_max_sum_exp) <
129+
dqkv_tensor_type, return_max_logit) <
130130
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
131131
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
132132
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
133133
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
134134
rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
135135
rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
136136
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
137-
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
137+
rhs.dqkv_tensor_type, rhs.return_max_logit);
138138
}
139139
};
140140

transformer_engine/common/include/transformer_engine/fused_attn.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
206206
* \param[in] head_dim_v The head dimension of V.
207207
* \param[in] window_size_left Sliding window size (the left half).
208208
* \param[in] window_size_right Sliding window size (the right half).
209-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
209+
* \param[in] return_max_logit Whether to produce Max along with Stats.
210210
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
211211
* \param[in] deterministic Whether determinism is required or not.
212212
*/
@@ -269,7 +269,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
269269
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
270270
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
271271
* \param[in] is_training Whether this is in training mode or inference.
272-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
272+
* \param[in] return_max_logit Whether to produce Max along with Stats.
273273
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
274274
* \param[in] attn_scale Scaling factor for Q * K.T.
275275
* \param[in] dropout Dropout probability.

transformer_engine/pytorch/cpp_extensions/fused_attn.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -353,24 +353,26 @@ def fused_attn_fwd(
353353

354354
if return_max_logit:
355355
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
356-
# thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
357-
# thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
358-
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
359-
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
360-
stats = output_tensors[1] + torch.log(output_tensors[2])
361-
max_tensor = output_tensors[1]
356+
# thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
357+
# thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
358+
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
359+
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
360+
aux_ctx_tensors = [output_tensors[1]] + list(
361+
output_tensors[3:]
362+
) # Stats + rng_state + optional tensors
363+
max_tensor = output_tensors[2]
364+
amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)
365+
362366
if qkv_format == "thd" and max_tensor.ndim == 4:
363367
# For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded
364368
# sequence positions. Exclude those padded positions when computing max_logit.
365369
seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
366370
sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1)
367371
valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
368372
max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
369-
amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)
373+
370374
# Max -> max_logit [h]
371375
max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype)
372-
aux_ctx_tensors = [stats]
373-
aux_ctx_tensors.extend(output_tensors[3:])
374376
return output_tensors[0], aux_ctx_tensors, max_logit
375377

376378
# out, aux_ctx_tensors

transformer_engine/pytorch/csrc/extensions/attention.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,16 @@ std::vector<py::object> fused_attn_fwd(
259259
// f16_max512 : S [b, h, sq, skv]
260260
// f16_arbitrary:
261261
// return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
262-
// return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
262+
// return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
263263
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
264264
size_t i = 0;
265265
at::Tensor output_tensor;
266-
// intermediate softmax tensor, S or M
266+
// intermediate softmax tensor, S or M (for fp8)
267267
output_tensor =
268268
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
269269
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
270270
set_tensor_param(i++, output_tensor);
271-
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
271+
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor
272272
if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
273273
output_tensor =
274274
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),

0 commit comments

Comments (0)