Commit cf7a643

make mla weight contiguous
Signed-off-by: Xinyu Chen <[email protected]>
1 parent e38c8e9 commit cf7a643

File tree

1 file changed: +7 -0 lines changed

vllm_gaudi/attention/backends/hpu_attn.py

Lines changed: 7 additions & 0 deletions
@@ -352,6 +352,13 @@ def _forward_decode( # type: ignore
         result = self._v_up_proj(output)
         return result
 
+    # NOTE(Xinyu): Make the loaded weight contiguous to avoid the transpose
+    # during each graph execution
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        super().process_weights_after_loading(act_dtype)
+        self.W_UV = self.W_UV.contiguous()
+        self.W_UK_T = self.W_UK_T.contiguous()
+
     # NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
     # so we override and always return a new tensor
     def _v_up_proj(self, x):
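
For context on why this helps: a weight like W_UK_T is typically produced as a transposed view of the loaded tensor, so its memory layout is non-contiguous, and (per the NOTE above) a backend that captures execution graphs then replays a transpose/re-layout on every graph run. Below is a minimal standalone PyTorch sketch of the idea; the tensor names and shapes are illustrative stand-ins, not the actual MLA weights or vLLM code:

import torch

# Illustrative shapes only; the real MLA weights live on the attention backend.
W_UK = torch.randn(512, 128)
W_UK_T = W_UK.t()  # transposing returns a view with a non-contiguous layout
assert not W_UK_T.is_contiguous()

x = torch.randn(4, 512)

# Each use of the non-contiguous view can force the backend to re-layout
# (effectively transpose) the weight; inside a captured graph that work may
# be replayed on every execution.
y_before = x @ W_UK_T

# The commit's approach: materialize the contiguous layout once, right after
# the weights are loaded, so later executions reuse the contiguous buffer.
W_UK_T = W_UK_T.contiguous()  # one-time copy
assert W_UK_T.is_contiguous()
y_after = x @ W_UK_T

# The result is numerically identical; only the memory layout changed.
torch.testing.assert_close(y_before, y_after)

In effect, the one-time copy in process_weights_after_loading trades a single materialization at load time for contiguous reads on every subsequent decode step.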
