fix max_seq_len and add torch init layer device
wnma3mz committed Jan 26, 2025
1 parent 9af300e commit c495bb5
Showing 7 changed files with 18 additions and 17 deletions.
4 changes: 2 additions & 2 deletions examples/run_engine.py
@@ -23,8 +23,8 @@ def parse_args():


args = parse_args()
os.environ["TLLM_BACKEND"] = args.backend.upper()
os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend.upper()
os.environ["TLLM_BACKEND"] = args.backend
os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend

from tllm.commons.manager import load_client_model, load_master_model
from tllm.commons.tp_communicator import Communicator
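Note: the two environment variables are now passed through exactly as given on the command line, so any case normalization has to happen wherever TLLM_BACKEND is read. A minimal, hypothetical sketch of such a consumer (the select_backend helper and the backend names "TORCH"/"MLX" are assumptions, not tllm API):

import os

def select_backend() -> str:
    # Normalize here so callers may export the variable in any casing.
    backend = os.environ.get("TLLM_BACKEND", "TORCH").upper()
    if backend not in {"TORCH", "MLX"}:
        raise ValueError(f"unsupported TLLM_BACKEND: {backend}")
    return backend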
2 changes: 1 addition & 1 deletion tllm/models/mlx/llama.py
@@ -57,7 +57,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
# rope_type="default",
# rope_scaling=1.0,
# )
-self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
self.head_dim = self.model.layers[-1].self_attn.head_dim

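Note: the getattr fallback lets models whose attention implementation does not define max_seq_len construct cleanly, reporting -1 ("no fixed cache length") instead of raising AttributeError. A standalone illustration (the Attn class is a stand-in, not tllm code):

class Attn:
    # Stand-in attention module that never sets max_seq_len.
    n_kv_heads = 8
    head_dim = 64

attn = Attn()
# attn.max_seq_len would raise AttributeError here;
# the fallback keeps construction working and signals "unknown" with -1.
max_seq_len = getattr(attn, "max_seq_len", -1)
print(max_seq_len)  # -1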
2 changes: 1 addition & 1 deletion tllm/models/mlx/qwen.py
@@ -35,7 +35,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
self.model = Decoder(args, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge)
self.num_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx

-self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
self.head_dim = self.model.layers[-1].self_attn.head_dim

11 changes: 6 additions & 5 deletions tllm/models/torch/layers.py
@@ -12,6 +12,7 @@
apply_rotary_pos_emb,
)

+from tllm import DEVICE
from tllm.commons.attn import ATTN_FUNC
from tllm.commons.cache import AttentionData, RequestsCache, zeros_func

@@ -45,7 +46,7 @@ def __init__(
assert col_size % self.world_size == 0
self.row_size, self.col_size = row_size, col_size
self.dup_layer = dup_layer
-self.layer = nn.Linear(row_size, col_size * self.dup_layer // self.world_size, bias=bias)
+self.layer = nn.Linear(row_size, col_size * self.dup_layer // self.world_size, bias=bias, device=DEVICE)

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
node_output = self.layer(x)
@@ -62,7 +63,7 @@ def __init__(self, row_size: int, col_size_list: List[int], world_size: int, ran

self.row_size, self.col_size = row_size, col_size
self.col_size_list = [x // self.world_size for x in col_size_list]
-self.layer = nn.Linear(row_size, col_size // self.world_size, bias=bias)
+self.layer = nn.Linear(row_size, col_size // self.world_size, bias=bias, device=DEVICE)

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
node_output = self.layer(x)
@@ -74,7 +75,7 @@ def __init__(self, row_size: int, col_size: int, world_size: int, rank: int, bia
super().__init__(world_size, rank)
assert row_size % self.world_size == 0
self.row_size, self.col_size = row_size, col_size
-self.layer = nn.Linear(row_size // self.world_size, col_size, bias=bias)
+self.layer = nn.Linear(row_size // self.world_size, col_size, bias=bias, device=DEVICE)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
@@ -152,8 +153,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
)

# self.max_seq_len = 1024
-# self._k_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
-# self._v_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
+# self._k_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
+# self._v_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
self.max_seq_len = -1
self._k_cache, self._v_cache = None, None

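Note: with device=DEVICE the parallel linear layers are allocated directly on the target device at construction time instead of being built on the CPU and moved afterwards, which skips one host-side allocation and copy per layer. A minimal sketch of the difference, assuming DEVICE resolves to a CUDA device when one is available (illustrative code, not the tllm layer):

import torch
import torch.nn as nn

# Assumption: tllm.DEVICE resolves to something equivalent to this.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Built on CPU first, then copied to the target device.
layer_cpu = nn.Linear(4096, 4096).to(DEVICE)

# Built directly on the target device via the device keyword argument.
layer_direct = nn.Linear(4096, 4096, device=DEVICE)

print(next(layer_direct.parameters()).device)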
6 changes: 3 additions & 3 deletions tllm/models/torch/llama.py
@@ -49,7 +49,7 @@ def __init__(self, config, is_merge: bool = True):
self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx
self.rotary_emb = HFLlamaRotaryEmbedding(config=config)

-self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads
self.head_dim = self.model.layers[-1].self_attn.head_dim

@@ -126,8 +126,8 @@ class HFLlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.vocab_size = config.vocab_size
-self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@classmethod
6 changes: 3 additions & 3 deletions tllm/models/torch/qwen.py
@@ -54,7 +54,7 @@ def __init__(self, config, is_merge: bool = True):
self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx
self.rotary_emb = HFQwen2RotaryEmbedding(config=config)

-self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads
self.head_dim = self.model.layers[-1].self_attn.head_dim

@@ -132,8 +132,8 @@ class HFQwen2ForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.vocab_size = config.vocab_size
-self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@classmethod
4 changes: 2 additions & 2 deletions tllm/models/torch/qwen_vl.py
@@ -32,8 +32,8 @@ def __init__(self, config):
super().__init__()
self.vocab_size = config.vocab_size
self.visual = HFQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
-self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@classmethod
