Skip to content

Commit

Permalink
Torch model reader refactor (#1103)
Browse files Browse the repository at this point in the history
Refactor PyTorch model reader & builder to handle weights and tensors that are keyword inputs.
  • Loading branch information
elad-c authored Jun 16, 2024
1 parent 1f62738 commit 920408b
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self,
functional_op: Any = None,
inputs_as_list: bool = False,
has_activation: bool = True,
tensor_input_indices = None):
tensor_input_allocs = None):
"""
Init a FunctionalNode object.
Expand All @@ -44,7 +44,7 @@ def __init__(self,
functional_op: The op the node implements.
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
has_activation: Whether the node has activations that we might want to quantize.
tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
tensor_input_allocs: A list of indices for activation tensors in the node's input tensor list
"""

Expand All @@ -63,7 +63,7 @@ def __init__(self,
self.op_call_args = op_call_args
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
def type(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core import common
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
Expand Down Expand Up @@ -66,31 +67,39 @@ def _build_input_tensors_list(node: BaseNode,
return input_tensors


def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
tensor_input_indices: List = None) -> List:
def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
tensor_input_allocs: List = None) -> Tuple[List, Dict]:
"""
Merge input tensors list with positional weights and op_call_args, according to correct order.
Args:
_node: The node the inputs are for.
input_tensors: activation input tensors to node.
op_call_args: framework node call args.
op_call_kwargs: framework node call kwargs.
tensor_input_allocs: List of input allocations to node.
Returns:
Combined list of input_tensors and op_call_args.
"""
if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
_input_list = op_call_args.copy()
if tensor_input_indices is None:
tensor_input_indices = _node.tensor_input_indices
assert len(tensor_input_indices) == len(input_tensors), \
f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
for i, t in zip(tensor_input_indices, input_tensors):
_input_list.insert(i, t)
if tensor_input_allocs is None:
tensor_input_allocs = _node.tensor_input_allocs
if len(tensor_input_allocs) != len(input_tensors):
Logger.error(f'Mismatch between input tensors ({len(tensor_input_allocs)}) '
f'and indices {len(input_tensors)} in node {_node.name}.') # pragma: no cover
for i, t in zip(tensor_input_allocs, input_tensors):
# insert input tensors in either args or kwargs, according to tensor_input_allocs
if isinstance(i, str):
assert i not in op_call_kwargs
op_call_kwargs.update({i: t})
else:
_input_list.insert(i, t)
else:
_input_list = input_tensors + op_call_args

return _input_list
return _input_list, op_call_kwargs


def _run_operation(n: BaseNode,
Expand Down Expand Up @@ -125,14 +134,15 @@ def _run_operation(n: BaseNode,
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
for t in input_tensors]
_tensor_input_indices = None
_tensor_input_allocs = None
else:
_tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

# Add a fake quant node if the node has an activation threshold.
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
OP_CALL_ARGS = 'op_call_args'
OP_CALL_KWARGS = 'op_call_kwargs'
INPUTS_AS_LIST = 'inputs_as_list'
TENSOR_INPUT_INDICES = 'tensor_input_indices'
TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
INPLACE = 'inplace'
HARDTANH_MIN_VAL = 'min_val'
HARDTANH_MAX_VAL = 'max_val'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def substitute(self,

# When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)
# this substitution converts it x.reshape((-1, channels, height, width)), so need to update the
# tensor_input_indices attribute.
# scalar argument's shape is [1] so remove those indices from tensor_input_indices
# tensor_input_allocs attribute.
# scalar argument's shape is [1] so remove those indices from tensor_input_allocs
# node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]

# modify the node input info
node.input_shape = [node.input_shape[0]]
Expand Down
18 changes: 12 additions & 6 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from model_compression_toolkit.core.common.graph.edge import Edge
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.logger import Logger

Expand Down Expand Up @@ -140,7 +140,7 @@ def nodes_builder(model: GraphModule,
weights.update({i: consts_dict[input_node]})

tensor_meta = input_node.meta
if tensor_meta[TYPE] == torch.Tensor:
if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
input_shape += [list(tensor_meta[TENSOR_META].shape)]
elif tensor_meta[TYPE] == tuple:
input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
Expand All @@ -159,8 +159,11 @@ def nodes_builder(model: GraphModule,

# filter Nodes from framework attributes, we replace these attributes with nx graph nodes
framework_attr_filtered = {}
framework_attr_nodes = {}
for k, v in framework_attr.items():
if not isinstance(v, torch.fx.node.Node):
if isinstance(v, torch.fx.node.Node):
framework_attr_nodes[k] = v
else:
framework_attr_filtered[k] = v
framework_attr = framework_attr_filtered

Expand All @@ -177,15 +180,18 @@ def nodes_builder(model: GraphModule,
[isinstance(n, torch.fx.node.Node) for n in node.args[0]])
inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
tensor_input_index = []
tensor_input_alloc = []
op_call_args = list(node.args)
if inputs_as_list:
op_call_args.pop(0)
else:
for in_node in node.all_input_nodes:
for i, arg in enumerate(node.args):
if arg == in_node:
tensor_input_index.append(i)
tensor_input_alloc.append(i)
for k, arg in framework_attr_nodes.items():
if arg == in_node:
tensor_input_alloc.append(k)

# remove torch.fx.node.Node from inputs to graph_node_type
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
Expand All @@ -197,7 +203,7 @@ def nodes_builder(model: GraphModule,
OP_CALL_ARGS: op_call_args,
OP_CALL_KWARGS: node_kwargs,
INPUTS_AS_LIST: inputs_as_list,
TENSOR_INPUT_INDICES: tensor_input_index}
TENSOR_INPUT_ALLOCS: tensor_input_alloc}
else:
graph_node_type = BaseNode
kwargs = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,56 @@ def __init__(self, unit_test, func, const, indices):

def create_networks(self):
return ConstRepresentationGetIndexNet(self.func, self.const, self.indices)


class ConstRepresentationCodeNet(nn.Module):
def __init__(self, input_shape):
super().__init__()
self.input_shape = input_shape
self.conv2d = nn.Conv2d(3, 16, 3, 2, padding=1)
self.bn = nn.BatchNorm2d(16)
self.register_buffer('sub_const', 10 * torch.rand((1, 16, 64)))

def forward(self, x):
_shape = x.shape[2:]
x = self.conv2d(x)

# input tensor in kwargs
x = nn.functional.interpolate(x, size=_shape)

# reshaping batch_norm input to 3 axes to avoid bn-folding.
x = x.reshape((-1, 16, int(np.prod(self.input_shape))))

# input const in kwargs (not the first kwargs!)
x = nn.functional.batch_norm(x,
self.bn.running_mean, self.bn.running_var,
momentum=0.2, eps=1e-6, bias=self.bn.bias)

# input all tensors and consts in kwargs
x = torch.sub(input=self.sub_const, other=x)

return torch.reshape(x, (-1, 16) + self.input_shape)


class ConstRepresentationCodeTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test=unit_test)

def create_networks(self):
return ConstRepresentationCodeNet(self.input_shape[2:])

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_activation_quantization': False})
return generate_pytorch_tpc(name="linear_collapsing_test", tp_model=tp)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
in_torch_tensor = to_torch_tensor(input_x[0])
set_model(float_model)
y = float_model(in_torch_tensor)
y_hat = quantized_model(in_torch_tensor)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
7 changes: 5 additions & 2 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@
from tests.pytorch_tests.model_tests.feature_models.metadata_test import MetadataTest
from tests.pytorch_tests.model_tests.feature_models.tpc_test import TpcTest
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest, ConstRepresentationLinearLayerTest, ConstRepresentationGetIndexTest
ConstRepresentationMultiInputTest, ConstRepresentationLinearLayerTest, ConstRepresentationGetIndexTest, \
ConstRepresentationCodeTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \
AdvancedConstQuantizationTest
Expand Down Expand Up @@ -269,7 +270,9 @@ def test_const_representation(self):
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1),
const=c_img, enable_weights_quantization=enable_weights_quantization).run_test()


ConstRepresentationCodeTest(self).run_test()

def test_permute_substitution(self):
"""
This test checks the permute substitution feature
Expand Down

0 comments on commit 920408b

Please sign in to comment.