
Resolve each training epoch resulting in progressively faster output #178

Open
dragonfyre13 wants to merge 1 commit into QwenLM:main from dragonfyre13:main

Conversation

@dragonfyre13

Unsure whether this part on input_embeddings was originally prepared for the sub-talker (the non-autoregressive bit), or whether there is some other reason it is here. The problem is that for the main autoregressive model (model.talker, which determines pacing, duration, prosody, etc. and is used for predicting codec_0_label), having it breaks causality. The AR model should be learning to look at the text and codec layer 0 only, not at layers 1-15, since those don't exist yet at actual inference time. Basically, instead of learning how to time speech based on text, it appears to be learning to time speech based on... the time the speech took. The end result is that it "rushes" through the generated audio, seemingly faster and faster with each training epoch performed. Remove this bit of code and all of that goes away. It doesn't even appear to have a distinct impact on loss reduction speed, based on my testing with a few datasets (elise, jenny, several of my own used for other models).

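The causality issue described above can be sketched roughly as follows. This is a minimal illustration, not the actual sft_12hz.py code: the function name, tensor layout, and the `causal` flag are all hypothetical, chosen only to show why summing sub-codebook embeddings leaks information the AR talker cannot have at inference time.

```python
# Hypothetical sketch of the causality break; names and shapes are
# illustrative, not the real Qwen3-TTS training code.
import torch

def build_talker_inputs(text_embeds, codec_embeds, causal=True):
    """codec_embeds: [num_layers=16, seq, hidden] embeddings per codec group.

    The AR talker predicts codec layer 0. When it runs at inference time,
    only layer 0 of each frame exists so far, so training inputs should
    also contain only layer 0. Summing layers 1-15 in (causal=False)
    conditions the model on acoustic detail it will never see at
    inference, e.g. how long the speech actually took.
    """
    inputs = text_embeds + codec_embeds[0]
    if not causal:
        # The removed accumulation loop: adds sub-codebook groups 1-15.
        for layer in range(1, codec_embeds.shape[0]):
            inputs = inputs + codec_embeds[layer]
    return inputs

text = torch.zeros(10, 4)
# Toy codec embeddings: layer k is the constant k, so the leak is visible.
codecs = torch.arange(16, dtype=torch.float32).view(16, 1, 1).expand(16, 10, 4)
causal_in = build_talker_inputs(text, codecs, causal=True)
leaky_in = build_talker_inputs(text, codecs, causal=False)
print(causal_in[0, 0].item())  # 0.0  -> only layer 0 contributes
print(leaky_in[0, 0].item())   # 120.0 -> sum of layers 0..15 leaks in
```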
@humblenginr

hey @dragonfyre13, did you also notice that the script double-shifts the labels? The forward function (https://github.com/QwenLM/Qwen3-TTS/blob/main/qwen_tts/core/models/modeling_qwen3_tts.py#L1731) uses the HF PreTrainedModel class's default loss_function, ForCausalLMLoss (https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L57), which shifts the labels internally. And the finetuning data preparation script also does label shifting.
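The double-shift concern can be demonstrated in isolation. The sketch below is an assumption-laden toy, not the transformers source: `hf_style_causal_loss` mimics what ForCausalLMLoss effectively does (drop the last logit, drop the first label, cross-entropy). If the data prep has already shifted the labels once, the internal shift happens again and position t ends up supervised by token t+2.

```python
# Toy reproduction of the double label shift; hf_style_causal_loss is a
# simplified stand-in for transformers' ForCausalLMLoss, not the real one.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def hf_style_causal_loss(logits, labels):
    # Internal shift: logit at position t is scored against label t+1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, logits.size(-1)),
                           shift_labels.reshape(-1), ignore_index=-100)

tokens = torch.tensor([[3, 1, 4, 1, 5]])

# Correct usage: pass labels aligned with inputs; the loss does the shift.
aligned_labels = tokens.clone()

# Buggy usage: data prep pre-shifts, then the loss shifts AGAIN, so
# position t is supervised with the token at position t+2.
pre_shifted = torch.full_like(tokens, -100)
pre_shifted[:, :-1] = tokens[:, 1:]

logits = torch.randn(1, 5, 10)
loss_ok = hf_style_causal_loss(logits, aligned_labels)
loss_double = hf_style_causal_loss(logits, pre_shifted)
print(loss_ok.item(), loss_double.item())  # the two losses target different tokens
```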

rekuenkdr added a commit to rekuenkdr/Qwen3-TTS-streaming that referenced this pull request Feb 6, 2026
Streaming inference (voice clone) - repetition penalty:

Without repetition penalty, the model can fall into a degenerate state
where it keeps sampling the same codec tokens over and over. This
manifests as:
- Looping audio: the same syllable or sound fragment repeats endlessly
- Extremely long generation: instead of reaching EOS in ~200-500 frames,
  it runs for thousands of frames (up to max_frames=10000)
- Apparent "slowness": a response that should take ~1s of audio takes
  10-30s to generate

The fix works by tracking previously generated first-codebook token IDs
and penalizing them before sampling:
- Tokens with positive logits get divided by repetition_penalty (lowering
  their probability)
- Tokens with negative logits get multiplied by it (pushing them further
  down)

This nudges the model away from re-selecting the same tokens, so it
progresses through the text naturally and reaches EOS in a reasonable
number of steps rather than looping. Default is 1.0 (disabled) and is
exposed through the supported_params whitelist in
stream_generate_voice_clone() so it can be set via generate_config or
user kwargs.
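The penalty rule described in this commit (divide positive logits, multiply negative ones) is the standard CTRL-style repetition penalty and can be sketched as follows. The function name, signature, and in-place style here are hypothetical illustrations, not the fork's actual stream_generate_voice_clone() code.

```python
# Minimal sketch of a CTRL-style repetition penalty over previously
# generated first-codebook token IDs; names are illustrative only.
import torch

def apply_repetition_penalty(logits, prev_token_ids, penalty=1.0):
    """Penalize already-seen token IDs in a 1-D logits vector.

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so seen tokens become less likely either way.
    penalty=1.0 leaves the logits untouched (disabled, the default).
    """
    if penalty == 1.0 or not prev_token_ids:
        return logits
    ids = torch.tensor(sorted(set(prev_token_ids)))
    seen = logits[ids]
    logits[ids] = torch.where(seen > 0, seen / penalty, seen * penalty)
    return logits

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
out = apply_repetition_penalty(logits, prev_token_ids=[0, 1], penalty=2.0)
print(out.tolist())  # [1.0, -2.0, 0.5, 3.0]
```

Tokens 0 and 1 were previously generated, so the positive logit 2.0 shrinks to 1.0 and the negative logit -1.0 sinks to -2.0, while unseen tokens 2 and 3 are untouched.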

Upstream sync (QwenLM/Qwen3-TTS):
- Bump version 0.0.4 -> 0.1.1 to match upstream release.
- finetuning/sft_12hz.py: weight sub-talker loss by 0.3 factor to
  prevent the code predictor gradient from dominating the main talker
  loss during SFT.
- finetuning/sft_12hz.py: remove sub-codebook embedding accumulation
  loop (codec groups 1-15) from input embeddings, unnecessary and
  harmful for finetuning convergence (upstream PR QwenLM#178).
- finetuning/README.md: update recommended hyperparameters to
  batch_size=32, lr=2e-6, num_epochs=10 for more stable training.
rekuenkdr added a commit to rekuenkdr/Qwen3-TTS-streaming that referenced this pull request Feb 6, 2026
…une-fixes

fix: add repetition penalty to streaming and sync upstream finetuning



LingyeSoul pushed a commit to LingyeSoul/Qwen3-TTS-Streaming that referenced this pull request Feb 8, 2026
…and-finetune-fixes

fix: add repetition penalty to streaming and sync upstream finetuning



@MrMuhannadObeidat

I tried the new repo above from rekuenkdr; it is producing a very decent model so far. I only tested with 10K samples, and even on the second epoch I am getting good Arabic voice and no speed-up.
I will run a bigger test with 60K samples next.
One question: the model produced is now labeled as custom_voice instead of base. This means that creating a voice prompt will not be supported. How can I fix that?

@rekuenkdr

rekuenkdr commented Feb 9, 2026

> I tried the new repo above from rekunenkdr, it is producing a very decent model so far. I only tested with 10K samples and even on the second epoch I am getting good arabic voice and no speed up. I will run a bigger test with 60K samples next. One question, the model produced is now labeled as custom_voice instead of base. This means that creating a voice prompt will not be supported. How can I fix that?

Hi, are you using my streaming fork?
It's not intended for finetuning (although it should work for that too) but for streaming. I applied several patches in that direction for a realtime AI assistant I'm developing, and I fixed the speed issue for voice clone (my streaming methods didn't have repetition_penalty at the beginning), not for finetune.

I merged this PR just to try it whenever I have the time, but it is still untested on my end.

You can try this: #179 (comment)
It seems to fix speed issues with finetuned models too; we are all investigating this bug and testing.

All finetuned models are labeled custom_voice to differentiate them from base, but in essence it is the base model, so you can still create a voice_prompt from them.

Load your checkpoint as you would load base and you are good to go.

rekuenkdr added a commit to rekuenkdr/Qwen3-TTS-streaming that referenced this pull request Feb 9, 2026
Reverts a change we applied from open upstream PR QwenLM#178 (commit f83f184),
which removed the sub-codebook embedding loop on the grounds that the talker
shouldn't see levels 1-15 during training. Investigation of the inference code
(modeling_qwen3_tts.py:1894-1920) shows the sub-talker generates levels
1-15 before the talker's next step, and all 16 levels are summed into
the input. Omitting them from training creates a train/inference mismatch.
guybrush1984 added a commit to guybrush1984/Qwen3-TTS that referenced this pull request Mar 10, 2026
- Extract training loop into run_sft() callable from external code
- Add _train_step() and _save_checkpoint() as modular functions
- Add on_epoch_end callback for volume commits etc.
- Support multi-speaker training via speaker_names list
- Remove codec layers 1-15 from AR input embeddings (PR QwenLM#178 fix)
  These layers break causality and cause speech to accelerate over epochs
- Keep train() CLI entry point for backward compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
guybrush1984 added a commit to guybrush1984/Qwen3-TTS that referenced this pull request Mar 11, 2026
… in training

Removing codec layers 1-15 from input_embeddings caused trained voices
to speak too fast. Restoring the original behavior from v0.1.2.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>