diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py old mode 100755 new mode 100644 index 447787e2f..c03cc9f19 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -133,7 +133,7 @@ def main(): verbose=1, ) elif args.agent is not None: - agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent) + agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent, path = args.path) if isinstance(agent, BaseAgent): model = PPO( agent.type.value, diff --git a/training/tools/argsparser.py b/training/tools/argsparser.py index 46645e7e2..cb9110a97 100644 --- a/training/tools/argsparser.py +++ b/training/tools/argsparser.py @@ -23,6 +23,7 @@ def training_args(parser): import rosnav.model.custom_policy import rosnav.model.custom_sb3_policy from rosnav.model.agent_factory import AgentFactory + import rosnav.model.custom_policy_from_json group.add_argument( "--agent", @@ -60,6 +61,11 @@ def training_args(parser): parser.add_argument( "--tb", action="store_true", help="enables tensorboard logging" ) + parser.add_argument( + "--path", + type=str, + help="path to the json file containing" "the neural network", + ) def run_agent_args(parser):