@@ -105,34 +105,3 @@ def test_group(
105105 model_shape [1 ],
106106 int (model_shape [0 ] / group_size ),
107107 )
108-
109-
@torch.no_grad
@pytest.mark.parametrize("input_symmetry", [True, False])
@pytest.mark.parametrize("weight_symmetry", [True, False])
@pytest.mark.parametrize("input_shape", [(32, 256), (300, 200), (400, 400)])
def test_token(
    mock_per_channel_calibration,
    mock_per_token_calibration,
    input_symmetry,
    weight_symmetry,
    input_shape,
):
    """Check qparam shapes for token-wise input / channel-wise weight quant.

    Builds a Linear layer, applies a config with TOKEN input strategy and
    CHANNEL weight strategy, runs the mock calibrations, and verifies the
    resulting scale/zero-point tensor shapes.
    """
    out_features = 256
    layer = Linear(input_shape[1], out_features)

    # Weights quantized per output channel, inputs per token.
    config = create_config(
        input_symmetry,
        weight_symmetry,
        w_strategy=QuantizationStrategy.CHANNEL,
        i_strategy=QuantizationStrategy.TOKEN,
    )
    apply_quantization_config(layer, config)

    sample = torch.randn(input_shape)
    mock_per_channel_calibration(layer, base_name="weight", value=layer.weight)
    mock_per_token_calibration(layer, base_name="input", value=sample)

    # Token strategy is expected to leave a single (1, 1) input qparam pair.
    assert layer.input_scale.shape == (1, 1)
    assert layer.input_zero_point.shape == (1, 1)

    # Channel strategy: one qparam row per output channel.
    assert layer.weight_scale.shape == (out_features, 1)
    assert layer.weight_zero_point.shape == (out_features, 1)
0 commit comments