Skip to content

Commit c6f5032

Browse files
liudangyicopybara-github
authored andcommitted
Make HowToQuantize fields optional with default values.
This simplifies the usage of `qarray.HowToQuantize` by providing default values for `channelwise_axes`, `tiled_axes`, and `calibration_method`. Tests are updated to reflect this change. PiperOrigin-RevId: 811189153
1 parent af61307 commit c6f5032

File tree

4 files changed

+9
-31
lines changed

4 files changed

+9
-31
lines changed

qwix/_src/core/qarray.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,17 +241,19 @@ class HowToQuantize:
241241
qtype: jax.typing.DTypeLike
242242
# Channelwise axes will have individual scales, which has the same effect
243243
# as setting their tile sizes to 1 in tiled_axes.
244-
channelwise_axes: Collection[int]
244+
channelwise_axes: Collection[int] = ()
245245
# Tiled axes have subchannel quantization enabled. The value is a mapping
246246
# from the tiled axis to the tile size. If the tile size is a float, it has
247247
# to be "1 / tile_count" and the actual tile size will be
248248
# round(axis_size * tile_size). Note that 1 and 1.0 have very different
249249
# meanings: a tile size of 1 means to use per-channel scale, while a
250250
# tile size of 1.0 means to use shared scale.
251-
tiled_axes: Mapping[int, int | float]
251+
tiled_axes: Mapping[int, int | float] = dataclasses.field(
252+
default_factory=dict
253+
)
252254
# The calibration method to use. The format is <method>[,<args>], e.g.
253255
# "absmax" or "fixed,-10,10". Check calibrate() for supported methods.
254-
calibration_method: str
256+
calibration_method: str = 'absmax'
255257

256258

257259
ShapeT: TypeAlias = Sequence[int]

tests/core/einsum_test.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,7 @@ def test_fake_quantization(self):
275275
fp_res = jnp.einsum(
276276
einsum_str, lhs, rhs, precision=jax.lax.Precision.HIGHEST
277277
)
278-
how = qarray.HowToQuantize(
279-
qtype=jnp.int8,
280-
channelwise_axes=(),
281-
tiled_axes={},
282-
calibration_method='absmax',
283-
)
278+
how = qarray.HowToQuantize(qtype=jnp.int8)
284279
lhs = qarray.quantize(lhs, how)
285280
rhs = qarray.quantize(rhs, how)
286281

@@ -321,21 +316,11 @@ def test_dequant_on_inputs(self):
321316
rhs = self._make_array((128, 128, 16), jnp.bfloat16)
322317
lhs = qarray.quantize(
323318
lhs,
324-
qarray.HowToQuantize(
325-
qtype=jnp.int8,
326-
channelwise_axes=(0, 1),
327-
tiled_axes={},
328-
calibration_method='absmax',
329-
),
319+
qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=(0, 1)),
330320
)
331321
rhs = qarray.quantize(
332322
rhs,
333-
qarray.HowToQuantize(
334-
qtype=jnp.int8,
335-
channelwise_axes=(0, 2),
336-
tiled_axes={},
337-
calibration_method='absmax',
338-
),
323+
qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=(0, 2)),
339324
)
340325
out = einsum.einsum('TNH,NHD -> TD', lhs, rhs)
341326
self.assertEqual(out.shape, (16, 16))

tests/core/pallas_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,7 @@ def dequantize_pallas(q: qarray.QArray):
211211
x = jax.random.uniform(jax.random.key(0), input_shape, jnp.float32)
212212
how = qarray.HowToQuantize(
213213
qtype="int8",
214-
channelwise_axes=[],
215214
tiled_axes=tiled_axes,
216-
calibration_method="absmax",
217215
)
218216
qx = qarray.quantize(x, how)
219217
self.assertTrue(jnp.allclose(dequantize_pallas(qx), qarray.dequantize(qx)))
@@ -262,7 +260,6 @@ def pallas_batch_matmul(
262260
qtype="int8",
263261
channelwise_axes=[0, 1],
264262
tiled_axes={2: 128},
265-
calibration_method="absmax",
266263
)
267264
qx = qarray.quantize(
268265
jax.random.uniform(jax.random.key(0), (4, 256, 256), jnp.float32), x_how
@@ -271,7 +268,6 @@ def pallas_batch_matmul(
271268
qtype="int8",
272269
channelwise_axes=[2],
273270
tiled_axes={1: 128},
274-
calibration_method="absmax",
275271
)
276272
qy = qarray.quantize(
277273
jax.random.uniform(jax.random.key(1), (4, 256, 256), jnp.float32), y_how

tests/core/qarray_test.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,7 @@ def test_exact_quantization(self, with_error):
153153
array += jax.random.uniform(
154154
jax.random.key(42), array.shape, minval=-1e-7, maxval=1e-7
155155
)
156-
how = qarray.HowToQuantize(
157-
qtype=jnp.int8,
158-
channelwise_axes=[],
159-
tiled_axes={},
160-
calibration_method='minmax',
161-
)
156+
how = qarray.HowToQuantize(qtype=jnp.int8, calibration_method='minmax')
162157
q_array = qarray.quantize(array, how)
163158
self.assertEqual(q_array.zero_point, jnp.array(-128, dtype=jnp.int8), array)
164159
expected_q_array = jnp.arange(-128, 128, dtype=jnp.int8)

0 commit comments

Comments
 (0)