Skip to content

Commit

Permalink
Deep copy the model in data generation in order not to modify its state.
Browse files Browse the repository at this point in the history
  • Loading branch information
liord committed Sep 25, 2024
1 parent 8402b45 commit ad6f9b2
Showing 1 changed file with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy

import time
from typing import Callable, Any, Tuple, List, Union

Expand Down Expand Up @@ -179,8 +181,11 @@ def pytorch_data_generation_experimental(
# get the model device
device = get_working_device()

# copy model for data generation
model_for_data_gen = copy.deepcopy(model)

# get a static graph representation of the model using torch.fx
fx_model = symbolic_trace(model)
fx_model = symbolic_trace(model_for_data_gen)

# Get Data Generation functions and classes
image_pipeline, normalization, bn_layer_weighting_fn, bn_alignment_loss_fn, output_loss_fn, \
Expand Down Expand Up @@ -208,23 +213,23 @@ def pytorch_data_generation_experimental(
scheduler = scheduler_get_fn(data_generation_config.n_iter)

# Set the current model
set_model(model)
set_model(model_for_data_gen)

# Create an activation extractor object to extract activations from the model
activation_extractor = PytorchActivationExtractor(
model,
model_for_data_gen,
fx_model,
data_generation_config.bn_layer_types,
data_generation_config.last_layer_types)

# Create an orig_bn_stats_holder object to hold original BatchNorm statistics
orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model_for_data_gen, data_generation_config.bn_layer_types)
if orig_bn_stats_holder.get_num_bn_layers() == 0:
Logger.critical(
f'Data generation requires a model with at least one BatchNorm layer.') # pragma: no cover

# Create an ImagesOptimizationHandler object for handling optimization
all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model_for_data_gen,
data_gen_batch_size=data_generation_config.data_gen_batch_size,
init_dataset=init_dataset,
optimizer=data_generation_config.optimizer,
Expand Down

0 comments on commit ad6f9b2

Please sign in to comment.