@@ -94,6 +94,14 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

+        # reshape for the per-channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )

+        # reshape back from the per-channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim
+        ).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim
+        ).transpose(1, 2)
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

         return keys_to_return, values_to_return
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps

         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())

         q_tensor = quantize(
             x=tensor,
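
As a sanity check on the reshape pair above, here is a minimal standalone sketch (assumed placeholder shapes, not the library's cache code) confirming that `.view(...).transpose(1, 2)` exactly inverts `.transpose(1, 2).flatten(2)`:

```python
import torch

# Assumed shapes for illustration only.
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# forward: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads * head_dim]
flat = key_states.transpose(1, 2).flatten(2)
assert flat.shape == (batch_size, seq_len, num_heads * head_dim)

# inverse: [batch, seq_len, num_heads * head_dim] -> [batch, num_heads, seq_len, head_dim]
restored = flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
assert torch.equal(restored, key_states)
```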
0 commit comments
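The `squeeze()` change in `_quantize` drops singleton dimensions from the observer output before it is cached. A hypothetical illustration (the `[num_channels, 1]` output shape is an assumption about the per-channel observer, not taken from the source):

```python
import torch

num_channels = 8 * 64  # e.g. num_heads * head_dim after the flatten above
scale = torch.rand(num_channels, 1)                   # assumed observer output shape
zp = torch.zeros(num_channels, 1, dtype=torch.int32)  # assumed observer output shape

print(scale.squeeze().shape)  # torch.Size([512]) -- flat per-channel vector
print(zp.squeeze().shape)     # torch.Size([512])
```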