27 changes: 21 additions & 6 deletions ding/policy/ppo.py
@@ -76,6 +76,10 @@ class PPOPolicy(Policy):
grad_clip_value=0.5,
# (bool) Whether ignore done (usually for max step termination env).
ignore_done=False,
# (str) The type of KL divergence estimator used for the penalty, one of ['k1', 'k2', 'k3'].
kl_type='k1',
# (float) The weight of the KL divergence penalty in the total loss; 0.0 (the default) disables it.
kl_beta=0.0,
),
# collect_mode config
collect=dict(
@@ -192,6 +196,8 @@ def _init_learn(self) -> None:
self._clip_ratio = self._cfg.learn.clip_ratio
self._adv_norm = self._cfg.learn.adv_norm
self._value_norm = self._cfg.learn.value_norm
self._kl_type = self._cfg.learn.kl_type
self._kl_beta = self._cfg.learn.kl_beta
if self._value_norm:
self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
self._gamma = self._cfg.collect.discount_factor
@@ -291,27 +297,29 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
batch['return'], batch['weight']
)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
elif self._action_space == 'discrete':
ppo_batch = ppo_data(
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
batch['return'], batch['weight']
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
ppo_discrete_batch = ppo_policy_data(
output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'],
adv, batch['weight']
)
ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(
ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type
)
# continuous part (continuous policy loss and entropy loss, value loss)
ppo_continuous_batch = ppo_data(
output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'],
output['value'], batch['value'], adv, batch['return'], batch['weight']
)
ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
ppo_continuous_batch, self._clip_ratio
ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type
)
# sum discrete and continuous loss
ppo_loss = type(ppo_continuous_loss)(
@@ -320,10 +328,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
)
ppo_info = type(ppo_continuous_info)(
max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac),
ppo_continuous_info.kl_div + ppo_discrete_info.kl_div  # KL of the hybrid policy: discrete part + continuous part
)
wv, we = self._value_weight, self._entropy_weight
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
kl_div = ppo_info.kl_div
# Add the KL divergence penalty, weighted by kl_beta, to the standard PPO loss.
total_loss = (
ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss +
self._kl_beta * kl_div
)

self._optimizer.zero_grad()
total_loss.backward()
@@ -346,6 +359,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
'value_max': output['value'].max().item(),
'approx_kl': ppo_info.approx_kl,
'clipfrac': ppo_info.clipfrac,
'kl_div': kl_div.item(),
}
if self._action_space == 'continuous':
return_info.update(
@@ -593,6 +607,7 @@ def _monitor_vars_learn(self) -> List[str]:
'clipfrac',
'value_max',
'value_mean',
'kl_div',
]
if self._action_space == 'continuous':
variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
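For orientation, the net effect of the ppo.py changes on the training objective can be summarized by a minimal standalone sketch (not part of the diff); `value_weight` and `entropy_weight` stand in for the policy's `self._value_weight` and `self._entropy_weight`, and the tensor values below are arbitrary placeholders:

```python
import torch

def ppo_total_loss(policy_loss: torch.Tensor, value_loss: torch.Tensor,
                   entropy_loss: torch.Tensor, kl_div: torch.Tensor,
                   value_weight: float, entropy_weight: float,
                   kl_beta: float = 0.0) -> torch.Tensor:
    # Mirrors the updated _forward_learn: the value loss and the KL penalty are
    # added, the entropy bonus is subtracted; kl_beta=0.0 (the config default)
    # recovers the original PPO objective.
    return (
        policy_loss
        + value_weight * value_loss
        - entropy_weight * entropy_loss
        + kl_beta * kl_div
    )

# Placeholder 0-dim tensors standing in for the real loss terms.
loss = ppo_total_loss(torch.tensor(0.30), torch.tensor(0.10), torch.tensor(1.20),
                      torch.tensor(0.005), value_weight=0.5, entropy_weight=0.01,
                      kl_beta=0.01)
```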
64 changes: 53 additions & 11 deletions ding/rl_utils/ppo.py
@@ -19,7 +19,7 @@
ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div'])


def shape_fn_ppo(args, kwargs):
@@ -46,7 +46,8 @@ def ppo_error(
data: namedtuple,
clip_ratio: float = 0.2,
use_value_clip: bool = True,
dual_clip: Optional[float] = None
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -57,6 +58,7 @@
- use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): the type of KL divergence estimator, one of ['k1', 'k2', 'k3'], defaults to 'k1'
Returns:
- ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring; approx_kl and clipfrac are Python scalars, while kl_div is a differentiable 0-dim tensor
@@ -97,7 +99,7 @@
)
logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
value_data = ppo_value_data(value_new, value_old, return_, weight)
value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

@@ -108,7 +110,8 @@ def ppo_policy_error(
data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None,
entropy_bonus: bool = True
entropy_bonus: bool = True,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -119,6 +122,7 @@
- dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
defaults to 5.0, if you don't want to use it, set this parameter to None
- entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
- kl_type (:obj:`str`): the type of KL divergence estimator, one of ['k1', 'k2', 'k3'], defaults to 'k1'
Returns:
- ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring; approx_kl and clipfrac are Python scalars, while kl_div is a differentiable 0-dim tensor
@@ -180,7 +184,18 @@ def ppo_policy_error(
approx_kl = (logp_old - logp_new).mean().item()
clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
clipfrac = torch.as_tensor(clipped).float().mean().item()
return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)

logr = logp_old - logp_new
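# The k1/k2/k3 estimators below match those in Schulman's "Approximating KL
# Divergence" note (http://joschu.net/blog/kl-approx.html): with samples drawn
# from the old policy, k1 is unbiased but high-variance, k2 is a biased
# low-variance approximation, and k3 is unbiased, low-variance and
# non-negative for every sample.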
if kl_type == 'k1':
kl_div = logr.mean()
elif kl_type == 'k2':
kl_div = (logr ** 2 / 2).mean()
elif kl_type == 'k3':
kl_div = (torch.exp(-logr) - 1 + logr).mean()
else:
raise ValueError(f"Unknown kl_type: {kl_type}")

return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)


def ppo_value_error(
@@ -232,7 +247,8 @@ def ppo_error_continuous(
data: namedtuple,
clip_ratio: float = 0.2,
use_value_clip: bool = True,
dual_clip: Optional[float] = None
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -243,6 +259,7 @@
- use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): the type of KL divergence estimator, one of ['k1', 'k2', 'k3'], defaults to 'k1'
Returns:
- ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring; approx_kl and clipfrac are Python scalars, while kl_div is a differentiable 0-dim tensor
@@ -314,12 +331,25 @@ def ppo_error_continuous(
else:
value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
logr = logp_old - logp_new
if kl_type == 'k1':
kl_div = logr.mean()
elif kl_type == 'k2':
kl_div = (logr ** 2 / 2).mean()
elif kl_type == 'k3':
kl_div = (torch.exp(-logr) - 1 + logr).mean()
else:
raise ValueError(f"Unknown kl_type: {kl_type}")

return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)


def ppo_policy_error_continuous(data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
def ppo_policy_error_continuous(
data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +358,7 @@ def ppo_policy_error_continuous(data: namedtuple,
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): the type of KL divergence estimator, one of ['k1', 'k2', 'k3'], defaults to 'k1'
Returns:
- ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring; approx_kl and clipfrac are Python scalars, while kl_div is a differentiable 0-dim tensor
@@ -377,4 +408,15 @@ def ppo_policy_error_continuous(data: namedtuple,
approx_kl = (logp_old - logp_new).mean().item()
clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
clipfrac = torch.as_tensor(clipped).float().mean().item()
return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)

logr = logp_old - logp_new
if kl_type == 'k1':
kl_div = logr.mean()
elif kl_type == 'k2':
kl_div = (logr ** 2 / 2).mean()
elif kl_type == 'k3':
kl_div = (torch.exp(-logr) - 1 + logr).mean()
else:
raise ValueError(f"Unknown kl_type: {kl_type}")

return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)
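As a quick sanity check on the three estimators added above (this snippet is illustrative and not part of the PR): with samples drawn from the old policy and `logr = logp_old - logp_new`, k1 and k3 are unbiased estimators of KL(old || new), while k2 is a biased, low-variance second-order approximation, so all three should agree with the analytic KL for a small policy change:

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_estimate(logp_old: torch.Tensor, logp_new: torch.Tensor, kl_type: str = 'k1') -> torch.Tensor:
    # Same formulas as in ppo_policy_error / ppo_error_continuous above.
    logr = logp_old - logp_new
    if kl_type == 'k1':
        return logr.mean()                           # unbiased, higher variance
    elif kl_type == 'k2':
        return (logr ** 2 / 2).mean()                # biased, low variance
    elif kl_type == 'k3':
        return (torch.exp(-logr) - 1 + logr).mean()  # unbiased, low variance, >= 0 per sample
    raise ValueError(f"Unknown kl_type: {kl_type}")

old_pi, new_pi = Normal(0.0, 1.0), Normal(0.1, 1.0)  # a small policy update
actions = old_pi.sample((100000,))                   # data is collected by the old policy
true_kl = kl_divergence(old_pi, new_pi).item()       # analytic KL(old || new) = 0.005
for kl_type in ('k1', 'k2', 'k3'):
    est = kl_estimate(old_pi.log_prob(actions), new_pi.log_prob(actions), kl_type)
    print(f"{kl_type}: {est.item():.4f}  (analytic: {true_kl:.4f})")
```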
2 changes: 2 additions & 0 deletions dizoo/atari/config/serial/pong/pong_ppo_config.py
Expand Up @@ -39,6 +39,8 @@
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
kl_beta=0.01,
kl_type='k1',
),
collect=dict(
n_sample=3200,
2 changes: 2 additions & 0 deletions
Expand Up @@ -44,6 +44,8 @@
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
kl_beta=0.05,
kl_type='k1',
),
collect=dict(
n_sample=1024,
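Because `kl_beta` defaults to 0.0 in the policy's default config, the penalty stays disabled unless a config opts in, as the two configs above do. A hypothetical learn-config fragment with the new keys (values copied from the configs in this PR, not tuned recommendations):

```python
# Hypothetical learn-config fragment; only the new keys are relevant here.
learn_config = dict(
    grad_clip_type='clip_norm',
    grad_clip_value=0.5,
    kl_beta=0.01,  # weight of the KL penalty in the total loss; 0.0 disables it
    kl_type='k1',  # KL estimator to use: 'k1', 'k2' or 'k3'
)
```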