feature(xjy): Fixed the accumulate_steps and game_segment/weighted_total_loss bugs, refined the prompts, compute_llm_prior, and SFT loss, and added cProfile functionality. #441
base: dev-multitask-balance-clean-rft
Conversation
…_llm_prior, and SFT loss
llm_sft_loss = torch.tensor(0.0, device=self._cfg.device)
if self.llm_policy_cfg.enable_llm and self.llm_policy_cfg.enable_rft:
    with self._profile_block(name="train_llm_rft"):
        llm_rft_loss = self.compute_rft_loss(
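The `_profile_block` helper itself is not shown in this hunk. Below is a minimal sketch of what a cProfile-backed context manager of that shape could look like; the standalone name `profile_block` and the `enabled`/`top_n` parameters are illustrative assumptions, not taken from the PR.

```python
import cProfile
import io
import pstats
from contextlib import contextmanager

@contextmanager
def profile_block(name, enabled=True, top_n=20):
    # Profile the wrapped block with cProfile and print the hottest call sites.
    # `enabled` and `top_n` are illustrative parameters, not taken from the PR.
    if not enabled:
        yield
        return
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        yield
    finally:
        profiler.disable()
        stream = io.StringIO()
        pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(top_n)
        print(f"[profile:{name}]\n{stream.getvalue()}")

# Usage mirrors the call in the diff above:
# with profile_block(name="train_llm_rft"):
#     llm_rft_loss = compute_rft_loss(...)
```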
My understanding is that this target_value should correspond one-to-one with the observations, with no misalignment, right? For example, the first value in target_value represents the outcome for the first state in obs.
sequence_log_probs = token_log_probs.sum(dim=-1) / (mask.sum(dim=-1) + 1e-8)

if self.llm_policy_cfg.rft_reward == 'value':
    rewards_tensor = torch.tensor(batch_values, device=self._cfg.device, dtype=torch.float32)
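Put together, a REINFORCE-style objective over these pieces might look like the sketch below. The function name, the explicit masking in the numerator, and the normalization of the values into an advantage are assumptions for illustration, not code from this diff.

```python
import torch

def reinforce_style_rft_loss(token_log_probs, mask, batch_values, device="cpu"):
    # token_log_probs: (B, T) log-probs of the sampled response tokens.
    # mask:            (B, T) with 1 for response tokens and 0 for padding.
    # batch_values:    per-sample scalars used as the reward signal,
    #                  as when rft_reward == 'value' in the hunk above.

    # Length-normalized sequence log-prob, matching the line above
    # (multiplying by mask assumes padded positions may carry non-zero values).
    sequence_log_probs = (token_log_probs * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-8)

    advantage = torch.tensor(batch_values, device=device, dtype=torch.float32)
    # Optional normalization into an advantage; an assumption, not shown in the hunk.
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

    # REINFORCE: increase the likelihood of sequences with positive advantage.
    return -(advantage * sequence_log_probs).mean()
```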
How about renaming rewards_tensor to advantage_tensor?
ok
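Applied to the hunk above, the agreed rename is a one-line change. The snippet below just restates that line under the new name; the sample values and the `device` variable stand in for the method's local state, and whether the values are further converted into a baseline-subtracted advantage elsewhere is not visible in this diff.

```python
import torch

batch_values = [0.3, -0.1, 0.7]  # hypothetical per-sample search values
device = "cpu"                   # stand-in for self._cfg.device

# Rename suggested in the review: the values act as the advantage signal
# when rft_reward == 'value'.
advantage_tensor = torch.tensor(batch_values, device=device, dtype=torch.float32)
```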
…lect to cprofile.
…ed the REINFORCE-series loss computation.
No description provided.