Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces hybrid quantization, allowing a single tensor to store rowwise data in one quantized format (e.g. FP8) and columnwise data in another (e.g. NVFP4), then dispatching the appropriate sub-storage to each GEMM based on the layout string. Key additions include:

Confidence Score: 4/5

Safe to merge for Linear/LayerNorm paths; the P1 in `_hybrid_split_quantize` only fires on mixed hybrid/non-hybrid GroupedLinear quantizer lists not covered by the new tests. One P1 (`_hybrid_split_quantize` attribute access crash on mixed lists) keeps the score at 4. All P2s are minor style/robustness issues. GEMM dispatch logic, CUDA nullptr guards, and the large new test suite are correct.

Important Files Changed

transformer_engine/pytorch/module/grouped_linear.py — `_has_hybrid_quantizer` / `_hybrid_split_quantize` mismatch
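The dispatch idea in the summary can be illustrated with a minimal, self-contained sketch. The class and attribute names below are stand-ins chosen for illustration, not the PR's actual API; the point is only that one logical tensor carries two quantized copies and hands the matching one to each GEMM:

```python
class HybridStorage:
    """Toy sketch of hybrid quantization: one logical tensor keeps a
    rowwise copy in one format and a columnwise copy in another, and
    the GEMM dispatch picks the sub-storage matching the requested
    layout. Illustrative only; names do not match the PR's real API."""

    def __init__(self, rowwise_storage, columnwise_storage):
        self.rowwise_storage = rowwise_storage
        self.columnwise_storage = columnwise_storage

    def storage_for(self, columnwise: bool):
        # Layout-based dispatch: a columnwise operand gets the
        # columnwise sub-storage, otherwise the rowwise one.
        return self.columnwise_storage if columnwise else self.rowwise_storage


s = HybridStorage(rowwise_storage="fp8_data", columnwise_storage="nvfp4_data")
print(s.storage_for(columnwise=False))  # fp8_data
print(s.storage_for(columnwise=True))   # nvfp4_data
```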
```python
def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for HybridQuantizer lists.

    Runs tex.split_quantize twice (once per direction with the native
    sub-quantizers), then zips the results into HybridQuantizedTensorStorage.
    Non-hybrid quantizers in the list fall back to per-split Python quantize.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    row_quantizers = [q.rowwise_quantizer for q in quantizers]
    col_quantizers = [q.columnwise_quantizer for q in quantizers]

    row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
    col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

    return [
        HybridStorage(
            rowwise_storage=row,
            columnwise_storage=col,
            rowwise_quantizer=rq,
            columnwise_quantizer=cq,
            quantizer=q,
            fake_dtype=tensor.dtype,
```
`_hybrid_split_quantize` crashes on mixed-quantizer lists

`_has_hybrid_quantizer` returns True if any quantizer in the list is a HybridQuantizer, but `_hybrid_split_quantize` unconditionally accesses `q.rowwise_quantizer` and `q.columnwise_quantizer` for every element. If the list contains even one non-hybrid quantizer, this raises AttributeError at runtime.

The docstring claims "Non-hybrid quantizers in the list fall back to per-split Python quantize", but no such fallback exists in the implementation:

```python
row_quantizers = [q.rowwise_quantizer for q in quantizers]  # crashes if q is not HybridQuantizer
col_quantizers = [q.columnwise_quantizer for q in quantizers]
```

Either the condition at the call site should assert all-or-nothing hybrid (`all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None)`), or the function needs to implement the per-element fallback its docstring promises. The same issue applies to all three call sites in both the forward and backward paths.
```python
        "fake_dtype": self._dtype,
    }

def __repr__(self):
    return (
        "HybridQuantizedTensorStorage("
        f"rowwise={type(self._rowwise_storage).__name__}, "
        f"columnwise={type(self._columnwise_storage).__name__}, "
        f"dtype={self._dtype})"
    )
```
```python
def make_empty(
    self,
    shape: Iterable[int],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> HybridQuantizedTensor:
    self.rowwise_quantizer.internal = True
    rowwise_empty = self.rowwise_quantizer.make_empty(
        shape,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
    )
    self.rowwise_quantizer.internal = False

    self.columnwise_quantizer.internal = True
    columnwise_empty = self.columnwise_quantizer.make_empty(
        shape,
        dtype=dtype,
        device=device,
```
timmoon10
left a comment
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
```python
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
```
Do we handle the case where not all usages are needed? I'd expect something like:
```diff
-rowwise_result = self.rowwise_quantizer.quantize(tensor)
-columnwise_result = self.columnwise_quantizer.quantize(tensor)
+rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
```
```python
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> HybridQuantizedTensor:
    self.rowwise_quantizer.internal = True
```
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
```python
def set_usage(
    self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
) -> None:
    super().set_usage(rowwise=rowwise, columnwise=columnwise)
```
This is redundant:
```diff
-def set_usage(
-    self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
-) -> None:
-    super().set_usage(rowwise=rowwise, columnwise=columnwise)
```
```python
def factory(role):
    if role == "linear_weight":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_mxfp8_quantizer(),
        )
    if role == "linear_input":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    if role in ("linear_grad_output", "linear_grad_input"):
        return HybridQuantizer(
            rowwise_quantizer=_make_mxfp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    return None
```
This is horrifying. Good test.
Description
Hybrid quantization is functional.
C++ optimizations will come in the next PRs.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: