Skip to content

Commit

Permalink
Fix Keras input_as_list in node builder (#1129)
Browse files Browse the repository at this point in the history
Fixing the test to decide whether a function op layer receives the input as a list or not.
The current test checks to see if the first argument is a list, but some ops (like subtract) can receive the arguments as separate lists representing tensors. In this case, the inputs are not in a single list.
The solution is to manually check for inputs list only on specific operators that can get all inputs in a list.

---------

Co-authored-by: Ofir Gordon <[email protected]>
  • Loading branch information
ofirgo and Ofir Gordon authored Jul 21, 2024
1 parent 9d54fe9 commit 58a2b51
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions model_compression_toolkit/core/keras/reader/node_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ def build_node(node: KerasNode,
output_shape = keras_layer.get_output_shape_at(io_index)

if layer_class in [TFOpLambda, SlicingOpLambda]:
# Some functional ops (such as tf.concat) should receive the input tensors as a list
# and some are not (such as tf.multiply), so each FunctionalNode holds
# a flag to indicate that.
inputs_as_list = __is_functional_inputs_a_list(op_call_args)
# Some functional ops should receive the input tensors as a list,
# so each FunctionalNode holds a flag to indicate that.
# Other functional ops can receive each argument as list, but in that case not all inputs appear in that list.
inputs_as_list = __is_functional_inputs_a_list(op_call_args, keras_layer)

kwarg2index = get_kwargs2index(keras_layer)

Expand Down Expand Up @@ -301,21 +301,20 @@ def build_node(node: KerasNode,
return node


def __is_functional_inputs_a_list(op_call_args: Any) -> bool:
def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""
Check whether the input tensors should be passed as a list
or not.
or not. This is relevant only for specific TF operators that are specified in the function's condition.
Args:
op_call_args: Arguments list to check.
keras_layer: TFOpLambda layer.
Returns:
Whether the input tensors should be passed as a list or not.
"""

if len(op_call_args) > 0 and isinstance(op_call_args[0], list):
inputs_as_list = True
for arg in op_call_args[0]:
inputs_as_list = inputs_as_list and (is_tensor(arg) or is_const(arg))
return inputs_as_list
return False
return (keras_layer.symbol in
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
len(op_call_args) > 0 and
isinstance(op_call_args[0], list))

0 comments on commit 58a2b51

Please sign in to comment.