diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index 600f51d..a9563cb 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -2,7 +2,7 @@
 import time
 from typing import Dict, List, Optional, Tuple, Union
 
-from tllm import BACKEND, DTYPE, BackendEnum
+from tllm import BACKEND, DEVICE, DTYPE, BackendEnum
 from tllm.schemas import MIX_TENSOR
 
 if BACKEND == BackendEnum.MLX:
@@ -16,7 +16,7 @@
     import torch
 
     cat_func = lambda tensors: torch.cat(tensors, dim=0)
-    zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE)
+    zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE, device=DEVICE)
     array_func = lambda x: torch.tensor([x], dtype=torch.long)
     arange_func = lambda x: torch.arange(0, x, dtype=torch.long)
 
diff --git a/tllm/models/mlx/layers.py b/tllm/models/mlx/layers.py
index 90a27d4..0f296dc 100644
--- a/tllm/models/mlx/layers.py
+++ b/tllm/models/mlx/layers.py
@@ -5,8 +5,7 @@
 import mlx.nn as nn
 from mlx_lm.models.llama import MLP, Attention, ModelArgs, TransformerBlock, initialize_rope
 
-from tllm import DTYPE
-from tllm.commons.cache import AttentionData, RequestsCache, cat_func
+from tllm.commons.cache import AttentionData, RequestsCache, cat_func, zeros_func
 
 
 class BaseParallelLayer(nn.Module):
@@ -146,8 +145,8 @@ def __init__(self, args, layer_idx: int, offset: int):
         self.rope = initialize_rope(args)
 
         # self.max_seq_len = 1024
-        # self._k_cache = mx.zeros(shape=(self.max_seq_len, self.n_kv_heads, self.head_dim), dtype=DTYPE)
-        # self._v_cache = mx.zeros(shape=(self.max_seq_len, self.n_kv_heads, self.head_dim), dtype=DTYPE)
+        # self._k_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
+        # self._v_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
         self.max_seq_len = -1
         self._k_cache, self._v_cache = None, None
 
diff --git a/tllm/models/torch/layers.py b/tllm/models/torch/layers.py
index c8eacc8..a64847d 100644
--- a/tllm/models/torch/layers.py
+++ b/tllm/models/torch/layers.py
@@ -12,9 +12,8 @@
     apply_rotary_pos_emb,
 )
 
-from tllm import DEVICE, DTYPE
 from tllm.commons.attn import ATTN_FUNC
-from tllm.commons.cache import AttentionData, RequestsCache
+from tllm.commons.cache import AttentionData, RequestsCache, zeros_func
 
 self_attn_func = ATTN_FUNC
 
@@ -153,8 +152,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         )
 
         # self.max_seq_len = 1024
-        # self._k_cache = torch.zeros(size=(max_seq_len, self.num_key_value_heads, self.head_dim), dtype=DTYPE, device=DEVICE)
-        # self._v_cache = torch.zeros(size=(max_seq_len, self.num_key_value_heads, self.head_dim), dtype=DTYPE, device=DEVICE)
+        # self._k_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
+        # self._v_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
         self.max_seq_len = -1
         self._k_cache, self._v_cache = None, None
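
Note (not part of the diff): a minimal sketch of the backend dispatch that zeros_func is assumed to provide in tllm/commons/cache.py after this change, so layer code can preallocate KV caches without importing DTYPE or DEVICE itself. The MLX branch and the usage line below are illustrative assumptions inferred from the commented-out mx.zeros call; only the torch line is taken from the diff.

from tllm import BACKEND, DEVICE, DTYPE, BackendEnum

if BACKEND == BackendEnum.MLX:
    import mlx.core as mx

    # Assumed MLX branch: unified memory, so only the dtype needs to be fixed.
    zeros_func = lambda x0, x1, x2: mx.zeros(shape=(x0, x1, x2), dtype=DTYPE)
else:
    import torch

    # Torch branch from this diff: the allocation now bakes in DEVICE as well as DTYPE,
    # so callers (e.g. the commented-out _k_cache/_v_cache preallocation) stay backend-agnostic.
    zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE, device=DEVICE)

# Hypothetical caller, mirroring the commented-out cache preallocation in the layers:
# self._k_cache = zeros_func(max_seq_len, num_key_value_heads, head_dim)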