Torchao floatx version guard #12923
Changes from all commits
The changes guard the floatx ("fpX") quantization test cases behind a torchao version check:

```diff
@@ -256,9 +256,12 @@ def test_quantization(self):
                 # Cutlass fails to initialize for below
                 # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                 # =====
-                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
-                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
             ])
+            if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                QUANTIZATION_TYPES_TO_TEST.extend([
+                    ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                    ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ])
         # fmt: on

         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
```
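The same guard is repeated in every hunk; pulled out on its own it reduces to a single version comparison. The sketch below is illustrative only, and the helper name `torchao_supports_floatx` is not part of the PR:

```python
import importlib.metadata

from packaging import version


def torchao_supports_floatx() -> bool:
    """Illustrative helper: the floatx ("fp4"/"fp6"/"fp5_e3m1") test cases only run
    when the installed torchao is 0.14.1 or older, matching the guard added in this diff."""
    return version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1")


if __name__ == "__main__":
    print("floatx test cases enabled:", torchao_supports_floatx())
```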
```diff
@@ -271,6 +274,34 @@ def test_quantization(self):
             )
             self._test_quant_type(quantization_config, expected_slice, model_id)
 
+    @unittest.skip("Skipping floatx quantization tests")
+    def test_floatx_quantization(self):
```
**Member:** I don't think floatX is used much, so we can avoid this test.

**Contributor (author):** Added a skip flag for this. We can also just remove it entirely if you think that's better.
```diff
+        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+            if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
+                if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                    quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+                    self._test_quant_type(
+                        quantization_config,
+                        np.array(
+                            [
+                                0.4648,
+                                0.5195,
+                                0.5547,
+                                0.4180,
+                                0.4434,
+                                0.6445,
+                                0.4316,
+                                0.4531,
+                                0.5625,
+                            ]
+                        ),
+                        model_id,
+                    )
+                else:
+                    # Make sure the correct error is thrown
+                    with self.assertRaisesRegex(ValueError, "Please downgrade"):
+                        quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+
     def test_int4wo_quant_bfloat16_conversion(self):
         """
         Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
```
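Read from a user's perspective, the behavior the new test encodes is roughly the following. This is a sketch that assumes, as the `assertRaisesRegex` block above implies, that `TorchAoConfig` raises at construction time when a floatx type is requested on a newer torchao:

```python
from diffusers import TorchAoConfig

# Sketch of the expected behavior exercised by test_floatx_quantization:
# - torchao <= 0.14.1: the floatx config is created as before.
# - torchao >= 0.15.0: a ValueError containing "Please downgrade" is raised.
try:
    quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
    print("fp4 config created; this torchao build still supports floatx")
except ValueError as err:
    assert "Please downgrade" in str(err)
    print(f"floatx rejected on this torchao version: {err}")
```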
```diff
@@ -794,8 +825,11 @@ def test_quantization(self):
         if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
+            if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                QUANTIZATION_TYPES_TO_TEST.extend([
+                    ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+                ])
         # fmt: on

         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
```
**Reviewer:** Can we also provide a recommendation on what needs to be done if a user wants to use X-bit ops in torchao >= 0.15.0?

**Reply:** They will need to downgrade to 0.14.1 or lower; we are removing this because we didn't see much usage for this feature.

**Reviewer:** Can we make that clear in the error message then?

**Reply:** There's an error message for this at line 541 for the case where a user tries to use floatx with torchao > 0.14.1. Is there a better place for it?
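For context, the kind of config-side check this thread points to (the error message at line 541) could look roughly like the sketch below. The helper name, the prefix list, and any wording beyond "Please downgrade" are assumptions, not the actual diffusers code:

```python
import importlib.metadata

from packaging import version

# Quant types from this PR's tests that depend on torchao's floatx support.
_FLOATX_QUANT_PREFIXES = ("fp4", "fp5", "fp6")


def _check_floatx_supported(quant_type: str) -> None:
    """Illustrative guard: reject floatx quant types on torchao releases newer than 0.14.1."""
    torchao_version = version.parse(importlib.metadata.version("torchao"))
    if quant_type.startswith(_FLOATX_QUANT_PREFIXES) and torchao_version > version.Version("0.14.1"):
        raise ValueError(
            f"Quantization type {quant_type!r} requires floatx support that was removed from torchao "
            "after 0.14.1. Please downgrade to torchao<=0.14.1 to use it."
        )


_check_floatx_supported("int8wo")  # always passes
_check_floatx_supported("fp4")     # raises on torchao >= 0.15.0
```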