Skip to content

Commit 63379bb

Browse files
committed
Remove config functions like int4_weight_only (#3145)
**Summary:** As a follow-up to #2994, this commit removes all quantization functions that were used as configs. These functions were deprecated in 0.14.0 and are now removed in this release, 0.15.0. **Test Plan:** CI
1 parent ff0e461 commit 63379bb

File tree

5 files changed

+21
-150
lines changed

5 files changed

+21
-150
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
243243
244244
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
245245
246-
1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
246+
1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
247247
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
248248
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
249249

test/quantization/test_quant_api.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -792,38 +792,28 @@ def test_int4wo_cuda_serialization(self):
792792

793793
def test_config_deprecation(self):
794794
"""
795-
Test that old config functions like `int4_weight_only` trigger deprecation warnings.
795+
Test that deprecated configs like `Int8DynamicActivationInt4WeightConfig` trigger warnings about moving to prototype in a future release.
796796
"""
797797
from torchao.quantization import (
798-
float8_dynamic_activation_float8_weight,
799-
float8_static_activation_float8_weight,
800-
float8_weight_only,
801-
fpx_weight_only,
802-
gemlite_uintx_weight_only,
803-
int4_dynamic_activation_int4_weight,
804-
int4_weight_only,
805-
int8_dynamic_activation_int4_weight,
806-
int8_dynamic_activation_int8_weight,
807-
int8_weight_only,
808-
uintx_weight_only,
798+
Float8StaticActivationFloat8WeightConfig,
799+
FPXWeightOnlyConfig,
800+
GemliteUIntXWeightOnlyConfig,
801+
Int4DynamicActivationInt4WeightConfig,
802+
Int8DynamicActivationInt4WeightConfig,
803+
UIntXWeightOnlyConfig,
809804
)
810805

811806
# Reset deprecation warning state, otherwise we won't log warnings here
812807
warnings.resetwarnings()
813808

814809
# Map from deprecated API to the args needed to instantiate it
815810
deprecated_apis_to_args = {
816-
float8_dynamic_activation_float8_weight: (),
817-
float8_static_activation_float8_weight: (torch.randn(3)),
818-
float8_weight_only: (),
819-
fpx_weight_only: (3, 2),
820-
gemlite_uintx_weight_only: (),
821-
int4_dynamic_activation_int4_weight: (),
822-
int4_weight_only: (),
823-
int8_dynamic_activation_int4_weight: (),
824-
int8_dynamic_activation_int8_weight: (),
825-
int8_weight_only: (),
826-
uintx_weight_only: (torch.uint4,),
811+
Float8StaticActivationFloat8WeightConfig: (torch.randn(3),),
812+
FPXWeightOnlyConfig: (3, 2),
813+
GemliteUIntXWeightOnlyConfig: (),
814+
Int4DynamicActivationInt4WeightConfig: (),
815+
Int8DynamicActivationInt4WeightConfig: (),
816+
UIntXWeightOnlyConfig: (torch.uint4,),
827817
}
828818

829819
# Call each deprecated API twice
@@ -832,19 +822,16 @@ def test_config_deprecation(self):
832822
cls(*args)
833823
cls(*args)
834824

835-
# Each call should have at least one warning.
836-
# Some of them can have two warnings - one for deprecation,
837-
# one for moving to prototype
838-
# 1 warning - just deprecation
839-
# 2 warnings - deprecation and prototype warnings
840-
self.assertTrue(len(_warnings) in (1, 2))
825+
self.assertTrue(len(_warnings) == 1)
841826
found_deprecated = False
842827
for w in _warnings:
843-
if "is deprecated and will be removed in a future release" in str(
828+
if "will be moving to prototype in a future release" in str(
844829
w.message
845830
):
846831
found_deprecated = True
847-
self.assertTrue(found_deprecated)
832+
self.assertTrue(
833+
found_deprecated, f"did not find deprecated warning for {cls}"
834+
)
848835

849836

850837
common_utils.instantiate_parametrized_tests(TestQuantFlow)

torchao/quantization/__init__.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,10 @@
6565
PlainLayout,
6666
TensorCoreTiledLayout,
6767
UIntXWeightOnlyConfig,
68-
float8_dynamic_activation_float8_weight,
69-
float8_static_activation_float8_weight,
70-
float8_weight_only,
71-
fpx_weight_only,
7268
fqn_matches_fqn_config,
73-
gemlite_uintx_weight_only,
74-
int4_dynamic_activation_int4_weight,
75-
int4_weight_only,
76-
int8_dynamic_activation_int4_weight,
77-
int8_dynamic_activation_int8_semi_sparse_weight,
78-
int8_dynamic_activation_int8_weight,
79-
int8_weight_only,
8069
intx_quantization_aware_training,
8170
quantize_,
8271
swap_conv2d_1x1_to_linear,
83-
uintx_weight_only,
8472
)
8573
from .quant_primitives import (
8674
MappingType,
@@ -131,20 +119,8 @@
131119
"ALL_AUTOQUANT_CLASS_LIST",
132120
# top level API - manual
133121
"quantize_",
134-
"int4_dynamic_activation_int4_weight",
135-
"int8_dynamic_activation_int4_weight",
136-
"int8_dynamic_activation_int8_weight",
137-
"int8_dynamic_activation_int8_semi_sparse_weight",
138-
"int4_weight_only",
139-
"int8_weight_only",
140122
"intx_quantization_aware_training",
141-
"float8_weight_only",
142-
"float8_dynamic_activation_float8_weight",
143-
"float8_static_activation_float8_weight",
144-
"uintx_weight_only",
145-
"fpx_weight_only",
146123
"fqn_matches_fqn_config",
147-
"gemlite_uintx_weight_only",
148124
"swap_conv2d_1x1_to_linear",
149125
"Int4DynamicActivationInt4WeightConfig",
150126
"Int8DynamicActivationInt4WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@
9797
to_weight_tensor_with_linear_activation_quantization_metadata,
9898
)
9999
from torchao.utils import (
100-
_ConfigDeprecationWrapper,
101100
is_MI300,
102101
is_sm_at_least_89,
103102
is_sm_at_least_90,
@@ -146,18 +145,7 @@
146145
"autoquant",
147146
"_get_subclass_inserter",
148147
"quantize_",
149-
"int8_dynamic_activation_int4_weight",
150-
"int8_dynamic_activation_int8_weight",
151-
"int8_dynamic_activation_int8_semi_sparse_weight",
152-
"int4_weight_only",
153-
"int8_weight_only",
154148
"intx_quantization_aware_training",
155-
"float8_weight_only",
156-
"uintx_weight_only",
157-
"fpx_weight_only",
158-
"gemlite_uintx_weight_only",
159-
"float8_dynamic_activation_float8_weight",
160-
"float8_static_activation_float8_weight",
161149
"Int8DynActInt4WeightQuantizer",
162150
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
163151
"ModuleFqnToConfig",
@@ -464,7 +452,7 @@ def quantize_(
464452
# Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
465453
# Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
466454
# Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile)
467-
from torchao.quantization.quant_api import int4_weight_only
455+
from torchao.quantization.quant_api import Int4WeightOnlyConfig
468456
469457
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
470458
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -596,12 +584,6 @@ def __post_init__(self):
596584
)
597585

598586

599-
# for BC
600-
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
601-
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
602-
)
603-
604-
605587
@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
606588
def _int8_dynamic_activation_int4_weight_transform(
607589
module: torch.nn.Module,
@@ -970,12 +952,6 @@ def __post_init__(self):
970952
)
971953

972954

973-
# for bc
974-
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
975-
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
976-
)
977-
978-
979955
@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
980956
def _int4_dynamic_activation_int4_weight_transform(
981957
module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1036,12 +1012,6 @@ def __post_init__(self):
10361012
)
10371013

10381014

1039-
# for BC
1040-
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
1041-
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
1042-
)
1043-
1044-
10451015
@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
10461016
def _gemlite_uintx_weight_only_transform(
10471017
module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1119,11 +1089,6 @@ def __post_init__(self):
11191089
torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
11201090

11211091

1122-
# for BC
1123-
# TODO maybe change other callsites
1124-
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
1125-
1126-
11271092
def _int4_weight_only_quantize_tensor(weight, config):
11281093
# TODO(future PR): perhaps move this logic to a different file, to keep the API
11291094
# file clean of implementation details
@@ -1335,10 +1300,6 @@ def __post_init__(self):
13351300
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
13361301

13371302

1338-
# for BC
1339-
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
1340-
1341-
13421303
def _int8_weight_only_quantize_tensor(weight, config):
13431304
mapping_type = MappingType.SYMMETRIC
13441305
target_dtype = torch.int8
@@ -1503,12 +1464,6 @@ def __post_init__(self):
15031464
)
15041465

15051466

1506-
# for BC
1507-
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
1508-
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
1509-
)
1510-
1511-
15121467
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
15131468
layout = config.layout
15141469
act_mapping_type = config.act_mapping_type
@@ -1614,12 +1569,6 @@ def __post_init__(self):
16141569
torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
16151570

16161571

1617-
# for BC
1618-
float8_weight_only = _ConfigDeprecationWrapper(
1619-
"float8_weight_only", Float8WeightOnlyConfig
1620-
)
1621-
1622-
16231572
def _float8_weight_only_quant_tensor(weight, config):
16241573
if config.version == 1:
16251574
warnings.warn(
@@ -1797,12 +1746,6 @@ def __post_init__(self):
17971746
self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
17981747

17991748

1800-
# for bc
1801-
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
1802-
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
1803-
)
1804-
1805-
18061749
def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
18071750
activation_dtype = config.activation_dtype
18081751
weight_dtype = config.weight_dtype
@@ -2005,12 +1948,6 @@ def __post_init__(self):
20051948
)
20061949

20071950

2008-
# for bc
2009-
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
2010-
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
2011-
)
2012-
2013-
20141951
@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
20151952
def _float8_static_activation_float8_weight_transform(
20161953
module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2096,12 +2033,6 @@ def __post_init__(self):
20962033
)
20972034

20982035

2099-
# for BC
2100-
uintx_weight_only = _ConfigDeprecationWrapper(
2101-
"uintx_weight_only", UIntXWeightOnlyConfig
2102-
)
2103-
2104-
21052036
@register_quantize_module_handler(UIntXWeightOnlyConfig)
21062037
def _uintx_weight_only_transform(
21072038
module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2383,10 +2314,6 @@ def __post_init__(self):
23832314
)
23842315

23852316

2386-
# for BC
2387-
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
2388-
2389-
23902317
@register_quantize_module_handler(FPXWeightOnlyConfig)
23912318
def _fpx_weight_only_transform(
23922319
module: torch.nn.Module, config: FPXWeightOnlyConfig

torchao/utils.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import reduce
1313
from importlib.metadata import version
1414
from math import gcd
15-
from typing import Any, Callable, Optional, Type
15+
from typing import Any, Callable, Optional
1616

1717
import torch
1818
import torch.nn.utils.parametrize as parametrize
@@ -368,25 +368,6 @@ def torch_version_at_least(min_version):
368368
return parse_version(torch.__version__) >= parse_version(min_version)
369369

370370

371-
class _ConfigDeprecationWrapper:
372-
"""
373-
A deprecation wrapper that directs users from a deprecated "config function"
374-
(e.g. `int4_weight_only`) to the replacement config class.
375-
"""
376-
377-
def __init__(self, deprecated_name: str, config_cls: Type):
378-
self.deprecated_name = deprecated_name
379-
self.config_cls = config_cls
380-
381-
def __call__(self, *args, **kwargs):
382-
warnings.warn(
383-
f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
384-
f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
385-
f" quantize_(model, {self.config_cls.__name__}(...))"
386-
)
387-
return self.config_cls(*args, **kwargs)
388-
389-
390371
"""
391372
Helper function for implementing aten op or torch function dispatch
392373
and dispatching to these implementations.

0 commit comments

Comments
 (0)