Skip to content
Merged
50 changes: 30 additions & 20 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,

- **Floating point X-bit quantization:**
- **Floating point X-bit quantization:** (in torchao <= 0.14.0, not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
Expand Down Expand Up @@ -531,12 +531,18 @@ def post_init(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()

if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
is_floatx_quant_type = self.quant_type.startswith("fp")
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.0"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.0. "
f"Please downgrade to torchao <= 0.14.0 to use this quantization type."
)

raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
Expand Down Expand Up @@ -622,14 +628,15 @@ def _get_torchao_quant_type_to_method(cls):
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)

if is_torchao_version("<=", "0.14.0"):
from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor

Expand All @@ -650,18 +657,21 @@ def generate_float8dq_types(dtype: torch.dtype):
return types

def generate_fpx_quantization_types(bits: int):
types = {}
if is_torchao_version("<=", "0.14.0"):
types = {}

for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)

non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)

return types
return types
else:
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")

INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
Expand Down Expand Up @@ -710,15 +720,15 @@ def generate_fpx_quantization_types(bits: int):
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
# fpx weight + bfloat16/float16 activation
**generate_fpx_quantization_types(3),
**generate_fpx_quantization_types(4),
**generate_fpx_quantization_types(5),
**generate_fpx_quantization_types(6),
**generate_fpx_quantization_types(7),
}

if is_torchao_version("<=", "0.14.0"):
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))

UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
Expand Down
39 changes: 36 additions & 3 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,12 @@ def test_quantization(self):
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use is_torchao_version here as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import would be a bit ugly for it lives in a different folder, also it seemed like other torchao version checks in that file seemed to use this syntax. Should I still change it?

QUANTIZATION_TYPES_TO_TEST.extend([
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on

for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
Expand All @@ -271,6 +274,33 @@ def test_quantization(self):
)
self._test_quant_type(quantization_config, expected_slice, model_id)

def test_floatx_quantization(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think floatX is used much so we can avoid this test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a skip flag for this. We can also just remove it entirely if you think that's better.

for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
self._test_quant_type(
quantization_config,
np.array(
[
0.4648,
0.5195,
0.5547,
0.4180,
0.4434,
0.6445,
0.4316,
0.4531,
0.5625,
]
),
model_id,
)
else:
# Make sure the correct error is thrown
with self.assertRaisesRegex(ValueError, "Please downgrade"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])

def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
Expand Down Expand Up @@ -794,8 +824,11 @@ def test_quantization(self):
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on

for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
Expand Down