Skip to content

Commit

Permalink
making balancer a default
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Oct 9, 2024
1 parent e99605d commit 56e4495
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions egs/libritts/CODEC/encodec/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_parser():
parser.add_argument(
"--use-balancer",
type=str2bool,
default=False,
default=True,
help="Whether to use the balancer for gradient scaling.",
)

Expand Down Expand Up @@ -1114,23 +1114,26 @@ def run(rank, world_size, args):
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

balancer = (
Balancer(
weights={
"gen_stft_adv_loss": 3.0,
"gen_period_adv_loss": 3.0,
"gen_scale_adv_loss": 3.0,
"feature_stft_loss": 3.0,
"feature_period_loss": 3.0,
"feature_scale_loss": 3.0,
"wav_reconstruction_loss": 0.1,
"mel_reconstruction_loss": 1.0,
}
# this setup follows the one described in the Encodec paper
if params.use_balancer:
balancer = (
Balancer(
weights={
"gen_stft_adv_loss": 3.0,
"gen_period_adv_loss": 3.0,
"gen_scale_adv_loss": 3.0,
"feature_stft_loss": 3.0,
"feature_period_loss": 3.0,
"feature_scale_loss": 3.0,
"wav_reconstruction_loss": 0.1,
"mel_reconstruction_loss": 1.0,
}
# this setup follows the one described in the Encodec paper
)
)
if params.use_balancer
else None
)
logging.info(f"Using balancer with weights: {balancer.weights}")
else:
balancer = None


for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
Expand Down

0 comments on commit 56e4495

Please sign in to comment.