Skip to content

Commit

Permalink
Improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Jan 6, 2025
1 parent 9977952 commit 2e8a70d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def __eq__(self, other) -> bool:
"""
if isinstance(other, Cut):
return self.mem_elements == other.mem_elements
return False
return False # pragma: no cover

def __hash__(self):
return hash((frozenset(self.op_order), frozenset(self.op_record), self.mem_elements))

def __repr__(self):
return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>"
return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,6 @@
BATCH_INPUT_SHAPE = 'batch_input_shape'


def get_node_name_from_layer(layer: Layer) -> str:
"""
Get a node's name from the layer it was built from. For TensorFlowOpLayer
we remove the prefix "tf_op_layer".
Args:
layer: Keras Layer to get its corresponding node's name.
Returns:
Name of the node that was built from the passed layer.
"""

name = layer.name
if isinstance(layer, TensorFlowOpLayer): # remove TF op layer prefix
name = '_'.join(name.split('_')[3:])
return name


def is_layer_fake_quant(layer: Layer) -> bool:
"""
Check whether a Keras layer is a fake quantization layer or not.
Args:
layer: Layer to check if it's a fake quantization layer or not.
Returns:
Whether a Keras layer is a fake quantization layer or not.
"""
# in tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda
return (isinstance(layer, TensorFlowOpLayer) and layer.node_def.op == FQ_NODE_OP_V2_3) or (
isinstance(layer, TFOpLambda) and layer.symbol == FQ_NODE_OP_V2_4)


class KerasModelBuilder(BaseModelBuilder):
"""
Builder for Keras models.
Expand Down Expand Up @@ -291,7 +259,7 @@ def _run_operation(self,
arg = n.weights.get(pos)
if arg is None:
if len(input_tensors) == 0:
Logger.critical(f"Couldn't find a weight or input tensor matching operator's "
Logger.critical(f"Couldn't find a weight or input tensor matching operator's " # pragma: no cover
f"argument name '{k}' in location {pos} for node {n.name}.")
arg = input_tensors.pop(0)
op_call_kwargs.update({k: arg})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def substitute(self,
strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
if strides is None:
# Non-standard strides -> skip substitution.
return graph
return graph # pragma: no cover
conv_fw_attr[STRIDES] = strides

padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
Expand All @@ -153,7 +153,7 @@ def substitute(self,
dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
if dilations is None:
# Non-standard dilations -> skip substitution.
return graph
return graph # pragma: no cover
conv_fw_attr[DILATION_RATE] = dilations

if b is None:
Expand Down

0 comments on commit 2e8a70d

Please sign in to comment.