8 changes: 6 additions & 2 deletions cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -295,7 +295,11 @@ struct CudaDataType<__nv_bfloat16>
};
#endif

inline int getSMVersion()
/// @brief Get the SM version of the current device.
/// @param queryRealSmArch Whether to query the real SM architecture. For example, query the real SM arch when doing
/// LUT tuning, and use the mapped arch when reusing SM120 code on SM121 devices.
/// @return The SM version of the current device.
inline int getSMVersion(bool queryRealSmArch = false)
{
int device{-1};
check_cuda_error(cudaGetDevice(&device));
@@ -304,7 +308,7 @@ inline int getSMVersion()
check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
int sm = sm_major * 10 + sm_minor;
if (sm == 121)
if (sm == 121 && !queryRealSmArch)
{
return 120;
}
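A minimal usage sketch of the new flag (not part of the diff; it assumes getSMVersion lives in the usual tensorrt_llm::common namespace, and the caller function is hypothetical):

#include "tensorrt_llm/common/cudaUtils.h"

// Kernel selection keeps the mapped value (SM121 reports 120),
// while LUT tuning asks for the physical architecture.
void pickArchitectures()
{
    int smForKernelReuse = tensorrt_llm::common::getSMVersion();                       // 120 on SM121 devices
    int smForLutTuning = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true); // 121 on SM121 devices
    (void) smForKernelReuse;
    (void) smForLutTuning;
}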
13 changes: 9 additions & 4 deletions cpp/kernels/fmha_v2/fmha_test.py
@@ -78,6 +78,10 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
# ada fp8 fmha only supports non-tiled kernels currently.
if dtype == '-e4m3' and sm_version == 89 and tiled_kernel == "":
pytest.skip("ada fp8 fmha only supports non-tiled kernels currently.")
# Skip the dense-mask test below for this combination; it has a known accuracy issue.
skip_dense_mask_test = False
if d == 64 and dtype in ['-fp16-fp32', '-bf16'] and tiled_kernel == "":
skip_dense_mask_test = True

# use higher error tolerance for bf16 and e4m3.
epsilon = ''
@@ -107,10 +111,11 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
if "softcapping-scale-bmm1" in flag:
pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")

subprocess.run(
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
check=True)
if not skip_dense_mask_test:
subprocess.run(
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
check=True)
subprocess.run(
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
170 changes: 73 additions & 97 deletions cpp/kernels/fmha_v2/src/fmha/fragment.h
@@ -12,6 +12,7 @@

#pragma once

#include <cfloat>
#include <fmha/traits.h>
#include <fmha/utils.h>

@@ -1250,6 +1251,23 @@ struct Tile_o_normalizer
BYTES_PER_ELEMENT = sizeof(float)
};

// Initialize the attention sinks.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
: attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] : -FLT_MAX)
{
}

// Update the sum when attention sinks are used.
inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
{
#pragma unroll
for (int i = 0; i < ROWS_PER_THREAD; ++i)
{
sum[i] += expf(attention_sink_value_ - max[i]);
}
}

// Update o.
inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&curr_max)[ROWS_PER_THREAD],
float const (&prev_max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
@@ -1331,8 +1349,9 @@ struct Tile_o_normalizer
}

// Update o.
inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float const (&sum)[ROWS_PER_THREAD])
inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
{

#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
Expand Down Expand Up @@ -1403,6 +1422,9 @@ struct Tile_o_normalizer
}
#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION
}

// Attention sink value.
float attention_sink_value_;
};

template <typename Traits, typename Cta_tile>
@@ -1461,6 +1483,23 @@ struct Tile_o_normalizer_fp32
BYTES_PER_ELEMENT = sizeof(float)
};

// Initialize the attention sinks.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer_fp32(Params const& params, Block_info const& binfo)
: attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] : -FLT_MAX)
{
}

// Update the sum when attention sinks are used.
inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
{
#pragma unroll
for (int i = 0; i < ROWS_PER_THREAD; ++i)
{
sum[i] += expf(attention_sink_value_ - max[i]);
}
}

// Update o.
inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&curr_max)[ROWS_PER_THREAD],
float const (&prev_max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
@@ -1501,7 +1540,7 @@ struct Tile_o_normalizer_fp32
}

// Update o after P * V
inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float const (&sum)[ROWS_PER_THREAD])
inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
{

#pragma unroll
Expand All @@ -1517,9 +1556,7 @@ struct Tile_o_normalizer_fp32
int jj = 2 * mi + ii;

// The diviser.
// printf("curr_sum_[ii] %lf %lf \n", curr_sum_[ii], curr_sum_[ii]);
beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj];
// printf("beta %lf \n", beta[ii]);
}

#pragma unroll
Expand All @@ -1538,6 +1575,9 @@ struct Tile_o_normalizer_fp32
}
}
}

// Attention sink value.
float attention_sink_value_;
};

template <typename Cta_tile>
Expand All @@ -1550,8 +1590,12 @@ struct Tile_o_normalizer<Ampere_hmma_fp32_traits, Cta_tile>
// The base class.
using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;

// Default ctor
Tile_o_normalizer() = default;
// The ctor.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
: Base(params, binfo)
{
}
};

template <typename Cta_tile>
Expand All @@ -1564,10 +1608,15 @@ struct Tile_o_normalizer<Ampere_hmma_bf16_traits, Cta_tile>
// The base class.
using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;

// Default ctor
Tile_o_normalizer() = default;
// The ctor.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
: Base(params, binfo)
{
}
};

// The attention sinks are not enabled for Volta.
template <typename Cta_tile>
struct Tile_o_normalizer<Volta_hmma_fp16_16x16x16_traits, Cta_tile>
{
@@ -1747,98 +1796,21 @@ struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile>
// The base class.
using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;

// Default ctor
Tile_o_normalizer() = default;

// The fragment accumulator.
using Fragment_accu = Fragment_accumulator<Traits>;

// The Mma tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;

// The number of MMAs in the M dimension.
enum
{
MMAS_M = Mma_tile::MMAS_M
};

// The number of MMAs in the N dimension.
enum
{
MMAS_N = Mma_tile::VALID_MMAS_N
};

// The number of rows per thread.
enum
{
ROWS_PER_THREAD = 2 * MMAS_M
};

// The number of registers per thread.
enum
{
REGS_PER_THREAD = 8
};

// Warps.
enum
{
WARPS_M = Cta_tile::WARPS_M
};

enum
{
WARPS_N = Cta_tile::WARPS_N
};

enum
{
WARPS_K = Cta_tile::WARPS_K
};

// softmax data bytes
enum
// The ctor.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
: Base(params, binfo)
{
BYTES_PER_ELEMENT = sizeof(float)
};
}

// Update o after P * V, the only difference from the basic class is we need to dequant the sum for softmax saver.
inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
// Update the sum.
inline __device__ void update_sum(float const (&max)[Base::ROWS_PER_THREAD], float (&sum)[Base::ROWS_PER_THREAD])
{

constexpr float dequant_scale = Traits::SOFTMAX_FP_DEQUANT_SCALE;
// Take the log2f(Traits::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum.
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
for (int i = 0; i < Base::ROWS_PER_THREAD; ++i)
{

// Precompute the scaling factors for the 2 rows.
float beta[2];
#pragma unroll
for (int ii = 0; ii < 2; ++ii)
{
// The row.
int jj = 2 * mi + ii;

// The diviser.
beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj];
// softmax saver need the original sum.
sum[jj] = sum[jj] * dequant_scale;
}

#pragma unroll
for (int ni = 0; ni < MMAS_N; ++ni)
{
#pragma unroll
for (int ii = 0; ii < REGS_PER_THREAD; ++ii)
{
// The register for O.
float acc_o_f = acc_o[mi][ni].elt(ii);
// Compute the next accumulator.
acc_o_f = acc_o_f * beta[(ii & 2) / 2];
// Update the accumulator.
acc_o[mi][ni].elt(ii) = acc_o_f;
}
}
sum[i] += expf(this->attention_sink_value_ - max[i]) * Traits::SOFTMAX_FP_QUANT_SCALE;
}
}
};
@@ -1878,8 +1850,12 @@ struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile, true>
REGS_PER_THREAD = 8
};

// Default ctor
Tile_o_normalizer() = default;
// The ctor.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
: Base(params, binfo)
{
}

inline __device__ void merge(Fragment_accu (&acc_dst)[MMAS_M][MMAS_N], Fragment_accu (&acc_src)[MMAS_M][MMAS_N])
{
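As a reference for what the new update_sum() adds to the existing final_update() normalization, here is a host-side sketch (not part of the diff; standard C++ only, with an illustrative function name): the attention sink contributes exp(sink - rowMax) to the softmax denominator but no value vector.

#include <algorithm>
#include <cmath>
#include <vector>

// Row-wise softmax with an attention sink: the sink enlarges the denominator,
// mirroring sum[i] += expf(attention_sink_value_ - max[i]) in update_sum().
std::vector<float> softmaxRowWithSink(std::vector<float> scores, float sink)
{
    float rowMax = *std::max_element(scores.begin(), scores.end());
    float denom = std::exp(sink - rowMax); // sink term added once per row
    for (float s : scores)
    {
        denom += std::exp(s - rowMax);
    }
    for (float& s : scores)
    {
        s = std::exp(s - rowMax) / denom; // final_update() applies 1/denom to acc_o
    }
    return scores;
}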
@@ -344,7 +344,7 @@ inline __device__ void device_flash_attention_nl(Params const& params)
fmha::Clear_accumulator<Acc_type_o, Cta_tile_o::WARPS_K>::apply(acc_o);

// Flash attention updater
fmha::Tile_o_normalizer<Traits_o, Cta_tile_o, Kernel_traits::SAGE_ATTENTION> acc_o_normalizer;
fmha::Tile_o_normalizer<Traits_o, Cta_tile_o, Kernel_traits::SAGE_ATTENTION> acc_o_normalizer(params, binfo);
if constexpr (Kernel_traits::SAGE_ATTENTION)
{
acc_o_normalizer.move_to_first_block(params, bidb, bidh);
@@ -709,6 +709,8 @@ inline __device__ void device_flash_attention_nl(Params const& params)
}
} // Inner loop over the key/value sequence length.

// Update the sum if attention sinks are used.
acc_o_normalizer.update_sum(global_max, global_sum);
// Update acc_o of flash attention
acc_o_normalizer.final_update(acc_o, global_sum);

@@ -265,7 +265,7 @@ inline __device__ void device_flash_attention_nl_tiled(Params const& params)
fmha::Clear_accumulator<Acc_type_o, Cta_tile_o::WARPS_K>::apply(acc_o);

// Flash attention updater
fmha::Tile_o_normalizer<Traits_o, Cta_tile_o> acc_o_normalizer;
fmha::Tile_o_normalizer<Traits_o, Cta_tile_o> acc_o_normalizer(params, binfo);
float global_max[Softmax::ROWS_PER_THREAD];
float global_sum[Softmax::ROWS_PER_THREAD];

@@ -588,6 +588,8 @@ inline __device__ void device_flash_attention_nl_tiled(Params const& params)

} // Inner loop over the key/value sequence length.

// Update the sum if attention sinks are used.
acc_o_normalizer.update_sum(global_max, global_sum);
// Update acc_o of flash attention
acc_o_normalizer.final_update(acc_o, global_sum);

10 changes: 10 additions & 0 deletions cpp/tensorrt_llm/common/cublasMMWrapper.cpp
@@ -73,6 +73,16 @@ void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
}

void CublasMMWrapper::setBiasDescriptor(void* bias)
{
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(void*)));

cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
}

void CublasMMWrapper::destroyDescriptors()
{
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
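A hedged sketch of how the new bias descriptor could be used (not part of the diff; the surrounding calls mirror the wrapper API shown in this PR, the tensorrt_llm::common namespace is assumed, the function and buffer names are illustrative, and the matmul launch itself is elided):

#include "tensorrt_llm/common/cublasMMWrapper.h"

// Assumes a configured wrapper and a device bias buffer dBias of length m.
void gemmWithBias(tensorrt_llm::common::CublasMMWrapper& wrapper, void* dBias, int m, int n, int k)
{
    wrapper.createDescriptors(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, /*lda=*/m, /*ldb=*/k, /*ldc=*/m);
    wrapper.setBiasDescriptor(dBias); // sets CUBLASLT_MATMUL_DESC_BIAS_POINTER and CUBLASLT_EPILOGUE_BIAS
    // ... run the matmul through the wrapper; cublasLt adds the bias vector in the epilogue ...
    wrapper.destroyDescriptors();
}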
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/cublasMMWrapper.h
@@ -130,6 +130,7 @@ class CublasMMWrapper
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
void setScaleDescriptors(void* scale_a, void* scale_b);
void setBiasDescriptor(void* bias);
void destroyDescriptors();

cublasHandle_t getCublasHandle()
@@ -359,7 +359,8 @@ struct CutlassGemmConfig
BLACKWELL = 1u << 4,
GROUPED_GEMM = 1u << 5,
FP8_ONLY = 1u << 6,
FP4_ONLY = 1u << 7
FP4_ONLY = 1u << 7,
FP8FP4_MIXED = 1u << 8
};

CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic;
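Since these enum values are bit flags, the new FP8FP4_MIXED entry composes with the existing ones; a small illustrative check (constants mirror the values in the hunk above, the mask and names here are hypothetical):

constexpr unsigned kGroupedGemm = 1u << 5; // GROUPED_GEMM
constexpr unsigned kFp8Fp4Mixed = 1u << 8; // FP8FP4_MIXED, added in this diff

constexpr unsigned candidateMask = kGroupedGemm | kFp8Fp4Mixed;
static_assert((candidateMask & kFp8Fp4Mixed) != 0, "mixed FP8 x FP4 configs requested");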
Git LFS file not shown
Git LFS file not shown