Skip to content

Commit 9aadef9

Browse files
committed
defaults
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a4aba3a commit 9aadef9

File tree

2 files changed

+34
-43
lines changed

2 files changed

+34
-43
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,6 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
263263
block_structure = model.block_structure
264264
actorder = model.actorder
265265
dynamic = model.dynamic
266-
observer = model.observer
267-
dynamic = model.dynamic
268266

269267
# infer strategy
270268
if strategy is None:
@@ -316,45 +314,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
316314
"activation ordering"
317315
)
318316

319-
# infer observer w.r.t. dynamic
320-
if dynamic:
321-
supported_strategies = (
322-
QuantizationStrategy.TOKEN,
323-
QuantizationStrategy.TENSOR,
324-
QuantizationStrategy.TENSOR_GROUP,
325-
QuantizationStrategy.GROUP,
326-
)
327-
if strategy not in supported_strategies:
328-
raise ValueError(
329-
f"One of {supported_strategies} must be used for dynamic quant."
330-
)
331-
332-
if (
333-
dynamic == DynamicType.LOCAL
334-
and strategy != QuantizationStrategy.TENSOR_GROUP
335-
):
336-
raise ValueError("local is only supported for strategy tensor_group")
337-
338-
if observer is not None:
339-
if dynamic is True: # checking if dynamic is True, not "local"
340-
if (
341-
observer != "memoryless"
342-
): # avoid annoying users with old configs
343-
warnings.warn(
344-
"No observer is used for dynamic quant., setting to None"
345-
)
346-
observer = None
347-
else:
348-
if dynamic == DynamicType.LOCAL:
349-
observer = "minmax"
350-
351-
elif observer is None:
352-
# default to minmax for non-dynamic cases
353-
observer = "minmax"
354-
355317
# write back modified values
356318
model.strategy = strategy
357-
model.observer = observer
358319
return model
359320

360321
def pytorch_dtype(self) -> torch.dtype:
@@ -373,10 +334,6 @@ def pytorch_dtype(self) -> torch.dtype:
373334
else:
374335
raise ValueError(f"Invalid quantization type {self.type}")
375336

376-
@deprecated("QuantizationArgs.observer")
377-
def get_observer(self) -> str:
378-
return self.observer
379-
380337
model_config = ConfigDict(extra="forbid")
381338

382339

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
5959
weights = model.weights
6060
format = model.format
6161

62+
# validate input args
6263
if inputs is not None:
6364
if inputs.strategy not in (
6465
QuantizationStrategy.TOKEN,
@@ -84,15 +85,21 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
8485
if inputs.actorder is not None:
8586
raise ValueError("Cannot apply actorder to input activations")
8687

88+
if inputs.observer is None:
89+
inputs.observer
90+
91+
# validate output args
8792
if outputs is not None:
8893
if outputs.actorder is not None:
8994
raise ValueError("Cannot apply actorder to output activations")
9095

96+
# validate format
9197
if format == CompressionFormat.mixed_precision.value:
9298
raise ValueError(
9399
"mixed-precision cannot be set as a format for a QuantizationScheme"
94100
)
95101

102+
# validate matching group sizes
96103
if (
97104
inputs
98105
and weights
@@ -110,8 +117,35 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
110117
stacklevel=2,
111118
)
112119

120+
# set observer defaults
121+
model._validate_observers()
122+
113123
return model
114124

125+
def _validate_observers(self):
126+
inputs = self.input_activations
127+
weights = self.weights
128+
outputs = self.output_activations
129+
130+
if inputs is not None and inputs.observer is None:
131+
if inputs.dynamic:
132+
inputs.observer = "memoryless-minmax"
133+
else:
134+
inputs.observer = "static-minmax"
135+
136+
if weights is not None and weights.observer is None:
137+
weights.observer = "memoryless-minmax"
138+
139+
if outputs is not None and outputs.observer is None:
140+
if outputs.dynamic:
141+
outputs.observer = "memoryless-minmax"
142+
else:
143+
outputs.observer = "static-minmax"
144+
145+
self.input_activations = inputs
146+
self.weights = weights
147+
self.output_activations = outputs
148+
115149
model_config = ConfigDict(extra="forbid")
116150

117151

0 commit comments

Comments
 (0)