
Commit 36e0631

ksivaman, pre-commit-ci[bot], vthumbe1503, vasunvidia, and timmoon10 authored and committed
GEMM + Swiglu fused Grouped MLP for MXFP8 (#2769)
* GEMM + Swiglu fused Grouped MLP for MXFP8
* cleanup/lint
* Properly cache the alpha tensor
* nD dummy grad
* 0 tokens in entire rank
* tmp downgrade cublas version check
* delayed wgrad tests pass for basic gl
* merge everything
* Rebase into fused_mxfp8_grouped_mlp; unit tests for delayed wgrad working
* Fix tests being skipped for fusible ops
* Integrate mxfp8 dbias kernel in group_quantize
* Add bias/dbias fused support with cute GEMMs
* Check bias/dbias support
* Pack biases more efficiently
* GroupedTensor for biases to avoid concat
* format
* Support 1D grouped tensor shape for bias and fix checkpointing
* Fixes and tests
* Refactor grouped tensor marking for paged stashing
* Remove setting logical_shape in mark_grouped_tensor
* Cleanup logical_shape
* pass the tests for now
* address some review comments
* address review comments
* more cleanups
* cleanup
* refactor wgrad logic
* Rename argument from single_grouped_parameter to single_grouped_weight
* Check wgrad store context is not empty for 0 token case.
* Test only checks for fusion if fused kernel is available
* fix the tolerance to be of bf16 for the cute gemm
* Update transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
* address further review comments
* address more review comments
* address more review comments + test for zero grouped tensor work case
* cublaslt remove zero work gemm avoidance
* fix the wgrad test
* split dbias functionality from gq api
* Format and lint
* port fixes and add better doc for page stashing war
* Guard fusion via env
* Change to trigger CI; remove unnecessary blank line in docstring
* To retrigger CI
* Space to trigger the pipeline
* fix zero work cublas gemm
* [pre-commit.ci] auto fixes from pre-commit.com hooks, applied repeatedly between the commits above (see https://pre-commit.ci)

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
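For orientation, the computation being fused is the standard gated MLP used in MoE layers: for each expert group, a first GEMM, a SwiGLU activation, then a second GEMM. Below is a minimal single-expert reference in plain C++ (fp32, row-major) to make that concrete. It is a sketch only: mlp_swiglu_reference and silu are hypothetical helpers written for this illustration, and the gate-first/linear-second split follows the common SwiGLU convention rather than anything confirmed by this commit; the actual code runs grouped MXFP8 GEMMs on the GPU with the activation fused in.

#include <cmath>
#include <cstddef>
#include <vector>

// silu(x) = x * sigmoid(x)
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// One expert of the grouped MLP: X is m x k, W1 is k x 2h (gate and
// linear halves concatenated), W2 is h x n. Hypothetical reference
// code, not a Transformer Engine API.
void mlp_swiglu_reference(const float *X, const float *W1, const float *W2,
                          float *Y, size_t m, size_t k, size_t h, size_t n) {
  std::vector<float> H(m * 2 * h, 0.0f);
  // First GEMM: H = X @ W1
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < 2 * h; ++j)
        H[i * 2 * h + j] += X[i * k + p] * W1[p * 2 * h + j];
  // SwiGLU then second GEMM: Y = (SiLU(gate) * linear) @ W2
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (size_t p = 0; p < h; ++p) {
        const float gate = silu(H[i * 2 * h + p]);  // first half gates
        const float lin = H[i * 2 * h + h + p];     // second half is linear
        acc += gate * lin * W2[p * n + j];
      }
      Y[i * n + j] = acc;
    }
  }
}

Fusing the activation into the grouped GEMM presumably avoids materializing the intermediate H for every expert in a separate kernel pass.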
1 parent: b8e17cb · commit: 36e0631

30 files changed: +3,784 −234 lines

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
+NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
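The only functional change here is the NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 prefix. Per the "Guard fusion via env" item in the commit message, the new CuTe-DSL fused grouped-MLP path is opt-in through this environment variable, so the fusible-ops tests enable it explicitly; without the variable the tests would presumably exercise the unfused path.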

tests/cpp/operator/test_grouped_gemm.cu

Lines changed: 369 additions & 4 deletions
Large diffs are not rendered by default.

tests/cpp/operator/test_swizzle.cu

Lines changed: 144 additions & 0 deletions
@@ -110,6 +110,115 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
 }
 }

+// Zero out padding in a scale_inv CPU buffer so that the CPU reference
+// matches the kernel, which zeroes elements outside the original dims.
+// The buffer is stored in leading-dim-major order (row-major for rowwise,
+// column-major for colwise). `padded_rows x padded_cols` is the full
+// (padded) shape; `orig_rows` / `orig_cols` are the unpadded extents.
+static void zero_scale_inv_padding(uint8_t *buf,
+                                   size_t padded_rows, size_t padded_cols,
+                                   size_t orig_rows, size_t orig_cols) {
+  for (size_t r = 0; r < padded_rows; ++r) {
+    for (size_t c = 0; c < padded_cols; ++c) {
+      if (r >= orig_rows || c >= orig_cols) {
+        buf[r * padded_cols + c] = 0;
+      }
+    }
+  }
+}
+
+void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const size_t K) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<std::unique_ptr<Tensor>> input_tensors;
+  std::vector<std::unique_ptr<Tensor>> output_tensors;
+  std::vector<Tensor*> input_ptrs;
+  std::vector<Tensor*> output_ptrs;
+  input_tensors.reserve(num_tensors);
+  output_tensors.reserve(num_tensors);
+  input_ptrs.reserve(num_tensors);
+  output_ptrs.reserve(num_tensors);
+
+  constexpr size_t BLOCK_SIZE = 32;
+  const std::vector<size_t> shape{M, K};
+  for (int i = 0; i < num_tensors; ++i) {
+    auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
+                                          DType::kFloat8E4M3, true, true,
+                                          NVTE_MXFP8_1D_SCALING);
+    auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
+                                           DType::kFloat8E4M3, true, true,
+                                           NVTE_MXFP8_1D_SCALING);
+    fillUniform(input.get());
+    fillUniform(output.get());
+
+    // The grouped swizzle kernel zeroes scale_inv elements that fall
+    // outside the original (unpadded) dimensions. Mirror that in the
+    // per-tensor CPU buffers so the CPU reference produces identical output.
+    input->to_cpu();
+    const NVTEShape rs = input->rowwise_scale_inv_shape();
+    zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
+                           rs.data[0], rs.data[1],
+                           M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    const NVTEShape cs = input->columnwise_scale_inv_shape();
+    zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
+                           cs.data[0], cs.data[1],
+                           (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+    input->from_cpu();
+
+    input_ptrs.push_back(input.get());
+    output_ptrs.push_back(output.get());
+    input_tensors.emplace_back(std::move(input));
+    output_tensors.emplace_back(std::move(output));
+  }
+
+  GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
+  const uint8_t input_swizzled = 0;
+  nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &input_swizzled, sizeof(input_swizzled));
+  const uint8_t output_swizzled = 1;
+  nvte_set_grouped_tensor_param(grouped_output.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &output_swizzled, sizeof(output_swizzled));
+
+  const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape();
+  const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape();
+  const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+  const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0, num_tensors * row_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
+
+  nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
+                                       grouped_output.get_handle(), 0);
+
+  std::vector<uint8_t> output_row(num_tensors * row_numel);
+  std::vector<uint8_t> output_col(num_tensors * col_numel);
+  NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(),
+                             output_row.size(), cudaMemcpyDeviceToHost));
+  NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), grouped_output.columnwise_scale_inv.get(),
+                             output_col.size(), cudaMemcpyDeviceToHost));
+
+  std::vector<uint8_t> ref_row(num_tensors * row_numel);
+  std::vector<uint8_t> ref_col(num_tensors * col_numel);
+  for (int i = 0; i < num_tensors; ++i) {
+    compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>(),
+                                      ref_row.data() + i * row_numel,
+                                      row_shape.data[0], row_shape.data[1]);
+    compute_ref_swizzle<128, 4, false>(
+        input_tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>(),
+        ref_col.data() + i * col_numel,
+        col_shape.data[1], col_shape.data[0]);
+  }
+
+  compareResults("grouped_swizzle_rowwise", output_row.data(), ref_row.data(),
+                 num_tensors * row_numel);
+  compareResults("grouped_swizzle_colwise", output_col.data(), ref_col.data(),
+                 num_tensors * col_numel);
+}
+
 class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};


@@ -126,6 +235,41 @@ TEST_P(SwizzleTestSuite, TestSwizzle) {
                      transa);
 }

+class SwizzleGroupedTestSuite
+    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
+
+TEST_P(SwizzleGroupedTestSuite, TestGroupedSwizzleMXFP8) {
+  const auto num_tensors = std::get<0>(GetParam());
+  const auto M = std::get<1>(GetParam());
+  const auto K = std::get<2>(GetParam());
+  performTestGroupedSwizzleMXFP8(num_tensors, M, K);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest,
+    SwizzleGroupedTestSuite,
+    ::testing::Values(
+        // M and K both divisible by 128
+        std::make_tuple(3, 256, 256),
+        std::make_tuple(4, 128, 128),
+        // M not divisible by 128
+        std::make_tuple(3, 200, 256),
+        std::make_tuple(2, 65, 256),
+        // K not divisible by 128
+        std::make_tuple(3, 256, 160),
+        std::make_tuple(2, 256, 96),
+        // Neither M nor K divisible by 128
+        std::make_tuple(3, 200, 160),
+        std::make_tuple(4, 33, 64),
+        std::make_tuple(2, 1, 32)
+    ),
+    [](const testing::TestParamInfo<SwizzleGroupedTestSuite::ParamType>& info) {
+      return "n" + std::to_string(std::get<0>(info.param)) +
+             "_M" + std::to_string(std::get<1>(info.param)) +
+             "_K" + std::to_string(std::get<2>(info.param));
+    }
+);
+
 namespace {

 std::vector<std::pair<int, int>> num_tiles = {
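To make the padding handled by zero_scale_inv_padding concrete, take the n3_M200_K160 case from the parameter list above. With BLOCK_SIZE = 32, the rowwise scale_inv buffer holds one scale per 32-element block along K, so its unpadded extents are

  200 x ceil(160 / 32) = 200 x 5

The swizzle works on 128 x 4 tiles of scales (the <128, 4> template arguments to compute_ref_swizzle), so the stored buffer is padded, presumably to

  roundup(200, 128) x roundup(5, 4) = 256 x 8

and the helper zeroes everything outside the 200 x 5 region so that the CPU reference matches the kernel, which likewise zeroes the padding.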
