Torch model reader refactor #1103

Merged 4 commits on Jun 16, 2024 (showing changes from 3 commits)
@@ -25,7 +25,7 @@ def __init__(self,
functional_op: Any = None,
inputs_as_list: bool = False,
has_activation: bool = True,
tensor_input_indices = None):
tensor_input_allocs = None):
"""
Init a FunctionalNode object.

@@ -44,7 +44,7 @@ def __init__(self,
functional_op: The op the node implements.
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
has_activation: Whether the node has activations that we might want to quantize.
tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
tensor_input_allocs: A list of indices and keyword-argument names marking where the node's activation tensors are placed in its inputs

"""

@@ -63,7 +63,7 @@ def __init__(self,
self.op_call_args = op_call_args
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
def type(self):
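The rename from tensor_input_indices to tensor_input_allocs reflects that an entry may now be either a positional index or a keyword-argument name describing where an activation tensor belongs when the call is rebuilt. A minimal illustration of the convention (the calls and values below are made-up examples, not code from this PR):

# Hypothetical contents of tensor_input_allocs for different functional calls:
#   torch.add(x, 2.0)                ->  [0]        activation x re-enters at positional index 0
#   torch.sub(input=const, other=x)  ->  ['other']  activation x re-enters through the 'other' kwarg
#   torch.cat([x, y], dim=1)         ->  handled by inputs_as_list rather than allocations
example_allocs = [0, 2, 'other']   # two positional activation inputs and one kwarg activation input
positional = [a for a in example_allocs if isinstance(a, int)]   # -> [0, 2]
keyword = [a for a in example_allocs if isinstance(a, str)]      # -> ['other']
assert positional == [0, 2] and keyword == ['other']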
@@ -66,31 +66,39 @@ def _build_input_tensors_list(node: BaseNode,
return input_tensors


def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
tensor_input_indices: List = None) -> List:
def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
tensor_input_allocs: List = None) -> List:
"""
Merge input tensors list with positional weights, op_call_args, and op_call_kwargs, in the correct order.

Args:
_node: The node the inputs are for.
input_tensors: activation input tensors to node.
op_call_args: framework node call args.
op_call_kwargs: framework node call kwargs.
tensor_input_allocs: List of input allocations to node.

Returns:
Combined list of input_tensors and op_call_args, and the updated op_call_kwargs.
"""
if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
_input_list = op_call_args.copy()
if tensor_input_indices is None:
tensor_input_indices = _node.tensor_input_indices
assert len(tensor_input_indices) == len(input_tensors), \
f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
for i, t in zip(tensor_input_indices, input_tensors):
_input_list.insert(i, t)
if tensor_input_allocs is None:
tensor_input_allocs = _node.tensor_input_allocs
assert len(tensor_input_allocs) == len(input_tensors), \
f'Mismatch between tensor input allocations ({len(tensor_input_allocs)}) and input tensors ({len(input_tensors)})'
for i, t in zip(tensor_input_allocs, input_tensors):
if isinstance(i, str):
# a string allocation means the tensor is passed as a keyword argument under that name
assert i not in op_call_kwargs, f'Tensor input kwarg {i} already exists in op_call_kwargs'
op_call_kwargs.update({i: t})
else:
_input_list.insert(i, t)
else:
_input_list = input_tensors + op_call_args

return _input_list
return _input_list, op_call_kwargs
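For readers following the new allocation handling, here is a self-contained sketch of the merge rule, simplified from the function above (plain strings stand in for tensors; this is illustrative, not the PR's code):

def merge_sketch(input_tensors, op_call_args, op_call_kwargs, tensor_input_allocs):
    # integer allocations are positions in the positional-argument list,
    # string allocations are keyword-argument names
    args, kwargs = list(op_call_args), dict(op_call_kwargs)
    for alloc, tensor in zip(tensor_input_allocs, input_tensors):
        if isinstance(alloc, str):
            assert alloc not in kwargs   # an activation kwarg must not collide with a recorded kwarg
            kwargs[alloc] = tensor
        else:
            args.insert(alloc, tensor)
    return args, kwargs

# torch.sub(input=const, other=x): the activation x is re-attached as the 'other' kwarg
assert merge_sketch(['x'], [], {'input': 'const'}, ['other']) == ([], {'input': 'const', 'other': 'x'})
# op(x, 2) with a positional activation: x is re-inserted at index 0
assert merge_sketch(['x'], [2], {}, [0]) == (['x', 2], {})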


def _run_operation(n: BaseNode,
@@ -125,14 +133,15 @@ def _run_operation(n: BaseNode,
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
for t in input_tensors]
_tensor_input_indices = None
_tensor_input_allocs = None
else:
_tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

# Add a fake quant node if the node has an activation threshold.
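As a quick illustration of the two dispatch paths above (tensors passed as one list versus spread into args/kwargs), under the assumption that op_func is a plain torch functional (this sketch is not the PR's code):

import torch

def dispatch_sketch(op_func, input_tensors, op_call_kwargs, inputs_as_list):
    if inputs_as_list:
        # the op expects its tensors bundled into a single list argument, e.g. torch.cat
        return op_func(input_tensors, **op_call_kwargs)
    # otherwise the (already merged) tensors are spread as positional arguments
    return op_func(*input_tensors, **op_call_kwargs)

x, y = torch.ones(1, 3), torch.zeros(1, 3)
assert dispatch_sketch(torch.cat, [x, y], {'dim': 0}, inputs_as_list=True).shape == (2, 3)
assert torch.equal(dispatch_sketch(torch.sub, [x, y], {}, inputs_as_list=False), x - y)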
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/pytorch/constants.py
@@ -40,7 +40,7 @@
OP_CALL_ARGS = 'op_call_args'
OP_CALL_KWARGS = 'op_call_kwargs'
INPUTS_AS_LIST = 'inputs_as_list'
TENSOR_INPUT_INDICES = 'tensor_input_indices'
TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
INPLACE = 'inplace'
HARDTANH_MIN_VAL = 'min_val'
HARDTANH_MAX_VAL = 'max_val'
@@ -65,10 +65,10 @@ def substitute(self,

# When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)),
# this substitution converts it to x.reshape((-1, channels, height, width)), so we need to update the
# tensor_input_indices attribute.
# scalar argument's shape is [1] so remove those indices from tensor_input_indices
# tensor_input_allocs attribute.
# scalar argument's shape is [1] so remove those indices from tensor_input_allocs
# node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]

# modify the node input info
node.input_shape = [node.input_shape[0]]
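The slicing above can be sanity-checked against the input-shape example from the comment (a standalone recalculation with hypothetical allocations, not the PR's code):

# node.input_shape example from the comment: the first entry is the reshaped tensor,
# the remaining [1] entries are the scalar shape arguments folded into a single tuple.
input_shape = [[1, 32, 4, 32], [1], [1], [1]]
tensor_input_allocs = [0, 1, 2, 3]                      # hypothetical allocations before the substitution
num_tensor_inputs = sum(s != [1] for s in input_shape)  # -> 1
assert tensor_input_allocs[:num_tensor_inputs] == [0]   # only the real tensor input keeps an allocation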
18 changes: 12 additions & 6 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -23,7 +23,7 @@
from model_compression_toolkit.core.common.graph.edge import Edge
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.logger import Logger

@@ -140,7 +140,7 @@ def nodes_builder(model: GraphModule,
weights.update({i: consts_dict[input_node]})

tensor_meta = input_node.meta
if tensor_meta[TYPE] == torch.Tensor:
if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
input_shape += [list(tensor_meta[TENSOR_META].shape)]
elif tensor_meta[TYPE] == tuple:
input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
@@ -159,8 +159,11 @@

# filter Nodes from framework attributes; these attributes are replaced with nx graph nodes
framework_attr_filtered = {}
framework_attr_nodes = {}
for k, v in framework_attr.items():
if not isinstance(v, torch.fx.node.Node):
if isinstance(v, torch.fx.node.Node):
framework_attr_nodes[k] = v
else:
framework_attr_filtered[k] = v
framework_attr = framework_attr_filtered

@@ -177,15 +180,18 @@
[isinstance(n, torch.fx.node.Node) for n in node.args[0]])
inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
tensor_input_index = []
tensor_input_alloc = []
op_call_args = list(node.args)
if inputs_as_list:
op_call_args.pop(0)
else:
for in_node in node.all_input_nodes:
for i, arg in enumerate(node.args):
if arg == in_node:
tensor_input_index.append(i)
tensor_input_alloc.append(i)
for k, arg in framework_attr_nodes.items():
if arg == in_node:
tensor_input_alloc.append(k)

# remove torch.fx.node.Node from inputs to graph_node_type
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
@@ -197,7 +203,7 @@ def nodes_builder(model: GraphModule,
OP_CALL_ARGS: op_call_args,
OP_CALL_KWARGS: node_kwargs,
INPUTS_AS_LIST: inputs_as_list,
TENSOR_INPUT_INDICES: tensor_input_index}
TENSOR_INPUT_ALLOCS: tensor_input_alloc}
else:
graph_node_type = BaseNode
kwargs = {}
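A framework-free sketch of how tensor_input_alloc is populated above, with plain strings standing in for torch.fx nodes (illustrative only, not the PR's code):

def build_alloc_sketch(call_args, call_kwargs, input_nodes):
    tensor_input_alloc = []
    for in_node in input_nodes:
        # occurrences in positional args are recorded by index ...
        for i, arg in enumerate(call_args):
            if arg == in_node:
                tensor_input_alloc.append(i)
        # ... occurrences in kwargs are recorded by kwarg name
        for k, arg in call_kwargs.items():
            if arg == in_node:
                tensor_input_alloc.append(k)
    return tensor_input_alloc

# torch.add(x, 2.0): the activation sits at positional index 0
assert build_alloc_sketch(['x', 2.0], {}, ['x']) == [0]
# torch.sub(input=const, other=x): the activation arrives through the 'other' kwarg
assert build_alloc_sketch([], {'input': 'const', 'other': 'x'}, ['x']) == ['other']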
@@ -154,3 +154,56 @@ def __init__(self, unit_test, func, const, indices):

def create_networks(self):
return ConstRepresentationGetIndexNet(self.func, self.const, self.indices)


class ConstRepresentationCodeNet(nn.Module):
def __init__(self, input_shape):
super().__init__()
self.input_shape = input_shape
self.conv2d = nn.Conv2d(3, 16, 3, 2, padding=1)
self.bn = nn.BatchNorm2d(16)
self.register_buffer('sub_const', 10 * torch.rand((1, 16, 64)))

def forward(self, x):
_shape = x.shape[2:]
x = self.conv2d(x)

# input tensor in kwargs
x = nn.functional.interpolate(x, size=_shape)

# reshaping batch_norm input to 3 axes to avoid bn-folding.
x = x.reshape((-1, 16, int(np.prod(self.input_shape))))

# input const in kwargs (not the first kwarg!)
x = nn.functional.batch_norm(x,
self.bn.running_mean, self.bn.running_var,
momentum=0.2, eps=1e-6, bias=self.bn.bias)

# input all tensors and consts in kwargs
x = torch.sub(input=self.sub_const, other=x)

return torch.reshape(x, (-1, 16) + self.input_shape)


class ConstRepresentationCodeTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test=unit_test)

def create_networks(self):
return ConstRepresentationCodeNet(self.input_shape[2:])

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_activation_quantization': False})
return generate_pytorch_tpc(name="linear_collapsing_test", tp_model=tp)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
in_torch_tensor = to_torch_tensor(input_x[0])
set_model(float_model)
y = float_model(in_torch_tensor)
y_hat = quantized_model(in_torch_tensor)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
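The comparison accepts the quantized model when the cosine similarity between float and quantized outputs is close to 1; the metric itself can be sketched as follows (a standalone reimplementation for reference, not the test helper used above):

import numpy as np

def cosine_similarity_sketch(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """cos(a, b) = <a, b> / (||a|| * ||b||), computed over flattened arrays."""
    a, b = a.flatten(), b.flatten()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps))

y = np.random.randn(1, 16, 8, 8).astype(np.float32)
assert np.isclose(cosine_similarity_sketch(y, y.copy()), 1.0)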
7 changes: 5 additions & 2 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -90,7 +90,8 @@
from tests.pytorch_tests.model_tests.feature_models.metadata_test import MetadataTest
from tests.pytorch_tests.model_tests.feature_models.tpc_test import TpcTest
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest, ConstRepresentationLinearLayerTest, ConstRepresentationGetIndexTest
ConstRepresentationMultiInputTest, ConstRepresentationLinearLayerTest, ConstRepresentationGetIndexTest, \
ConstRepresentationCodeTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \
AdvancedConstQuantizationTest
@@ -269,7 +270,9 @@ def test_const_representation(self):
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1),
const=c_img, enable_weights_quantization=enable_weights_quantization).run_test()


ConstRepresentationCodeTest(self).run_test()

def test_permute_substitution(self):
"""
This test checks the permute substitution feature