Commit 9af300e: update zeros_func
wnma3mz committed Jan 26, 2025
1 parent 19c506b
Showing 3 changed files with 8 additions and 10 deletions.

tllm/commons/cache.py — 4 changes: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
import time
from typing import Dict, List, Optional, Tuple, Union

-from tllm import BACKEND, DTYPE, BackendEnum
+from tllm import BACKEND, DEVICE, DTYPE, BackendEnum
from tllm.schemas import MIX_TENSOR

if BACKEND == BackendEnum.MLX:
@@ -16,7 +16,7 @@
import torch

cat_func = lambda tensors: torch.cat(tensors, dim=0)
-zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE)
+zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE, device=DEVICE)
array_func = lambda x: torch.tensor([x], dtype=torch.long)
arange_func = lambda x: torch.arange(0, x, dtype=torch.long)

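With this change the torch branch allocates cache buffers directly on the target device instead of on the CPU. A minimal sketch of the dispatch, assuming DEVICE and DTYPE resolve at import time to a torch device string and dtype (the concrete values tllm picks may differ):

import torch

# Assumed resolution of the package-level constants; tllm chooses these per backend.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16

# Allocating on DEVICE up front avoids a CPU allocation followed by a
# host-to-device copy when a KV-cache buffer is first created.
zeros_func = lambda x0, x1, x2: torch.zeros(size=(x0, x1, x2), dtype=DTYPE, device=DEVICE)
cat_func = lambda tensors: torch.cat(tensors, dim=0)

k_cache = zeros_func(1024, 8, 128)  # (max_seq_len, n_kv_heads, head_dim)
print(k_cache.shape, k_cache.dtype, k_cache.device)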

tllm/models/mlx/layers.py — 7 changes: 3 additions & 4 deletions
@@ -5,8 +5,7 @@
import mlx.nn as nn
from mlx_lm.models.llama import MLP, Attention, ModelArgs, TransformerBlock, initialize_rope

-from tllm import DTYPE
-from tllm.commons.cache import AttentionData, RequestsCache, cat_func
+from tllm.commons.cache import AttentionData, RequestsCache, cat_func, zeros_func


class BaseParallelLayer(nn.Module):
@@ -146,8 +145,8 @@ def __init__(self, args, layer_idx: int, offset: int):

self.rope = initialize_rope(args)
# self.max_seq_len = 1024
-# self._k_cache = mx.zeros(shape=(self.max_seq_len, self.n_kv_heads, self.head_dim), dtype=DTYPE)
-# self._v_cache = mx.zeros(shape=(self.max_seq_len, self.n_kv_heads, self.head_dim), dtype=DTYPE)
+# self._k_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
+# self._v_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
self.max_seq_len = -1
self._k_cache, self._v_cache = None, None

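On the MLX backend the same zeros_func name covers the allocation, which is why the layer no longer imports DTYPE itself. A sketch of what the MLX-side definition is assumed to look like (the real one lives in tllm/commons/cache.py):

import mlx.core as mx

# Assumed dtype; MLX arrays live in unified memory, so no device argument is needed.
DTYPE = mx.float16

zeros_func = lambda x0, x1, x2: mx.zeros(shape=(x0, x1, x2), dtype=DTYPE)

k_cache = zeros_func(1024, 8, 128)  # (max_seq_len, n_kv_heads, head_dim)
print(k_cache.shape, k_cache.dtype)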

tllm/models/torch/layers.py — 7 changes: 3 additions & 4 deletions
@@ -12,9 +12,8 @@
apply_rotary_pos_emb,
)

-from tllm import DEVICE, DTYPE
from tllm.commons.attn import ATTN_FUNC
-from tllm.commons.cache import AttentionData, RequestsCache
+from tllm.commons.cache import AttentionData, RequestsCache, zeros_func

self_attn_func = ATTN_FUNC

@@ -153,8 +152,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
)

# self.max_seq_len = 1024
-# self._k_cache = torch.zeros(size=(max_seq_len, self.num_key_value_heads, self.head_dim), dtype=DTYPE, device=DEVICE)
-# self._v_cache = torch.zeros(size=(max_seq_len, self.num_key_value_heads, self.head_dim), dtype=DTYPE, device=DEVICE)
+# self._k_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
+# self._v_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
self.max_seq_len = -1
self._k_cache, self._v_cache = None, None

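If the commented-out preallocation were ever re-enabled, the torch layer could build its K/V buffers through the shared helper, keeping dtype and device decisions in one place. A hypothetical sketch with illustrative sizes (max_seq_len, head count, and head_dim are assumptions here, not values from the repository):

from tllm.commons.cache import zeros_func

max_seq_len = 1024            # assumed cap, mirroring the commented-out code
num_key_value_heads = 8       # illustrative value
head_dim = 128                # illustrative value

# zeros_func already bakes in DTYPE and DEVICE, so the layer needs no per-layer imports.
k_cache = zeros_func(max_seq_len, num_key_value_heads, head_dim)
v_cache = zeros_func(max_seq_len, num_key_value_heads, head_dim)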
