Commit 2276744
Lantao Yu committed Dec 4, 2019
1 parent: 79c4d2d
Showing 74 changed files with 5,813 additions and 0 deletions.
@@ -0,0 +1 @@
+.idea
Binary file added (+11.8 MB): ...16_itr-20_preepoch-1000_entropy-1.0_RandomPol_Rew-2-32/2019_05_14_02_33_17_0/itr_2800.pkl (contents not shown)
1 empty file added.
11 additional binary files added (contents not shown).
@@ -0,0 +1,161 @@
import time
from rllab.algos.base import RLAlgorithm
import rllab.misc.logger as logger
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
from rllab.sampler.utils import rollout

from inverse_rl.utils.hyperparametrized import Hyperparametrized

class BatchPolopt(RLAlgorithm, metaclass=Hyperparametrized):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
    """

    def __init__(
            self,
            env,
            policy,
            baseline,
            scope=None,
            n_itr=500,
            start_itr=0,
            batch_size=5000,
            max_path_length=500,
            discount=0.99,
            gae_lambda=1,
            plot=False,
            pause_for_plot=False,
            center_adv=True,
            positive_adv=False,
            store_paths=False,
            whole_paths=True,
            fixed_horizon=False,
            sampler_cls=None,
            sampler_args=None,
            force_batch_sampler=False,
            **kwargs
    ):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
            simultaneously, each using different environments and policies.
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
            conjunction with center_adv the advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self):
        self.sampler.start_worker()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()

    def log_diagnostics(self, paths):
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
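
To make the abstract hooks concrete, here is a minimal, hypothetical sketch of how a subclass might plug into the training loop above. The class name SimplePolicyGradient, the import path, and the no-op bodies are illustrative assumptions, not part of this commit; real implementations (e.g. VPG or TRPO in rllab) build a full TensorFlow surrogate-loss graph in init_opt() and take gradient steps in optimize_policy().

# Hypothetical sketch (not part of this commit): a minimal subclass that
# satisfies the BatchPolopt interface so train() can run end to end.
# The module path below is an assumption about where batch_polopt.py lives in this repo.
from inverse_rl.algos.batch_polopt import BatchPolopt


class SimplePolicyGradient(BatchPolopt):
    def init_opt(self):
        # Called once from __init__: a real algorithm would build the TF graph
        # for its surrogate loss and optimizer here. Left empty in this sketch.
        pass

    def optimize_policy(self, itr, samples_data):
        # samples_data comes from process_samples() and typically carries
        # "observations", "actions", and "advantages" arrays; a real
        # implementation would run one or more gradient steps on them.
        pass

    def get_itr_snapshot(self, itr, samples_data):
        # Whatever is returned here is persisted by logger.save_itr_params().
        return dict(itr=itr, policy=self.policy, baseline=self.baseline,
                    env=self.env)


# Usage (env, policy, baseline constructed with the usual rllab classes):
# algo = SimplePolicyGradient(env=env, policy=policy, baseline=baseline,
#                             n_itr=100, batch_size=4000, max_path_length=200)
# algo.train()  # runs the sample -> process -> optimize -> snapshot loop above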