42 changes: 21 additions & 21 deletions roll/distributed/scheduler/generate_scheduler.py
@@ -171,7 +171,7 @@ def generate(self, data: DataProto, actor_cluster: Union[Any, Cluster], pipeline

def get_available_dp_rank(self):
while True:
# 负载均衡逻辑,期望各dp 正在处理的条数基本接近
# Load-balancing logic: keep the number of in-flight items roughly equal across dp ranks
sorted_ranks = sorted(
self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank)
)
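The comment above describes the balancing rule. A minimal sketch of the same idea, with a plain dict standing in for `load_balance_coordinator` (the helper name is illustrative, not part of the repository):

```python
from typing import Dict

def pick_least_loaded_rank(in_flight: Dict[int, int]) -> int:
    """Return the dp rank with the fewest in-flight requests (ties broken by rank)."""
    return min(in_flight, key=lambda rank: (in_flight[rank], rank))

# Toy usage: rank 1 is the least busy, so it receives the next request.
in_flight = {0: 3, 1: 1, 2: 2}
rank = pick_least_loaded_rank(in_flight)   # -> 1
in_flight[rank] += 1                       # mirrors the load_balance_coordinator bookkeeping
```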
@@ -205,26 +205,26 @@ def generate_opt_level_1(self, data: DataProto):
)
self.cluster.start_server(data=DataProto(meta_info=data.meta_info), blocking=True)

# 分发数据至收到target rollout 完成
# 无限循环,把所有的response发送给dp worker
# Distribute data until the target rollouts are complete
# Infinite loop: send all of the responses to the dp workers
send_request_count = 0
request_refs = []
data_index_counter = itertools.count()
last_alive_check = time.time()
while not self.is_completed:

# 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang
# Check whether the dp workers are alive; a dp worker's server thread may exit on an exception and cause a hang
current_time = time.time()
if current_time - last_alive_check >= self.alive_check_interval:
self.cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto())
last_alive_check = current_time

if send_request_count < data.batch.batch_size[0]:
# 取一个可以发送request的dp worker
# Pick a dp worker that a request can be sent to
dp_rank = next(self.get_available_dp_rank())

# 还有数据需要发送, 取需要发送的数据
# request_id 全局递增,否则vllm/sglang scheduler状态不对
# There is still data to send; fetch the next item to dispatch
# request_id must increase globally, otherwise the vllm/sglang scheduler state becomes inconsistent
request_id = next(self.request_counter)
data_index = next(data_index_counter)
request_data = collate_fn([self.data[data_index]])
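A small sketch of the globally increasing `request_id` requirement noted above: a single `itertools.count` lives for the scheduler's lifetime, so ids keep growing across `generate` calls instead of restarting per batch (the class and attribute names here are illustrative):

```python
import itertools

class RequestIdSource:
    def __init__(self):
        # One counter for the scheduler's lifetime; never reset between batches.
        self._counter = itertools.count()

    def next_id(self) -> int:
        return next(self._counter)

ids = RequestIdSource()
first_batch = [ids.next_id() for _ in range(3)]   # [0, 1, 2]
second_batch = [ids.next_id() for _ in range(2)]  # [3, 4]; still increasing
```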
@@ -235,7 +235,7 @@ def generate_opt_level_1(self, data: DataProto):
].item()
self.request_id_2_dp_rank[request_data.meta_info["request_id"]] = dp_rank
self.prompt_id_2_request_ids[prompt_id].add(request_data.meta_info["request_id"])
# 需要注意上面的调用顺序, report_response中会更新request_id索引dp_rank,所以这里需要最后add request_id
# Note the call order above: report_response updates the request_id -> dp_rank index, so the request_id must be added last
request_data.meta_info["response_callback_fn"] = self.response_callback_fn
request_data.meta_info["generation_config"] = data.meta_info["generation_config"]
request_refs.append(
@@ -394,7 +394,7 @@ def set_scheduler(
state: Dict[str, Any] = None,
):
"""
GenerateScheduler可以由多个实例,不再局限于单例
GenerateScheduler may have multiple instances; it is no longer limited to being a singleton
"""
self.actor_cluster = actor_cluster
self.reward_clusters = reward_clusters
@@ -459,9 +459,9 @@ def reset_status(self):

def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
"""
从dataset里,按给定策略sample batch
1. 常规无过滤
2. 动态过滤
Sample a batch from the dataset using the given strategy:
1. Regular sampling, no filtering
2. Dynamic filtering
"""
self.batch_size = batch_size
self.reset_status()
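A rough sketch of the two strategies named in the docstring: take samples as-is, or keep drawing until `batch_size` samples survive a filter. The filter signature and list-based data are placeholders, not the pipeline's real interfaces:

```python
from typing import Any, Callable, Iterable, List, Optional

def sample_batch(
    samples: Iterable[Any],
    batch_size: int,
    query_filter_fn: Optional[Callable[[Any], bool]] = None,
) -> List[Any]:
    """Strategy 1: no filtering. Strategy 2: dynamic filtering, i.e. draw until the batch is full."""
    collected: List[Any] = []
    for sample in samples:
        if query_filter_fn is not None and not query_filter_fn(sample):
            continue  # filtered out; draw another sample instead
        collected.append(sample)
        if len(collected) >= batch_size:
            break
    return collected

# Keep only even numbers until we have three of them.
print(sample_batch(range(10), batch_size=3, query_filter_fn=lambda x: x % 2 == 0))  # [0, 2, 4]
```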
@@ -522,7 +522,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
f"used queries: {query_use_count} query_filter_count: {self.query_filter_count} "
f"response_filter_count: {self.response_filter_count}"
)
# TODO: 这里 len(collect_data) > rollout_batch_size, 可以尝试动态扩大batch_size
# TODO: here len(collect_data) > rollout_batch_size; we could try expanding batch_size dynamically
batch = DataProto.concat(collect_data[: self.batch_size * num_return_sequences])
batch.meta_info["metrics"] = {
f"scheduler/query_filter_count": self.query_filter_count,
@@ -531,7 +531,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
f"scheduler/query_use_count": query_use_count,
}

# 统计全部response metrics
# Aggregate metrics over all responses
metrics = {}
for domain, response_batches in self.response_cache.items():
response_batch = DataProto.concat(response_batches[:])
@@ -548,8 +548,8 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
"""
这里需要考虑多线程数据访问
data 返回可能有多条的
Multi-threaded data access needs to be considered here.
The returned data may contain multiple entries.
"""
try:
request_id = data.meta_info["request_id"]
@@ -570,15 +570,15 @@ def report_response(self, data: DataProto):
return

# call reward
# reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意...
# 多域的时候,llm as judge, 需要单独为reward worker分配gpu
# The reward worker must support computing rewards for a single sample; with dynamic sampling, cases that need batch reward computation require extra care...
# In the multi-domain case (LLM as judge), GPUs need to be allocated separately for the reward worker
rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch))
batch.union(rewards)

response_buffers: List[DataProto] = []
batch_expanded = [batch[[idx]] for idx in range(output_count)]

# response_filter, 不太需要response filter
# response_filter: a response filter is usually not needed here
for batch_item in batch_expanded:
if self.response_filter_fn(batch_item, self.pipeline_config):
response_buffers.append(batch_item)
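A compact sketch of the flow above: score the returned batch, expand it into single-response items, and keep only the ones that pass the response filter. The callable signatures are assumptions for illustration, not the repository's actual interfaces:

```python
from typing import Any, Callable, Dict, List

def handle_response(
    responses: List[Dict[str, Any]],
    compute_rewards: Callable[[List[Dict[str, Any]]], List[float]],
    response_filter_fn: Callable[[Dict[str, Any]], bool],
) -> List[Dict[str, Any]]:
    rewards = compute_rewards(responses)              # e.g. a (remote) reward-worker call
    kept: List[Dict[str, Any]] = []
    for item, reward in zip(responses, rewards):      # expand to per-response granularity
        item = {**item, "reward": reward}             # merge the reward into the item
        if response_filter_fn(item):
            kept.append(item)
    return kept
```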
@@ -706,7 +706,7 @@ def expand_requests(self, data: DataProto):
return target_requests

def check_worker_alive(self, cluster):
# 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang
# Check whether the dp workers are alive; a dp worker's server thread may exit on an exception and cause a hang
current_time = time.time()
if current_time - self.last_alive_check >= self.alive_check_interval:
cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto())
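The alive-check pattern reduced to its essentials: probe only when the configured interval has elapsed since the previous probe. Here `probe` stands in for `cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, ...)`, and the class is a sketch rather than the repository's implementation:

```python
import time
from typing import Callable

class AliveChecker:
    def __init__(self, probe: Callable[[], None], interval_s: float = 10.0):
        self.probe = probe
        self.interval_s = interval_s
        self.last_check = time.time()

    def maybe_check(self) -> None:
        now = time.time()
        if now - self.last_check >= self.interval_s:
            self.probe()             # a dead dp-worker server thread surfaces here
            self.last_check = now

checker = AliveChecker(probe=lambda: print("ping dp workers"), interval_s=10.0)
checker.maybe_check()  # no-op until 10 s have passed since construction
```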
@@ -727,7 +727,7 @@ def check_send_new_request(self) -> bool:

def get_available_dp_rank(self):
while True:
# 负载均衡逻辑,期望各dp 正在处理的条数基本接近
# Load-balancing logic: keep the number of in-flight items roughly equal across dp ranks
sorted_ranks = sorted(
self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank)
)
22 changes: 11 additions & 11 deletions roll/distributed/scheduler/reward_scheduler.py
@@ -15,14 +15,14 @@
@ray.remote
class RewardScheduler:
"""
reward 服务化和generate不同, request接口:
reward scheduler需要解决的是不同域的sample的reward计算问题, 不需要实现request粒度的接口;
并且reward计算和vllm不同,vllm可以continue batch,所以可以动态add request, reward不行,
直接rpc调用reward_cluster.compute_rewards即可(使用rpc方式调用,可以增加reward的数量,增大并发处理能力)
Reward serving differs from generation in its request interface:
the reward scheduler has to handle reward computation for samples from different domains and does not need a request-granularity interface.
Reward computation also differs from vllm: vllm does continuous batching and can therefore add requests dynamically, while reward computation cannot,
so reward_cluster.compute_rewards is simply called via RPC (calling via RPC makes it possible to add more reward workers and increase concurrency)

reward scheduler需要解决的问题:
按domain路由reward
dp dispatch 均分/不足dp_size 的限制
Problems the reward scheduler needs to solve:
routing rewards by domain
dp dispatch: splitting evenly and handling batches smaller than dp_size
"""

def __init__(self):
@@ -32,13 +32,13 @@ def __init__(self):

def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipeline_config) -> DataProto:
"""
保序返回rewards
Return rewards in the original input order
"""
self.pipeline_config = pipeline_config
self.reward_clusters = reward_clusters
data.batch["prompt_id"] = torch.arange(data.batch.batch_size[0], device=data.batch.device)

# 按domain group by data
# Group data by domain
grouped_data: Dict[str, DataProto] = data.group_by("domain")

domain_rewards_refs: Dict[str, List[ray.ObjectRef]] = defaultdict(list)
@@ -51,8 +51,8 @@ def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipe

rewards_list: List[DataProto] = []
for domain, domain_rewards_ref in domain_rewards_refs.items():
# 各reward的输出schema要求一致
# reward worker compute_rewards 接口返回结果保序
# All reward workers must produce a consistent output schema
# The reward worker's compute_rewards interface returns results in order
if domain not in grouped_data.keys():
continue
domain_rewards: DataProto = DataProto.materialize_concat(data_refs=domain_rewards_ref)
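A sketch of the routing contract described in the docstring: remember each sample's original position (the role of `prompt_id`), group by domain, score each group with that domain's reward function, and sort the results back into input order. Plain dicts and lists stand in for DataProto:

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple

def compute_rewards_by_domain(
    samples: List[Dict[str, Any]],
    reward_fns: Dict[str, Callable[[List[Dict[str, Any]]], List[float]]],
) -> List[float]:
    grouped: Dict[str, List[Tuple[int, Dict[str, Any]]]] = defaultdict(list)
    for idx, sample in enumerate(samples):               # idx plays the role of prompt_id
        grouped[sample["domain"]].append((idx, sample))

    scored: List[Tuple[int, float]] = []
    for domain, items in grouped.items():
        rewards = reward_fns[domain]([s for _, s in items])   # per-domain worker call
        scored.extend((idx, r) for (idx, _), r in zip(items, rewards))

    scored.sort(key=lambda pair: pair[0])                # restore the original order
    return [r for _, r in scored]

rewards = compute_rewards_by_domain(
    [{"domain": "math", "x": 1}, {"domain": "code", "x": 2}, {"domain": "math", "x": 3}],
    {"math": lambda xs: [1.0] * len(xs), "code": lambda xs: [0.5] * len(xs)},
)
print(rewards)  # [1.0, 0.5, 1.0]
```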
4 changes: 2 additions & 2 deletions roll/models/model_providers.py
@@ -376,9 +376,9 @@ def default_reward_model_provider(
is_trainable: Optional[bool] = False,
):
"""
model.forward 遵循TokenClassifierOutput 协议
model.forward follows the TokenClassifierOutput protocol
class TokenClassifierOutput(ModelOutput):
logits: torch.FloatTensor # 必须要有
logits: torch.FloatTensor # Required
loss: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
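A minimal model honoring the protocol quoted above: `forward` returns a `TokenClassifierOutput` whose `logits` field is always populated. The tiny backbone is a placeholder for a real pretrained model:

```python
import torch
from torch import nn
from transformers.modeling_outputs import TokenClassifierOutput

class TinyRewardModel(nn.Module):
    def __init__(self, vocab_size: int = 1000, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)  # stand-in for a pretrained backbone
        self.score = nn.Linear(hidden_size, 1)              # one score per token

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> TokenClassifierOutput:
        hidden = self.embed(input_ids)
        logits = self.score(hidden)          # shape (batch, seq_len, 1); the required field
        return TokenClassifierOutput(logits=logits)

model = TinyRewardModel()
out = model(torch.randint(0, 1000, (2, 8)))
print(out.logits.shape)  # torch.Size([2, 8, 1])
```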
4 changes: 2 additions & 2 deletions roll/pipeline/agentic/agentic_config.py
@@ -45,7 +45,7 @@ class EnvManagerConfig(WorkerConfig):

def __post_init__(self):
"""
根据es config计算world_size
Calculate world_size based on es config
"""
self.world_size = self.env_groups * self.group_size
self.env_configs: Optional[Dict[int, Dict]] = None
@@ -266,7 +266,7 @@ def set_max_steps(self, max_steps: int):
self.critic.training_args.per_device_train_batch_size
* self.critic.training_args.gradient_accumulation_steps
)
# 没有除dp_size,需要在分布式环境初始化后再除
# Not yet divided by dp_size; the division happens after the distributed environment is initialized
self.actor_train.training_args.max_steps = max_steps * (
self.rollout_batch_size
* self.actor_infer.generating_args.num_return_sequences
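The arithmetic behind the comment above, sketched with made-up numbers under the assumption that one rollout yields `rollout_batch_size * num_return_sequences` samples, one optimizer step consumes `per_device_train_batch_size * gradient_accumulation_steps` samples per device, and the division by dp_size is deferred:

```python
# Made-up values, purely to show the shape of the computation.
max_steps = 100
rollout_batch_size = 64
num_return_sequences = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 2

samples_per_rollout = rollout_batch_size * num_return_sequences                          # 256
samples_per_step_per_device = per_device_train_batch_size * gradient_accumulation_steps  # 16

train_max_steps = max_steps * samples_per_rollout // samples_per_step_per_device
print(train_max_steps)  # 1600; still to be divided by dp_size once the world size is known
```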
12 changes: 6 additions & 6 deletions roll/pipeline/agentic/agentic_pipeline.py
@@ -107,7 +107,7 @@ def __init__(self, pipeline_config: AgenticConfig):

@torch.no_grad()
def run(self):
# 计算tokens per second 系统吞吐
# Track system throughput in tokens per second
tps_timer = _Timer(window_size=5)

for global_step in range(self.pipeline_config.max_steps):
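A small sketch of a windowed throughput meter in the spirit of the `_Timer(window_size=5)` above: keep the last N (tokens, seconds) pairs and report tokens per second over that window. `_Timer`'s real interface may differ; only the idea is shown:

```python
from collections import deque

class WindowedTps:
    def __init__(self, window_size: int = 5):
        self.samples = deque(maxlen=window_size)   # (tokens, seconds) for the last N steps

    def push(self, tokens: int, seconds: float) -> None:
        self.samples.append((tokens, seconds))

    def tps(self) -> float:
        total_tokens = sum(t for t, _ in self.samples)
        total_seconds = sum(s for _, s in self.samples)
        return total_tokens / total_seconds if total_seconds > 0 else 0.0

meter = WindowedTps(window_size=5)
meter.push(tokens=2048, seconds=1.6)
meter.push(tokens=4096, seconds=3.0)
print(f"{meter.tps():.1f} tokens/s")   # ~1335.7 tokens/s
```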
@@ -191,8 +191,8 @@ def run(self):
metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {})))
metrics["time/old_log_probs_values"] = cal_old_logpb_timer.last

# 要按group by处理reward
# 可以tag(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv
# Rewards have to be processed with a group-by
# Reward/advantage can be computed grouped by tag(env_type)/traj_group_id(group)/batch(rollout_batch)...
batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device)
with Timer(name="adv", logger=None) as timer:
grouping = self.pipeline_config.reward_normalization.grouping
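A sketch of the grouped handling described above: normalize rewards within each group (for example per traj_group_id) while preserving row order, so the result can feed the advantage computation. Plain tensors replace DataProto, and the normalization choice is illustrative:

```python
import torch

def grouped_normalize(rewards: torch.Tensor, group_ids: torch.Tensor) -> torch.Tensor:
    """Normalize rewards within each group, keeping the original row order."""
    out = torch.zeros_like(rewards)
    for gid in group_ids.unique():
        mask = group_ids == gid
        group = rewards[mask]
        out[mask] = (group - group.mean()) / (group.std() + 1e-6)
    return out

rewards = torch.tensor([1.0, 0.0, 2.0, 4.0])
group_ids = torch.tensor([0, 0, 1, 1])   # e.g. one traj_group_id per trajectory group
print(grouped_normalize(rewards, group_ids))
```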
@@ -228,7 +228,7 @@ def run(self):
batch = DataProto.concat(batch_list)
batch.reorder(indices=torch.argsort(batch.batch["prompt_id"]))
batch.pop("prompt_id")
# advantage是全局batch计算,还是group内计算?
# Should the advantage be computed over the whole batch, or within each group?
batch = compute_advantage(
data=batch,
gamma=self.pipeline_config.gamma,
@@ -314,8 +314,8 @@ def run(self):


def compute_data_metrics(batch):
# token_level_scores 是reward model给每个token的打分,可能经过了norm/clip
# score 为env的reward,raw value
# token_level_scores are the per-token scores from the reward model, possibly after norm/clip
# score is the raw environment reward
sequence_score = batch.batch["scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]