From f429a9c1cc9871eafb1e75ba909b3da8ca36d145 Mon Sep 17 00:00:00 2001 From: ReykCS Date: Mon, 30 Jan 2023 11:42:39 +0100 Subject: [PATCH] updated versions in rosinstall | added seconds to model name | removed wandb because it's missing in stable baselines branch --- .rosinstall | 4 ++-- arena_bringup/launch/start_arena.launch | 11 ++++++++--- arena_bringup/launch/start_training.launch | 9 +++++++-- training/configs/training_config.yaml | 3 +-- training/scripts/train_agent.py | 11 ++++++++++- training/tools/general.py | 2 +- training/tools/model_utils.py | 5 ++--- 7 files changed, 31 insertions(+), 14 deletions(-) diff --git a/.rosinstall b/.rosinstall index d969ff16f..86e841d48 100755 --- a/.rosinstall +++ b/.rosinstall @@ -38,7 +38,7 @@ - git: local-name: ../utils/arena-utils uri: https://github.com/Arena-Rosnav/arena-utils.git - version: v2.1.0 + version: v2.2.0 - git: local-name: ../utils/task-generator @@ -55,7 +55,7 @@ - git: local-name: ../planners/rosnav uri: https://github.com/Arena-Rosnav/rosnav.git - version: v1.1.1 + version: v1.1.2 - git: local-name: ../planners/arena-ros diff --git a/arena_bringup/launch/start_arena.launch b/arena_bringup/launch/start_arena.launch index 29f04204f..f83f6424e 100755 --- a/arena_bringup/launch/start_arena.launch +++ b/arena_bringup/launch/start_arena.launch @@ -9,14 +9,19 @@ - + + + + - + + + - + diff --git a/arena_bringup/launch/start_training.launch b/arena_bringup/launch/start_training.launch index e8152c06e..81b69eef2 100644 --- a/arena_bringup/launch/start_training.launch +++ b/arena_bringup/launch/start_training.launch @@ -5,12 +5,17 @@ - + + + + - + + + diff --git a/training/configs/training_config.yaml b/training/configs/training_config.yaml index 499cbaab3..f64942098 100644 --- a/training/configs/training_config.yaml +++ b/training/configs/training_config.yaml @@ -10,7 +10,7 @@ no_gpu: false ### Training Monitoring monitoring: # weights and biases logging - use_wandb: true + use_wandb: false # save evaluation 
stats during training in log file eval_log: false @@ -71,4 +71,3 @@ rl_agent: m_batch_size: 20 n_epochs: 3 clip_range: 0.22 - diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py index aa6aefbf3..90c50e3ec 100644 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -9,9 +9,14 @@ from tools.model_utils import init_callbacks, get_ppo_instance from tools.env_utils import init_envs +def on_shutdown(model): + model.env.close() + sys.exit() + def main(): args, _ = parse_training_args() + config = load_config(args.config) populate_ros_configs(config) @@ -47,7 +52,11 @@ def main(): eval_cb = init_callbacks(config, train_env, eval_env, PATHS) model = get_ppo_instance(config, train_env, PATHS, AgentFactory) - rospy.on_shutdown(model.env.close()) + rospy.on_shutdown(lambda: on_shutdown(model)) + + ## Save model once + if not config["debug_mode"]: + model.save(os.path.join(PATHS["model"], "best_model")) # start training start = time.time() diff --git a/training/tools/general.py b/training/tools/general.py index c4b1ed528..782fdde87 100644 --- a/training/tools/general.py +++ b/training/tools/general.py @@ -220,7 +220,7 @@ def generate_agent_name(config: dict) -> str: :param config (dict): Dict containing the program arguments """ if config["rl_agent"]["resume"] is None: - START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M") + START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M_%S") robot_model = rospy.get_param("robot_model") architecture_name, encoder_name = config["rl_agent"][ "architecture_name" diff --git a/training/tools/model_utils.py b/training/tools/model_utils.py index 466d0a0c8..3e5913c22 100644 --- a/training/tools/model_utils.py +++ b/training/tools/model_utils.py @@ -1,4 +1,5 @@ import os +import sys from typing import Union, Type import wandb @@ -144,9 +145,7 @@ def instantiate_new_model( "n_epochs": ppo_config["n_epochs"], "clip_range": ppo_config["clip_range"], "tensorboard_log": PATHS["tb"], - "use_wandb": False - if 
config["debug_mode"] - else config["monitoring"]["use_wandb"], + # "use_wandb": False if config["debug_mode"] else config["monitoring"]["use_wandb"], "verbose": 1, }