diff --git a/hw2/train_pg_f18.py b/hw2/train_pg_f18.py index ecca81964..c46440407 100644 --- a/hw2/train_pg_f18.py +++ b/hw2/train_pg_f18.py @@ -10,7 +10,6 @@ import os import time import inspect -from multiprocessing import Process #============================================================================================# # Utilities @@ -81,6 +80,9 @@ def init_tf_sess(self): self.sess.__enter__() # equivalent to `with self.sess:` tf.global_variables_initializer().run() #pylint: disable=E1101 + def close_tf_sess(self): + self.sess.__exit__(None, None, None) + tf.reset_default_graph() #========================================================================================# # ----------PROBLEM 2---------- #========================================================================================# @@ -629,6 +631,7 @@ def train_PG( logz.dump_tabular() logz.pickle_tf_vars() + agent.close_tf_sess() def main(): import argparse @@ -659,41 +662,28 @@ def main(): max_path_length = args.ep_len if args.ep_len > 0 else None - processes = [] - for e in range(args.n_experiments): seed = args.seed + 10*e print('Running experiment with seed %d'%seed) - def train_func(): - train_PG( - exp_name=args.exp_name, - env_name=args.env_name, - n_iter=args.n_iter, - gamma=args.discount, - min_timesteps_per_batch=args.batch_size, - max_path_length=max_path_length, - learning_rate=args.learning_rate, - reward_to_go=args.reward_to_go, - animate=args.render, - logdir=os.path.join(logdir,'%d'%seed), - normalize_advantages=not(args.dont_normalize_advantages), - nn_baseline=args.nn_baseline, - seed=seed, - n_layers=args.n_layers, - size=args.size - ) - # # Awkward hacky process runs, because Tensorflow does not like - # # repeatedly calling train_PG in the same thread. - p = Process(target=train_func, args=tuple()) - p.start() - processes.append(p) - # if you comment in the line below, then the loop will block - # until this process finishes - # p.join() - - for p in processes: - p.join() + + train_PG( + exp_name=args.exp_name, + env_name=args.env_name, + n_iter=args.n_iter, + gamma=args.discount, + min_timesteps_per_batch=args.batch_size, + max_path_length=max_path_length, + learning_rate=args.learning_rate, + reward_to_go=args.reward_to_go, + animate=args.render, + logdir=os.path.join(logdir,'%d'%seed), + normalize_advantages=not(args.dont_normalize_advantages), + nn_baseline=args.nn_baseline, + seed=seed, + n_layers=args.n_layers, + size=args.size + ) if __name__ == "__main__": main()