Skip to content

Commit 3c6c35a

Browse files
committed
fix zero work cublas gemm
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent 49c2169 commit 3c6c35a

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

tests/pytorch/test_numerics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3036,7 +3036,7 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -
30363036

30373037
def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a):
30383038
"""Create a GroupedTensor with non-zero logical_shape but zero first_dims."""
3039-
buf = torch.randn(k, logical_last_dim, dtype=dtype, device=device)
3039+
buf = torch.randn(0, logical_last_dim, dtype=dtype, device=device)
30403040
if use_mxfp8:
30413041
if is_a:
30423042
rowwise, columnwise = transa, not transa

transformer_engine/pytorch/csrc/type_converters.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) {
221221
DType data_dtype =
222222
quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype;
223223
ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data));
224+
} else if (quantizer_dtype != DType::kNumTypes) {
225+
ret.set_rowwise_data(nullptr, quantizer_dtype, std::vector<size_t>{0});
224226
}
225227

226228
// Columnwise data
@@ -229,6 +231,8 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) {
229231
DType data_dtype =
230232
quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype;
231233
ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data));
234+
} else if (quantizer_dtype != DType::kNumTypes) {
235+
ret.set_columnwise_data(nullptr, quantizer_dtype, std::vector<size_t>{0});
232236
}
233237

234238
// Scale

transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def make_grouped_tensor(
543543
all_same_last = last_dims is None
544544

545545
assert all_same_last, "Last dim must be uniform for GroupedTensor"
546-
assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor"
546+
assert logical_first_dim >= 0, "Logical first dim must be non-negative for GroupedTensor"
547547
assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor"
548548

549549
# assert (

0 commit comments

Comments
 (0)