Commit
Merge branch 'main' of github.com:wnma3mz/tLLM
wnma3mz committed Feb 9, 2025
2 parents 229243c + 286834e commit ad6b451
Showing 5 changed files with 62 additions and 121 deletions.
83 changes: 45 additions & 38 deletions tllm/generate/llm_generator.py
@@ -30,16 +30,7 @@ async def forward(self, inputs_embeds: MIX_TENSOR, seq_input: SeqInput) -> Forwa
calc_cost_time=sum(calc_cost_time_list),
)

async def generate(self, request_list: List[SequenceRequestData]):
"""
@params:
request_list: List[SequenceRequestData]
Params:
input_ids: List[int]
"""
is_gen_image = any(x.is_gen_image for x in request_list) # In Experiment

async def process_input(self, request_list: List[SequenceRequestData]):
uuid_list, input_ids_list, mm_input_list = [], [], []
for sequence_request in request_list:
uuid_list.append(sequence_request.request_id)
@@ -52,47 +43,24 @@ async def generate(self, request_list: List[SequenceRequestData]):

if sequence_request.is_gen_image:
compare_input_ids = copy.deepcopy(input_ids_list[-1])
uuid_list.append(sequence_request.request_id + "-bak") # 每个 request 对应两个句子
uuid_list.append(sequence_request.request_id + "-bak") # For Janus-Pro, 每个 request 对应两个句子
compare_input_ids[1:-1] = self.model.pad_token_id
input_ids_list.append(compare_input_ids)
else:
input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
if sequence_request.is_gen_image:
uuid_list.append(sequence_request.request_id + "-bak") # 每个 request 对应两个句子
uuid_list.append(sequence_request.request_id + "-bak")
input_ids_list.append(np.array([sequence_request.output_ids[-1]]))

mm_input = self.merge_mm_input(mm_input_list) if self.merge_mm_input is not None else None

input_ids = np.concatenate(input_ids_list, axis=-1)
# [seq_len1 + seq_len2 + ...] -> [seq_len1 + seq_len2 + ..., hidden_size]

if mm_input is not None:
input_embeds = self.model.get_input_embeddings(input_ids, **mm_input)
else:
if not is_gen_image:
input_embeds = self.model.get_input_embeddings(input_ids)
else:
if request_list[0].is_prefill:
input_embeds = self.model.get_input_embeddings(input_ids)
else:
# when generating an image, after the second token has been generated
input_embeds = self.model.get_gen_img_embeds(input_ids)
input_ids = np.concatenate(input_ids_list, axis=-1)

seq_input = SeqInput(uuid_list=uuid_list, input_ids_list=input_ids_list)
s0 = time.perf_counter()
forward_result = await self.forward(input_embeds, seq_input)
self.logger.debug(f"decoder cost time: {time.perf_counter() - s0:.4f}s")
s1 = time.perf_counter()
if is_gen_image:
seq_logits: List[MIX_TENSOR] = self.model.get_gen_head(forward_result.hidden_states)
else:
seq_logits: List[MIX_TENSOR] = self.model.get_logits(forward_result.hidden_states)

self.logger.debug(f"logits cost time: {time.perf_counter() - s1:.4f}s")
assert seq_logits.shape[0] == len(request_list)

s1 = time.perf_counter()
return seq_input, input_ids, mm_input

async def process_output(self, request_list: List[SequenceRequestData], seq_logits: MIX_TENSOR, s0: float):
# TODO: batch decode by group
# TODO: sequence_request.sampling_params
seq_generate_ids: List[int] = sampling_func(seq_logits)
@@ -132,6 +100,45 @@ async def generate(self, request_list: List[SequenceRequestData]):
sequence_request.decode_start_ts = time.perf_counter()
sequence_request.is_prefill = False

async def generate(self, request_list: List[SequenceRequestData]):
"""
@params:
request_list: List[SequenceRequestData]
Params:
input_ids: List[int]
"""
is_gen_image = any(x.is_gen_image for x in request_list) # In Experiment
seq_input, input_ids, mm_input = await self.process_input(request_list)

if mm_input is not None:
input_embeds = self.model.get_input_embeddings(input_ids, **mm_input)
else:
if not is_gen_image:
input_embeds = self.model.get_input_embeddings(input_ids)
else:
if request_list[0].is_prefill:
input_embeds = self.model.get_input_embeddings(input_ids)
else:
# For Janus-Pro: when generating an image, after the second token has been generated
input_embeds = self.model.get_gen_img_embeds(input_ids)

s0 = time.perf_counter()
forward_result = await self.forward(input_embeds, seq_input)
self.logger.debug(f"decoder cost time: {time.perf_counter() - s0:.4f}s")
s1 = time.perf_counter()
if is_gen_image:
# For Janus-Pro
seq_logits: List[MIX_TENSOR] = self.model.get_gen_head(forward_result.hidden_states)
else:
seq_logits: List[MIX_TENSOR] = self.model.get_logits(forward_result.hidden_states)

self.logger.debug(f"logits cost time: {time.perf_counter() - s1:.4f}s")
assert seq_logits.shape[0] == len(request_list)

s1 = time.perf_counter()
await self.process_output(request_list, seq_logits, s0)

fraction = forward_result.comm_cost_time / (forward_result.comm_cost_time + forward_result.calc_cost_time)
self.logger.debug(f"de tokenizer cost time: {time.perf_counter() - s1:.4f}s")
self.logger.debug(f"communication cost time: {forward_result.comm_cost_time:.4f}s({fraction*100:.1f}%)")
4 changes: 3 additions & 1 deletion tllm/models/mlx/helper.py
@@ -66,11 +66,13 @@ def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> m
if hidden_states.dtype == mx.float16: # float16 is much slower than bfloat16
hidden_states = hidden_states.astype(mx.bfloat16)

attn_mask = build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list)

self.q_len_list = q_len_list
self.hit_cache_len_list = hit_cache_len_list
self.attn_data = AttentionData(
request_cache=self.request_cache,
attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list),
attn_mask=attn_mask if attn_mask is None else attn_mask.astype(hidden_states.dtype),
uuid_list=seq_input.uuid_list,
)
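
The mask is now built and dtype-aligned once inside build_forward_cache instead of being cast in each model's __call__ (see the llama.py and qwen2.py diffs below), so the attention kernel receives an additive mask in the same dtype as the bfloat16 activations. A minimal MLX sketch of that pattern, using illustrative shapes and a plain causal mask rather than build_mlx_mask:

    import mlx.core as mx

    # (batch, heads, seq_len, head_dim) — illustrative sizes only
    q = mx.random.normal((1, 4, 8, 64)).astype(mx.bfloat16)
    k = mx.random.normal((1, 4, 8, 64)).astype(mx.bfloat16)
    v = mx.random.normal((1, 4, 8, 64)).astype(mx.bfloat16)

    # Additive causal mask: 0 on/below the diagonal, -inf above; float32 by default
    mask = mx.triu(mx.full((8, 8), float("-inf")), k=1)
    mask = mask if mask is None else mask.astype(q.dtype)   # cast once, same None-guard as the diff

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5, mask=mask)
    print(out.shape)  # same shape as q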

86 changes: 12 additions & 74 deletions tllm/models/mlx/layers.py
@@ -3,9 +3,9 @@

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.llama import MLP, Attention, ModelArgs, TransformerBlock, initialize_rope
from mlx_lm.models.llama import ModelArgs, TransformerBlock, initialize_rope

from tllm.commons.cache import AttentionData, RequestsCache, cat_func, zeros_func
from tllm.commons.cache import AttentionData, RequestsCache, cat_func


class BaseParallelLayer(nn.Module):
@@ -146,9 +146,6 @@ def __init__(self, args, layer_idx: int, offset: int):
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings
)

# self.max_seq_len = 1024
# self._k_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
# self._v_cache = zeros_func(self.max_seq_len, self.n_kv_heads, self.head_dim)
self.max_seq_len = -1
self._k_cache, self._v_cache = None, None

@@ -162,12 +159,7 @@ def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]
start = end
return cat_func(x_list)

def __call__(
self,
x: mx.array,
mask: mx.array,
cache: AttentionData,
) -> mx.array:
def __call__(self, x: mx.array, cache: AttentionData) -> mx.array:
L, _ = x.shape
queries, keys, values = self.qkv_proj(x)
# Prepare the queries, keys and values for the attention computation
@@ -186,63 +178,13 @@ def __call__(
keys, values, cache.uuid_list, self.layer_idx - self.offset, self._k_cache, self._v_cache
)

output = sdap(queries, keys, values, scale=self.scale, mask=mask)
output = sdap(queries, keys, values, scale=self.scale, mask=cache.attn_mask)
output = output.reshape(L, -1)

attn_output = self.o_proj(output)
return self.comm.all_reduce(attn_output)


class PlainAttention(Attention):
def __init__(self, args, layer_idx: int, offset: int):
super().__init__(args)
o_proj_bias = getattr(args, "o_proj_bias", False)
self.o_proj = nn.Linear(self.n_heads * self.head_dim, args.hidden_size, bias=o_proj_bias)
self.layer_idx = layer_idx
self.offset = offset

def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]) -> List[mx.array]:
offset_list = request_cache.get_offset_list(uuid_list, self.layer_idx - self.offset)
x_list = []
start = 0
for uuid, offset in zip(uuid_list, offset_list):
end = start + request_cache.get_q_len(uuid)
x_list.append(self.rope(xs[start:end].transpose(1, 0, 2), offset).transpose(1, 0, 2))
start = end
return cat_func(x_list)

def __call__(
self,
x: mx.array,
mask: mx.array,
cache: AttentionData,
) -> mx.array:
L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(L, self.n_heads, -1)
keys = keys.reshape(L, self.n_kv_heads, -1)
values = values.reshape(L, self.n_kv_heads, -1)

# must have a cache, and split by uuid
request_cache: RequestsCache = cache.request_cache
queries = self._rope(queries, request_cache, cache.uuid_list)
keys = self._rope(keys, request_cache, cache.uuid_list)
keys, values = request_cache.update(keys, values, cache.uuid_list, self.layer_idx - self.offset)

output = mx.fast.scaled_dot_product_attention(
mx.expand_dims(queries.transpose(1, 0, 2), axis=0),
mx.expand_dims(keys.transpose(1, 0, 2), axis=0),
mx.expand_dims(values.transpose(1, 0, 2), axis=0),
scale=self.scale,
mask=mask,
)[0]
output = output.transpose(1, 0, 2).reshape(L, -1)

return self.o_proj(output)


class MergedMLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -269,19 +211,15 @@ def __init__(self, args: ModelArgs, layer_idx: int, offset: int, is_merge: bool
super(TransformerBlock).__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
if is_merge:
self.self_attn = MergedAttention(args, layer_idx, offset)
self.mlp = MergedMLP(args)
else:
self.self_attn = PlainAttention(args, layer_idx, offset)
self.mlp = MLP(args)
self.self_attn = MergedAttention(args, layer_idx, offset)
self.mlp = MergedMLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
self.layer_idx = layer_idx

def __call__(self, x: mx.array, mask, cache) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
def __call__(self, x: mx.array, cache) -> mx.array:
r = self.self_attn(self.input_layernorm(x), cache)
h = x + r
# do not skip the initial tokens; skip middle blocks, https://arxiv.org/abs/2404.03865
# if 24 <= self.layer_idx <= 28 and x.shape[0] == 1:
@@ -291,7 +229,7 @@ def __call__(self, x: mx.array, mask, cache) -> mx.array:
return out


def empty_func(h, mask, cache):
def empty_func(h, cache):
# TODO
return h

@@ -306,7 +244,7 @@ def __init__(self, args: ModelArgs, start_layer_idx: int, end_layer_idx: int, is
for layer_idx in range(start_layer_idx, end_layer_idx)
]

def __call__(self, h: mx.array, mask, cache: AttentionData) -> mx.array:
for layer in self.layers:
h = layer(h, mask, cache=cache)
def __call__(self, h: mx.array, cache: AttentionData) -> mx.array:
for i, layer in enumerate(self.layers):
h = layer(h, cache=cache)
return h
5 changes: 1 addition & 4 deletions tllm/models/mlx/llama.py
@@ -79,12 +79,9 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):

def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input)

# cos, sin = self.rotary_emb(attention_data.position_ids)
# attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
mask = cache_manager.attn_data.attn_mask
mask = mask if mask is None else mask.astype(hidden_states.dtype)
output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data)
output = self.model(hidden_states, cache=cache_manager.attn_data)

# TODO: update the cache asynchronously
cache_manager.update_cache(seq_input)
5 changes: 1 addition & 4 deletions tllm/models/mlx/qwen2.py
@@ -46,10 +46,7 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):

def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input)

mask = cache_manager.attn_data.attn_mask
mask = mask if mask is None else mask.astype(hidden_states.dtype)
output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data)
output = self.model(hidden_states, cache=cache_manager.attn_data)

# TODO: update the cache asynchronously
cache_manager.update_cache(seq_input)
