diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 958cba1d83..3d7ee4b29c 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -76,6 +76,15 @@ class PPOPolicy(Policy): grad_clip_value=0.5, # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, + # (str) The type of KL divergence loss between current policy and pretrained policy, ['k1', 'k2', 'k3']. + # Reference: http://joschu.net/blog/kl-approx.html + kl_type='k1', + # (float) The weight of KL divergence loss. + kl_beta=0.0, + # (Optional[str]) The path of pretrained model checkpoint. + # If provided, KL regularizer will be calculated between current policy and pretrained policy. + # Default to None, which means KL is not calculated. + pretrained_model_path=None, ), # collect_mode config collect=dict( @@ -186,12 +195,23 @@ def _init_learn(self) -> None: self._learn_model = model_wrap(self._model, wrapper_name='base') + # load pretrained model + if self._cfg.learn.pretrained_model_path is not None: + self._pretrained_model = copy.deepcopy(self._model) + state_dict = torch.load(self._cfg.learn.pretrained_model_path, map_location='cpu') + self._pretrained_model.load_state_dict(state_dict) + self._pretrained_model.eval() + else: + self._pretrained_model = None + # Algorithm config self._value_weight = self._cfg.learn.value_weight self._entropy_weight = self._cfg.learn.entropy_weight self._clip_ratio = self._cfg.learn.clip_ratio self._adv_norm = self._cfg.learn.adv_norm self._value_norm = self._cfg.learn.value_norm + self._kl_type = self._cfg.learn.kl_type + self._kl_beta = self._cfg.learn.kl_beta if self._value_norm: self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device) self._gamma = self._cfg.collect.discount_factor @@ -285,45 +305,57 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Normalize advantage in a train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) + if self._pretrained_model is not None: + with torch.no_grad(): + logit_pretrained = self._pretrained_model.forward(batch['obs'], mode='compute_actor')['logit'] + else: + logit_pretrained = None + # Calculate ppo error if self._action_space == 'continuous': ppo_batch = ppo_data( output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, - batch['return'], batch['weight'] + batch['return'], batch['weight'], logit_pretrained ) - ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio) + ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'discrete': ppo_batch = ppo_data( output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, - batch['return'], batch['weight'] + batch['return'], batch['weight'], logit_pretrained ) - ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) ppo_discrete_batch = ppo_policy_data( output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'], adv, batch['weight'] ) - ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio) + ppo_discrete_loss, ppo_discrete_info = ppo_policy_error( + ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type + ) # continuous part (continuous policy loss and entropy loss, value loss) ppo_continuous_batch = ppo_data( output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'], output['value'], batch['value'], adv, batch['return'], batch['weight'] ) ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous( - ppo_continuous_batch, self._clip_ratio + ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type ) # sum discrete and continuous loss ppo_loss = type(ppo_continuous_loss)( ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss, - ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss + ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss, ppo_continuous_loss.kl_div ) ppo_info = type(ppo_continuous_info)( max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl), max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) wv, we = self._value_weight, self._entropy_weight - total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + kl_div = ppo_loss.kl_div + total_loss = ( + ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + + self._kl_beta * kl_div + ) self._optimizer.zero_grad() total_loss.backward() @@ -346,6 +378,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 'value_max': output['value'].max().item(), 'approx_kl': ppo_info.approx_kl, 'clipfrac': ppo_info.clipfrac, + 'kl_div': kl_div.item(), } if self._action_space == 'continuous': return_info.update( @@ -594,6 +627,8 @@ def _monitor_vars_learn(self) -> List[str]: 'value_max', 'value_mean', ] + if self._pretrained_model is not None: + variables += ['kl_div'] if self._action_space == 'continuous': variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act'] return variables diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index c88c647e7c..bb83d3b9bc 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -6,22 +6,54 @@ from ding.hpc_rl import hpc_wrapper ppo_data = namedtuple( - 'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] + 'ppo_data', + ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained'] ) ppo_data_continuous = namedtuple( - 'ppo_data_continuous', - ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] + 'ppo_data_continuous', [ + 'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', + 'logit_pretrained' + ] +) +ppo_policy_data = namedtuple( + 'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained'] ) -ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight']) ppo_policy_data_continuous = namedtuple( - 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight'] + 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained'] ) ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight']) -ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss']) -ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss']) +ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div']) +ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div']) ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) +def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor: + """ + Overview: + Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)], + where q is the current policy and p is the pretrained policy. + The implementation is based on John Schulman's blog post "Approximating KL Divergence". + Reference: http://joschu.net/blog/kl-approx.html + Arguments: + - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be + log(q/p) = logp_new - logp_pretrained. + - kl_type (:obj:`str`): The type of KL divergence estimator to use. + - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`. + - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`. + - 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`. + Returns: + - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate. + """ + if kl_type == 'k1': + return log_ratio.mean() + elif kl_type == 'k2': + return (log_ratio ** 2 / 2).mean() + elif kl_type == 'k3': + return (torch.exp(-log_ratio) - 1 + log_ratio).mean() + else: + raise ValueError(f"Unknown kl_type: {kl_type}") + + def shape_fn_ppo(args, kwargs): r""" Overview: @@ -46,7 +78,8 @@ def ppo_error( data: namedtuple, clip_ratio: float = 0.2, use_value_clip: bool = True, - dual_clip: Optional[float] = None + dual_clip: Optional[float] = None, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -57,6 +90,7 @@ def ppo_error( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -95,20 +129,23 @@ def ppo_error( assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data - policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight) - policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip) + logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data + policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained) + policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type) value_data = ppo_value_data(value_new, value_old, return_, weight) value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip) - return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info + return ppo_loss( + policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div + ), policy_info def ppo_policy_error( data: namedtuple, clip_ratio: float = 0.2, dual_clip: Optional[float] = None, - entropy_bonus: bool = True + entropy_bonus: bool = True, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -119,6 +156,7 @@ def ppo_policy_error( - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf), \ defaults to 5.0, if you don't want to use it, set this parameter to None - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it. + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -148,7 +186,7 @@ def ppo_policy_error( .. note:: For the action mask often used in LLM/VLM, users can set the `weight` to the action mask. """ - logit_new, logit_old, action, adv, weight = data + logit_new, logit_old, action, adv, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) dist_new = torch.distributions.categorical.Categorical(logits=logit_new) @@ -180,7 +218,16 @@ def ppo_policy_error( approx_kl = (logp_old - logp_new).mean().item() clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + + if logit_pretrained is not None: + dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained) + logp_pretrained = dist_pretrained.log_prob(action) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) + else: + kl_div = 0 + + return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) def ppo_value_error( @@ -232,7 +279,8 @@ def ppo_error_continuous( data: namedtuple, clip_ratio: float = 0.2, use_value_clip: bool = True, - dual_clip: Optional[float] = None + dual_clip: Optional[float] = None, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -243,6 +291,7 @@ def ppo_error_continuous( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -281,7 +330,7 @@ def ppo_error_continuous( assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data + mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) @@ -314,12 +363,23 @@ def ppo_error_continuous( else: value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean() - return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + if logit_pretrained is not None: + dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) + logp_pretrained = dist_pretrained.log_prob(action) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) + else: + kl_div = 0 + + return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) -def ppo_policy_error_continuous(data: namedtuple, - clip_ratio: float = 0.2, - dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]: +def ppo_policy_error_continuous( + data: namedtuple, + clip_ratio: float = 0.2, + dual_clip: Optional[float] = None, + kl_type: str = 'k1' +) -> Tuple[namedtuple, namedtuple]: """ Overview: Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip @@ -328,6 +388,7 @@ def ppo_policy_error_continuous(data: namedtuple, - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2 - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -353,7 +414,7 @@ def ppo_policy_error_continuous(data: namedtuple, assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - mu_sigma_new, mu_sigma_old, action, adv, weight = data + mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) @@ -377,4 +438,13 @@ def ppo_policy_error_continuous(data: namedtuple, approx_kl = (logp_old - logp_new).mean().item() clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + + if logit_pretrained is not None: + dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) + logp_pretrained = dist_pretrained.log_prob(action) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) + else: + kl_div = 0 + + return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) diff --git a/dizoo/atari/config/serial/pong/pong_ppo_config.py b/dizoo/atari/config/serial/pong/pong_ppo_config.py index df74a30a55..0d7ae3ed78 100644 --- a/dizoo/atari/config/serial/pong/pong_ppo_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppo_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict pong_ppo_config = dict( + exp_name='pong_ppo_seed0', env=dict( collector_env_num=8, evaluator_env_num=8, @@ -39,6 +40,12 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + # KL divergence regularization between current policy and pretrained policy. + # Supported KL divergence estimators: ['k1', 'k2', 'k3']. + # KL divergence loss will be calculated only when pretrained_model_path is provided. + kl_beta=0.01, + kl_type='k1', + pretrained_model_path=None, ), collect=dict( n_sample=3200, diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py index cb94b49e3b..fb5969282c 100644 --- a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py @@ -44,6 +44,12 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + # KL divergence regularization between current policy and pretrained policy. + # Supported KL divergence estimators: ['k1', 'k2', 'k3']. + # KL divergence loss will be calculated only when pretrained_model_path is provided. + kl_beta=0.05, + kl_type='k1', + pretrained_model_path=None, ), collect=dict( n_sample=1024, diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py index 82d6c673ec..a80662941a 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py @@ -63,4 +63,3 @@ from ding.entry import serial_pipeline with DDPContext(): serial_pipeline((main_config, create_config), seed=0) - diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py index 144feac1dd..e3aa855afe 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py index 545ecf970b..440525a320 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py index d48a1fb472..0974735b72 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py index 6aef029c5e..2eebce2771 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_iql_config.py b/dizoo/d4rl/config/hopper_medium_iql_config.py index 8f429be268..61dbb5fac3 100644 --- a/dizoo/d4rl/config/hopper_medium_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py index ad1b222843..df96a84aea 100644 --- a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None,