Fix label alignment bug in finetuning #278

Open
matevosashot wants to merge 1 commit into QwenLM:main from matevosashot:main
Conversation

@matevosashot

Fix label alignment bug in finetuning

Fixes label alignment issues in sft_12hz.py and modeling_qwen3_tts.py caused by incorrect interaction with ForCausalLMLoss from transformers.

Context

ForCausalLMLoss has two modes: passing labels automatically shifts them left by one (token n predicts n+1), while passing shift_labels uses them as-is.
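The two modes can be illustrated with a minimal sketch. The `causal_lm_loss` helper below is hypothetical, but its dispatch mirrors the quoted transformers implementation: `labels` are padded with `ignore_index` and shifted left internally, while `shift_labels` are consumed as-is.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels=None, shift_labels=None, ignore_index=-100):
    # Mimics ForCausalLMLoss' dispatch: `labels` get shifted left by one
    # internally; `shift_labels` are used exactly as passed.
    if shift_labels is None:
        labels = F.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

torch.manual_seed(0)
logits = torch.randn(1, 4, 10)
labels = torch.tensor([[1, 2, 3, 4]])

# Passing `labels` is equivalent to pre-shifting them yourself and
# passing the result via `shift_labels`.
auto = causal_lm_loss(logits, labels=labels)
manual = causal_lm_loss(logits, shift_labels=torch.tensor([[2, 3, 4, -100]]))
assert torch.allclose(auto, manual)
```

The equivalence only holds if the shift happens exactly once, which is the crux of both fixes below.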

sft_12hz.py

The old code manually shifted inputs and labels before passing them to the talker (inputs_embeds[:, :-1], labels[:, 1:]). This is unnecessary: ForCausalLMLoss already performs the left shift internally, so shifting twice pairs each logit with the wrong target. The fix passes the full, unshifted tensors and adjusts the hidden-state slicing accordingly. It also adds the missing text_projection call on the text embeddings.
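The double-shift can be seen directly on the label tensor. The `internal_shift` helper below is a hypothetical stand-in for what the `labels` path of ForCausalLMLoss does before computing the loss:

```python
import torch
import torch.nn.functional as F

def internal_shift(labels, ignore_index=-100):
    # Stand-in for ForCausalLMLoss' `labels` path: pad with ignore_index,
    # then drop slot 0 so position t is scored against token t+1.
    padded = F.pad(labels, (0, 1), value=ignore_index)
    return padded[..., 1:]

labels = torch.tensor([[1, 2, 3, 4, 5, 6]])

# Fixed behavior: pass full labels, the shift happens exactly once and
# position t is scored against token t+1.
assert internal_shift(labels)[0].tolist() == [2, 3, 4, 5, 6, -100]

# Old behavior: labels[:, 1:] were pre-shifted before the call, so the
# internal shift fires a second time and position t is scored against
# token t+2.
assert internal_shift(labels[:, 1:])[0].tolist() == [3, 4, 5, 6, -100]
```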

modeling_qwen3_tts.py

The subtalker outputs 15 codes per position. Passing them via labels causes ForCausalLMLoss to shift and drop one, leaving only 14 — misaligning with the 15 logit outputs. The fix passes subtalker labels via shift_labels instead, bypassing the automatic shift.
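A minimal sketch of the off-by-one, modeling the 15 subtalker codes at one position as a length-15 target sequence (an assumption for illustration, not the actual model tensors):

```python
import torch
import torch.nn.functional as F

codes = torch.arange(15).unsqueeze(0)  # 15 codebook targets, one position

# `labels` path: pad with ignore_index and shift left. The first code is
# dropped, the last slot becomes ignore_index, and every remaining logit
# slot is paired with the following code instead of its own.
shifted = F.pad(codes, (0, 1), value=-100)[..., 1:]
assert shifted[0, 0].item() == 1        # slot 0 now scores code 1
assert shifted[0, -1].item() == -100    # only 14 real targets survive

# `shift_labels` path: targets are consumed as-is, all 15 codes stay
# aligned with the 15 logit outputs.
assert codes[0].tolist() == list(range(15))
```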


ForCausalLMLoss implementation (from transformers.loss.loss_utils; imports added here for context):

import torch
from torch import nn
from transformers.loss.loss_utils import fixed_cross_entropy

def ForCausalLMLoss(
    logits, labels, vocab_size,
    num_items_in_batch=None, ignore_index=-100,
    shift_labels=None, **kwargs,
) -> torch.Tensor:
    logits = logits.float()
    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
