update rllab
LantaoYu committed Jun 25, 2020
1 parent 924bd0c commit 455782c
Showing 64 changed files with 7,570 additions and 1 deletion.
1 change: 0 additions & 1 deletion rllab/.gitignore
@@ -36,4 +36,3 @@ blackbox.zip
blackbox
rllab/config_personal.py
*.swp
sandbox
Empty file added rllab/sandbox/__init__.py
Empty file.
Empty file added rllab/sandbox/rocky/__init__.py
Empty file.
1 change: 1 addition & 0 deletions rllab/sandbox/rocky/tf/algos/__init__.py
@@ -0,0 +1 @@

160 changes: 160 additions & 0 deletions rllab/sandbox/rocky/tf/algos/batch_polopt.py
@@ -0,0 +1,160 @@
import time
from rllab.algos.base import RLAlgorithm
import rllab.misc.logger as logger
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
from rllab.sampler.utils import rollout


class BatchPolopt(RLAlgorithm):
"""
Base class for batch sampling-based policy optimization methods.
This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
"""

def __init__(
self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs
):
"""
:param env: Environment
:param policy: Policy
:type policy: Policy
:param baseline: Baseline
:param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
simultaneously, each using different environments and policies
:param n_itr: Number of iterations.
:param start_itr: Starting iteration.
:param batch_size: Number of samples per iteration.
:param max_path_length: Maximum length of a single rollout.
:param discount: Discount.
:param gae_lambda: Lambda used for generalized advantage estimation.
:param plot: Plot evaluation run after each iteration.
:param pause_for_plot: Whether to pause before contiuing when plotting.
:param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
:param positive_adv: Whether to shift the advantages so that they are always positive. When used in
conjunction with center_adv the advantages will be standardized before shifting.
:param store_paths: Whether to save all paths data to the snapshot.
:return:
"""
self.env = env
self.policy = policy
self.baseline = baseline
self.scope = scope
self.n_itr = n_itr
self.start_itr = start_itr
self.batch_size = batch_size
self.max_path_length = max_path_length
self.discount = discount
self.gae_lambda = gae_lambda
self.plot = plot
self.pause_for_plot = pause_for_plot
self.center_adv = center_adv
self.positive_adv = positive_adv
self.store_paths = store_paths
self.whole_paths = whole_paths
self.fixed_horizon = fixed_horizon
if sampler_cls is None:
if self.policy.vectorized and not force_batch_sampler:
sampler_cls = VectorizedSampler
else:
sampler_cls = BatchSampler
if sampler_args is None:
sampler_args = dict()
self.sampler = sampler_cls(self, **sampler_args)
self.init_opt()

def start_worker(self):
self.sampler.start_worker()

def shutdown_worker(self):
self.sampler.shutdown_worker()

def obtain_samples(self, itr):
return self.sampler.obtain_samples(itr)

def process_samples(self, itr, paths):
return self.sampler.process_samples(itr, paths)

def train(self, sess=None):
created_session = True if (sess is None) else False
if sess is None:
sess = tf.Session()
sess.__enter__()

sess.run(tf.global_variables_initializer())
self.start_worker()
start_time = time.time()
for itr in range(self.start_itr, self.n_itr):
itr_start_time = time.time()
with logger.prefix('itr #%d | ' % itr):
logger.log("Obtaining samples...")
paths = self.obtain_samples(itr)
logger.log("Processing samples...")
samples_data = self.process_samples(itr, paths)
logger.log("Logging diagnostics...")
self.log_diagnostics(paths)
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
if self.store_paths:
params["paths"] = samples_data["paths"]
logger.save_itr_params(itr, params)
logger.log("Saved")
logger.record_tabular('Time', time.time() - start_time)
logger.record_tabular('ItrTime', time.time() - itr_start_time)
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")
self.shutdown_worker()
if created_session:
sess.close()

def log_diagnostics(self, paths):
self.env.log_diagnostics(paths)
self.policy.log_diagnostics(paths)
self.baseline.log_diagnostics(paths)

def init_opt(self):
"""
Initialize the optimization procedure. If using tensorflow, this may
include declaring all the variables and compiling functions
"""
raise NotImplementedError

def get_itr_snapshot(self, itr, samples_data):
"""
Returns all the data that should be saved in the snapshot for this
iteration.
"""
raise NotImplementedError

def optimize_policy(self, itr, samples_data):
raise NotImplementedError

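Note: the discount and gae_lambda arguments above control how the sampler turns raw rewards and baseline predictions into advantage estimates. The following is a minimal NumPy sketch of generalized advantage estimation for a single path; it is only illustrative (the helper name gae_advantages and the toy numbers are made up here) and is not rllab's exact process_samples() implementation.

import numpy as np

def gae_advantages(rewards, values, discount=0.99, gae_lambda=1.0):
    # `values` holds the baseline prediction for every state in the path
    # plus one trailing terminal value (0.0 for a finished episode).
    deltas = rewards + discount * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # A_t = delta_t + discount * gae_lambda * A_{t+1}
        running = deltas[t] + discount * gae_lambda * running
        advantages[t] = running
    return advantages

rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.6, 0.0])   # baseline values plus terminal value
adv = gae_advantages(rewards, values)
# With center_adv=True the advantages are then standardized to mean 0 and std 1;
# with positive_adv=True they are shifted so the smallest advantage is non-negative.
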
1 change: 1 addition & 0 deletions rllab/sandbox/rocky/tf/algos/npg.py
@@ -0,0 +1 @@

130 changes: 130 additions & 0 deletions rllab/sandbox/rocky/tf/algos/npo.py
@@ -0,0 +1,130 @@



from rllab.misc import ext
from rllab.misc.overrides import overrides
import rllab.misc.logger as logger
from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer
from sandbox.rocky.tf.algos.batch_polopt import BatchPolopt
from sandbox.rocky.tf.misc import tensor_utils
import tensorflow as tf


class NPO(BatchPolopt):
"""
Natural Policy Optimization.
"""

def __init__(
self,
optimizer=None,
optimizer_args=None,
step_size=0.01,
**kwargs):
if optimizer is None:
if optimizer_args is None:
optimizer_args = dict()
optimizer = PenaltyLbfgsOptimizer(**optimizer_args)
self.optimizer = optimizer
self.step_size = step_size
super(NPO, self).__init__(**kwargs)

@overrides
def init_opt(self):
is_recurrent = int(self.policy.recurrent)
obs_var = self.env.observation_space.new_tensor_variable(
'obs',
extra_dims=1 + is_recurrent,
)
action_var = self.env.action_space.new_tensor_variable(
'action',
extra_dims=1 + is_recurrent,
)
advantage_var = tensor_utils.new_tensor(
'advantage',
ndim=1 + is_recurrent,
dtype=tf.float32,
)
dist = self.policy.distribution

old_dist_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
for k, shape in dist.dist_info_specs
}
old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

state_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
for k, shape in self.policy.state_info_specs
}
state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

if is_recurrent:
valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
else:
valid_var = None

dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
if is_recurrent:
mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
surr_loss = - tf.reduce_sum(lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
else:
mean_kl = tf.reduce_mean(kl)
surr_loss = - tf.reduce_mean(lr * advantage_var)

input_list = [
obs_var,
action_var,
advantage_var,
] + state_info_vars_list + old_dist_info_vars_list
if is_recurrent:
input_list.append(valid_var)

self.optimizer.update_opt(
loss=surr_loss,
target=self.policy,
leq_constraint=(mean_kl, self.step_size),
inputs=input_list,
constraint_name="mean_kl"
)
return dict()

@overrides
def optimize_policy(self, itr, samples_data):
all_input_values = tuple(ext.extract(
samples_data,
"observations", "actions", "advantages"
))
agent_infos = samples_data["agent_infos"]
state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
all_input_values += tuple(state_info_list) + tuple(dist_info_list)
if self.policy.recurrent:
all_input_values += (samples_data["valids"],)
logger.log("Computing loss before")
loss_before = self.optimizer.loss(all_input_values)
logger.log("Computing KL before")
mean_kl_before = self.optimizer.constraint_val(all_input_values)
logger.log("Optimizing")
self.optimizer.optimize(all_input_values)
logger.log("Computing KL after")
mean_kl = self.optimizer.constraint_val(all_input_values)
logger.log("Computing loss after")
loss_after = self.optimizer.loss(all_input_values)
logger.record_tabular('LossBefore', loss_before)
logger.record_tabular('LossAfter', loss_after)
logger.record_tabular('MeanKLBefore', mean_kl_before)
logger.record_tabular('MeanKL', mean_kl)
logger.record_tabular('dLoss', loss_before - loss_after)
return dict()

@overrides
def get_itr_snapshot(self, itr, samples_data):
return dict(
itr=itr,
policy=self.policy,
baseline=self.baseline,
env=self.env,
)
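
Note: the symbolic graph built in init_opt() above reduces to a likelihood-ratio surrogate loss with a mean-KL trust-region constraint. Below is a minimal NumPy sketch of the non-recurrent case, with made-up probabilities and advantages purely for illustration; it is not part of rllab.

import numpy as np

# Hypothetical per-sample action probabilities under the old (sampling) policy
# and the current policy, plus the corresponding advantage estimates.
old_logp = np.log(np.array([0.50, 0.20, 0.90]))
new_logp = np.log(np.array([0.60, 0.25, 0.85]))
advantages = np.array([1.2, -0.3, 0.8])

lr = np.exp(new_logp - old_logp)        # likelihood ratio pi_new(a|s) / pi_old(a|s)
surr_loss = -np.mean(lr * advantages)   # negated surrogate objective, to be minimized

# NPO additionally requires the mean KL divergence between the old and new
# action distributions to stay below step_size (0.01 by default); the
# penalty-based L-BFGS optimizer enforces that constraint.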
21 changes: 21 additions & 0 deletions rllab/sandbox/rocky/tf/algos/trpo.py
@@ -0,0 +1,21 @@


from sandbox.rocky.tf.algos.npo import NPO
from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer


class TRPO(NPO):
"""
Trust Region Policy Optimization
"""

def __init__(
self,
optimizer=None,
optimizer_args=None,
**kwargs):
if optimizer is None:
if optimizer_args is None:
optimizer_args = dict()
optimizer = ConjugateGradientOptimizer(**optimizer_args)
super(TRPO, self).__init__(optimizer=optimizer, **kwargs)
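
A minimal end-to-end usage sketch for the class above, assuming the standard rllab example setup (TfEnv wrapper, Gaussian MLP policy, linear feature baseline); constructor arguments may differ slightly between rllab versions.

from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
from sandbox.rocky.tf.algos.trpo import TRPO
from sandbox.rocky.tf.envs.base import TfEnv
from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy

# Wrap a simple continuous-control environment for the TF sandbox.
env = TfEnv(normalize(CartpoleEnv()))
policy = GaussianMLPPolicy(name="policy", env_spec=env.spec, hidden_sizes=(32, 32))
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    batch_size=4000,
    max_path_length=100,
    n_itr=40,
    discount=0.99,
    step_size=0.01,
)
algo.train()

TRPO differs from the base NPO class only in its default optimizer: the conjugate-gradient optimizer takes a natural-gradient step with a backtracking line search, instead of the penalty-based L-BFGS optimizer.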