
[Pytorch][Common] Hybrid quantization#2817

Open
negvet wants to merge 2 commits into NVIDIA:main from negvet:hybrid_quantization

Conversation

@negvet
Collaborator

@negvet negvet commented Mar 31, 2026

Description

Hybrid quantization is functional.
C++ optimizations will come in the next PRs.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR introduces hybrid quantization, allowing a single tensor to store rowwise data in one quantized format (e.g. FP8) and columnwise data in another (e.g. NVFP4), then dispatching the appropriate sub-storage to each GEMM based on the layout string. Key additions include HybridQuantizer, HybridQuantizedTensor(Storage), per-operand unwrap helpers in gemm.py, a CUDA kernel extension to support columnwise-only output (null output.dptr), and bypass flags in GroupedLinear/LayerNormLinear/LayerNormMLP for unsupported fused paths.
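Conceptually, the dispatch described above can be sketched as follows. This is an illustrative stand-in, not the PR's actual API: `HybridTensor`, `pick_operand`, and the boolean `transposed` flag are assumptions standing in for `HybridQuantizedTensorStorage` and the layout-string logic in `gemm.py`.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class HybridTensor:
    """Illustrative stand-in: one logical tensor backed by two quantized storages."""
    rowwise_storage: Optional[Any]     # e.g. FP8 data for non-transposed GEMM operands
    columnwise_storage: Optional[Any]  # e.g. NVFP4 data for transposed operands


def pick_operand(tensor: HybridTensor, transposed: bool):
    """Select the sub-storage a GEMM needs, based on the operand's layout flag."""
    storage = tensor.columnwise_storage if transposed else tensor.rowwise_storage
    if storage is None:
        raise ValueError("requested usage was not quantized for this tensor")
    return storage


t = HybridTensor(rowwise_storage="fp8_data", columnwise_storage="nvfp4_data")
assert pick_operand(t, transposed=False) == "fp8_data"
assert pick_operand(t, transposed=True) == "nvfp4_data"
```

The point of the design is that each GEMM sees only the sub-storage appropriate to its layout, so rowwise and columnwise usages can carry different quantized formats.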

Confidence Score: 4/5

Safe to merge for Linear/LayerNorm paths; the P1 in _hybrid_split_quantize only fires on mixed hybrid/non-hybrid GroupedLinear quantizer lists not covered by the new tests.

One P1 (_hybrid_split_quantize attribute access crash on mixed lists) keeps the score at 4. All P2s are minor style/robustness issues. GEMM dispatch logic, CUDA nullptr guards, and the large new test suite are correct.

transformer_engine/pytorch/module/grouped_linear.py — _has_hybrid_quantizer / _hybrid_split_quantize mismatch

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Adds _hybrid_split_quantize and _has_hybrid_quantizer; the function crashes with AttributeError on any non-hybrid quantizer in a mixed list, despite a docstring claiming graceful fallback. |
| transformer_engine/pytorch/tensor/hybrid_tensor.py | New HybridQuantizer and HybridQuantizedTensor; make_empty lacks try/finally exception safety around internal flag mutations. |
| transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py | New HybridQuantizedTensorStorage; repr shows "NoneType" instead of "None" for dropped sub-storages. |
| transformer_engine/pytorch/cpp_extensions/gemm.py | _unwrap_hybrid_A/_unwrap_hybrid_B correctly select sub-storage by layout flag; wired into both general_gemm and general_grouped_gemm. |
| transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu | Supports columnwise-only output via null output.dptr; nullptr kernel guards are correct but a defensive NVTE_CHECK for scale_inv.dptr is missing. |

Comments Outside Diff (1)

  1. transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu, lines 503-528

    P2 Missing NVTE_CHECK when return_identity=false but scale_inv.dptr is non-null

    With zero strides, all GPU threads would race-write to index [0] of scale_inv, silently corrupting data. An NVTE_CHECK(scale_inv.dptr == nullptr) guard when !return_identity would catch this misuse early.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +64 to +86
```python
def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for HybridQuantizer lists.

    Runs tex.split_quantize twice (once per direction with the native
    sub-quantizers), then zips the results into HybridQuantizedTensorStorage.
    Non-hybrid quantizers in the list fall back to per-split Python quantize.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    row_quantizers = [q.rowwise_quantizer for q in quantizers]
    col_quantizers = [q.columnwise_quantizer for q in quantizers]

    row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
    col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

    return [
        HybridStorage(
            rowwise_storage=row,
            columnwise_storage=col,
            rowwise_quantizer=rq,
            columnwise_quantizer=cq,
            quantizer=q,
            fake_dtype=tensor.dtype,
```

P1 _hybrid_split_quantize crashes on mixed-quantizer lists

_has_hybrid_quantizer returns True if any quantizer in the list is a HybridQuantizer, but _hybrid_split_quantize unconditionally accesses q.rowwise_quantizer and q.columnwise_quantizer for every element. If the list contains even one non-hybrid quantizer, this raises AttributeError at runtime.

The docstring claims "Non-hybrid quantizers in the list fall back to per-split Python quantize", but no such fallback exists in the implementation:

```python
row_quantizers = [q.rowwise_quantizer for q in quantizers]  # crashes if q is not HybridQuantizer
col_quantizers = [q.columnwise_quantizer for q in quantizers]
```

Either the condition at the call site should assert all-or-nothing hybrid (all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None)), or the function needs to implement the per-element fallback its docstring promises. The same issue applies to all three call sites in both the forward and backward paths.
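The first option, an all-or-nothing check at the call site, can be sketched like this. `all_hybrid` is a hypothetical helper name, and `HybridQuantizer` is stubbed here rather than imported from the PR:

```python
class HybridQuantizer:
    """Stub standing in for the PR's HybridQuantizer."""

    def __init__(self, rowwise_quantizer=None, columnwise_quantizer=None):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer


def all_hybrid(quantizers):
    """True only if every non-None entry is a HybridQuantizer, so the
    vectorized hybrid path is safe; mixed lists must take the fallback."""
    present = [q for q in quantizers if q is not None]
    return bool(present) and all(isinstance(q, HybridQuantizer) for q in present)


assert all_hybrid([HybridQuantizer(), HybridQuantizer()])
assert not all_hybrid([HybridQuantizer(), object()])  # mixed list -> take the fallback path
```

Guarding with such a predicate instead of `any(...)` would make the crash impossible without having to implement the promised per-element fallback.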

Comment on lines +146 to +155
```python
            "fake_dtype": self._dtype,
        }

    def __repr__(self):
        return (
            "HybridQuantizedTensorStorage("
            f"rowwise={type(self._rowwise_storage).__name__}, "
            f"columnwise={type(self._columnwise_storage).__name__}, "
            f"dtype={self._dtype})"
        )
```

P2 __repr__ shows NoneType instead of None for missing sub-storages

When _rowwise_storage or _columnwise_storage is None, type(None).__name__ produces the string "NoneType" rather than "None". HybridQuantizedTensor.__repr__ already handles this correctly with an explicit is not None guard.

Comment on lines +75 to +97
```python
def make_empty(
    self,
    shape: Iterable[int],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> HybridQuantizedTensor:
    self.rowwise_quantizer.internal = True
    rowwise_empty = self.rowwise_quantizer.make_empty(
        shape,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
    )
    self.rowwise_quantizer.internal = False

    self.columnwise_quantizer.internal = True
    columnwise_empty = self.columnwise_quantizer.make_empty(
        shape,
        dtype=dtype,
        device=device,
```

P2 make_empty leaves sub-quantizer internal flag set on exception

If make_empty raises, the internal = False reset is skipped and the sub-quantizer is permanently left with internal=True. Consider using a try/finally block for both sub-quantizer flag resets.
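The suggested try/finally pattern, sketched with a stub quantizer (all names here are illustrative, not the PR's code):

```python
class StubQuantizer:
    """Stand-in for a sub-quantizer whose `internal` flag must be restored."""

    def __init__(self, fail=False):
        self.internal = False
        self.fail = fail

    def make_empty(self):
        if self.fail:
            raise RuntimeError("allocation failed")
        return "empty-tensor"


def make_empty_safely(quantizer):
    quantizer.internal = True
    try:
        return quantizer.make_empty()
    finally:
        # Reset even if make_empty raises, so the flag never leaks.
        quantizer.internal = False


q = StubQuantizer(fail=True)
try:
    make_empty_safely(q)
except RuntimeError:
    pass
assert q.internal is False  # flag restored despite the exception
```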

@timmoon10 (Collaborator) left a comment

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
```python
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
```

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change:

```diff
-rowwise_result = self.rowwise_quantizer.quantize(tensor)
-columnwise_result = self.columnwise_quantizer.quantize(tensor)
+rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
```

```python
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> HybridQuantizedTensor:
    self.rowwise_quantizer.internal = True
```

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
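What this suggestion amounts to, sketched with stub classes (all names illustrative): set the flag once at construction time instead of toggling it around every `make_empty` call.

```python
class SubQuantizerStub:
    """Stand-in for a rowwise or columnwise sub-quantizer."""

    def __init__(self):
        self.internal = False


class HybridQuantizerSketch:
    """Mark sub-quantizers internal once in the constructor, since the
    per-usage data never needs full PyTorch tensor functionality."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer
        self.rowwise_quantizer.internal = True
        self.columnwise_quantizer.internal = True


hq = HybridQuantizerSketch(SubQuantizerStub(), SubQuantizerStub())
assert hq.rowwise_quantizer.internal and hq.columnwise_quantizer.internal
```

This would also sidestep the exception-safety concern above, since there is no temporary flag mutation left to unwind.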

Comment on lines +114 to +118
```python
def set_usage(
    self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
) -> None:
    super().set_usage(rowwise=rowwise, columnwise=columnwise)
```

This is redundant:

Suggested change:

```diff
-def set_usage(
-    self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
-) -> None:
-    super().set_usage(rowwise=rowwise, columnwise=columnwise)
```

Comment on lines +1339 to +1355
```python
def factory(role):
    if role == "linear_weight":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_mxfp8_quantizer(),
        )
    if role == "linear_input":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    if role in ("linear_grad_output", "linear_grad_input"):
        return HybridQuantizer(
            rowwise_quantizer=_make_mxfp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    return None
```

This is horrifying. Good test.
