From 9633f547db606a2d61825bbaae53c98ace98590b Mon Sep 17 00:00:00 2001 From: cliebig2019 Date: Thu, 21 Jul 2022 12:24:33 +0200 Subject: [PATCH 1/2] custom neural network from json --- training/scripts/train_agent.py | 2 +- training/tools/argsparser.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py index 447787e2f..c03cc9f19 100755 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -133,7 +133,7 @@ def main(): verbose=1, ) elif args.agent is not None: - agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent) + agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent, path = args.path) if isinstance(agent, BaseAgent): model = PPO( agent.type.value, diff --git a/training/tools/argsparser.py b/training/tools/argsparser.py index 46645e7e2..cb9110a97 100644 --- a/training/tools/argsparser.py +++ b/training/tools/argsparser.py @@ -23,6 +23,7 @@ def training_args(parser): import rosnav.model.custom_policy import rosnav.model.custom_sb3_policy from rosnav.model.agent_factory import AgentFactory + import rosnav.model.custom_policy_from_json group.add_argument( "--agent", @@ -60,6 +61,11 @@ def training_args(parser): parser.add_argument( "--tb", action="store_true", help="enables tensorboard logging" ) + parser.add_argument( + "--path", + type=str, + help="path to the json file containing" "the neural network", + ) def run_agent_args(parser): From fc4a1af04706c2a73002c27ce3a694fcd6fefe94 Mon Sep 17 00:00:00 2001 From: cliebig2019 <82951475+cliebig2019@users.noreply.github.com> Date: Thu, 21 Jul 2022 13:01:01 +0200 Subject: [PATCH 2/2] train_agent passing path to agentFactory --- training/scripts/train_agent.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 training/scripts/train_agent.py diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py old mode 100755 new mode 100644