diff --git a/lightllm/models/cohere/__init__.py b/lightllm/models/cohere/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/cohere/layer_infer/__init__.py b/lightllm/models/cohere/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py
new file mode 100644
index 000000000..27b8688fc
--- /dev/null
+++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py
@@ -0,0 +1,139 @@
+import torch
+import torch.distributed as dist
+import numpy as np
+
+from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
+from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
+from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
+
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+from einops import rearrange
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.basemodel import PostLayerInferTpl
+
+
+class CoherePostLayerInfer(PostLayerInferTpl):
+    def __init__(self, tp_rank, world_size, network_config, mode):
+        super().__init__(tp_rank, world_size, network_config, mode)
+        self.eps_ = network_config["layer_norm_eps"]
+        self.vocab_size_ = network_config["vocab_size"]
+        self.embed_dim_ = network_config["n_embed"]
+        self.logits_scale = network_config["logit_scale"]
+        return
+
+    def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
+        return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+
+    def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):
+        if infer_state.is_splitfuse:
+            # for SplitFuse
+            batch_size = infer_state.batch_size
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
+            tmp_ = torch.cat(
+                [
+                    torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
+                    infer_state.prefill_b_seq_len - infer_state.prefill_b_split_ready_cache_len,
+                ],
+                dim=0,
+            )
+            last_index = torch.cumsum(tmp_, dim=0, dtype=torch.long) - 1
+            last_input[:, :] = input_embdings[last_index, :]
+            return last_input, batch_size
+
+        if infer_state.is_prefill and infer_state.is_token_healing:
+            batch_size = infer_state.batch_size
+            b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy()
+            select_index = []
+            start_index = 0
+            select_token_num = 0
+            for cur_len in b_seq_len_numpy:
+                if cur_len == 1:
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 1
+                else:
+                    select_index.append(start_index + cur_len - 2)
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 2
+
+            last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
+            last_input = torch.empty(
+                (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
+
+            last_input[:, :] = input_embdings[last_index, :]
+            return last_input, select_token_num
+
+        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
+            batch_size = infer_state.batch_size
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
+            last_index = (
+                torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
+            )
+            last_input[:, :] = input_embdings[last_index, :]
+            return last_input, batch_size
+
+        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics:
+            total_tokens = infer_state.total_token_num
+            return input_embdings, total_tokens
+
+        if not infer_state.is_splitfuse and not infer_state.is_prefill:
+            batch_size = infer_state.batch_size
+            return input_embdings[-batch_size:, :], batch_size
+
+        assert False, "Error State"
+
+    def soft_max(self, data):
+        return torch.softmax(data.permute(1, 0).float(), dim=-1)
+
+    def token_forward(
+        self,
+        input_embdings,
+        infer_state: LlamaInferStateInfo,
+        layer_weight: LlamaPreAndPostLayerWeight,
+        return_logics=False,
+    ):
+        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
+        input_embdings_dtype = input_embdings.dtype
+        input_embdings = None
+        last_input = self._norm(last_input, infer_state, layer_weight)
+        last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
+        logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input)
+
+        last_input = None
+        if self.world_size_ == 1:
+            gather_data = logic_batch
+        else:
+            gather_data = torch.empty(
+                (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
+            )
+            split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
+            dist.all_gather(
+                [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
+                logic_batch,
+                group=None,
+                async_op=False,
+            )
+        gather_data = gather_data * self.logits_scale
+        logic_batch = None
+
+        if not return_logics:
+            prob_out = self.soft_max(gather_data)
+            gather_data = None
+            return prob_out
+        else:
+            ans_logics = gather_data.permute(1, 0).float()
+            gather_data = None
+            return ans_logics
+
+    # @mark_cost_time("splitfuse post forward")
+    def splitfuse_forward(
+        self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight: BaseLayerWeight, return_logics=False
+    ):
+        return self.token_forward(input_embdings, infer_state, layer_weight, return_logics=return_logics)
diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py
new file mode 100755
index 000000000..ed386f0c0
--- /dev/null
+++ b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,174 @@
+import torch
+from functools import partial
+
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl
+from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
+from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, mh_layernorm_forward
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+import torch.distributed as dist
+
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.utils.infer_utils import mark_cost_time
+
+
+
+class CohereTransformerLayerInfer(LlamaTransformerLayerInfer):
+    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
+        network_config["rms_norm_eps"] = network_config["layer_norm_eps"]  # Cohere uses layer_norm_eps
+        self.use_qk_norm = network_config.get("use_qk_norm", False)
+        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
+        self.eps_ = network_config["layer_norm_eps"]  # overwrite eps
+        self._bind_func()
+        return
+
+    def _att_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
+        return layernorm_forward(input, weight=layer_weight.att_norm_weight_, eps=self.eps_)
+
+    def _q_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
+        return mh_layernorm_forward(input, weight=layer_weight.q_norm_weight_, eps=self.eps_)
+
+    def _k_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
+        return mh_layernorm_forward(input, weight=layer_weight.k_norm_weight_, eps=self.eps_)
+
+    def _bind_norm(self):
+        self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
+        self._ffn_norm = None  # no ffn norm in cohere models
+        self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) if self.use_qk_norm else None
+        self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) if self.use_qk_norm else None
+
+    def _get_qkv(
+        self, input, cache_kv, infer_state, layer_weight
+    ) -> torch.Tensor:
+        q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
+        torch.mm(
+            input.view(-1, self.embed_dim_),
+            layer_weight.kv_weight_,
+            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
+        )
+        if self.use_qk_norm:
+            q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+            k = cache_kv[:, 0 : self.tp_k_head_num_, :]
+            q = self._q_norm(q, infer_state, layer_weight)
+            cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
+        rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, 0 : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
+        return q, cache_kv
+
+    @mark_cost_time("trans context ffn forward time cost")  # don't remove this; removing it degrades performance, reason unknown
+    def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = input_embdings
+        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        infer_state._ffn_out = ffn_out
+        return
+
+    def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = input_embdings
+        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._ffn_out = ffn_out
+        return
+
+    # @mark_cost_time("trans context ffn forward time cost")  # don't remove this; removing it degrades performance, reason unknown
+    def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
+        input1 = input_embdings
+        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._ffn_out = ffn_out
+        return
+
+    @mark_cost_time("trans context flash forward time cost")  # don't remove this; removing it degrades performance, reason unknown
+    def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
+        input1 = input_embding
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
+        input1 = None
+        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._attn_out = o
+        return
+
+    # this impl doesn't use @mark_cost_time
+    def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
+        input1 = input_embding
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._token_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._attn_out = o
+        return
+
+    # @mark_cost_time("trans context flash forward time cost")  # don't remove this; removing it degrades performance, reason unknown
+    def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
+        input1 = input_embding
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._attn_out = o
+        return
+
+    def _cohere_residual(self, input_embdings, infer_state: InferStateInfo):
+        emb_addr = input_embdings.data_ptr()
+        attn_out_addr = infer_state._attn_out.data_ptr()
+        ffn_addr = infer_state._ffn_out.data_ptr()
+        assert emb_addr != attn_out_addr
+        assert emb_addr != ffn_addr
+        assert attn_out_addr != ffn_addr
+        input_embdings.add_(infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_))
+
+    def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        self._context_attention(input1,
+                                infer_state,
+                                layer_weight=layer_weight)
+        self._context_ffn(input1, infer_state, layer_weight)
+        self._cohere_residual(input_embdings, infer_state)
+        return input_embdings
+
+    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        self._token_attention(input1,
+                              infer_state,
+                              layer_weight=layer_weight)
+        self._token_ffn(input1, infer_state, layer_weight)
+        self._cohere_residual(input_embdings, infer_state)
+        return input_embdings
+
+    def splitfuse_forward(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        self._splitfuse_attention(input1,
+                                  infer_state,
+                                  layer_weight=layer_weight)
+        self._splitfuse_ffn(input1, infer_state, layer_weight)
+        self._cohere_residual(input_embdings, infer_state)
+        return input_embdings
diff --git a/lightllm/models/cohere/layer_weights/__init__.py b/lightllm/models/cohere/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..fb32d2c9e
--- /dev/null
+++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,36 @@
+import torch
+import numpy as np
+
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+
+class CoherePreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def load_hf_weights(self, weights):
+        vob_size = self.network_config_["vocab_size"]
+        tie_weight = self.network_config_.get("tie_word_embeddings", True)
+        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
+        split_start = split_indexes[self.tp_rank_]
+        split_end = split_indexes[self.tp_rank_ + 1]
+        if "model.embed_tokens.weight" in weights:
+            # print(weights['model.embed_tokens.weight'].shape)
+            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
+            if tie_weight:
+                self.lm_head_weight_ = self.wte_weight_
+        if "model.norm.weight" in weights:
+            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
+        if "model.lm_head.weight" in weights:
+            self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"])
+        return
+
+    def verify_load(self):
+        super().verify_load()
+
+        errors = "tie weights load not ok"
+        tie_weight = self.network_config_.get("tie_word_embeddings", True)
+        if tie_weight:
+            assert self.lm_head_weight_ is not None, errors
+            assert self.wte_weight_ is self.lm_head_weight_, errors
+        else:
+            assert self.lm_head_weight_ is not None, errors
+            assert self.wte_weight_ is not None, errors
+            assert self.wte_weight_ is not self.lm_head_weight_, errors
+
diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py
new file mode 100644
index 000000000..7920de6c0
--- /dev/null
+++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,107 @@
+from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
+
+
+class CohereTransformerLayerWeight(LlamaTransformerLayerWeight):
+    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
+        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
+        self.use_qk_norm = network_config.get("use_qk_norm", False)
+        return
+
+    def load_hf_weights(self, weights):
+        self._load_qkvo_weights(weights)
+        self._load_ffn_weights(weights)
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [
+            self.att_norm_weight_,
+            self.q_weight_,
+            self.kv_weight_,
+            self.o_weight_,
+            self.gate_up_proj,
+            self.down_proj,
+        ]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        if self.use_qk_norm:
+            qk_weights = [self.q_norm_weight_, self.k_norm_weight_]
+            for i in range(len(qk_weights)):
+                assert qk_weights[i] is not None, "index:" + str(i + len(weights)) + " " + errors
+        return
+
+    def _load_qkvo_weights(self, weights):
+        # input layernorm params
+        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
+            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])
+
+        n_embed = self.network_config_["hidden_size"]
+        q_split_n_embed = n_embed // self.world_size_
+        kv_split_n_embed = (
+            n_embed
+            // self.network_config_["num_attention_heads"]
self.network_config_["num_key_value_heads"] + // self.world_size_ + ) + q_split_head = self.network_config_["num_attention_heads"] // self.world_size_ + kv_split_head = self.network_config_["num_key_value_heads"] // self.world_size_ + + # q k v weights for llama + if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: + self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + + if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.self_attn.q_norm.weight" in weights: + q_norm_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_norm.weight"] + q_norm_weight_ = q_norm_weight_[q_split_head * self.tp_rank_ : q_split_head * (self.tp_rank_ + 1)] + self.q_norm_weight_ = self._cuda(q_norm_weight_) + if f"model.layers.{self.layer_num_}.self_attn.k_norm.weight" in weights: + k_norm_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_norm.weight"] + k_norm_weight_ = k_norm_weight_[kv_split_head * self.tp_rank_ : kv_split_head * (self.tp_rank_ + 1)] + self.k_norm_weight_ = self._cuda(k_norm_weight_) + + # attention output dense params + if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + return + + def _load_ffn_weights(self, weights): + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + + if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + + if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + return diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py new file mode 100644 index 000000000..9d4e57b58 --- /dev/null +++ 
+++ b/lightllm/models/cohere/model.py
@@ -0,0 +1,43 @@
+import os
+import json
+import torch
+from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
+from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
+from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
+from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
+from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
+from lightllm.models.llama.layer_weights.ds_load_utils import load_ds_weights
+from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
+
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo
+from lightllm.common.basemodel import TpPartBaseModel
+from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+class CohereTpPartModel(LlamaTpPartModel):
+    r'''
+    The Cohere model is derived from the llama model:
+    1. the token embedding is the same
+    2. layer_norm is a standard layer_norm instead of llama's rms_norm; there is only an input norm, no output norm
+    3. the mlp is not configurable and has no bias
+    4. rotary_emb is the same
+    5. the final lm_head is tied to the token embedding
+    6. res = emb + attn(iln_emb) + mlp(iln_emb)
+    '''
+    pre_and_post_weight_class = CoherePreAndPostLayerWeight
+    transformer_weight_class = CohereTransformerLayerWeight
+
+    pre_layer_infer_class = LlamaPreLayerInfer
+    post_layer_infer_class = CoherePostLayerInfer
+    transformer_layer_infer_class = CohereTransformerLayerInfer
+
+    infer_state_class = LlamaInferStateInfo
+    splitfuse_infer_state_class = LlamaSplitFuseInferStateInfo
diff --git a/lightllm/models/cohere/triton_kernels/__init__.py b/lightllm/models/cohere/triton_kernels/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py
new file mode 100644
index 000000000..4ebf8c2c3
--- /dev/null
+++ b/lightllm/models/cohere/triton_kernels/layernorm.py
@@ -0,0 +1,66 @@
+import torch
+
+import triton
+import triton.language as tl
+
+@torch.no_grad()
+def layernorm_forward(x, weight, eps):
+    return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps)
+
+def mh_layernorm_forward(x, weight, eps):
+    # x shape : (bs, seqlen, head, head_dim)
+    inp_dtype = x.dtype
+    x = x.to(torch.float32)
+    mean = x.mean(-1, keepdim=True)
+    variance = (x - mean).pow(2).mean(-1, keepdim=True)
+    x = (x - mean) * torch.rsqrt(variance + eps)
+    x = weight.to(torch.float32) * x
+    return x.to(inp_dtype)
+
+
+class CohereLayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim."""
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        mean = hidden_states.mean(-1, keepdim=True)
+        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = self.weight.to(torch.float32) * hidden_states
+        return hidden_states.to(input_dtype)
+
+
+def test():
+    hidden_size = 768
+    bs = 10
+    seqlen = 128
+    m = CohereLayerNorm(hidden_size).to(torch.float32)
+    for i in range(10):
+        x = torch.randn(bs, seqlen, hidden_size, dtype=torch.float32)
+        output_1 = m(x)
+        output_2 = layernorm_forward(x, m.weight, m.variance_epsilon)
+        print(torch.allclose(output_1, output_2, atol=1e-4))
+        max_err = torch.max(torch.abs(output_1 - output_2))
+        print("max error:", max_err)
+
+    head = 8
+    head_dim = 64
+    m = CohereLayerNorm((head, head_dim)).to(torch.float32)
+    print(m.weight.shape)
+    for i in range(10):
+        x = torch.randn(bs * seqlen, head, head_dim, dtype=torch.float32)
+        output_1 = m(x)
+        output_2 = mh_layernorm_forward(
+            x.view(bs * seqlen, head, head_dim), m.weight, m.variance_epsilon
+        )
+        print(torch.allclose(output_1, output_2, atol=1e-4))
+        max_err = torch.max(torch.abs(output_1 - output_2))
+        print("max error:", max_err)
+
+if __name__ == "__main__":
+    test()
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index fcacb8b8b..aff4ff37d 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -6,6 +6,7 @@
 from datetime import timedelta
 from typing import Dict, List, Tuple
 from transformers.configuration_utils import PretrainedConfig
+from lightllm.models.cohere.model import CohereTpPartModel
 from lightllm.models.mixtral.model import MixtralTpPartModel
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 from rpyc.utils.classic import obtain
@@ -171,6 +172,8 @@ def init_model(self, kvargs):
                 self.model = Qwen2TpPartModel(model_kvargs)
             elif self.model_type == "gemma":
                 self.model = Gemma_2bTpPartModel(model_kvargs)
+            elif self.model_type == "cohere":
+                self.model = CohereTpPartModel(model_kvargs)
             else:
                 raise Exception(f"can not support {self.model_type} now")
         except Exception as e:
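
Note (not part of the patch): the model.py docstring summarizes the Cohere block as res = emb + attn(iln_emb) + mlp(iln_emb), and context_forward / token_forward together with _cohere_residual implement exactly that: attention and MLP both consume the same input-layernormed activations, and their outputs are added back to the un-normed residual stream in a single step. The sketch below is a minimal single-device illustration of that structure; toy_cohere_block, attn_fn, and mlp_fn are hypothetical stand-ins for the tensor-parallel, KV-cached kernels used above, not lightllm APIs.

import torch
import torch.nn.functional as F

def toy_cohere_block(x, norm_weight, attn_fn, mlp_fn, eps=1e-5):
    # one shared input LayerNorm (weight only, no bias) and no post-attention norm
    normed = F.layer_norm(x, (x.shape[-1],), weight=norm_weight, eps=eps)
    # attention and MLP both read the same normed activations ...
    attn_out = attn_fn(normed)
    ffn_out = mlp_fn(normed)
    # ... and are summed into the un-normed residual in one step:
    # res = emb + attn(iln_emb) + mlp(iln_emb)
    return x + attn_out + ffn_out

With attn_fn = mlp_fn = torch.nn.Identity(), the block reduces to x + 2 * layer_norm(x), which makes the residual wiring easy to sanity-check in isolation.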
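
A second note on the tensor-parallel vocabulary split: CoherePreAndPostLayerWeight.load_hf_weights slices the (optionally tied) embedding rows per rank with np.linspace, and CoherePostLayerInfer.token_forward reuses the same split points to place each rank's partial logits into gather_data via dist.all_gather before multiplying by logit_scale. The snippet below only illustrates how those split points are computed; the vocabulary size and world size are made-up example values.

import numpy as np

vocab_size, world_size = 256000, 4  # illustrative values, not from any real config
split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
# split_indexes == [0, 64000, 128000, 192000, 256000]
for rank in range(world_size):
    start, end = split_indexes[rank], split_indexes[rank + 1]
    # rank owns embedding rows [start, end); its partial logits land in
    # gather_data[start:end, :], and logit_scale is applied after the gather
    print(rank, start, end)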