From ef5289aa71aa221aa2583433d4106bd052babe7e Mon Sep 17 00:00:00 2001
From: elad-c
Date: Sun, 22 Sep 2024 13:24:16 +0300
Subject: [PATCH] Add const quantization for torch.gather ("params" attribute
 only).

---
 .../pytorch/back2framework/pytorch_model_builder.py   | 11 ++++++++---
 .../pytorch/builder/fully_quantized_model_builder.py  |  4 +++-
 .../tpc_models/imx500_tpc/v4/tpc_pytorch.py           |  2 +-
 .../feature_models/const_quantization_test.py         |  8 ++++++--
 .../feature_models/const_representation_test.py       |  6 +++++-
 5 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
index 3d7e1f32f..7527ea189 100644
--- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
+++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -145,9 +145,14 @@ def _run_operation(n: BaseNode,
         else:
             out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
     else:
-        merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
-                                                         tensor_input_allocs=_tensor_input_allocs)
-        out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
+        if isinstance(op_func, PytorchQuantizationWrapper) and isinstance(n, FunctionalNode) and n.functional_op is not torch.gather:
+            # In wrapped nodes, the op args & kwargs are already stored in the PytorchQuantizationWrapper.
+            # Temporary patch: for torch.gather this is not the case, so its inputs still need to be merged.
+            out_tensors_of_n_float = op_func(*input_tensors)
+        else:
+            merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
+                                                             tensor_input_allocs=_tensor_input_allocs)
+            out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
 
     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
index 64621f265..1c7ebcf91 100644
--- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
+++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
@@ -50,9 +50,11 @@ def fully_quantized_wrapper(node: common.BaseNode,
                              for attr in weight_quantizers if isinstance(attr, int)}
         # When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
         # are used during wrapper call method.
+        # Temporary patch: for torch.gather this is not the case, so args & kwargs shouldn't be
+        # saved in the wrapper.
         func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
                             OP_CALL_KWARGS: node.op_call_kwargs
-                            } if isinstance(node, FunctionalNode) else {}
+                            } if isinstance(node, FunctionalNode) and node.functional_op is not torch.gather else {}
         return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
                                           is_inputs_as_list=node.inputs_as_list,
                                           **func_node_kwargs)
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
index f2a10b3d0..a2eabdf79 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
@@ -78,7 +78,6 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                          split,
                                                          chunk,
                                                          unbind,
-                                                         gather,
                                                          MaxPool2d])
         tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS, [Flatten,
                                                                     flatten,
@@ -88,6 +87,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                                     squeeze,
                                                                     permute,
                                                                     transpose])
+        tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather])
         tp.OperationsSetToLayers(OPSET_MERGE_OPS,
                                  [torch.stack, torch.cat, torch.concat, torch.concatenate])
 
diff --git a/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
index ad865ad73..e239e7fe9 100644
--- a/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
@@ -152,6 +152,7 @@ def __init__(self):
         self.register_buffer('concatenate_const_3', to_torch_tensor(np.random.randint(-128, 127, size=(1, 3, 36, 36))))
         self.register_buffer('stack_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
         self.register_buffer('stack_const_2', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
+        self.register_buffer('gather_const', to_torch_tensor(np.random.randint(-128, 127, size=(1, 2*36*36))))
 
     def forward(self, x):
         x = torch.cat([self.cat_const_1, x, self.cat_const_2], dim=2)
@@ -161,7 +162,10 @@ def forward(self, x):
                                self.concatenate_const_3, self.concatenate_const_1], dim=1)
         x = torch.stack([self.stack_const_1, x, self.stack_const_2], dim=1)
         x = torch.reshape(x, (1, 3*39, 36, 36))
-        return x
+
+        inds = torch.argmax(torch.reshape(x, (-1, 117, 36*36)), dim=2)
+        b = torch.reshape(torch.gather(self.gather_const, 1, inds), (-1, 117, 1, 1))
+        return x + b
 
 
 class ConstQuantizationMultiInputTest(BasePytorchFeatureNetworkTest):
@@ -185,7 +189,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         y_hat = quantized_model(in_torch_tensor)
         self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
         cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
-        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-3), msg=f'fail cosine similarity check: {cs}')
 
         # check quantization layers:
         for op in [torch.cat, torch.concat, torch.concatenate, torch.stack]:
diff --git a/tests/pytorch_tests/model_tests/feature_models/const_representation_test.py b/tests/pytorch_tests/model_tests/feature_models/const_representation_test.py
index 4c3d9734d..b7a03e19f 100644
--- a/tests/pytorch_tests/model_tests/feature_models/const_representation_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/const_representation_test.py
@@ -86,12 +86,16 @@ def __init__(self):
         self.const1 = to_torch_tensor(np.random.random((32,)))
         self.const2 = to_torch_tensor(np.random.random((32,)))
         self.const3 = to_torch_tensor(np.random.random((1, 5, 32, 32)))
+        self.gather_const = to_torch_tensor(np.random.random((1, 2000)))
 
     def forward(self, x):
         x1 = sum(
             [self.const1, x, self.const2])  # not really a 3-input add operation, but just in case torch will support it
         x = torch.cat([x1, self.const3, x], dim=1)
-        return x
+
+        inds = torch.argmax(torch.reshape(x, (-1, 37*32, 32)), dim=1)
+        b = torch.reshape(torch.gather(self.gather_const, 1, inds), (-1, 1, 1, 32))
+        return x + b
 
 
 class ConstRepresentationMultiInputTest(ConstRepresentationTest):
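Reviewer note: below is a minimal standalone sketch (plain PyTorch, no MCT involved) of the gather-on-const pattern the updated tests exercise. The shapes mirror the const_quantization_test network above; the helper name gather_on_const and the use of float32 random data are illustrative assumptions, not part of the patch. Per the subject line, only the constant "params" input of torch.gather is quantized; the index tensor remains a regular activation input.

    import numpy as np
    import torch

    # Constant "params" tensor, analogous to the gather_const buffer registered in the test.
    gather_const = torch.from_numpy(np.random.random((1, 2 * 36 * 36)).astype(np.float32))

    def gather_on_const(x: torch.Tensor) -> torch.Tensor:
        # x has shape (1, 117, 36, 36), matching the test network's intermediate tensor.
        inds = torch.argmax(torch.reshape(x, (-1, 117, 36 * 36)), dim=2)         # (1, 117) int64 indices
        b = torch.reshape(torch.gather(gather_const, 1, inds), (-1, 117, 1, 1))  # gather along dim 1 of the const
        return x + b

    print(gather_on_const(torch.randn(1, 117, 36, 36)).shape)  # torch.Size([1, 117, 36, 36])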