51 changes: 43 additions & 8 deletions ding/policy/ppo.py
@@ -76,6 +76,15 @@ class PPOPolicy(Policy):
grad_clip_value=0.5,
# (bool) Whether ignore done (usually for max step termination env).
ignore_done=False,
# (str) The type of KL-divergence estimator between current and pretrained policy, one of ['k1', 'k2', 'k3'].
# Reference: http://joschu.net/blog/kl-approx.html
kl_type='k1',
# (float) The weight of KL divergence loss.
kl_beta=0.0,
# (Optional[str]) The path of the pretrained model checkpoint.
# If provided, the KL regularizer is computed between the current policy and the pretrained policy.
# Defaults to None, which means the KL term is not calculated.
pretrained_model_path=None,
),
# collect_mode config
collect=dict(
@@ -186,12 +195,23 @@ def _init_learn(self) -> None:

self._learn_model = model_wrap(self._model, wrapper_name='base')

# Load the pretrained model used as the frozen reference policy for the KL regularizer.
if self._cfg.learn.pretrained_model_path is not None:
self._pretrained_model = copy.deepcopy(self._model)
state_dict = torch.load(self._cfg.learn.pretrained_model_path, map_location='cpu')
# Be compatible with checkpoints that store the model weights under a 'model' key.
if isinstance(state_dict, dict) and 'model' in state_dict:
    state_dict = state_dict['model']
self._pretrained_model.load_state_dict(state_dict)
self._pretrained_model.eval()
else:
self._pretrained_model = None

# Algorithm config
self._value_weight = self._cfg.learn.value_weight
self._entropy_weight = self._cfg.learn.entropy_weight
self._clip_ratio = self._cfg.learn.clip_ratio
self._adv_norm = self._cfg.learn.adv_norm
self._value_norm = self._cfg.learn.value_norm
self._kl_type = self._cfg.learn.kl_type
self._kl_beta = self._cfg.learn.kl_beta
if self._value_norm:
self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
self._gamma = self._cfg.collect.discount_factor
@@ -285,45 +305,57 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Normalize advantage in a train_batch
adv = (adv - adv.mean()) / (adv.std() + 1e-8)

if self._pretrained_model is not None:
with torch.no_grad():
logit_pretrained = self._pretrained_model.forward(batch['obs'], mode='compute_actor')['logit']
else:
logit_pretrained = None

# Calculate ppo error
if self._action_space == 'continuous':
ppo_batch = ppo_data(
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
batch['return'], batch['weight']
batch['return'], batch['weight'], logit_pretrained
)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
elif self._action_space == 'discrete':
ppo_batch = ppo_data(
output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
batch['return'], batch['weight']
batch['return'], batch['weight'], logit_pretrained
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
ppo_discrete_batch = ppo_policy_data(
output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'],
adv, batch['weight'], logit_pretrained['action_type'] if logit_pretrained is not None else None
)
ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(
ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type
)
# continuous part (continuous policy loss and entropy loss, value loss)
ppo_continuous_batch = ppo_data(
output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'],
output['value'], batch['value'], adv, batch['return'], batch['weight'],
logit_pretrained['action_args'] if logit_pretrained is not None else None
)
ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
ppo_continuous_batch, self._clip_ratio
ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type
)
# sum discrete and continuous loss
ppo_loss = type(ppo_continuous_loss)(
ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss,
ppo_continuous_loss.kl_div + ppo_discrete_loss.kl_div
)
ppo_info = type(ppo_continuous_info)(
max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
)
wv, we = self._value_weight, self._entropy_weight
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
kl_div = ppo_loss.kl_div
total_loss = (
ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss +
self._kl_beta * kl_div
)

self._optimizer.zero_grad()
total_loss.backward()
@@ -346,6 +378,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
'value_max': output['value'].max().item(),
'approx_kl': ppo_info.approx_kl,
'clipfrac': ppo_info.clipfrac,
'kl_div': kl_div.item() if isinstance(kl_div, torch.Tensor) else float(kl_div),
}
if self._action_space == 'continuous':
return_info.update(
@@ -594,6 +627,8 @@ def _monitor_vars_learn(self) -> List[str]:
'value_max',
'value_mean',
]
if self._pretrained_model is not None:
variables += ['kl_div']
if self._action_space == 'continuous':
variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
return variables
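Taken together, the policy-side changes thread the frozen pretrained logits into the loss computation and add a `kl_beta * kl_div` term to the total PPO loss. A minimal, self-contained sketch (not part of this diff) of calling the updated `ppo_error` with the new `logit_pretrained` field and `kl_type` argument; the batch shapes and the 0.5 / 0.01 / 0.05 weights are illustrative placeholders:

import torch
from ding.rl_utils.ppo import ppo_data, ppo_error

# Toy batch: 4 samples, 3 discrete actions. All values are random and only for illustration.
B, N = 4, 3
logit_new = torch.randn(B, N, requires_grad=True)
logit_old = torch.randn(B, N)
logit_pretrained = torch.randn(B, N)  # logits of the frozen reference (pretrained) policy
action = torch.randint(0, N, (B, ))
value_new = torch.randn(B, requires_grad=True)
value_old, adv, return_ = torch.randn(B), torch.randn(B), torch.randn(B)

data = ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, None, logit_pretrained)
loss, info = ppo_error(data, clip_ratio=0.2, kl_type='k3')
total = loss.policy_loss + 0.5 * loss.value_loss - 0.01 * loss.entropy_loss + 0.05 * loss.kl_div
total.backward()
print(info.approx_kl, info.clipfrac, loss.kl_div.item())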
116 changes: 93 additions & 23 deletions ding/rl_utils/ppo.py
@@ -6,22 +6,54 @@
from ding.hpc_rl import hpc_wrapper

ppo_data = namedtuple(
'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
'ppo_data',
['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained'],
defaults=(None, )
)
ppo_data_continuous = namedtuple(
'ppo_data_continuous',
['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
'ppo_data_continuous', [
'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight',
'logit_pretrained'
],
defaults=(None, )
)
ppo_policy_data = namedtuple(
'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained'], defaults=(None, )
)
ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
ppo_policy_data_continuous = namedtuple(
'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight']
'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained'],
defaults=(None, )
)
ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])


def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor:
"""
Overview:
Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)],
where q is the current policy and p is the pretrained policy.
The implementation is based on John Schulman's blog post "Approximating KL Divergence".
Reference: http://joschu.net/blog/kl-approx.html
Arguments:
- log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be
log(q/p) = logp_new - logp_pretrained.
- kl_type (:obj:`str`): The type of KL divergence estimator to use.
- 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`.
- 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`.
- 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`.
Returns:
- kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate.
"""
if kl_type == 'k1':
return log_ratio.mean()
elif kl_type == 'k2':
return (log_ratio ** 2 / 2).mean()
elif kl_type == 'k3':
return (torch.exp(-log_ratio) - 1 + log_ratio).mean()
else:
raise ValueError(f"Unknown kl_type: {kl_type}")


def shape_fn_ppo(args, kwargs):
r"""
Overview:
@@ -46,7 +78,8 @@ def ppo_error(
data: namedtuple,
clip_ratio: float = 0.2,
use_value_clip: bool = True,
dual_clip: Optional[float] = None
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -57,6 +90,7 @@
- use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): which KL-divergence estimator to use, one of ['k1', 'k2', 'k3'], defaults to 'k1'.
Returns:
- ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -95,20 +129,23 @@
assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
dual_clip
)
logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained)
policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
value_data = ppo_value_data(value_new, value_old, return_, weight)
value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
return ppo_loss(
policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div
), policy_info


def ppo_policy_error(
data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None,
entropy_bonus: bool = True
entropy_bonus: bool = True,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -119,6 +156,7 @@ def ppo_policy_error(
- dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
defaults to 5.0, if you don't want to use it, set this parameter to None
- entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
- kl_type (:obj:`str`): which KL-divergence estimator to use, one of ['k1', 'k2', 'k3'], defaults to 'k1'.
Returns:
- ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -148,7 +186,7 @@ def ppo_policy_error(
.. note::
For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
"""
logit_new, logit_old, action, adv, weight = data
logit_new, logit_old, action, adv, weight, logit_pretrained = data
if weight is None:
weight = torch.ones_like(adv)
dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
@@ -180,7 +218,16 @@ def ppo_policy_error(
approx_kl = (logp_old - logp_new).mean().item()
clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
clipfrac = torch.as_tensor(clipped).float().mean().item()
return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)

if logit_pretrained is not None:
dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained)
logp_pretrained = dist_pretrained.log_prob(action)
log_ratio = logp_new - logp_pretrained
kl_div = calculate_kl_div(log_ratio, kl_type)
else:
kl_div = 0

return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


def ppo_value_error(
@@ -232,7 +279,8 @@ def ppo_error_continuous(
data: namedtuple,
clip_ratio: float = 0.2,
use_value_clip: bool = True,
dual_clip: Optional[float] = None
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
@@ -243,6 +291,7 @@
- use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): which KL-divergence estimator to use, one of ['k1', 'k2', 'k3'], defaults to 'k1'.
Returns:
- ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -281,7 +330,7 @@
assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
dual_clip
)
mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data
mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
if weight is None:
weight = torch.ones_like(adv)

@@ -314,12 +363,23 @@
else:
value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
if logit_pretrained is not None:
dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
logp_pretrained = dist_pretrained.log_prob(action)
log_ratio = logp_new - logp_pretrained
kl_div = calculate_kl_div(log_ratio, kl_type)
else:
kl_div = 0

return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


def ppo_policy_error_continuous(data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
def ppo_policy_error_continuous(
data: namedtuple,
clip_ratio: float = 0.2,
dual_clip: Optional[float] = None,
kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
"""
Overview:
Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +388,7 @@ def ppo_policy_error_continuous(data: namedtuple,
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
defaults to 5.0, if you don't want to use it, set this parameter to None
- kl_type (:obj:`str`): which KL-divergence estimator to use, one of ['k1', 'k2', 'k3'], defaults to 'k1'.
Returns:
- ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
- ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -353,7 +414,7 @@
assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
dual_clip
)
mu_sigma_new, mu_sigma_old, action, adv, weight = data
mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data
if weight is None:
weight = torch.ones_like(adv)

@@ -377,4 +438,13 @@
approx_kl = (logp_old - logp_new).mean().item()
clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
clipfrac = torch.as_tensor(clipped).float().mean().item()
return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)

if logit_pretrained is not None:
dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
logp_pretrained = dist_pretrained.log_prob(action)
log_ratio = logp_new - logp_pretrained
kl_div = calculate_kl_div(log_ratio, kl_type)
else:
kl_div = 0

return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
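As a quick sanity check on the new `calculate_kl_div` helper (not part of this diff), the three estimators can be compared against the exact KL divergence between two small categorical distributions; `k1` and `k3` are unbiased, while `k2` is a biased second-order approximation with lower variance. The logits below are arbitrary illustrative values:

import torch
from torch.distributions import Categorical, kl_divergence
from ding.rl_utils.ppo import calculate_kl_div

torch.manual_seed(0)
q = Categorical(logits=torch.tensor([1.0, 0.5, -0.5]))  # current policy
p = Categorical(logits=torch.tensor([0.2, 0.2, 0.2]))   # pretrained (reference) policy

actions = q.sample((100000, ))
log_ratio = q.log_prob(actions) - p.log_prob(actions)   # log(q/p), as built in ppo_policy_error

exact = kl_divergence(q, p)
for kl_type in ('k1', 'k2', 'k3'):
    print(kl_type, calculate_kl_div(log_ratio, kl_type).item(), 'exact', exact.item())
# 'k1' and 'k3' converge to the exact value as the sample size grows; 'k2' stays slightly biased.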
7 changes: 7 additions & 0 deletions dizoo/atari/config/serial/pong/pong_ppo_config.py
@@ -1,6 +1,7 @@
from easydict import EasyDict

pong_ppo_config = dict(
exp_name='pong_ppo_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
@@ -39,6 +40,12 @@
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
# KL divergence regularization between current policy and pretrained policy.
# Supported KL divergence estimators: ['k1', 'k2', 'k3'].
# KL divergence loss will be calculated only when pretrained_model_path is provided.
kl_beta=0.01,
kl_type='k1',
pretrained_model_path=None,
),
collect=dict(
n_sample=3200,
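For reference, a hedged usage sketch of enabling the regularizer with the Pong config above: the checkpoint path is a placeholder, the nested `policy.learn` layout follows the usual DI-engine config structure, and the `pong_ppo_create_config` name is assumed to follow the standard naming in this module.

from ding.entry import serial_pipeline
from dizoo.atari.config.serial.pong.pong_ppo_config import pong_ppo_config, pong_ppo_create_config

# Placeholder checkpoint path: any previously trained checkpoint for the same PPO model works here.
pong_ppo_config['policy']['learn']['pretrained_model_path'] = './pong_ppo_seed0/ckpt/ckpt_best.pth.tar'
pong_ppo_config['policy']['learn']['kl_beta'] = 0.01
pong_ppo_config['policy']['learn']['kl_type'] = 'k1'

if __name__ == '__main__':
    serial_pipeline((pong_ppo_config, pong_ppo_create_config), seed=0)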
Changes to a second PPO config file (filename not shown in this view):
@@ -44,6 +44,12 @@
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
# KL divergence regularization between current policy and pretrained policy.
# Supported KL divergence estimators: ['k1', 'k2', 'k3'].
# KL divergence loss will be calculated only when pretrained_model_path is provided.
kl_beta=0.05,
kl_type='k1',
pretrained_model_path=None,
),
collect=dict(
n_sample=1024,
Changes to an additional training entry file (filename not shown in this view):
@@ -63,4 +63,3 @@
from ding.entry import serial_pipeline
with DDPContext():
serial_pipeline((main_config, create_config), seed=0)
