9 changes: 9 additions & 0 deletions chatlearn/synchronizer/parameter_sync.py
@@ -1180,6 +1180,15 @@ def setup_rank_mapping(self):
f"greater or equal to expert parallel world size for inference ({self.num_dst_expert_parallel}) with HEP enabled."
)
if self.dst_model.use_vllm_backend:
if (
self.hep_num_mapping != 1
and get_args().runtime_args.routed_expert_regrouping_comm_type == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
):
raise NotImplementedError(
"all-to-all routed expert weight is only supported when src TP size * src EP size = dst TP size. "
"Please consider setting `routed_expert_regrouping_comm_type` to allgather or adjusting the model's parallel size."
)

if self.tp_num_mapping == 1:
if self.ep_num_mapping == 1:
self.build_rank_mapping()
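The new guard only affects the vLLM destination path. A minimal, self-contained sketch of the same check follows; the helper name and the formula for `hep_num_mapping` are assumptions inferred from the error message (it reads as equal to 1 exactly when src TP * src EP == dst TP), not part of the diff:

```python
# Hypothetical standalone version of the guard added above; the
# hep_num_mapping formula is assumed from the error message.
def check_regroup_comm_type(src_tp: int, src_ep: int, dst_tp: int, comm_type: str) -> None:
    hep_num_mapping = (src_tp * src_ep) // dst_tp  # assumed definition
    if hep_num_mapping != 1 and comm_type == "alltoall":
        raise NotImplementedError(
            "all-to-all regrouping needs src TP * src EP == dst TP; "
            "use allgather or adjust the parallel sizes"
        )

# src TP=2, EP=4 gives 8 hybrid shards but dst TP=2, so alltoall would be
# rejected here and the caller must use allgather instead.
check_regroup_comm_type(src_tp=2, src_ep=4, dst_tp=2, comm_type="allgather")  # passes
```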
2 changes: 1 addition & 1 deletion chatlearn/utils/arguments.py
@@ -311,7 +311,7 @@ class RuntimeConfig(BaseConfig):
#: parameter sync max workers
param_sync_max_workers: int = None
#: communication type to regroup routed experts, allgather/alltoall
routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER
#: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes
#: if `max_relay_episode` is set to 0, then relay is disabled
max_relay_episode: int = 0
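With the default flipped to allgather, alltoall becomes an explicit opt-in. A hedged usage sketch, assuming `RuntimeConfig` can be constructed directly with its defaults (the field name is taken from the diff):

```python
# Hypothetical usage sketch; assumes RuntimeConfig() works with defaults only.
from chatlearn.utils.arguments import RuntimeConfig

cfg = RuntimeConfig()
assert cfg.routed_expert_regrouping_comm_type == "allgather"  # new default

# Only opt back into alltoall when src TP * src EP == dst TP; otherwise the
# new guard in parameter_sync.py raises NotImplementedError.
cfg.routed_expert_regrouping_comm_type = "alltoall"
```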
2 changes: 1 addition & 1 deletion chatlearn/utils/constant.py
@@ -61,5 +61,5 @@ class PARAM_SYNC_COMM_TYPE(str, Enum):

class ROUTED_EXPERT_REGROUPING_COMM_TYPE(str, Enum):
"""communication type of routed expert regrouping."""
ALLTOALL = "alltoall"
ALLGATHER = "allgather"
ALLTOALL = "alltoall"
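Reordering the members does not change behaviour: because the enum subclasses `str`, members still compare equal to the plain strings used in configs. A small self-contained illustration (the class body mirrors the diff):

```python
from enum import Enum

class ROUTED_EXPERT_REGROUPING_COMM_TYPE(str, Enum):
    """communication type of routed expert regrouping."""
    ALLGATHER = "allgather"
    ALLTOALL = "alltoall"

# str-backed enum members compare equal to their raw string values,
# so a YAML/CLI value of "allgather" matches the enum default directly.
assert ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER == "allgather"
assert "alltoall" == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
```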
18 changes: 9 additions & 9 deletions tests/test_hep_eptp_vllm_tp.py
@@ -195,21 +195,21 @@ def test_hep_eptp_vllm_tp_dst_ep1_tp2_pp1_src_ep4_tp2_pp1():
assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]

# Judge alltoall actors
alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
# Judge allgather actors
allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
actor2rank = param_sync_group.actor2rank

assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
assert len(alltoall_actors) == 1
assert len(alltoall_actors[0]) == 8 # all src ranks should all-to-all routed experts
assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
assert len(allgather_actors) == 1
assert len(allgather_actors[0]) == 8 # all src ranks should all-gather routed experts
assert len(actor2rank) == 16 # all of the 16 actors should have rank
assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique

alltoall_actor_ranks = []
for actor in alltoall_actors[0]:
alltoall_actor_ranks.append(actor2rank[actor])
allgather_actor_ranks = []
for actor in allgather_actors[0]:
allgather_actor_ranks.append(actor2rank[actor])

assert alltoall_actor_ranks == [0, 1, 2, 3, 4, 5, 6, 7]
assert allgather_actor_ranks == [0, 1, 2, 3, 4, 5, 6, 7]

# Judge src->dst rank mappings
comm_pairs = []
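For reference, a hedged worked example of the mapping arithmetic this test asserts; the ordering of `tuples` as (src_ep, src_tp, src_pp, dst_ep, dst_tp, dst_pp) is an assumption read off the test name, not confirmed by the diff:

```python
# Hypothetical recomputation of the assertions above for
# dst ep1 tp2 pp1 / src ep4 tp2 pp1; the tuple ordering is assumed.
tuples = (4, 2, 1, 1, 2, 1)

ep_num_mapping = tuples[0] / tuples[3]   # 4.0: four src EP shards per dst EP group
tp_num_mapping = tuples[1] // tuples[4]  # 1: src and dst TP sizes match

assert ep_num_mapping == 4.0
assert tp_num_mapping == 1
# 4 (EP) * 2 (TP) * 1 (PP) = 8 src ranks land in a single regroup group,
# matching len(allgather_actors[0]) == 8 in the test above.
assert tuples[0] * tuples[1] * tuples[2] == 8
```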
22 changes: 11 additions & 11 deletions tests/test_hep_eptppp_vllm_tp.py
@@ -197,24 +197,24 @@ def test_hep_eptppp_vllm_tp_dst_ep1_tp2_pp1_src_ep2_tp2_pp2():
assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]

# Judge alltoall actors
alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
# Judge allgather actors
allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
actor2rank = param_sync_group.actor2rank

assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
assert len(alltoall_actors) == 2
assert len(alltoall_actors[0]) == 4 # prev 4 src ranks should all-to-all routed experts
assert len(alltoall_actors[1]) == 4 # last 4 src ranks should all-to-all routed experts
assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
assert len(allgather_actors) == 2
assert len(allgather_actors[0]) == 4 # first 4 src ranks should all-gather routed experts
assert len(allgather_actors[1]) == 4 # last 4 src ranks should all-gather routed experts
assert len(actor2rank) == 16 # all of the 16 actors should have rank
assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique

alltoall_actor_ranks = []
for actor_list in alltoall_actors:
alltoall_actor_ranks.append([])
allgather_actor_ranks = []
for actor_list in allgather_actors:
allgather_actor_ranks.append([])
for actor in actor_list:
alltoall_actor_ranks[-1].append(actor2rank[actor])
allgather_actor_ranks[-1].append(actor2rank[actor])

assert alltoall_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert allgather_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]

# Judge src->dst rank mappings
comm_pairs = []
22 changes: 11 additions & 11 deletions tests/test_hep_tp_vllm_tp.py
@@ -195,24 +195,24 @@ def test_hep_tp_vllm_tp_dst_ep1_tp4_pp1_src_ep1_tp4_pp1():
assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]

# Judge alltoall actors
alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
# Judge allgather actors
allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
actor2rank = param_sync_group.actor2rank

assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
assert len(alltoall_actors) == 2
assert len(alltoall_actors[0]) == 4 # prev 4 src ranks should all-to-all routed experts
assert len(alltoall_actors[1]) == 4 # last 4 src ranks should all-to-all routed experts
assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
assert len(allgather_actors) == 2
assert len(allgather_actors[0]) == 4 # first 4 src ranks should all-gather routed experts
assert len(allgather_actors[1]) == 4 # last 4 src ranks should all-gather routed experts
assert len(actor2rank) == 16 # all of the 16 actors should have rank
assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique

alltoall_actor_ranks = []
for actor_list in alltoall_actors:
alltoall_actor_ranks.append([])
allgather_actor_ranks = []
for actor_list in allgather_actors:
allgather_actor_ranks.append([])
for actor in actor_list:
alltoall_actor_ranks[-1].append(actor2rank[actor])
allgather_actor_ranks[-1].append(actor2rank[actor])

assert alltoall_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert allgather_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]

# Judge src->dst rank mappings
comm_pairs = []