From 56e449583c352cdea040ed79599aa80b58dfd6aa Mon Sep 17 00:00:00 2001 From: JinZr Date: Wed, 9 Oct 2024 12:26:20 +0800 Subject: [PATCH] making balancer a default --- egs/libritts/CODEC/encodec/train.py | 37 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 459fd72c3d..36842d73ce 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -142,7 +142,7 @@ def get_parser(): parser.add_argument( "--use-balancer", type=str2bool, - default=False, + default=True, help="Whether to use the balancer for gradient scaling.", ) @@ -1114,23 +1114,26 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - balancer = ( - Balancer( - weights={ - "gen_stft_adv_loss": 3.0, - "gen_period_adv_loss": 3.0, - "gen_scale_adv_loss": 3.0, - "feature_stft_loss": 3.0, - "feature_period_loss": 3.0, - "feature_scale_loss": 3.0, - "wav_reconstruction_loss": 0.1, - "mel_reconstruction_loss": 1.0, - } - # this setup follows the one described in the Encodec paper + if params.use_balancer: + balancer = ( + Balancer( + weights={ + "gen_stft_adv_loss": 3.0, + "gen_period_adv_loss": 3.0, + "gen_scale_adv_loss": 3.0, + "feature_stft_loss": 3.0, + "feature_period_loss": 3.0, + "feature_scale_loss": 3.0, + "wav_reconstruction_loss": 0.1, + "mel_reconstruction_loss": 1.0, + } + # this setup follows the one described in the Encodec paper + ) ) - if params.use_balancer - else None - ) + logging.info(f"Using balancer with weights: {balancer.weights}") + else: + balancer = None + for epoch in range(params.start_epoch, params.num_epochs + 1): logging.info(f"Start epoch {epoch}")