Skip to content

Commit

Permalink
Fix tf imports when tf was not found
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Nov 29, 2023
1 parent c1a3bc1 commit 476fcc8
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions model_compression_toolkit/pruning/keras/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
set_quantization_configuration_to_graph
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
TargetPlatformCapabilities

if FOUND_TF:
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from tensorflow.keras.models import Model
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

Expand Down Expand Up @@ -73,7 +72,7 @@ def keras_pruning_experimental(model: Model,
fw_impl)

# Apply quantization configuration to the graph. This step is necessary even when not quantizing,
# as it prepares the graph for compression.
# as it prepares the graph for the pruning process.
float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
quant_config=DEFAULTCONFIG,
mixed_precision_enable=False)
Expand All @@ -88,7 +87,7 @@ def keras_pruning_experimental(model: Model,
target_platform_capabilities)

# Apply the pruning process.
pruned_graph = pruner.get_pruned_graph() #TODO:rename
pruned_graph = pruner.get_pruned_graph() # TODO:rename

# Retrieve pruning information which includes the pruning masks and scores.
pruning_info = pruner.get_pruning_info()
Expand Down

0 comments on commit 476fcc8

Please sign in to comment.