1212
1313#pragma once
1414
15+ #include < cfloat>
1516#include < fmha/traits.h>
1617#include < fmha/utils.h>
1718
@@ -1250,6 +1251,23 @@ struct Tile_o_normalizer
12501251 BYTES_PER_ELEMENT = sizeof (float )
12511252 };
12521253
// Construct the normalizer and record the attention-sink value it will use.
//
// params: must expose `attention_sinks`, a pointer that may be null.
// binfo:  must expose `bidh`, the index used to read from `attention_sinks`.
//
// When no sink table is supplied, -FLT_MAX is stored instead.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
{
    // Start from the "no sink" value and override it only if a table exists.
    attention_sink_value_ = -FLT_MAX;
    if (params.attention_sinks != nullptr)
    {
        attention_sink_value_ = params.attention_sinks[binfo.bidh];
    }
}
1260+
// Add the attention-sink term to the running softmax sum of every row.
//
// max: per-row maxima (read only); the sink term is expf(sink - max[i]).
// sum: per-row sums, updated in place.
//
// The constructor stores -FLT_MAX when no sink table is supplied. In that case
// the sink must contribute exactly zero, so the update is skipped outright:
// evaluating expf(-FLT_MAX - max[i]) would return 1 when max[i] is itself
// -FLT_MAX and +inf when max[i] is -inf (possible for a row that never saw a
// valid score), silently corrupting the sum.
inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
{
    // No attention sink configured: nothing to add.
    if (attention_sink_value_ == -FLT_MAX)
    {
        return;
    }

#pragma unroll
    for (int i = 0; i < ROWS_PER_THREAD; ++i)
    {
        sum[i] += expf(attention_sink_value_ - max[i]);
    }
}
1270+
12531271 // Update o.
12541272 inline __device__ void update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&curr_max)[ROWS_PER_THREAD],
12551273 float const (&prev_max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
@@ -1331,8 +1349,9 @@ struct Tile_o_normalizer
13311349 }
13321350
13331351 // Update o.
1334- inline __device__ void final_update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float const (&sum)[ROWS_PER_THREAD])
1352+ inline __device__ void final_update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
13351353 {
1354+
13361355#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation
13371356#pragma unroll
13381357 for (int mi = 0 ; mi < MMAS_M; ++mi)
@@ -1403,6 +1422,9 @@ struct Tile_o_normalizer
14031422 }
14041423#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION
14051424 }
1425+
1426+ // Attention sink value.
1427+ float attention_sink_value_;
14061428};
14071429
14081430template <typename Traits, typename Cta_tile>
@@ -1461,6 +1483,23 @@ struct Tile_o_normalizer_fp32
14611483 BYTES_PER_ELEMENT = sizeof (float )
14621484 };
14631485
// Construct the fp32 normalizer and record the attention-sink value it will use.
//
// params: must expose `attention_sinks`, a pointer that may be null.
// binfo:  must expose `bidh`, the index used to read from `attention_sinks`.
//
// When no sink table is supplied, -FLT_MAX is stored instead.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer_fp32(Params const& params, Block_info const& binfo)
{
    // Start from the "no sink" value and override it only if a table exists.
    attention_sink_value_ = -FLT_MAX;
    if (params.attention_sinks != nullptr)
    {
        attention_sink_value_ = params.attention_sinks[binfo.bidh];
    }
}
1492+
// Add the attention-sink term to the running softmax sum of every row.
//
// max: per-row maxima (read only); the sink term is expf(sink - max[i]).
// sum: per-row sums, updated in place.
//
// The constructor stores -FLT_MAX when no sink table is supplied. In that case
// the sink must contribute exactly zero, so the update is skipped outright:
// evaluating expf(-FLT_MAX - max[i]) would return 1 when max[i] is itself
// -FLT_MAX and +inf when max[i] is -inf (possible for a row that never saw a
// valid score), silently corrupting the sum.
inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
{
    // No attention sink configured: nothing to add.
    if (attention_sink_value_ == -FLT_MAX)
    {
        return;
    }

#pragma unroll
    for (int i = 0; i < ROWS_PER_THREAD; ++i)
    {
        sum[i] += expf(attention_sink_value_ - max[i]);
    }
}
1502+
14641503 // Update o.
14651504 inline __device__ void update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&curr_max)[ROWS_PER_THREAD],
14661505 float const (&prev_max)[ROWS_PER_THREAD], float (&sum)[ROWS_PER_THREAD])
@@ -1501,7 +1540,7 @@ struct Tile_o_normalizer_fp32
15011540 }
15021541
15031542 // Update o after P * V
1504- inline __device__ void final_update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float const (&sum)[ROWS_PER_THREAD])
1543+ inline __device__ void final_update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
15051544 {
15061545
15071546#pragma unroll
@@ -1517,9 +1556,7 @@ struct Tile_o_normalizer_fp32
15171556 int jj = 2 * mi + ii;
15181557
15191558 // The diviser.
1520- // printf("curr_sum_[ii] %lf %lf \n", curr_sum_[ii], curr_sum_[ii]);
15211559 beta[ii] = (sum[jj] == 0 .f || sum[jj] != sum[jj]) ? 1 .f : 1 .f / sum[jj];
1522- // printf("beta %lf \n", beta[ii]);
15231560 }
15241561
15251562#pragma unroll
@@ -1538,6 +1575,9 @@ struct Tile_o_normalizer_fp32
15381575 }
15391576 }
15401577 }
1578+
1579+ // Attention sink value.
1580+ float attention_sink_value_;
15411581};
15421582
15431583template <typename Cta_tile>
@@ -1550,8 +1590,12 @@ struct Tile_o_normalizer<Ampere_hmma_fp32_traits, Cta_tile>
15501590 // The base class.
15511591 using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;
15521592
1553- // Default ctor
1554- Tile_o_normalizer () = default ;
// Forwarding constructor for the Ampere fp32-accumulation specialization.
// All state lives in Base (Tile_o_normalizer_fp32), whose constructor reads
// the attention-sink value from params/binfo; nothing is added here.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
    : Base(params, binfo)
{
}
15551599};
15561600
15571601template <typename Cta_tile>
@@ -1564,10 +1608,15 @@ struct Tile_o_normalizer<Ampere_hmma_bf16_traits, Cta_tile>
15641608 // The base class.
15651609 using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;
15661610
1567- // Default ctor
1568- Tile_o_normalizer () = default ;
// Forwarding constructor for the Ampere bf16 specialization.
// All state lives in Base (Tile_o_normalizer_fp32), whose constructor reads
// the attention-sink value from params/binfo; nothing is added here.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
    : Base(params, binfo)
{
}
15691617};
15701618
1619+ // The attention sinks are not enabled for Volta.
15711620template <typename Cta_tile>
15721621struct Tile_o_normalizer <Volta_hmma_fp16_16x16x16_traits, Cta_tile>
15731622{
@@ -1747,98 +1796,21 @@ struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile>
17471796 // The base class.
17481797 using Base = Tile_o_normalizer_fp32<Traits, Cta_tile>;
17491798
1750- // Default ctor
1751- Tile_o_normalizer () = default ;
1752-
1753- // The fragment accumulator.
1754- using Fragment_accu = Fragment_accumulator<Traits>;
1755-
1756- // The Mma tile.
1757- using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
1758-
1759- // The number of MMAs in the M dimension.
1760- enum
1761- {
1762- MMAS_M = Mma_tile::MMAS_M
1763- };
1764-
1765- // The number of MMAs in the N dimension.
1766- enum
1767- {
1768- MMAS_N = Mma_tile::VALID_MMAS_N
1769- };
1770-
1771- // The number of rows per thread.
1772- enum
1773- {
1774- ROWS_PER_THREAD = 2 * MMAS_M
1775- };
1776-
1777- // The number of registers per thread.
1778- enum
1779- {
1780- REGS_PER_THREAD = 8
1781- };
1782-
1783- // Warps.
1784- enum
1785- {
1786- WARPS_M = Cta_tile::WARPS_M
1787- };
1788-
1789- enum
1790- {
1791- WARPS_N = Cta_tile::WARPS_N
1792- };
1793-
1794- enum
1795- {
1796- WARPS_K = Cta_tile::WARPS_K
1797- };
1798-
1799- // softmax data bytes
1800- enum
// Forwarding constructor for the Ada e4m3 specialization.
// Base (Tile_o_normalizer_fp32) reads the attention-sink value from
// params/binfo; this specialization adds no state of its own.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
    : Base(params, binfo)
{
}
18041805
1805- // Update o after P * V, the only difference from the basic class is we need to dequant the sum for softmax saver .
1806- inline __device__ void final_update (Fragment_accu (&acc_o)[MMAS_M][MMAS_N] , float (&sum)[ROWS_PER_THREAD])
// Add the attention-sink term to the running softmax sum of every row.
//
// max: per-row maxima (read only); the sink term is expf(sink - max[i]).
// sum: per-row sums, updated in place.
//
// The sink term is multiplied by Traits::SOFTMAX_FP_QUANT_SCALE so that it is
// expressed in the same scaled domain as sum.
// NOTE(review): this presumes sum already carries that scale (the original
// comment says the same scale "has been applied to sum") - confirm against the
// softmax code that produces sum.
//
// The base constructor stores -FLT_MAX when no sink table is supplied. In that
// case the sink must contribute exactly zero, so the update is skipped:
// expf(-FLT_MAX - max[i]) would return 1 when max[i] is itself -FLT_MAX and
// +inf when max[i] is -inf, silently corrupting the sum.
inline __device__ void update_sum(float const (&max)[Base::ROWS_PER_THREAD], float (&sum)[Base::ROWS_PER_THREAD])
{
    // No attention sink configured: nothing to add.
    if (this->attention_sink_value_ == -FLT_MAX)
    {
        return;
    }

#pragma unroll
    for (int i = 0; i < Base::ROWS_PER_THREAD; ++i)
    {
        sum[i] += expf(this->attention_sink_value_ - max[i]) * Traits::SOFTMAX_FP_QUANT_SCALE;
    }
}
18441816};
@@ -1878,8 +1850,12 @@ struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile, true>
18781850 REGS_PER_THREAD = 8
18791851 };
18801852
1881- // Default ctor
1882- Tile_o_normalizer () = default ;
// Forwarding constructor: passes params and binfo straight through to Base.
// This specialization performs no initialization of its own here.
template <typename Params, typename Block_info>
inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo)
    : Base(params, binfo)
{
}
18831859
18841860 inline __device__ void merge (Fragment_accu (&acc_dst)[MMAS_M][MMAS_N], Fragment_accu (&acc_src)[MMAS_M][MMAS_N])
18851861 {
0 commit comments