Commit
Showing 64 changed files with 7,570 additions and 1 deletion.
.gitignore
@@ -36,4 +36,3 @@ blackbox.zip
 blackbox
 rllab/config_personal.py
 *.swp
-sandbox
Empty file.
Empty file.
Empty file.
@@ -0,0 +1 @@
+
sandbox/rocky/tf/algos/batch_polopt.py
@@ -0,0 +1,160 @@
import time
from rllab.algos.base import RLAlgorithm
import rllab.misc.logger as logger
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
from rllab.sampler.utils import rollout


class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
    """

    def __init__(
            self,
            env,
            policy,
            baseline,
            scope=None,
            n_itr=500,
            start_itr=0,
            batch_size=5000,
            max_path_length=500,
            discount=0.99,
            gae_lambda=1,
            plot=False,
            pause_for_plot=False,
            center_adv=True,
            positive_adv=False,
            store_paths=False,
            whole_paths=True,
            fixed_horizon=False,
            sampler_cls=None,
            sampler_args=None,
            force_batch_sampler=False,
            **kwargs
    ):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
            simultaneously, each using different environments and policies
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
            conjunction with center_adv the advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            # Prefer the vectorized sampler when the policy supports it,
            # unless a plain batch sampler is explicitly requested.
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self):
        self.sampler.start_worker()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        # Reuse the caller's session if one is given; otherwise create and
        # enter our own, and close it once training finishes.
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()

    def log_diagnostics(self, paths):
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
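The three methods above that raise NotImplementedError are the contract concrete algorithms fill in: init_opt() builds the update, optimize_policy() consumes one batch of processed samples, and get_itr_snapshot() decides what gets persisted, while train() drives sampling, processing, optimization, and snapshotting every iteration. A minimal sketch of a subclass follows; it is not part of this commit, and the class name and its trivial method bodies are purely illustrative.

# Illustrative only: a stub subclass showing the hooks BatchPolopt expects.
from sandbox.rocky.tf.algos.batch_polopt import BatchPolopt


class NoOpPolopt(BatchPolopt):
    """Hypothetical subclass that satisfies the BatchPolopt interface."""

    def init_opt(self):
        # A real algorithm would build its surrogate loss and optimizer here
        # from self.policy, self.env and self.baseline (as NPO does below).
        return dict()

    def optimize_policy(self, itr, samples_data):
        # samples_data is the dict produced by the sampler's process_samples(),
        # with keys such as "observations", "actions" and "advantages".
        return dict()

    def get_itr_snapshot(self, itr, samples_data):
        # Whatever is returned here is what logger.save_itr_params() stores.
        return dict(itr=itr, policy=self.policy, baseline=self.baseline, env=self.env)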
@@ -0,0 +1 @@
+
sandbox/rocky/tf/algos/npo.py
@@ -0,0 +1,130 @@
from rllab.misc import ext
from rllab.misc.overrides import overrides
import rllab.misc.logger as logger
from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer
from sandbox.rocky.tf.algos.batch_polopt import BatchPolopt
from sandbox.rocky.tf.misc import tensor_utils
import tensorflow as tf


class NPO(BatchPolopt):
    """
    Natural Policy Optimization.
    """

    def __init__(
            self,
            optimizer=None,
            optimizer_args=None,
            step_size=0.01,
            **kwargs):
        if optimizer is None:
            if optimizer_args is None:
                optimizer_args = dict()
            optimizer = PenaltyLbfgsOptimizer(**optimizer_args)
        self.optimizer = optimizer
        self.step_size = step_size
        super(NPO, self).__init__(**kwargs)

    @overrides
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)
        obs_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + is_recurrent,
        )
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1 + is_recurrent,
        )
        advantage_var = tensor_utils.new_tensor(
            'advantage',
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )
        dist = self.policy.distribution

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
            for k, shape in dist.dist_info_specs
        }
        old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

        state_info_vars = {
            k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

        if is_recurrent:
            # valid_var masks out padded timesteps when paths are stacked
            # into a [batch, time] layout for recurrent policies.
            valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
        else:
            valid_var = None

        dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
        # Surrogate objective: negative likelihood-ratio-weighted advantage,
        # minimized subject to mean KL(old || new) <= step_size.
        if is_recurrent:
            mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
            surr_loss = - tf.reduce_sum(lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
        else:
            mean_kl = tf.reduce_mean(kl)
            surr_loss = - tf.reduce_mean(lr * advantage_var)

        input_list = [
            obs_var,
            action_var,
            advantage_var,
        ] + state_info_vars_list + old_dist_info_vars_list
        if is_recurrent:
            input_list.append(valid_var)

        self.optimizer.update_opt(
            loss=surr_loss,
            target=self.policy,
            leq_constraint=(mean_kl, self.step_size),
            inputs=input_list,
            constraint_name="mean_kl"
        )
        return dict()

    @overrides
    def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(all_input_values)
        logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.log("Optimizing")
        self.optimizer.optimize(all_input_values)
        logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(all_input_values)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    @overrides
    def get_itr_snapshot(self, itr, samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
            env=self.env,
        )
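Concretely, init_opt() builds the standard natural-gradient surrogate: the likelihood ratio between the updated and sampling policies weighted by the advantages, minimized (negated) subject to a bound on the mean KL divergence. A rough NumPy illustration of these quantities for the non-recurrent case follows; it is not part of the commit, assumes a diagonal Gaussian policy, and all arrays are made-up stand-ins for one batch.

# Illustrative only: the quantities NPO.init_opt() builds symbolically,
# computed numerically for a diagonal Gaussian policy on fake data.
import numpy as np

rng = np.random.default_rng(0)
N, d = 64, 2                                   # batch size, action dimension
actions = rng.normal(size=(N, d))
advantages = rng.normal(size=N)
old_mu, old_log_std = rng.normal(size=(N, d)), np.zeros((N, d))
new_mu, new_log_std = old_mu + 0.05, old_log_std - 0.01


def log_prob(a, mu, log_std):
    # Log density of a diagonal Gaussian, summed over action dimensions.
    z = (a - mu) / np.exp(log_std)
    return -0.5 * np.sum(z ** 2, axis=-1) - np.sum(log_std, axis=-1) - 0.5 * d * np.log(2 * np.pi)


# Likelihood ratio and per-sample KL(old || new), mirroring
# likelihood_ratio_sym and kl_sym for a diagonal Gaussian.
lr = np.exp(log_prob(actions, new_mu, new_log_std) - log_prob(actions, old_mu, old_log_std))
kl = np.sum(new_log_std - old_log_std
            + (np.exp(2 * old_log_std) + (old_mu - new_mu) ** 2) / (2 * np.exp(2 * new_log_std))
            - 0.5, axis=-1)

surr_loss = -np.mean(lr * advantages)          # what the optimizer minimizes
mean_kl = np.mean(kl)                          # constrained to be <= step_size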
sandbox/rocky/tf/algos/trpo.py
@@ -0,0 +1,21 @@
from sandbox.rocky.tf.algos.npo import NPO
from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer


class TRPO(NPO):
    """
    Trust Region Policy Optimization
    """

    def __init__(
            self,
            optimizer=None,
            optimizer_args=None,
            **kwargs):
        # TRPO is NPO with a conjugate gradient optimizer for the
        # KL-constrained step instead of the penalized L-BFGS default.
        if optimizer is None:
            if optimizer_args is None:
                optimizer_args = dict()
            optimizer = ConjugateGradientOptimizer(**optimizer_args)
        super(TRPO, self).__init__(optimizer=optimizer, **kwargs)
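A usage sketch for the classes added in this commit follows. It is not part of the commit itself, and the environment, policy, and baseline imports (TfEnv, GaussianMLPPolicy, LinearFeatureBaseline, CartpoleEnv) are assumed to be available from the surrounding rllab tree; any other rllab environment and compatible policy/baseline would do.

# Illustrative only: wiring TRPO together with assumed rllab components.
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from sandbox.rocky.tf.algos.trpo import TRPO
from sandbox.rocky.tf.envs.base import TfEnv
from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy

env = TfEnv(CartpoleEnv())
policy = GaussianMLPPolicy(name="policy", env_spec=env.spec, hidden_sizes=(32, 32))
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    batch_size=4000,
    max_path_length=100,
    n_itr=40,
    discount=0.99,
    step_size=0.01,   # trust-region bound on mean KL, from NPO
)
algo.train()          # creates its own tf.Session when none is passed in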