add_twomicrobatch #34

Open · wants to merge 2 commits into base: deepep
68 changes: 66 additions & 2 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List
from typing import List, Any

import torch
import torch.distributed as dist
@@ -361,6 +361,15 @@ def __init__(self,
hidden_size=hidden_dim,
params_dtype=out_dtype,
)
self.token_dispatcher_for2mb = DeepEPDispatcher(
group=ep_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_experts // ep_size,
hidden_size=hidden_dim,
params_dtype=out_dtype,
)
self.experts = DeepEPMoE(num_experts, ep_size, [block_size,block_size])
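Note on the second dispatcher: each DeepEPDispatcher keeps per-call state between dispatch and combine (the handle, topk ids/weights, the shared-experts reference), so running two microbatches in flight needs one dispatcher per microbatch. A minimal sketch of the selection that forward_yield performs further down (the helper name _select_dispatcher is hypothetical; the PR inlines this logic):

def _select_dispatcher(self, tag: str):
    # tag[0] == "0" -> first microbatch  (self.token_dispatcher)
    # tag[0] == "1" -> second microbatch (self.token_dispatcher_for2mb)
    return self.token_dispatcher_for2mb if tag[0] == "1" else self.token_dispatcher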

def forward(self,
@@ -377,7 +386,7 @@ def forward(self,
recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
self.token_dispatcher.dispatch(
hidden_states,
topk_ids.to(torch.int32),
topk_ids.to(torch.int64),
topk_weights.to(torch.float32),
self.num_experts,
)
@@ -387,6 +396,61 @@
out_states = self.token_dispatcher.combine(out_states)
return out_states

def forward_yield(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
gate_up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None,
tag: Any = None,
shared_experts: Any = None):
"""forward_yield."""
topk_weights = _renormalize(topk_weights, self.renormalize)

if shared_experts is not None:
if self.token_dispatcher.get_shared_experts() is None:
self.token_dispatcher.set_shared_experts(shared_experts)
if self.token_dispatcher_for2mb.get_shared_experts() is None:
self.token_dispatcher_for2mb.set_shared_experts(shared_experts)

assert tag is not None and len(tag) >= 1
# tag[0] selects the microbatch: "0" -> self.token_dispatcher, "1" -> self.token_dispatcher_for2mb
_token_dispatcher = self.token_dispatcher
if tag[0] == "1":
_token_dispatcher = self.token_dispatcher_for2mb
# tag[1] (when present) marks the phase: "P" -> prefill, "D" -> decoding
is_prefill = len(tag) > 1 and tag[1].upper() == "P"
is_decoding = len(tag) > 1 and tag[1].upper() == "D"

_token_dispatcher.set_shared_experts(shared_experts)
# yield for attn1, dis (+share), dis_wait, moe
recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert, shared_states_indispatch = (
yield from _token_dispatcher.dispatch_yield(
hidden_states,
topk_ids.to(torch.int64),
topk_weights.to(torch.float32),
self.num_experts,
is_prefill=is_prefill,
is_decoding=is_decoding
)
)
out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, gate_up_weights, gate_up_scale,
down_weights, down_scale)
# yield for moe, comb, (+share) comb_wait, (+share) attn0
out_states, shared_states_incomb = yield from _token_dispatcher.combine_yield(out_states,
hidden_states,
is_prefill,
is_decoding)
shared_states = shared_states_indispatch if shared_states_indispatch is not None else shared_states_incomb
return out_states, shared_states
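A minimal driver sketch (an assumption about how a scheduler could consume forward_yield; not code from this PR) that advances two decode microbatches tagged "0D" and "1D" in lockstep, so that one microbatch's dispatch/combine all-to-all overlaps the other's compute. The generator's result arrives via StopIteration.value as (out_states, shared_states):

def run_two_decode_microbatches(moe, args0, args1, shared_experts=None):
    # args0 / args1: the positional arguments of forward_yield for each microbatch
    gens = [
        moe.forward_yield(*args0, tag="0D", shared_experts=shared_experts),
        moe.forward_yield(*args1, tag="1D", shared_experts=shared_experts),
    ]
    results = [None, None]
    while any(r is None for r in results):
        for i, gen in enumerate(gens):
            if results[i] is None:
                try:
                    next(gen)  # advance this microbatch by one overlap phase
                except StopIteration as stop:
                    results[i] = stop.value  # (out_states, shared_states)
    return results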

class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
"""triton fused moe blocked f8 builder."""

151 changes: 149 additions & 2 deletions lmdeploy/pytorch/backends/cuda/token_dispatcher.py
@@ -6,7 +6,7 @@
use_deepep = False

import os
from typing import Optional, Tuple
from typing import Optional, Tuple, Any

import torch
import torch.distributed as dist
@@ -156,6 +156,8 @@ def __init__(
self.token_probs = None
# Handle used for combine operation
self.handle = None
# shared experts
self.shared_experts = None

# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -181,7 +183,6 @@ def dispatch(
num_max_dispatch_tokens_per_rank: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.hidden_shape = hidden_states.shape
topk_idx = topk_idx.to(torch.int64)
(
hidden_states,
topk_idx,
@@ -205,6 +206,53 @@
hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
return hidden_states, topk_idx, topk_weights, tokens_per_expert

def dispatch_yield(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event=None,
num_max_dispatch_tokens_per_rank: int = 128,
is_prefill: bool = False,
is_decoding: bool = False
):
self.hidden_shape = hidden_states.shape
# yield for attn1, dis (+share)
yield
previous_event = self.buffer_normal.capture()
(
recv_hidden_states,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
) = self.dispatch_normal_async(
hidden_states, topk_idx, topk_weights, num_experts, previous_event, True
)
if is_decoding and self.shared_experts is not None:
shared_states = self.shared_experts(hidden_states)
else:
shared_states = None
# yield for dis (+share), dis_wait
yield
event.current_stream_wait()
# yield for dis_wait, moe
yield
self.tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list,
device=hidden_states.device,
dtype=torch.int64,
)
tokens_per_expert = self.get_number_of_tokens_per_expert()
self.handle = handle
self.topk_idx = recv_topk_idx
self.topk_weights = recv_topk_weights
if recv_hidden_states.shape[0] > 0:
recv_hidden_states = self.get_permuted_hidden_states_by_experts(recv_hidden_states)
return recv_hidden_states, recv_topk_idx, recv_topk_weights, tokens_per_expert, shared_states
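A hand-driven sketch of dispatch_yield (an assumption, not part of the PR), e.g. for a single microbatch without a scheduler. Here dispatcher is an already constructed DeepEPDispatcher and the tensors are the same inputs used above; each next() crosses one overlap-phase boundary, and the result tuple is delivered through StopIteration:

gen = dispatcher.dispatch_yield(hidden_states, topk_idx.to(torch.int64),
                                topk_weights.to(torch.float32), num_experts,
                                is_prefill=False, is_decoding=True)
try:
    while True:
        next(gen)  # attn of other microbatch | dispatch launch | dispatch wait
except StopIteration as stop:
    (recv_hidden_states, recv_topk_idx, recv_topk_weights,
     tokens_per_expert, shared_states) = stop.value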

def dispatch_normal(
self,
x: torch.Tensor,
@@ -256,6 +304,57 @@ def dispatch_normal(
event,
)

def dispatch_normal_async(
self,
x: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event=None,
async_finish=True
):
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = self.buffer_normal.get_dispatch_layout(
topk_idx,
num_experts,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=previous_event is not None and async_finish,
)

(
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
) = self.buffer_normal.dispatch(
x,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=previous_event is not None and async_finish,
)

return (
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
)
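A usage sketch (assumption) of the async idiom dispatch_normal_async is built on, mirroring what dispatch_yield does above: capture an event on the current stream, launch the dispatch with async_finish=True, overlap independent work, then block the current stream on the returned event before reading the received tensors:

previous_event = dispatcher.buffer_normal.capture()
(recv_x, recv_topk_idx, recv_topk_weights,
 num_recv_tokens_per_expert_list, handle, event) = dispatcher.dispatch_normal_async(
    x, topk_idx, topk_weights, num_experts,
    previous_event=previous_event, async_finish=True)
# ... overlap independent work here (shared experts, the other microbatch) ...
event.current_stream_wait()  # received tensors are now safe to use on this stream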

def combine(
self, hidden_states: torch.Tensor
@@ -268,6 +367,36 @@
self.handle = None
return hidden_states.view(self.hidden_shape)

def combine_yield(
self,
out_states: torch.Tensor,
hidden_states: torch.Tensor,
is_prefill: bool = False,
is_decoding: bool = False
):
if out_states.shape[0] > 0:
out_states = self.get_restored_hidden_states_by_experts(
out_states
)
# yield for moe, comb
yield
previous_event = self.buffer_normal.capture()
out_states, event = self.combine_normal_async(out_states,
self.handle,
previous_event=previous_event,
async_finish=True)
# yield for comb, (+share) comb_wait,
yield
if is_prefill and self.shared_experts is not None:
shared_states = self.shared_experts(hidden_states)
else:
shared_states = None
event.current_stream_wait()
# yield for (+share) comb_wait, (+share) attn0
yield
self.handle = None
return out_states.view(self.hidden_shape), shared_states
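Symmetrically, a hand-driven sketch of combine_yield (assumption, not in the PR). Note that shared_states is only produced here in the prefill case; in decoding the shared experts already ran inside dispatch_yield while the dispatch was in flight:

gen = dispatcher.combine_yield(out_states, hidden_states,
                               is_prefill=True, is_decoding=False)
try:
    while True:
        next(gen)  # moe | combine launch | combine wait
except StopIteration as stop:
    out_states, shared_states = stop.value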

def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
combined_x, _, event = self.buffer_normal.combine(
x,
@@ -278,6 +407,16 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
)
return combined_x, event

def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
combined_x, _, event = self.buffer_normal.combine(
x,
handle,
async_finish=async_finish,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None and async_finish,
)
return combined_x, event

def _indices_to_multihot(self, indices, probs):
batch_size = indices.shape[0]
multihot_routing_map = torch.zeros(
@@ -341,3 +480,11 @@ def get_restored_hidden_states_by_experts(
fused=self.permute_fusion,
)
return hidden_states.to(input_dtype)

def set_shared_experts(self, shared_experts: Any = None):
if self.shared_experts is None:
self.shared_experts = shared_experts
return self.shared_experts

def get_shared_experts(self):
return self.shared_experts
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/models/deepseek_v2.py
@@ -894,3 +894,5 @@ def __skip_nextn(name, nextn_keys):
else:
param = params_dict[name]
load_weight(param, loaded_weight)

import lmdeploy.pytorch.models.utils.microbatch  # imported for its side effects