diff --git a/examples/train.py b/examples/train.py
index f889ac9b6..959596697 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -21,6 +21,7 @@
 from flow.utils.registry import env_constructor
 from flow.utils.rllib import FlowParamsEncoder, get_flow_params
 from flow.utils.registry import make_create_env
+from flow.visualize.i210_replay import create_parser, generate_graphs
 
 
 def parse_args(args):
@@ -87,6 +88,11 @@ def parse_args(args):
     parser.add_argument('--multi_node', action='store_true',
                         help='Set to true if this will be run in cluster mode.'
                              'Relevant for rllib')
+    parser.add_argument(
+        '--upload_graphs', type=str, nargs=2,
+        help='Whether to generate and upload graphs to leaderboard at the end of training.'
+             'Arguments are name of the submitter and name of the strategy.'
+             'Only relevant for i210 training on rllib')
 
     return parser.parse_known_args(args)[0]
 
@@ -376,6 +382,49 @@ def trial_str_creator(trial):
         exp_dict['upload_dir'] = s3_string
     tune.run(**exp_dict, queue_trials=False, raise_on_failed_trial=False)
+
+    if flags.upload_graphs:
+        print('Generating experiment graphs and uploading them to leaderboard')
+        submitter_name, strategy_name = flags.upload_graphs
+
+        # reset ray
+        ray.shutdown()
+        if flags.local_mode:
+            ray.init(local_mode=True)
+        else:
+            ray.init()
+
+        # grab checkpoint path
+        for (dirpath, _, _) in os.walk(os.path.expanduser("~/ray_results")):
+            if "checkpoint_{}".format(flags.checkpoint_freq) in dirpath \
+                    and dirpath.split('/')[-3] == flags.exp_title:
+                checkpoint_path = os.path.dirname(dirpath)
+                checkpoint_number = -1
+                for name in os.listdir(checkpoint_path):
+                    if name.startswith('checkpoint'):
+                        cp = int(name.split('_')[1])
+                        checkpoint_number = max(checkpoint_number, cp)
+
+                # create dir for graphs output
+                output_dir = os.path.join(checkpoint_path, 'output_graphs')
+                if not os.path.exists(output_dir):
+                    os.mkdir(output_dir)
+
+                # run graph generation script
+                parser = create_parser()
+
+                strategy_name_full = str(strategy_name)
+                if flags.grid_search:
+                    strategy_name_full += '__' + dirpath.split('/')[-2]
+
+                args = parser.parse_args([
+                    '-r', checkpoint_path, '-c', str(checkpoint_number),
+                    '--gen_emission', '--use_s3', '--num_cpus', str(flags.num_cpus),
+                    '--output_dir', output_dir,
+                    '--submitter_name', submitter_name,
+                    '--strategy_name', strategy_name_full.replace(',', '_').replace(';', '_')
+                ])
+                generate_graphs(args)
 
 
 def train_h_baselines(env_name, args, multiagent):
     """Train policies using SAC and TD3 with h-baselines."""
diff --git a/flow/data_pipeline/data_pipeline.py b/flow/data_pipeline/data_pipeline.py
index f0e3637f6..8f73e7e5b 100644
--- a/flow/data_pipeline/data_pipeline.py
+++ b/flow/data_pipeline/data_pipeline.py
@@ -99,7 +99,7 @@ def get_extra_info(veh_kernel, extra_info, veh_ids, source_id, run_id):
     extra_info["run_id"].append(run_id)
 
 
-def get_configuration():
+def get_configuration(submitter_name=None, strategy_name=None):
     """Get configuration for the metadata table."""
     try:
         config_df = pd.read_csv('./data_pipeline_config')
@@ -107,13 +107,19 @@
         config_df = pd.DataFrame(data={"submitter_name": [""], "strategy": [""]})
 
     if not config_df['submitter_name'][0]:
-        name = input("Please enter your name:").strip()
-        while not name:
-            name = input("Please enter a non-empty name:").strip()
+        if submitter_name:
+            name = submitter_name
+        else:
+            name = input("Please enter your name:").strip()
+            while not name:
+                name = input("Please enter a non-empty name:").strip()
         config_df['submitter_name'] = [name]
 
-    strategy = input(
-        "Please enter strategy name (current: \"{}\"):".format(config_df["strategy"][0])).strip()
+    if strategy_name:
+        strategy = strategy_name
+    else:
+        strategy = input(
+            "Please enter strategy name (current: \"{}\"):".format(config_df["strategy"][0])).strip()
     if strategy:
         config_df['strategy'] = [strategy]
 
diff --git a/flow/utils/rllib.py b/flow/utils/rllib.py
index fc3229e52..db0e811b8 100644
--- a/flow/utils/rllib.py
+++ b/flow/utils/rllib.py
@@ -8,6 +8,7 @@
 import os
 import sys
 
+import flow.config
 import flow.envs
 from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams, \
     SumoParams, InitialConfig, EnvParams, NetParams, InFlows
@@ -149,8 +150,7 @@ def get_flow_params(config):
     net.inflows.__dict__ = flow_params["net"]["inflows"].copy()
 
     if net.template is not None and len(net.template) > 0:
-        dirname = os.getcwd()
-        filename = os.path.join(dirname, '../../examples')
+        filename = os.path.join(flow.config.PROJECT_PATH, 'examples')
         split = net.template.split('examples')[1][1:]
         path = os.path.abspath(os.path.join(filename, split))
         net.template = path
diff --git a/flow/visualize/i210_replay.py b/flow/visualize/i210_replay.py
index 4c7498413..23ef5fdd4 100644
--- a/flow/visualize/i210_replay.py
+++ b/flow/visualize/i210_replay.py
@@ -6,6 +6,7 @@
 import numpy as np
 import json
 import os
+import os.path
 import pytz
 import subprocess
 import time
@@ -241,8 +242,8 @@ def replay(args, flow_params, output_dir=None, transfer_test=None, rllib_config=
         metadata['submission_time'].append(cur_time)
         metadata['network'].append(network_name_translate(env.network.name.split('_20')[0]))
         metadata['is_baseline'].append(str(args.is_baseline))
-        if args.to_aws:
-            name, strategy = get_configuration()
+        if args.use_s3:
+            name, strategy = get_configuration(args.submitter_name, args.strategy_name)
             metadata['submitter_name'].append(name)
             metadata['strategy'].append(strategy)
 
@@ -362,8 +363,12 @@ def replay(args, flow_params, output_dir=None, transfer_test=None, rllib_config=
                 '{0}/test_time_rollout/{1}'.format(dir_path, emission_filename)
 
             output_path = os.path.join(output_dir, '{}-emission.csv'.format(exp_name))
-            # convert the emission file into a csv file
-            emission_to_csv(emission_path, output_path=output_path)
+            if os.path.exists(emission_path.replace('emission.xml', '0_emission.csv')):
+                # csv already exists
+                os.rename(emission_path.replace('emission.xml', '0_emission.csv'), output_path)
+            else:
+                # convert the emission file into a csv file
+                emission_to_csv(emission_path, output_path=output_path)
 
             # generate the trajectory output file
             trajectory_table_path = os.path.join(dir_path, '{}.csv'.format(source_id))
@@ -384,7 +389,8 @@
             print("\nGenerated emission file at " + output_path)
 
             # delete the .xml version of the emission file
-            os.remove(emission_path)
+            if os.path.exists(emission_path):
+                os.remove(emission_path)
 
         all_trip_energies = os.path.join(output_dir, '{}-all_trip_energies.npy'.format(exp_name))
         np.save(all_trip_energies, dict(all_trip_energy_distribution))
@@ -500,16 +506,20 @@ def create_parser():
         action='store_true',
         help='specifies whether this is a baseline run'
     )
+    parser.add_argument('--submitter_name', type=str, required=False, default=None,
+                        help='Name of the submitter (replaces the prompt asking for '
+                             'the name) (stored locally so only necessary once)')
+    parser.add_argument('--strategy_name', type=str, required=False, default=None,
+                        help='Name of the training strategy (replaces the prompt '
+                             'asking for the strategy)')
     return parser
 
 
-if __name__ == '__main__':
+def generate_graphs(args):
+    """Generate the graphs."""
     date = datetime.now(tz=pytz.utc)
     date = date.astimezone(pytz.timezone('US/Pacific')).strftime("%m-%d-%Y")
 
-    parser = create_parser()
-    args = parser.parse_args()
-
     rllib_config = None
     rllib_result_dir = None
     if args.rllib_result_dir is not None:
@@ -520,12 +530,13 @@
 
     flow_params = deepcopy(I210_MA_DEFAULT_FLOW_PARAMS)
 
-    if args.multi_node:
-        ray.init(redis_address='localhost:6379')
-    elif args.local:
-        ray.init(local_mode=True, object_store_memory=200 * 1024 * 1024)
-    else:
-        ray.init(num_cpus=args.num_cpus + 1, object_store_memory=200 * 1024 * 1024)
+    if not ray.is_initialized():
+        if args.multi_node:
+            ray.init(redis_address='localhost:6379')
+        elif args.local:
+            ray.init(local_mode=True, object_store_memory=200 * 1024 * 1024)
+        else:
+            ray.init(num_cpus=args.num_cpus + 1, object_store_memory=200 * 1024 * 1024)
 
     if args.exp_title:
         output_dir = os.path.join(args.output_dir, args.exp_title)
@@ -573,3 +584,10 @@
                 p1.wait(50)
         except Exception as e:
             print('This is the error ', e)
+
+
+if __name__ == '__main__':
+    parser = create_parser()
+    args = parser.parse_args()
+
+    generate_graphs(args)