Add const quantization for torch.gather ("params" attribute only). (#…
elad-c authored Sep 23, 2024
1 parent 34d24a4 commit 3eed10b
Showing 5 changed files with 23 additions and 8 deletions.
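For context, the pattern this commit targets is a functional torch.gather whose first positional input (the "params" table being gathered from) is a constant tensor held by the model rather than a runtime activation. Below is a minimal illustrative sketch of such a module, loosely modeled on the test added in this commit; the class, buffer name, and shapes are made up for the example and are not taken from the repository.

import torch
import torch.nn as nn


class GatherFromConst(nn.Module):
    """Toy module whose torch.gather reads from a constant 'params' tensor."""

    def __init__(self):
        super().__init__()
        # The constant table; with this commit, MCT can quantize it as a
        # weight attribute of the wrapped gather node.
        self.register_buffer('gather_const', torch.randn(1, 64))

    def forward(self, x):
        # x: (1, 8, 8) -> per row, pick the column index with the largest value.
        inds = torch.argmax(x, dim=2)                 # (1, 8), values in [0, 8)
        b = torch.gather(self.gather_const, 1, inds)  # (1, 8), read from the const table
        return x + b.unsqueeze(-1)                    # broadcast the gathered bias onto x


if __name__ == '__main__':
    m = GatherFromConst()
    print(m(torch.randn(1, 8, 8)).shape)  # torch.Size([1, 8, 8])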
@@ -145,9 +145,14 @@ def _run_operation(n: BaseNode,
         else:
             out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
     else:
-        merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
-                                                         tensor_input_allocs=_tensor_input_allocs)
-        out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
+        if isinstance(op_func, PytorchQuantizationWrapper) and isinstance(n, FunctionalNode) and n.functional_op is not torch.gather:
+            # in wrapped nodes, the op args & kwargs are already in the PytorchQuantizationWrapper.
+            # Temporary patch: for torch.gather this is not the case, so need to merge inputs.
+            out_tensors_of_n_float = op_func(*input_tensors)
+        else:
+            merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
+                                                             tensor_input_allocs=_tensor_input_allocs)
+            out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
 
     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
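The branch added above calls wrapped functional nodes directly with the runtime tensors, since their call args and kwargs already live inside the PytorchQuantizationWrapper, and keeps the input-merging path for torch.gather. The sketch below only illustrates the general idea of "merging inputs": placing runtime tensors into recorded argument slots of the original functional call. MCT's _merge_inputs is more involved; merge_inputs_sketch and merge_positions are hypothetical stand-ins for it and for tensor_input_allocs.

from typing import Any, List, Sequence


def merge_inputs_sketch(runtime_tensors: List[Any],
                        op_call_args: Sequence[Any],
                        merge_positions: Sequence[int]) -> List[Any]:
    """Insert runtime tensors into recorded argument slots of a functional op.

    Illustrative only: the real helper also handles kwargs, weights and
    validation; this shows just the positional idea.
    """
    merged = list(op_call_args)
    for tensor, pos in zip(runtime_tensors, merge_positions):
        merged.insert(pos, tensor)
    return merged


# Hypothetical usage: place a runtime `index` tensor into slot 2 of
# torch.gather(params, dim, index) when `params` and `dim` were recorded with the node:
#   merged = merge_inputs_sketch([index_tensor], op_call_args=[params_const, 1], merge_positions=[2])
#   out = torch.gather(*merged)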
@@ -50,9 +50,11 @@ def fully_quantized_wrapper(node: common.BaseNode,
                        for attr in weight_quantizers if isinstance(attr, int)}
         # When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
         # are used during wrapper call method.
+        # Temporary patch: for torch.gather this is not the case, so args & kwargs shouldn't be
+        # saved in the wrapper.
         func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
                             OP_CALL_KWARGS: node.op_call_kwargs
-                            } if isinstance(node, FunctionalNode) else {}
+                            } if isinstance(node, FunctionalNode) and not node.functional_op is torch.gather else {}
         return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
                                           is_inputs_as_list=node.inputs_as_list,
                                           **func_node_kwargs)
@@ -78,7 +78,6 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                              split,
                                              chunk,
                                              unbind,
-                                             gather,
                                              MaxPool2d])
     tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS, [Flatten,
                                                                 flatten,
@@ -88,6 +87,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                                 squeeze,
                                                                 permute,
                                                                 transpose])
+    tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather])
     tp.OperationsSetToLayers(OPSET_MERGE_OPS,
                              [torch.stack, torch.cat, torch.concat, torch.concatenate])
 
@@ -152,6 +152,7 @@ def __init__(self):
         self.register_buffer('concatenate_const_3', to_torch_tensor(np.random.randint(-128, 127, size=(1, 3, 36, 36))))
         self.register_buffer('stack_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
         self.register_buffer('stack_const_2', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
+        self.register_buffer('gather_const', to_torch_tensor(np.random.randint(-128, 127, size=(1, 2*36*36))))
 
     def forward(self, x):
         x = torch.cat([self.cat_const_1, x, self.cat_const_2], dim=2)
@@ -161,7 +162,10 @@ def forward(self, x):
                                self.concatenate_const_3, self.concatenate_const_1], dim=1)
         x = torch.stack([self.stack_const_1, x, self.stack_const_2], dim=1)
         x = torch.reshape(x, (1, 3*39, 36, 36))
-        return x
+
+        inds = torch.argmax(torch.reshape(x, (-1, 117, 36*36)), dim=2)
+        b = torch.reshape(torch.gather(self.gather_const, 1, inds), (-1, 117, 1, 1))
+        return x + b
 
 
 class ConstQuantizationMultiInputTest(BasePytorchFeatureNetworkTest):
@@ -185,7 +189,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         y_hat = quantized_model(in_torch_tensor)
         self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
         cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
-        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-3), msg=f'fail cosine similarity check: {cs}')
 
         # check quantization layers:
         for op in [torch.cat, torch.concat, torch.concatenate, torch.stack]:
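As a quick sanity check of the gather path added to the test network above, the following standalone snippet reproduces its shape mechanics with a batch of one. The constants here are random and purely illustrative; only the shapes mirror the test.

import torch

gather_const = torch.randint(-128, 127, (1, 2 * 36 * 36)).float()  # the const "params" table
x = torch.randn(1, 3 * 39, 36, 36)                                 # output of the reshape in the test

inds = torch.argmax(torch.reshape(x, (-1, 117, 36 * 36)), dim=2)   # (1, 117), values in [0, 1296)
b = torch.reshape(torch.gather(gather_const, 1, inds), (-1, 117, 1, 1))
print((x + b).shape)  # torch.Size([1, 117, 36, 36])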
@@ -86,12 +86,16 @@ def __init__(self):
         self.const1 = to_torch_tensor(np.random.random((32,)))
         self.const2 = to_torch_tensor(np.random.random((32,)))
         self.const3 = to_torch_tensor(np.random.random((1, 5, 32, 32)))
+        self.gather_const = to_torch_tensor(np.random.random((1, 2000)))
 
     def forward(self, x):
         x1 = sum(
             [self.const1, x, self.const2])  # not really a 3-input add operation, but just in case torch will support it
         x = torch.cat([x1, self.const3, x], dim=1)
-        return x
+
+        inds = torch.argmax(torch.reshape(x, (-1, 37*32, 32)), dim=1)
+        b = torch.reshape(torch.gather(self.gather_const, 1, inds), (-1, 1, 1, 32))
+        return x + b
 
 
 class ConstRepresentationMultiInputTest(ConstRepresentationTest):
