
Adding a PretrainedCREPEEmbeddingLoss to training #229

Open
JCBrouwer opened this issue Oct 25, 2020 · 3 comments
@JCBrouwer

Hello, I've trained a model for a while using the solo_instrument config at 48 kHz, but the audio is still fairly noisy even after 117k steps (spectral loss is ~9 on average).

I'd like to continue training with the PretrainedCREPEEmbeddingLoss() enabled as well to encourage more natural / perceptually realistic synthesis.

I've tried simply adding the loss to the ae.gin file, but I get the following error, which I don't really understand:

Traceback (most recent call last):
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 231, in <module>
    console_entry_point()
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 227, in console_entry_point
    app.run(main)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 205, in main
    report_loss_to_hypertune=FLAGS.hypertune,
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1078, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
    six.raise_from(proxy.with_traceback(exception.__traceback__), None)
  File "<string>", line 3, in raise_from
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1055, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/hans/code/maua-ddsp/ddsp/training/train_util.py", line 185, in train
    trainer.build(next(dataset_iter))
  File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 134, in build
    _ = self.run(tf.function(self.model.__call__), batch)
  File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 129, in run
    return self.strategy.run(fn, args=args, kwargs=kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1211, in run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2585, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 585, in _call_for_each_replica
    self._container_strategy(), fn, args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_run.py", line 78, in call_for_each_replica
    return wrapped(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 904, in _call
    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2828, in __call__
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:896 fn_with_cond  *
        functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py:1180 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/cond_v2.py:92 cond_v2
        op_return_value=pred)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1848 _filtered_call
        cancellation_manager=cancellation_manager)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1877 _call_flat
        for v in self._func_graph.variables:
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 variables
        return tuple(deref(v) for v in self._weak_variables)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 <genexpr>
        return tuple(deref(v) for v in self._weak_variables)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:482 deref
        "Called a function referencing variables which have been deleted. "

    AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.

  In call to configurable 'train' (<function train at 0x7f995ef00268>)
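For what it's worth, the final AssertionError names the mechanism: TensorFlow's FuncGraph keeps only weak references to the variables a traced function uses (the `_weak_variables` visible in the traceback), so a variable created inside the function and not stored anywhere else is garbage-collected as soon as tracing finishes. A minimal stdlib-only illustration of that failure mode and of the fix the message suggests (storing the object on an attribute), using a hypothetical `Variable` stand-in rather than real TF variables:

```python
import weakref


class Variable:
    """Stand-in for a TF variable; only its lifetime matters here."""

    def __init__(self, value):
        self.value = value


def trace_without_owner():
    # A function-local "variable" that nothing references strongly,
    # mirroring a loss that builds its model inside the traced function.
    v = Variable(3.0)
    return weakref.ref(v)


class Owner:
    def build(self):
        # Storing the variable on an attribute keeps a strong reference.
        self.v = Variable(3.0)
        return weakref.ref(self.v)


dead = trace_without_owner()
print(dead())           # None: the local was garbage-collected on return

owner = Owner()
alive = owner.build()
print(alive() is None)  # False: the attribute keeps it alive
```

In CPython the local `v` is collected immediately when `trace_without_owner` returns, so the first weak reference is already dead, just like the deleted graph variables in the error above.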

How can I train with this loss enabled?

@jesseengel
Contributor

Can you give more details on your initial config/run command and the one used for restarting the job? Are you warmstarting from the pretrained checkpoint but adding a new loss?

@JCBrouwer
Author

Yes, I want to warm-start from the pretrained checkpoint, although I get the same error when training from scratch with the CREPE embedding loss added in ae.gin.

My original training command:

python -m ddsp.training.ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=300000" \
  --gin_param="train_util.train.steps_per_save=3000" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10" \
  --gin_param="TFRecordProvider.example_secs=4" \
  --gin_param="TFRecordProvider.sample_rate=48000" \
  --gin_param="TFRecordProvider.frame_rate=250" \
  --gin_param="Additive.n_samples=192000" \
  --gin_param="Additive.sample_rate=48000" \
  --gin_param="FilteredNoise.n_samples=192000"
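As a sanity check on the numbers in this command: with 4-second examples at 48 kHz, the synthesizer output length must be example_secs × sample_rate samples, which is where 192000 comes from (a quick stdlib-only check, no DDSP required):

```python
example_secs = 4      # TFRecordProvider.example_secs
sample_rate = 48000   # TFRecordProvider.sample_rate
frame_rate = 250      # TFRecordProvider.frame_rate

n_samples = example_secs * sample_rate
n_frames = example_secs * frame_rate

print(n_samples)  # 192000, matching Additive.n_samples and FilteredNoise.n_samples
print(n_frames)   # 1000 conditioning frames per example
```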

Then, after training overnight, I added PretrainedCREPEEmbeddingLoss() in ae.gin (which solo_instrument.gin inherits from):

Autoencoder.losses = [
    @losses.SpectralLoss(),
    @losses.PretrainedCREPEEmbeddingLoss(),
]

Then I run the following and get the error above (it is the same with or without --restore_dir):

python -m ddsp.training.ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz-crepe/" \
  --restore_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=300000" \
  --gin_param="train_util.train.steps_per_save=3000" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10" \
  --gin_param="TFRecordProvider.example_secs=4" \
  --gin_param="TFRecordProvider.sample_rate=48000" \
  --gin_param="TFRecordProvider.frame_rate=250" \
  --gin_param="Additive.n_samples=192000" \
  --gin_param="Additive.sample_rate=48000" \
  --gin_param="FilteredNoise.n_samples=192000"

@JCBrouwer
Author

Update: I've found that training with the PretrainedCREPEEmbeddingLoss does work when running on a single GPU (via CUDA_VISIBLE_DEVICES=0).

Is there a way to make the PretrainedCREPEEmbeddingLoss work with multi-GPU training?
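One possible direction, based on the hint in the AssertionError ("consider storing variables in an object attribute on first call"): under MirroredStrategy the CREPE model inside the loss is apparently built lazily during tf.function tracing, so nothing holds a strong reference to its variables. Building the loss eagerly, once, before the replicated step function is traced, and keeping the built model on an object attribute, may avoid this. A stand-in sketch of the pattern with a hypothetical mock class (not DDSP's real API; untested against DDSP):

```python
class LazyEmbeddingLoss:
    """Stand-in for a loss that loads a pretrained model on first use."""

    def __init__(self):
        self.model = None  # the "variables" live here once built

    def ensure_built(self):
        # Build eagerly and keep a strong reference on `self`, so the
        # model survives for the lifetime of the object instead of being
        # garbage-collected after a traced/replicated call.
        if self.model is None:
            self.model = {"weights": [0.1, 0.2]}

    def __call__(self, audio):
        self.ensure_built()
        return sum(w * audio for w in self.model["weights"])


loss = LazyEmbeddingLoss()
loss.ensure_built()   # build once, before any replicated/traced step
print(loss(10.0))     # 3.0
```

In the real setup, the analogous move would be to trigger the loss's model construction inside the strategy scope before `trainer.build` traces the model call.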
