@@ -115,13 +115,14 @@ def update(
115115 q_value_states , KVCacheScaleType .VALUE , layer_idx
116116 )
117117
118- # reshape for per channel scenario
119- # from [batch_size, seq_len - residual_length, num_heads * head_dim]
120- # to [batch_size, num_heads, seq_len - residual_length, head_dim]
121- qdq_key_states = qdq_key_states .view (
122- qdq_key_states .shape [0 ], qdq_key_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
123- qdq_value_states = qdq_value_states .view (
124- qdq_value_states .shape [0 ], qdq_value_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
118+ if self .quantization_args .strategy == QuantizationStrategy .CHANNEL :
119+ # reshape for per channel scenario
120+ # from [batch_size, seq_len - residual_length, num_heads * head_dim]
121+ # to [batch_size, num_heads, seq_len - residual_length, head_dim]
122+ qdq_key_states = qdq_key_states .view (
123+ qdq_key_states .shape [0 ], qdq_key_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
124+ qdq_value_states = qdq_value_states .view (
125+ qdq_value_states .shape [0 ], qdq_value_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
125126
126127 keys_to_return , values_to_return = qdq_key_states , qdq_value_states
127128
0 commit comments