Torchao floatx version guard #12923
Changes from 5 commits
```diff
@@ -256,9 +256,12 @@ def test_quantization(self):
         # Cutlass fails to initialize for below
         # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
         # =====
-        ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
-        ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
         ])
+        if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ])
         # fmt: on

         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
```
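The guard added here (and reused in the hunks below) compares the installed torchao release against 0.14.0 at runtime. Pulled out on its own, the pattern looks like the following sketch; the constant names and placeholder cases are illustrative, not taken from the PR:

```python
import importlib.metadata

from packaging import version

# Resolve the installed torchao version once. importlib.metadata.version()
# raises PackageNotFoundError if torchao is not installed.
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))

# The diff assumes floatx quant types (fp4/fp5/fp6) exist only in
# torchao releases up to and including 0.14.0.
FLOATX_AVAILABLE = TORCHAO_VERSION <= version.Version("0.14.0")

# Extend the parametrization list only when the installed torchao
# still supports the floatx kernels.
test_cases = [("int8wo",), ("int8dq",)]  # placeholder baseline cases
if FLOATX_AVAILABLE:
    test_cases.extend([("fp4",), ("fp6",)])
```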
```diff
@@ -271,6 +274,33 @@ def test_quantization(self):
             )
             self._test_quant_type(quantization_config, expected_slice, model_id)

+    def test_floatx_quantization(self):
+        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+            if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
+                if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):
+                    quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+                    self._test_quant_type(
+                        quantization_config,
+                        np.array(
+                            [
+                                0.4648,
+                                0.5195,
+                                0.5547,
+                                0.4180,
+                                0.4434,
+                                0.6445,
+                                0.4316,
+                                0.4531,
+                                0.5625,
+                            ]
+                        ),
+                        model_id,
+                    )
+                else:
+                    # Make sure the correct error is thrown
+                    with self.assertRaisesRegex(ValueError, "Please downgrade"):
+                        quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+
     def test_int4wo_quant_bfloat16_conversion(self):
         """
         Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
```

Review thread on `def test_floatx_quantization(self):`

Member: I don't think floatX is used much, so we can avoid this test.

Contributor (Author): Added a skip flag for this. We can also just remove it entirely if you think that's better.
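The skip flag the author refers to is not part of these five commits, so it does not appear in the diff above. As a rough illustration only (the flag name and mechanism here are assumptions, not the PR's actual change), such a gate is commonly implemented as an environment-variable opt-in:

```python
import os
import unittest

# Hypothetical opt-in flag; the real flag added later in the PR may differ.
RUN_FLOATX_TESTS = os.environ.get("RUN_FLOATX_TESTS", "0") == "1"


class TorchAoQuantizationTests(unittest.TestCase):
    # Skipped by default; opt in explicitly since floatx is rarely exercised.
    @unittest.skipUnless(RUN_FLOATX_TESTS, "floatx is rarely used; set RUN_FLOATX_TESTS=1 to run")
    def test_floatx_quantization(self):
        ...  # body as in the diff above
```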
```diff
@@ -794,8 +824,11 @@ def test_quantization(self):
         if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
+            if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.0"):
+                QUANTIZATION_TYPES_TO_TEST.extend([
+                    ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+                ])
         # fmt: on

         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
```
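All three hunks assume that on torchao releases newer than 0.14.0, requesting a floatx quant type makes `TorchAoConfig` raise a ValueError whose message includes "Please downgrade"; that is what the `assertRaisesRegex` branch in `test_floatx_quantization` checks. A minimal sketch of such a validation, where the function and constant names are assumptions rather than diffusers' actual implementation:

```python
import importlib.metadata

from packaging import version

# Assumed set of affected quant types, inferred from the tests above.
FLOATX_QUANT_TYPES = {"fp4", "fp5_e3m1", "fp6"}
LAST_FLOATX_TORCHAO = version.Version("0.14.0")


def validate_floatx_support(quant_type: str) -> None:
    """Reject floatx quant types when the installed torchao no longer ships them."""
    if quant_type in FLOATX_QUANT_TYPES:
        installed = version.parse(importlib.metadata.version("torchao"))
        if installed > LAST_FLOATX_TORCHAO:
            # Message includes "Please downgrade" to satisfy the test's regex.
            raise ValueError(
                f"Quantization type '{quant_type}' requires torchao<={LAST_FLOATX_TORCHAO}. "
                f"Please downgrade from torchao {installed} to use it."
            )
```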