diff --git a/examples/run_engine.py b/examples/run_engine.py
index c99e6ad..293195c 100644
--- a/examples/run_engine.py
+++ b/examples/run_engine.py
@@ -23,8 +23,8 @@ def parse_args():
 
 args = parse_args()
 
-os.environ["TLLM_BACKEND"] = args.backend.upper()
-os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend.upper()
+os.environ["TLLM_BACKEND"] = args.backend
+os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend
 
 from tllm.commons.manager import load_client_model, load_master_model
 from tllm.commons.tp_communicator import Communicator
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index 174e953..67cc9e5 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -57,7 +57,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         #     rope_type="default",
         #     rope_scaling=1.0,
         # )
 
-        self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+        self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
diff --git a/tllm/models/mlx/qwen.py b/tllm/models/mlx/qwen.py
index 82a8f99..a1fa65a 100644
--- a/tllm/models/mlx/qwen.py
+++ b/tllm/models/mlx/qwen.py
@@ -35,7 +35,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         self.model = Decoder(args, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge)
         self.num_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx
 
-        self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+        self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
 
diff --git a/tllm/models/torch/layers.py b/tllm/models/torch/layers.py
index a64847d..3b2b285 100644
--- a/tllm/models/torch/layers.py
+++ b/tllm/models/torch/layers.py
@@ -12,6 +12,7 @@
     apply_rotary_pos_emb,
 )
 
+from tllm import DEVICE
 from tllm.commons.attn import ATTN_FUNC
 from tllm.commons.cache import AttentionData, RequestsCache, zeros_func
 
@@ -45,7 +46,7 @@ def __init__(
         assert col_size % self.world_size == 0
         self.row_size, self.col_size = row_size, col_size
         self.dup_layer = dup_layer
-        self.layer = nn.Linear(row_size, col_size * self.dup_layer // self.world_size, bias=bias)
+        self.layer = nn.Linear(row_size, col_size * self.dup_layer // self.world_size, bias=bias, device=DEVICE)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         node_output = self.layer(x)
@@ -62,7 +63,7 @@ def __init__(self, row_size: int, col_size_list: List[int], world_size: int, ran
 
         self.row_size, self.col_size = row_size, col_size
         self.col_size_list = [x // self.world_size for x in col_size_list]
-        self.layer = nn.Linear(row_size, col_size // self.world_size, bias=bias)
+        self.layer = nn.Linear(row_size, col_size // self.world_size, bias=bias, device=DEVICE)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         node_output = self.layer(x)
@@ -74,7 +75,7 @@ def __init__(self, row_size: int, col_size: int, world_size: int, rank: int, bia
         super().__init__(world_size, rank)
         assert row_size % self.world_size == 0
         self.row_size, self.col_size = row_size, col_size
-        self.layer = nn.Linear(row_size // self.world_size, col_size, bias=bias)
+        self.layer = nn.Linear(row_size // self.world_size, col_size, bias=bias, device=DEVICE)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.layer(x)
@@ -152,8 +153,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         )
 
         # self.max_seq_len = 1024
-        # self._k_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
-        # self._v_cache = zeros_func(max_seq_len, self.num_key_value_heads, self.head_dim)
+        # self._k_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
+        # self._v_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
         self.max_seq_len = -1
         self._k_cache, self._v_cache = None, None
 
diff --git a/tllm/models/torch/llama.py b/tllm/models/torch/llama.py
index 7605faa..afe1e40 100644
--- a/tllm/models/torch/llama.py
+++ b/tllm/models/torch/llama.py
@@ -49,7 +49,7 @@ def __init__(self, config, is_merge: bool = True):
         self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx
         self.rotary_emb = HFLlamaRotaryEmbedding(config=config)
 
-        self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+        self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
 
@@ -126,8 +126,8 @@ class HFLlamaForCausalLM(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     @classmethod
diff --git a/tllm/models/torch/qwen.py b/tllm/models/torch/qwen.py
index 3fe298b..0e56300 100644
--- a/tllm/models/torch/qwen.py
+++ b/tllm/models/torch/qwen.py
@@ -54,7 +54,7 @@ def __init__(self, config, is_merge: bool = True):
         self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx
         self.rotary_emb = HFQwen2RotaryEmbedding(config=config)
 
-        self.max_seq_len = self.model.layers[-1].self_attn.max_seq_len
+        self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
 
@@ -132,8 +132,8 @@ class HFQwen2ForCausalLM(nn.Module):
    def __init__(self, config):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     @classmethod
diff --git a/tllm/models/torch/qwen_vl.py b/tllm/models/torch/qwen_vl.py
index 45ece35..8c5ccba 100644
--- a/tllm/models/torch/qwen_vl.py
+++ b/tllm/models/torch/qwen_vl.py
@@ -32,8 +32,8 @@ def __init__(self, config):
         super().__init__()
         self.vocab_size = config.vocab_size
         self.visual = HFQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, device=DEVICE)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, device=DEVICE)
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     @classmethod
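
Note (illustration, not part of the patch): the repeated `getattr(..., "max_seq_len", -1)` change makes the model wrappers tolerant of attention modules that do not expose `max_seq_len` (layers.py now initializes it to -1 and defers KV-cache allocation). A minimal sketch of the fallback, using a made-up `DummyAttention` stand-in rather than a real tllm class:

# Sketch only: DummyAttention is hypothetical; real tllm attention modules
# may or may not define max_seq_len themselves.
class DummyAttention:
    n_kv_heads = 8
    head_dim = 64

attn = DummyAttention()

# Before the patch this raised AttributeError when max_seq_len was absent:
#     max_seq_len = attn.max_seq_len
# After the patch, -1 acts as a "no preallocated KV cache" sentinel:
max_seq_len = getattr(attn, "max_seq_len", -1)
print(max_seq_len)  # -1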
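
Note (illustration, not part of the patch): passing `device=DEVICE` to `nn.Linear` / `nn.Embedding` allocates the parameters directly on the target device instead of building them on CPU and moving them afterwards. A minimal sketch, assuming `DEVICE` resolves to a torch device string as it would via `from tllm import DEVICE`:

import torch
import torch.nn as nn

# Stand-in for tllm's DEVICE constant; the real value comes from the tllm package.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The device= factory kwarg creates the weights on the target device at
# construction time, avoiding a CPU allocation followed by a .to(DEVICE) copy.
embed_tokens = nn.Embedding(32000, 4096, device=DEVICE)
lm_head = nn.Linear(4096, 32000, bias=False, device=DEVICE)

print(embed_tokens.weight.device, lm_head.weight.device)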