Skip to content

Commit 2f66edc

Browse files
Torchao floatx version guard (#12923)
* Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py * Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py Altered tests in test_torchao.py to version guard floatx Created new test to verify version guard of floatx support * Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py Altered tests in test_torchao.py to version guard floatx Created new test to verify version guard of floatx support * Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py Altered tests in test_torchao.py to version guard floatx Created new test to verify version guard of floatx support * Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py Altered tests in test_torchao.py to version guard floatx Created new test to verify version guard of floatx support * Adding torchao version guard for floatx usage Summary: TorchAO removing floatx support, added version guard in quantization_config.py Altered tests in test_torchao.py to version guard floatx Created new test to verify version guard of floatx support --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent be38f41 commit 2f66edc

2 files changed

Lines changed: 67 additions & 23 deletions

File tree

src/diffusers/quantizers/quantization_config.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
457457
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
458458
`float8_e4m3_tensor`, `float8_e4m3_row`,
459459
460-
- **Floating point X-bit quantization:**
460+
- **Floating point X-bit quantization:** (supported only in torchao <= 0.14.1; removed in torchao >= 0.15.0)
461461
- Full function names: `fpx_weight_only`
462462
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
463463
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ def post_init(self):
531531
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
532532

533533
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
534-
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
535-
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
534+
is_floatx_quant_type = self.quant_type.startswith("fp")
535+
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
536+
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
536537
raise ValueError(
537538
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
538539
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
539540
)
541+
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
542+
raise ValueError(
543+
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
544+
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
545+
)
540546

541547
raise ValueError(
542548
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
@@ -622,14 +628,15 @@ def _get_torchao_quant_type_to_method(cls):
622628
float8_dynamic_activation_float8_weight,
623629
float8_static_activation_float8_weight,
624630
float8_weight_only,
625-
fpx_weight_only,
626631
int4_weight_only,
627632
int8_dynamic_activation_int4_weight,
628633
int8_dynamic_activation_int8_weight,
629634
int8_weight_only,
630635
uintx_weight_only,
631636
)
632637

638+
if is_torchao_version("<=", "0.14.1"):
639+
from torchao.quantization import fpx_weight_only
633640
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
634641
from torchao.quantization.observer import PerRow, PerTensor
635642

@@ -650,18 +657,21 @@ def generate_float8dq_types(dtype: torch.dtype):
650657
return types
651658

652659
def generate_fpx_quantization_types(bits: int):
653-
types = {}
660+
if is_torchao_version("<=", "0.14.1"):
661+
types = {}
654662

655-
for ebits in range(1, bits):
656-
mbits = bits - ebits - 1
657-
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
663+
for ebits in range(1, bits):
664+
mbits = bits - ebits - 1
665+
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
658666

659-
non_sign_bits = bits - 1
660-
default_ebits = (non_sign_bits + 1) // 2
661-
default_mbits = non_sign_bits - default_ebits
662-
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
667+
non_sign_bits = bits - 1
668+
default_ebits = (non_sign_bits + 1) // 2
669+
default_mbits = non_sign_bits - default_ebits
670+
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
663671

664-
return types
672+
return types
673+
else:
674+
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
665675

666676
INT4_QUANTIZATION_TYPES = {
667677
# int4 weight + bfloat16/float16 activation
@@ -710,15 +720,15 @@ def generate_fpx_quantization_types(bits: int):
710720
**generate_float8dq_types(torch.float8_e4m3fn),
711721
# float8 weight + float8 activation (static)
712722
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
713-
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
714-
# fpx weight + bfloat16/float16 activation
715-
**generate_fpx_quantization_types(3),
716-
**generate_fpx_quantization_types(4),
717-
**generate_fpx_quantization_types(5),
718-
**generate_fpx_quantization_types(6),
719-
**generate_fpx_quantization_types(7),
720723
}
721724

725+
if is_torchao_version("<=", "0.14.1"):
726+
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
727+
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
728+
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
729+
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
730+
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
731+
722732
UINTX_QUANTIZATION_DTYPES = {
723733
"uintx_weight_only": uintx_weight_only,
724734
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),

tests/quantization/torchao/test_torchao.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,12 @@ def test_quantization(self):
256256
# Cutlass fails to initialize for below
257257
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
258258
# =====
259-
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
260-
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
261259
])
260+
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
261+
QUANTIZATION_TYPES_TO_TEST.extend([
262+
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
263+
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
264+
])
262265
# fmt: on
263266

264267
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
@@ -271,6 +274,34 @@ def test_quantization(self):
271274
)
272275
self._test_quant_type(quantization_config, expected_slice, model_id)
273276

277+
@unittest.skip("Skipping floatx quantization tests")
278+
def test_floatx_quantization(self):
279+
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
280+
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
281+
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
282+
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
283+
self._test_quant_type(
284+
quantization_config,
285+
np.array(
286+
[
287+
0.4648,
288+
0.5195,
289+
0.5547,
290+
0.4180,
291+
0.4434,
292+
0.6445,
293+
0.4316,
294+
0.4531,
295+
0.5625,
296+
]
297+
),
298+
model_id,
299+
)
300+
else:
301+
# Make sure the correct error is thrown
302+
with self.assertRaisesRegex(ValueError, "Please downgrade"):
303+
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
304+
274305
def test_int4wo_quant_bfloat16_conversion(self):
275306
"""
276307
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
@@ -794,8 +825,11 @@ def test_quantization(self):
794825
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
795826
QUANTIZATION_TYPES_TO_TEST.extend([
796827
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
797-
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
798828
])
829+
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
830+
QUANTIZATION_TYPES_TO_TEST.extend([
831+
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
832+
])
799833
# fmt: on
800834

801835
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:

0 commit comments

Comments (0)