Description
Thank you for open-sourcing this excellent work!
I'm encountering an error when applying S2FT to Qwen3:
IndexError: index 12 is out of bounds for dimension 0 with size 8
The error occurs specifically at this code block:
S2FT/experiments/utils/s2_utils.py
Lines 68 to 72 in a58edbb
```python
q_weight = q_weight.reshape(
    model.config.num_key_value_heads, -1, q_weight.shape[-1]
)
layer.self_attn.q_proj.weight.data = q_weight[order, :, :].reshape(
    -1, q_weight.shape[-1]
)
```
Root Cause Analysis
For the Qwen3-4B model:
model.config.num_key_value_heads = 8
model.config.num_attention_heads = 32
With this configuration, the order variable contains values up to 31 (num_attention_heads - 1), but after the reshape the first dimension of q_weight has size num_key_value_heads = 8, so its only valid indices are 0 to 7. Indexing q_weight[order, :, :] with index 12 therefore raises the IndexError.
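The mismatch can be reproduced with a small NumPy sketch. It keeps the 32/8 head counts reported for Qwen3-4B but uses toy values for head_dim and hidden_size. The order array is a hypothetical permutation of all 32 query heads starting at head 12, and reorder_rows is a hypothetical helper that mirrors the quoted reshape-and-index pattern; it is not the repository's code.

```python
import numpy as np

# Head counts as reported for Qwen3-4B; HEAD_DIM and HIDDEN_SIZE are toy values.
NUM_ATTENTION_HEADS = 32
NUM_KEY_VALUE_HEADS = 8
HEAD_DIM = 4
HIDDEN_SIZE = 16

# q_proj.weight has shape (num_attention_heads * head_dim, hidden_size):
# one block of head_dim rows per *query* head.
q_weight = np.random.default_rng(0).standard_normal(
    (NUM_ATTENTION_HEADS * HEAD_DIM, HIDDEN_SIZE)
)

# Hypothetical order: all 32 query heads, starting at head 12 ([12, ..., 31, 0, ..., 11]).
order = np.roll(np.arange(NUM_ATTENTION_HEADS), -12)


def reorder_rows(weight, order, num_blocks):
    """Split rows into num_blocks equal blocks and reorder the blocks by `order`
    (the same pattern as the quoted reshape followed by q_weight[order, :, :])."""
    blocks = weight.reshape(num_blocks, -1, weight.shape[-1])
    return blocks[order, :, :].reshape(-1, weight.shape[-1])


# Grouping by num_key_value_heads yields only 8 blocks (each 4 query heads tall),
# so any index >= 8 in `order` is out of bounds.
try:
    reorder_rows(q_weight, order, NUM_KEY_VALUE_HEADS)
    kv_grouped_failed = False
except IndexError as err:
    kv_grouped_failed = True
    print("IndexError:", err)

# Grouping by num_attention_heads yields one block per query head, so every index is valid.
per_head = reorder_rows(q_weight, order, NUM_ATTENTION_HEADS)
print(per_head.shape)  # (128, 16)
```

Reshaping by num_attention_heads removes the IndexError, but on its own it says nothing about how k_proj, v_proj, and o_proj must be reordered to stay consistent; under GQA, each run of num_attention_heads // num_key_value_heads consecutive query heads shares one KV head.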
Request
Could you please advise how to modify the code to properly handle Qwen's grouped-query attention architecture?
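Any fix has to respect one GQA constraint: in the Hugging Face transformers implementation (repeat_kv), query head h reads KV head h // (num_attention_heads // num_key_value_heads), so reordering q_proj heads alone re-pairs them with the wrong KV heads. The sketch below is a toy single-layer GQA in NumPy (no masking or RoPE, toy dimensions, hypothetical helper names such as apply_group_order) that checks one permutation that is safe: reorder whole KV groups, expand that group order into a query-head order for q_proj rows and o_proj columns, and apply the group order itself to k_proj and v_proj. It is not the S2FT authors' method and does not reproduce how order is chosen.

```python
import numpy as np

# Toy dimensions; only the 32/8 head counts come from the issue.
N_HEADS, N_KV, HEAD_DIM, HIDDEN, SEQ = 32, 8, 4, 16, 5
N_REP = N_HEADS // N_KV  # query heads per KV head (4 here)

rng = np.random.default_rng(0)
x = rng.standard_normal((SEQ, HIDDEN))
wq = rng.standard_normal((N_HEADS * HEAD_DIM, HIDDEN))
wk = rng.standard_normal((N_KV * HEAD_DIM, HIDDEN))
wv = rng.standard_normal((N_KV * HEAD_DIM, HIDDEN))
wo = rng.standard_normal((HIDDEN, N_HEADS * HEAD_DIM))


def gqa_attention(wq, wk, wv, wo):
    """Single-layer GQA without masking or RoPE; query head h reads KV head h // N_REP."""
    q = (x @ wq.T).reshape(SEQ, N_HEADS, HEAD_DIM)
    k = np.repeat((x @ wk.T).reshape(SEQ, N_KV, HEAD_DIM), N_REP, axis=1)
    v = np.repeat((x @ wv.T).reshape(SEQ, N_KV, HEAD_DIM), N_REP, axis=1)
    scores = np.einsum("thd,shd->hts", q, k) / np.sqrt(HEAD_DIM)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = np.einsum("hts,shd->thd", probs, v).reshape(SEQ, N_HEADS * HEAD_DIM)
    return out @ wo.T


def permute_rows(w, order):
    # Reorder head-sized blocks of output rows (q_proj / k_proj / v_proj).
    return w.reshape(-1, HEAD_DIM, w.shape[-1])[order].reshape(-1, w.shape[-1])


def permute_cols(w, order):
    # Reorder head-sized blocks of input columns (o_proj).
    return w.reshape(w.shape[0], -1, HEAD_DIM)[:, order].reshape(w.shape[0], -1)


def apply_group_order(group_order):
    """Hypothetical helper: move whole KV groups so every query head keeps its KV head."""
    q_order = np.array([g * N_REP + j for g in group_order for j in range(N_REP)])
    return (permute_rows(wq, q_order), permute_rows(wk, group_order),
            permute_rows(wv, group_order), permute_cols(wo, q_order))


reference = gqa_attention(wq, wk, wv, wo)

# Query head 12 belongs to KV group 12 // N_REP = 3; move that whole group to the front.
group_order = [3, 0, 1, 2, 4, 5, 6, 7]
grouped = gqa_attention(*apply_group_order(group_order))
print(np.allclose(reference, grouped))  # True: the layer computes the same function

# Reordering query heads alone (k_proj/v_proj untouched) pairs them with the wrong KV heads.
q_only = np.roll(np.arange(N_HEADS), -12)
naive = gqa_attention(permute_rows(wq, q_only), wk, wv, permute_cols(wo, q_only))
print(np.allclose(reference, naive))  # False
```

The group-level order is only one option; any query-head order in which each run of N_REP consecutive positions comes from a single KV group, with k_proj and v_proj reordered to match, would also preserve the layer's output in this toy setting.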