diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 644aaf864..674fee142 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2647,7 +2647,7 @@ def double_quant( """ coo_tensor = None - quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant( + quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant( A, col_stats, row_stats, @@ -2657,16 +2657,15 @@ def double_quant( ) if threshold > 0.0: - # Build COO tensor for any outliers. - outlier_mask = A.abs() >= threshold - outlier_locations = outlier_mask.nonzero() - outliers = A[outlier_mask] + # Build a COO tensor including all of the outlier columns. + outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32) + outliers = A[:, outlier_cols] coo_tensor = COOSparseTensor( A.shape[0], A.shape[1], outliers.numel(), - outlier_locations[:, 0].int(), - outlier_locations[:, 1].int(), + outlier_rows.repeat_interleave(outliers.size(1)), + outlier_cols.repeat(outliers.size(0)).int(), outliers, )