1717
1818namespace fbgemm_gpu {
1919
20+ template <typename T, typename ... Ts>
21+ constexpr inline bool is_one_of_v = (std::is_same_v<T, Ts> || ...);
22+
2023// //////////////////////////////////////////////////////////////////////////////
2124// Quantized Load and Store
2225// //////////////////////////////////////////////////////////////////////////////
@@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store(
3740template <typename dst_t , typename src_t >
3841DEVICE_INLINE Vec4T<dst_t > dequantize_load (
3942 const src_t * value,
40- const float2 /* unused */ ) {
41- return Vec4T<dst_t >(value);
42- }
43-
44- template <>
45- DEVICE_INLINE Vec4T<float > dequantize_load (
46- const uint8_t * value,
47- const float2 qparams) {
48- Vec4T<float > out;
49- out.acc .x = value[0 ] * qparams.x + qparams.y ;
50- out.acc .y = value[1 ] * qparams.x + qparams.y ;
51- out.acc .z = value[2 ] * qparams.x + qparams.y ;
52- out.acc .w = value[3 ] * qparams.x + qparams.y ;
53- return out;
54- }
43+ [[maybe_unused]] const float2 qparams) {
44+ if constexpr (
45+ std::is_same_v<src_t , uint8_t > && is_one_of_v<dst_t , float , at::Half>) {
46+ Vec4T<dst_t > out;
47+ out.acc .x = value[0 ] * qparams.x + qparams.y ;
48+ out.acc .y = value[1 ] * qparams.x + qparams.y ;
49+ out.acc .z = value[2 ] * qparams.x + qparams.y ;
50+ out.acc .w = value[3 ] * qparams.x + qparams.y ;
51+ return out;
5552
56- template <>
57- DEVICE_INLINE Vec4T<at::Half> dequantize_load (
58- const uint8_t * value,
59- const float2 qparams) {
60- Vec4T<at::Half> out;
61- out.acc .x = value[0 ] * qparams.x + qparams.y ;
62- out.acc .y = value[1 ] * qparams.x + qparams.y ;
63- out.acc .z = value[2 ] * qparams.x + qparams.y ;
64- out.acc .w = value[3 ] * qparams.x + qparams.y ;
65- return out;
53+ } else {
54+ return Vec4T<dst_t >(value);
55+ }
6656}
6757
6858template <typename emb_t >
@@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
7464 return qparams;
7565}
7666
77- template <typename emb_t >
78- DEVICE_INLINE void store_qparams_to_row (emb_t * ptr, float2 qparams) {
79- CUDA_KERNEL_ASSERT (false ); // Only int8 embeddding should call this
80- }
81-
82- template <>
8367DEVICE_INLINE void store_qparams_to_row (uint8_t * ptr, float2 qparams) {
8468 auto ptr_as_uint = reinterpret_cast <uintptr_t >(ptr);
8569 if (ptr_as_uint % 8 == 0 ) {
@@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
11296
11397// //////////////////////////////////////////////////////////////////////////////
11498// Weight Row
99+ //
100+ // This is a memory accessor around a row of dim_ number of embedding weights.
101+ // It provides for loading and storing of 4 elements at a time (Vec4T<dst_t>)
102+ // from and to the embedding table or cache. It also provides for quantization
103+ // and de-quantization of the data. The cache row pointer is optional, and if
104+ // not provided, then the embedding table is assumed to be the source of truth.
105+ //
106+ // Template parameters:
107+ // emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half)
108+ // cache_t : The type of the cache
109+ // dst_t : The type of the registers
115110// //////////////////////////////////////////////////////////////////////////////
116111
117112template <typename emb_t , typename cache_t , typename dst_t >
118113// TODO: pass in dimension info and calculate qparams for rowwise integer
119114// quantization
120- struct WeightRow {
115+ class WeightRow {
116+ public:
121117 // Constructor for no stochastic rounding
122118 DEVICE_INLINE WeightRow (emb_t * row, cache_t * cache_row, int dim)
123119 : row_(row),
@@ -144,65 +140,54 @@ struct WeightRow {
144140 }
145141 }
146142
147- emb_t * row_;
148- cache_t * cache_row_;
149- int dim_;
150- StochasticRoundingRNGState stoc_rounding_state_;
151- StochasticRoundingRNGState* stoc_rounding_state_ptr_;
143+ // ////////////////////////////////////////////////////////////////////////////
144+ // Load 4 elements from the table row at element offset d into a register
145+ // variable (Vec4T<dst_t>)
146+ //
147+ // If the cache row pointer is valid, then data will be read from the cache
148+ // instead of embedding table.
149+ // ////////////////////////////////////////////////////////////////////////////
152150
153- // Load from cache if resident; else load from embedding
154151 DEVICE_INLINE Vec4T<dst_t > load (const int32_t d, const float2 qparams) const {
152+ // Load from the cache if resident; else load from the embedding table.
153+ //
154+ // Note: This method assumes that dst_t is of higher precision than cache_t
155+ // and emb_t
155156 if (cache_row_) {
156157 return dequantize_load<dst_t , cache_t >(cache_row_ + d, qparams);
157158 } else {
158159 return dequantize_load<dst_t , emb_t >(row_ + d, qparams);
159160 }
160161 }
161162
162- // Write back weight (high precision) to cache if resident; else write to
163- // embedding assume dst_t is higher precision than cache_t and emb_t
163+ // ////////////////////////////////////////////////////////////////////////////
164+ // Store register variable of 4 elements (Vec4T<dst_t>) back into the table
165+ // row at element offset d
166+ //
167+ // If the cache row pointer is valid, then data will be written to the cache
168+ // instead of embedding table.
169+ // ////////////////////////////////////////////////////////////////////////////
170+
164171 DEVICE_INLINE void
165172 store (const Vec4T<dst_t >& v, const int32_t d, const float2 qparams) {
173+ // Write back weight (high precision) to cache if resident; else write to
174+ // embedding table.
175+ //
176+ // Note: This method assumes that dst_t is of higher precision than cache_t
177+ // and emb_t
166178 if (cache_row_) {
167179 quantize_store (cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
168180 } else {
169181 quantize_store (row_ + d, v, stoc_rounding_state_ptr_, qparams);
170182 }
171183 }
172184
173- // Copy vector from src_vec to dst_vec (both are float)
174- DEVICE_INLINE void same_type_vector_copy (
175- float * dst_vec,
176- const float * src_vec) {
177- *reinterpret_cast <float4 *>(dst_vec) =
178- *reinterpret_cast <const float4 *>(src_vec);
179- }
180-
181- // Copy vector from src_vec to dst_vec (both are at::Half)
182- DEVICE_INLINE void same_type_vector_copy (
183- at::Half* dst_vec,
184- const at::Half* src_vec) {
185- *reinterpret_cast <float2 *>(dst_vec) =
186- *reinterpret_cast <const float2 *>(src_vec);
187- }
188-
189- // Evict cached row into embedding row (high prec -> low prec)
190- DEVICE_INLINE void evict_cache (const int32_t d, const float2 qparams) {
191- if constexpr (std::is_same_v<emb_t , cache_t >) {
192- // No conversion required when emb_t and cache_t are the same type
193- same_type_vector_copy (
194- reinterpret_cast <cache_t *>(row_ + d),
195- reinterpret_cast <const cache_t *>(cache_row_ + d));
196- } else {
197- // Does 2-step conversion: cache_t -> FP32 -> weight_t
198- const auto cache_slice = load (d, qparams);
199- quantize_store (row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
200- }
201- }
202-
203- DEVICE_INLINE void store_qparams (const float2 qparams) {
204- store_qparams_to_row (row_ + dim_, qparams);
205- }
185+ // ////////////////////////////////////////////////////////////////////////////
186+ // Fetch the quantization parameters of the table row
187+ //
188+ // Qparams are fetched from the end of the row in the embedding table, not the
189+ // cache.
190+ // ////////////////////////////////////////////////////////////////////////////
206191
207192 DEVICE_INLINE float2 load_qparams () const {
208193 if constexpr (std::is_same_v<emb_t , uint8_t >) {
@@ -212,13 +197,35 @@ struct WeightRow {
212197 }
213198 }
214199
200+ // ////////////////////////////////////////////////////////////////////////////
201+ // Update the quantization parameters of the table row
202+ //
203+ // Qparams are stored at the end of the row in the embedding table, not the
204+ // cache.
205+ // ////////////////////////////////////////////////////////////////////////////
206+
207+ template <typename T = emb_t >
208+ DEVICE_INLINE auto store_qparams (const float2 qparams) const
209+ -> std::enable_if_t<std::is_same_v<T, uint8_t>, void> {
210+ store_qparams_to_row (row_ + dim_, qparams);
211+ }
212+
213+ // ////////////////////////////////////////////////////////////////////////////
214+ // Load the row from the embedding table into the cache
215+ //
216+ // De-quantization will be applied if the embedding table type is uint8_t (low
217+ // prec -> high prec).
218+ // ////////////////////////////////////////////////////////////////////////////
219+
215220 DEVICE_INLINE void warp_copy_to_cache (
216221 cache_t * dst_row,
217222 const uint32_t dim_length,
223+
218224 const uint32_t num_lanes,
219225 const uint32_t lane_id) {
220226 if constexpr (std::is_same_v<emb_t , cache_t >) {
221- // No conversion required when emb_t and cache_t are the same type
227+ // If the embedding table and cache types are the same, then simply copy
228+ // data from embedding table to cache.
222229 for (auto d = lane_id * 4 ; d < dim_length; d += num_lanes * 4 ) {
223230 same_type_vector_copy (
224231 dst_row + d, reinterpret_cast <const cache_t *>(row_ + d));
@@ -236,6 +243,31 @@ struct WeightRow {
236243 }
237244 }
238245
246+ // ////////////////////////////////////////////////////////////////////////////
247+ // Evict 4 elements at offset d from the cache row to the embedding table row
248+ // ////////////////////////////////////////////////////////////////////////////
249+
250+ DEVICE_INLINE void evict_cache (const uint32_t d, const float2 qparams) {
251+ if constexpr (std::is_same_v<emb_t , cache_t >) {
252+ // If the embedding table and cache types are the same, then simply copy
253+ // data from cache to embedding table.
254+ same_type_vector_copy (
255+ reinterpret_cast <emb_t *>(row_ + d),
256+ reinterpret_cast <const cache_t *>(cache_row_ + d));
257+ } else {
258+ // Else, do 2-step conversion: cache_t -> dst_t (register) -> emb_t
259+ const auto cache_slice = load (d, qparams);
260+ quantize_store (row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
261+ }
262+ }
263+
264+ // ////////////////////////////////////////////////////////////////////////////
265+ // Evict the row from the cache and into the embedding table.
266+ //
267+ // Quantization will be applied if the embedding table type is uint8_t (high
268+ // prec -> low prec).
269+ // ////////////////////////////////////////////////////////////////////////////
270+
239271 DEVICE_INLINE void warp_evict_cache (
240272 const uint32_t dim_length,
241273 const uint32_t num_lanes,
@@ -268,36 +300,86 @@ struct WeightRow {
268300 evict_cache (d, qparams);
269301 }
270302 }
303+
304+ private:
305+ // The pointer to the row of weights in the embedding table
306+ emb_t * const row_;
307+
308+ // The pointer to the row of weights in the cache
309+ cache_t * const cache_row_;
310+
311+ // The number of elements per table row
312+ int32_t const dim_;
313+
314+ // The state for stochastic rounding
315+ StochasticRoundingRNGState stoc_rounding_state_;
316+ StochasticRoundingRNGState* stoc_rounding_state_ptr_;
317+
318+ // ////////////////////////////////////////////////////////////////////////////
319+ // Copy 4 elements (float or at::Half) from src_vec to dst_vec
320+ //
321+ // Reinterpret cast to float4* or float2* for mass copy
322+ // ////////////////////////////////////////////////////////////////////////////
323+
324+ template <
325+ typename T,
326+ typename = std::enable_if_t <is_one_of_v<T, float , at::Half>>>
327+ DEVICE_INLINE void same_type_vector_copy (T* dst_vec, const T* src_vec) {
328+ // Copy from src_vec to dst_vec as one float4 (T = float) or float2 (T = at::Half)
329+ using ptr_t = std::conditional_t <std::is_same_v<T, float >, float4 , float2 >;
330+ *reinterpret_cast <ptr_t *>(dst_vec) =
331+ *reinterpret_cast <const ptr_t *>(src_vec);
332+ }
271333};
272334
273335// //////////////////////////////////////////////////////////////////////////////
274336// Weight Row Accessor
275337//
276- // This is a basic memory accessor around a row of dim_ number of embedding
277- // weights of type row_t, and provides for loading 4 elements at a time into
278- // Vec4T<dst_t> with de-quantization support. Unlike WeightRow, this accessor
279- // is for reading only, and does not take into account embedding vs cache table,
280- // etc.
338+ // This is a lightweight memory accessor around a row of dim_ number of
339+ // embedding weights of type row_t (can be HBM or UVM), and provides for loading
340+ // 4 elements at a time into Vec4T<dst_t> with de-quantization support. Unlike
341+ // the heavyweight WeightRow class, this accessor is for reading values only,
342+ // and does not handle embedding vs cache tables, etc.
343+ //
344+ // Template parameters:
345+ // row_t : The type of the table row (e.g. uint8_t, float, at::Half)
346+ // dst_t : The type of the registers
281347// //////////////////////////////////////////////////////////////////////////////
282348
283349template <typename row_t , typename dst_t >
284- struct WeightRowAccessor {
285- const row_t * row_;
350+ class WeightRowAccessor {
351+ // The pointer to the row of weights in the table
352+ const row_t * const row_;
353+
354+ // The number of elements per table row.
355+ //
356+ // This is NOT necessarily equivalent to the row stride D_emb, as there may be
357+ // quantization parameters and optimizer states packed into the back of the
358+ // row.
359+ //
360+ // dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for
361+ // max register occupancy.
286362 const int32_t dim_;
287- const float2 qparams_;
288363
364+ // [OPTIONAL] The quantization parameters for the row. If the row type is not
365+ // uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f).
366+ float2 qparams_ = make_float2(0 .0f , 0 .0f );
367+
368+ public:
289369 DEVICE_INLINE
290370 WeightRowAccessor (const row_t * const row, const int32_t dim)
291- : row_(row), dim_(dim), qparams_(qparams()) {}
292-
293- DEVICE_INLINE auto qparams () const {
371+ : row_(row), dim_(dim) {
294372 if constexpr (std::is_same_v<row_t , uint8_t >) {
295- return load_qparams_from_row<row_t >(row_ + dim_);
296- } else {
297- return make_float2 (0 .0f , 0 .0f );
373+ qparams_ = qparams ();
298374 }
299375 }
300376
377+ template <typename T = row_t >
378+ DEVICE_INLINE auto qparams () const
379+ -> std::enable_if_t<std::is_same_v<T, uint8_t>, float2> {
380+ return load_qparams_from_row<row_t >(row_ + dim_);
381+ }
382+
301383 DEVICE_INLINE Vec4T<dst_t > load (const int32_t d) const {
302384 return dequantize_load<dst_t , row_t >(row_ + d, qparams_);
303385 }
0 commit comments