Commit: update code
Lantao Yu committed Dec 4, 2019
1 parent 79c4d2d commit 2276744
Showing 74 changed files with 5,813 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.idea
24 changes: 24 additions & 0 deletions README.md
@@ -2,3 +2,27 @@
Lantao Yu*, Tianhe Yu*, Chelsea Finn, Stefano Ermon.<br>
The 33rd Conference on Neural Information Processing Systems. (NeurIPS 2019)<br>
[[Paper]](https://arxiv.org/pdf/1909.09314.pdf) [[Website]](https://sites.google.com/view/pemirl)

### Usage
Requirement: this project depends on the rllab package bundled with the repository, provided [here](https://github.com/ermongroup/MetaIRL/tree/master/rllab).

To collect expert trajectories for the downstream tasks:
```
python scripts/maze_data_collect.py
```

After collecting the expert trajectories, run meta-inverse RL to learn context-dependent reward functions:
```
python scripts/maze_wall_meta_irl.py
```
We provide a pretrained IRL model [here](https://github.com/ermongroup/MetaIRL/tree/master/data_fusion_discrete/maze_wall_meta_irl_imitcoeff-0.01_infocoeff-0.1_mbs-50_bs-16_itr-20_preepoch-1000_entropy-1.0_RandomPol_Rew-2-32/2019_05_14_02_33_17_0), which the following scripts load by default.

To visualize the context-dependent reward function (Figure 2 in the paper):
```
python scripts/maze_visualize_reward.py
```

To use the context-dependent reward function to train a new policy under new dynamics:
```
python scripts/maze_wall_meta_irl_test.py
```
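The four commands above can also be chained end to end. The following is a minimal sketch (not part of the repository) that assumes the script paths listed above and that the bundled rllab package is importable:
```python
# Minimal sketch: run the README workflow in order.
# Assumes the scripts exist at the paths given above and that rllab is on PYTHONPATH.
import subprocess

steps = [
    "scripts/maze_data_collect.py",        # 1. collect expert trajectories
    "scripts/maze_wall_meta_irl.py",       # 2. meta-IRL: learn context-dependent rewards
    "scripts/maze_visualize_reward.py",    # 3. visualize the learned reward (Figure 2)
    "scripts/maze_wall_meta_irl_test.py",  # 4. train a new policy under new dynamics
]

for script in steps:
    subprocess.run(["python", script], check=True)
```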
Empty file added inverse_rl/__init__.py
Binary file added inverse_rl/__pycache__/__init__.cpython-35.pyc (not shown)
Binary file added inverse_rl/algos/__pycache__/irl_trpo.cpython-35.pyc (not shown)
Binary file added inverse_rl/algos/__pycache__/npo.cpython-35.pyc (not shown)
Binary file added inverse_rl/algos/__pycache__/trpo.cpython-35.pyc (not shown)
Additional binary files were added (names not shown).
161 changes: 161 additions & 0 deletions inverse_rl/algos/batch_polopt.py
@@ -0,0 +1,161 @@
import time
from rllab.algos.base import RLAlgorithm
import rllab.misc.logger as logger
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
from rllab.sampler.utils import rollout

from inverse_rl.utils.hyperparametrized import Hyperparametrized


class BatchPolopt(RLAlgorithm, metaclass=Hyperparametrized):
"""
Base class for batch sampling-based policy optimization methods.
This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
"""

def __init__(
self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs
):
"""
:param env: Environment
:param policy: Policy
:type policy: Policy
:param baseline: Baseline
:param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
simultaneously, each using different environments and policies
:param n_itr: Number of iterations.
:param start_itr: Starting iteration.
:param batch_size: Number of samples per iteration.
:param max_path_length: Maximum length of a single rollout.
:param discount: Discount.
:param gae_lambda: Lambda used for generalized advantage estimation.
:param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
:param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
:param positive_adv: Whether to shift the advantages so that they are always positive. When used in
conjunction with center_adv the advantages will be standardized before shifting.
:param store_paths: Whether to save all paths data to the snapshot.
:return:
"""
self.env = env
self.policy = policy
self.baseline = baseline
self.scope = scope
self.n_itr = n_itr
self.start_itr = start_itr
self.batch_size = batch_size
self.max_path_length = max_path_length
self.discount = discount
self.gae_lambda = gae_lambda
self.plot = plot
self.pause_for_plot = pause_for_plot
self.center_adv = center_adv
self.positive_adv = positive_adv
self.store_paths = store_paths
self.whole_paths = whole_paths
self.fixed_horizon = fixed_horizon
if sampler_cls is None:
if self.policy.vectorized and not force_batch_sampler:
sampler_cls = VectorizedSampler
else:
sampler_cls = BatchSampler
if sampler_args is None:
sampler_args = dict()
self.sampler = sampler_cls(self, **sampler_args)
self.init_opt()

def start_worker(self):
self.sampler.start_worker()

def shutdown_worker(self):
self.sampler.shutdown_worker()

def obtain_samples(self, itr):
return self.sampler.obtain_samples(itr)

def process_samples(self, itr, paths):
return self.sampler.process_samples(itr, paths)

def train(self, sess=None):
created_session = True if (sess is None) else False
if sess is None:
sess = tf.Session()
sess.__enter__()

sess.run(tf.global_variables_initializer())
self.start_worker()
start_time = time.time()
for itr in range(self.start_itr, self.n_itr):
itr_start_time = time.time()
with logger.prefix('itr #%d | ' % itr):
logger.log("Obtaining samples...")
paths = self.obtain_samples(itr)
logger.log("Processing samples...")
samples_data = self.process_samples(itr, paths)
logger.log("Logging diagnostics...")
self.log_diagnostics(paths)
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
if self.store_paths:
params["paths"] = samples_data["paths"]
logger.save_itr_params(itr, params)
logger.log("Saved")
logger.record_tabular('Time', time.time() - start_time)
logger.record_tabular('ItrTime', time.time() - itr_start_time)
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")
self.shutdown_worker()
if created_session:
sess.close()

def log_diagnostics(self, paths):
self.env.log_diagnostics(paths)
self.policy.log_diagnostics(paths)
self.baseline.log_diagnostics(paths)

def init_opt(self):
"""
Initialize the optimization procedure. If using tensorflow, this may
include declaring all the variables and compiling functions
"""
raise NotImplementedError

def get_itr_snapshot(self, itr, samples_data):
"""
Returns all the data that should be saved in the snapshot for this
iteration.
"""
raise NotImplementedError

def optimize_policy(self, itr, samples_data):
raise NotImplementedError
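

# ----------------------------------------------------------------------------
# The three methods above (init_opt, get_itr_snapshot, optimize_policy) are the
# hooks that concrete algorithms must implement. The class below is a
# hypothetical minimal sketch -- NOT part of this commit -- illustrating the
# expected contract; the class name and method bodies are placeholders.
# ----------------------------------------------------------------------------
class MinimalPolopt(BatchPolopt):
    def init_opt(self):
        # Called once at the end of BatchPolopt.__init__; build any losses or
        # optimizers needed by optimize_policy here.
        pass

    def get_itr_snapshot(self, itr, samples_data):
        # Whatever is returned here is persisted via logger.save_itr_params.
        return dict(itr=itr, policy=self.policy, baseline=self.baseline, env=self.env)

    def optimize_policy(self, itr, samples_data):
        # One policy-improvement step on the processed batch (samples_data).
        pass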