From 286834e3a0bf5bd2f8d67f32f80c968b09a6aaeb Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Thu, 6 Feb 2025 17:25:08 +0800
Subject: [PATCH] split generate function and use attn_data.mask

---
 tllm/generate/llm_generator.py | 83 +++++++++++++++++---------------
 tllm/models/mlx/helper.py      |  4 +-
 tllm/models/mlx/layers.py      | 86 +++++-----------------------------
 tllm/models/mlx/llama.py       |  5 +-
 tllm/models/mlx/qwen2.py       |  5 +-
 5 files changed, 62 insertions(+), 121 deletions(-)

diff --git a/tllm/generate/llm_generator.py b/tllm/generate/llm_generator.py
index e6c1d95..f0dfa86 100644
--- a/tllm/generate/llm_generator.py
+++ b/tllm/generate/llm_generator.py
@@ -30,16 +30,7 @@ async def forward(self, inputs_embeds: MIX_TENSOR, seq_input: SeqInput) -> Forwa
             calc_cost_time=sum(calc_cost_time_list),
         )

-    async def generate(self, request_list: List[SequenceRequestData]):
-        """
-        @params:
-            request_list: List[SequenceRequestData]
-                Params:
-                    input_ids: List[int]
-
-        """
-        is_gen_image = any(x.is_gen_image for x in request_list)  # In Experiment
-
+    async def process_input(self, request_list: List[SequenceRequestData]):
         uuid_list, input_ids_list, mm_input_list = [], [], []
         for sequence_request in request_list:
             uuid_list.append(sequence_request.request_id)
@@ -52,47 +43,24 @@ async def generate(self, request_list: List[SequenceRequestData]):

                 if sequence_request.is_gen_image:
                     compare_input_ids = copy.deepcopy(input_ids_list[-1])
-                    uuid_list.append(sequence_request.request_id + "-bak")  # each request corresponds to two sequences
+                    uuid_list.append(sequence_request.request_id + "-bak")  # For Janus-Pro, each request corresponds to two sequences
                     compare_input_ids[1:-1] = self.model.pad_token_id
                     input_ids_list.append(compare_input_ids)
             else:
                 input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
                 if sequence_request.is_gen_image:
-                    uuid_list.append(sequence_request.request_id + "-bak")  # each request corresponds to two sequences
+                    uuid_list.append(sequence_request.request_id + "-bak")
                     input_ids_list.append(np.array([sequence_request.output_ids[-1]]))

         mm_input = self.merge_mm_input(mm_input_list) if self.merge_mm_input is not None else None

-        input_ids = np.concatenate(input_ids_list, axis=-1)  # [seq_len1 + seq_len2 + ...] -> [seq_len1 + seq_len2 + ..., hidden_size]
-
-        if mm_input is not None:
-            input_embeds = self.model.get_input_embeddings(input_ids, **mm_input)
-        else:
-            if not is_gen_image:
-                input_embeds = self.model.get_input_embeddings(input_ids)
-            else:
-                if request_list[0].is_prefill:
-                    input_embeds = self.model.get_input_embeddings(input_ids)
-                else:
-                    # during image generation, from the second token onward
-                    input_embeds = self.model.get_gen_img_embeds(input_ids)
+        input_ids = np.concatenate(input_ids_list, axis=-1)
         seq_input = SeqInput(uuid_list=uuid_list, input_ids_list=input_ids_list)

-        s0 = time.perf_counter()
-        forward_result = await self.forward(input_embeds, seq_input)
-        self.logger.debug(f"decoder cost time: {time.perf_counter() - s0:.4f}s")
-        s1 = time.perf_counter()
-        if is_gen_image:
-            seq_logits: List[MIX_TENSOR] = self.model.get_gen_head(forward_result.hidden_states)
-        else:
-            seq_logits: List[MIX_TENSOR] = self.model.get_logits(forward_result.hidden_states)
-
-        self.logger.debug(f"logits cost time: {time.perf_counter() - s1:.4f}s")
-        assert seq_logits.shape[0] == len(request_list)
-
-        s1 = time.perf_counter()
+        return seq_input, input_ids, mm_input
+
+    async def process_output(self, request_list: List[SequenceRequestData], seq_logits: MIX_TENSOR, s0: float):
         # TODO: batch decode by group
         # TODO: sequence_request.sampling_params
         seq_generate_ids: List[int] = sampling_func(seq_logits)
@@ -132,6 +100,45 @@ async def generate(self, request_list: List[SequenceRequestData]):
             sequence_request.decode_start_ts = time.perf_counter()
             sequence_request.is_prefill = False

+    async def generate(self, request_list: List[SequenceRequestData]):
+        """
+        @params:
+            request_list: List[SequenceRequestData]
+                Params:
+                    input_ids: List[int]
+
+        """
+        is_gen_image = any(x.is_gen_image for x in request_list)  # In Experiment
+        seq_input, input_ids, mm_input = await self.process_input(request_list)
+
+        if mm_input is not None:
+            input_embeds = self.model.get_input_embeddings(input_ids, **mm_input)
+        else:
+            if not is_gen_image:
+                input_embeds = self.model.get_input_embeddings(input_ids)
+            else:
+                if request_list[0].is_prefill:
+                    input_embeds = self.model.get_input_embeddings(input_ids)
+                else:
+                    # For Janus-Pro: during image generation, from the second token onward
+                    input_embeds = self.model.get_gen_img_embeds(input_ids)
+
+        s0 = time.perf_counter()
+        forward_result = await self.forward(input_embeds, seq_input)
+        self.logger.debug(f"decoder cost time: {time.perf_counter() - s0:.4f}s")
+        s1 = time.perf_counter()
+        if is_gen_image:
+            # For Janus-Pro
+            seq_logits: List[MIX_TENSOR] = self.model.get_gen_head(forward_result.hidden_states)
+        else:
+            seq_logits: List[MIX_TENSOR] = self.model.get_logits(forward_result.hidden_states)
+
+        self.logger.debug(f"logits cost time: {time.perf_counter() - s1:.4f}s")
+        assert seq_logits.shape[0] == len(request_list)
+
+        s1 = time.perf_counter()
+        await self.process_output(request_list, seq_logits, s0)
+
         fraction = forward_result.comm_cost_time / (forward_result.comm_cost_time + forward_result.calc_cost_time)
         self.logger.debug(f"de tokenizer cost time: {time.perf_counter() - s1:.4f}s")
         self.logger.debug(f"communication cost time: {forward_result.comm_cost_time:.4f}s({fraction*100:.1f}%)")
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index c1dd357..fbe76c3 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -66,11 +66,13 @@ def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> m
         if hidden_states.dtype == mx.float16:  # float16 is much slower than bfloat16
             hidden_states = hidden_states.astype(mx.bfloat16)
+        attn_mask = build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list)
+
         self.q_len_list = q_len_list
         self.hit_cache_len_list = hit_cache_len_list
         self.attn_data = AttentionData(
             request_cache=self.request_cache,
-            attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list),
+            attn_mask=attn_mask if attn_mask is None else attn_mask.astype(hidden_states.dtype),
             uuid_list=seq_input.uuid_list,
         )
diff --git a/tllm/models/mlx/layers.py b/tllm/models/mlx/layers.py
index 7bde2b8..30b666f 100644
--- a/tllm/models/mlx/layers.py
+++ b/tllm/models/mlx/layers.py
@@ -3,9 +3,9 @@
 import mlx.core as mx
 import mlx.nn as nn

-from mlx_lm.models.llama import MLP, Attention, ModelArgs, TransformerBlock, initialize_rope
+from mlx_lm.models.llama import ModelArgs, TransformerBlock, initialize_rope

-from tllm.commons.cache import AttentionData, RequestsCache, cat_func, zeros_func
+from tllm.commons.cache import AttentionData, RequestsCache, cat_func


 class BaseParallelLayer(nn.Module):
@@ -146,9 +146,6 @@ def __init__(self, args, layer_idx: int, offset: int):
             self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings
         )

-        # self.max_seq_len = 1024
-        # self._k_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
-        # self._v_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
         self.max_seq_len = -1
         self._k_cache, self._v_cache = None, None

@@ -162,12 +159,7 @@ def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]
             start = end
         return cat_func(x_list)

-    def __call__(
-        self,
-        x: mx.array,
-        mask: mx.array,
-        cache: AttentionData,
-    ) -> mx.array:
+    def __call__(self, x: mx.array, cache: AttentionData) -> mx.array:
         L, _ = x.shape
         queries, keys, values = self.qkv_proj(x)
         # Prepare the queries, keys and values for the attention computation
@@ -186,63 +178,13 @@ def __call__(
             keys, values, cache.uuid_list, self.layer_idx - self.offset, self._k_cache, self._v_cache
         )

-        output = sdap(queries, keys, values, scale=self.scale, mask=mask)
+        output = sdap(queries, keys, values, scale=self.scale, mask=cache.attn_mask)

         output = output.reshape(L, -1)
         attn_output = self.o_proj(output)

         return self.comm.all_reduce(attn_output)


-class PlainAttention(Attention):
-    def __init__(self, args, layer_idx: int, offset: int):
-        super().__init__(args)
-        o_proj_bias = getattr(args, "o_proj_bias", False)
-        self.o_proj = nn.Linear(self.n_heads * self.head_dim, args.hidden_size, bias=o_proj_bias)
-        self.layer_idx = layer_idx
-        self.offset = offset
-
-    def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]) -> List[mx.array]:
-        offset_list = request_cache.get_offset_list(uuid_list, self.layer_idx - self.offset)
-        x_list = []
-        start = 0
-        for uuid, offset in zip(uuid_list, offset_list):
-            end = start + request_cache.get_q_len(uuid)
-            x_list.append(self.rope(xs[start:end].transpose(1, 0, 2), offset).transpose(1, 0, 2))
-            start = end
-        return cat_func(x_list)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: mx.array,
-        cache: AttentionData,
-    ) -> mx.array:
-        L, D = x.shape
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(L, self.n_heads, -1)
-        keys = keys.reshape(L, self.n_kv_heads, -1)
-        values = values.reshape(L, self.n_kv_heads, -1)
-
-        # must has cache, and split by uuid
-        request_cache: RequestsCache = cache.request_cache
-        queries = self._rope(queries, request_cache, cache.uuid_list)
-        keys = self._rope(keys, request_cache, cache.uuid_list)
-        keys, values = request_cache.update(keys, values, cache.uuid_list, self.layer_idx - self.offset)
-
-        output = mx.fast.scaled_dot_product_attention(
-            mx.expand_dims(queries.transpose(1, 0, 2), axis=0),
-            mx.expand_dims(keys.transpose(1, 0, 2), axis=0),
-            mx.expand_dims(values.transpose(1, 0, 2), axis=0),
-            scale=self.scale,
-            mask=mask,
-        )[0]
-        output = output.transpose(1, 0, 2).reshape(L, -1)
-
-        return self.o_proj(output)
-
-
 class MergedMLP(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -269,19 +211,15 @@ def __init__(self, args: ModelArgs, layer_idx: int, offset: int, is_merge: bool
         super(TransformerBlock).__init__()
         self.num_attention_heads = args.num_attention_heads
         self.hidden_size = args.hidden_size
-        if is_merge:
-            self.self_attn = MergedAttention(args, layer_idx, offset)
-            self.mlp = MergedMLP(args)
-        else:
-            self.self_attn = PlainAttention(args, layer_idx, offset)
-            self.mlp = MLP(args)
+        self.self_attn = MergedAttention(args, layer_idx, offset)
+        self.mlp = MergedMLP(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.args = args
         self.layer_idx = layer_idx

-    def __call__(self, x: mx.array, mask, cache) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
+    def __call__(self, x: mx.array, cache) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), cache)
         h = x + r
         # no skip some begin token, and skip middle block, https://arxiv.org/abs/2404.03865
         # if 20 <= self.layer_idx <= 24 and x.shape[0] == 1:
@@ -291,7 +229,7 @@ def __call__(self, x: mx.array, mask, cache) -> mx.array:
         return out


-def empty_func(h, mask, cache):
+def empty_func(h, cache):
     # TODO
     return h

@@ -306,7 +244,7 @@ def __init__(self, args: ModelArgs, start_layer_idx: int, end_layer_idx: int, is
             for layer_idx in range(start_layer_idx, end_layer_idx)
         ]

-    def __call__(self, h: mx.array, mask, cache: AttentionData) -> mx.array:
-        for layer in self.layers:
-            h = layer(h, mask, cache=cache)
+    def __call__(self, h: mx.array, cache: AttentionData) -> mx.array:
+        for i, layer in enumerate(self.layers):
+            h = layer(h, cache=cache)
         return h
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index f38cfdc..617c13b 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -79,12 +79,9 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):

     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input)
-
         # cos, sin = self.rotary_emb(attention_data.position_ids)
         # attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
-        mask = cache_manager.attn_data.attn_mask
-        mask = mask if mask is None else mask.astype(hidden_states.dtype)
-        output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data)
+        output = self.model(hidden_states, cache=cache_manager.attn_data)

         # TODO: update the cache asynchronously
         cache_manager.update_cache(seq_input)
diff --git a/tllm/models/mlx/qwen2.py b/tllm/models/mlx/qwen2.py
index 5b928b5..3aa5294 100644
--- a/tllm/models/mlx/qwen2.py
+++ b/tllm/models/mlx/qwen2.py
@@ -46,10 +46,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):

     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input)
-
-        mask = cache_manager.attn_data.attn_mask
-        mask = mask if mask is None else mask.astype(hidden_states.dtype)
-        output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data)
+        output = self.model(hidden_states, cache=cache_manager.attn_data)

         # TODO: update the cache asynchronously
         cache_manager.update_cache(seq_input)
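
After this patch, generate() in llm_generator.py is a thin orchestrator: process_input builds the SeqInput and the flattened input_ids, forward runs the distributed decoder, and process_output samples and updates each SequenceRequestData. The following is a simplified sketch of that flow, not the literal patched code: the nested Janus-Pro embedding branches are collapsed into a single conditional, the logging and asserts are dropped, and the method names (process_input, forward, get_logits, get_gen_head, process_output, forward_result.hidden_states) plus the time module and self.model come from the diff above.

    # condensed sketch of LLMGenerator.generate as introduced by this patch
    async def generate(self, request_list):
        is_gen_image = any(x.is_gen_image for x in request_list)  # experimental image-generation path
        seq_input, input_ids, mm_input = await self.process_input(request_list)

        # pick the embedding path: multimodal inputs, plain text, or Janus-Pro image decoding
        if mm_input is not None:
            input_embeds = self.model.get_input_embeddings(input_ids, **mm_input)
        elif not is_gen_image or request_list[0].is_prefill:
            input_embeds = self.model.get_input_embeddings(input_ids)
        else:
            input_embeds = self.model.get_gen_img_embeds(input_ids)

        s0 = time.perf_counter()
        forward_result = await self.forward(input_embeds, seq_input)

        # image generation uses the generation head, text uses the LM head
        head = self.model.get_gen_head if is_gen_image else self.model.get_logits
        seq_logits = head(forward_result.hidden_states)

        await self.process_output(request_list, seq_logits, s0)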
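
On the MLX side, the attention mask now travels with the cache: build_forward_cache creates it once via build_mlx_mask, casts it to the hidden-state dtype, and stores it as AttentionData.attn_mask, so TransformerBlock, MergedAttention, empty_func, and Decoder drop their separate mask parameter and read cache.attn_mask instead. A rough sketch of that calling convention, using the names from the diff; q_len_list, k_len_list, hit_cache_len_list, request_cache, hidden_states, and decoder are assumed to be set up by the surrounding code.

    # inside build_forward_cache: build the mask once and attach it to the cache object
    attn_mask = build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list)
    attn_data = AttentionData(
        request_cache=request_cache,
        # cast here so every attention layer sees a mask that already matches hidden_states.dtype
        attn_mask=attn_mask if attn_mask is None else attn_mask.astype(hidden_states.dtype),
        uuid_list=seq_input.uuid_list,
    )

    # callers then pass only the cache; each layer reads cache.attn_mask internally
    output = decoder(hidden_states, cache=attn_data)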