diff --git a/lib/RLTrader.py b/lib/RLTrader.py
index 4529570..850be21 100644
--- a/lib/RLTrader.py
+++ b/lib/RLTrader.py
@@ -230,14 +230,16 @@ def test(self, model_epoch: int = 0, should_render: bool = True):
 
         del train_provider
 
-        test_env = DummyVecEnv([make_env(test_provider, i) for i in range(1)])
+        init_envs = DummyVecEnv([make_env(test_provider) for _ in range(self.n_envs)])
 
         model_path = path.join('data', 'agents', f'{self.study_name}__{model_epoch}.pkl')
-        model = self.Model.load(model_path, env=test_env)
+        model = self.Model.load(model_path, env=init_envs)
+
+        test_env = DummyVecEnv([make_env(test_provider) for _ in range(1)])
 
         self.logger.info(f'Testing model ({self.study_name}__{model_epoch})')
 
-        zero_completed_obs = np.zeros((self.n_envs,) + test_env.observation_space.shape)
+        zero_completed_obs = np.zeros((self.n_envs,) + init_envs.observation_space.shape)
         zero_completed_obs[0, :] = test_env.reset()
 
         state = None
@@ -245,7 +247,7 @@ def test(self, model_epoch: int = 0, should_render: bool = True):
 
         for _ in range(len(test_provider.data_frame)):
             action, state = model.predict(zero_completed_obs, state=state)
-            obs, reward, _, __ = test_env.step([action])
+            obs, reward, _, __ = test_env.step([action[0]])
 
             zero_completed_obs[0, :] = obs
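For context: stable-baselines recurrent policies bake the training-time n_envs into their LSTM state shapes, so model.predict() must always be fed a batch of n_envs observations even when only one environment is actually being evaluated. That is why the patch loads the model against an n_envs-wide init_envs, zero-pads the observation batch, and steps the lone test_env with action[0]. Below is a minimal standalone sketch of the same padding pattern, not the repo's code: gym's CartPole-v1 stands in for the repo's make_env(test_provider) factory, and the 'agent.pkl' path is illustrative.

import numpy as np
import gym

from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

def make_env():
    # Stand-in for the repo's make_env(test_provider) env factory.
    return gym.make('CartPole-v1')

n_envs = 4  # must equal the value used at training time

# The LSTM policy's state tensors are shaped for n_envs environments,
# so the saved model has to be loaded against a vec env of that width.
init_envs = DummyVecEnv([make_env for _ in range(n_envs)])
model = PPO2.load('agent.pkl', env=init_envs)  # hypothetical path

# Only one environment is actually evaluated.
test_env = DummyVecEnv([make_env])

# Zero-pad the observation batch to n_envs rows; row 0 is the live env.
obs_batch = np.zeros((n_envs,) + init_envs.observation_space.shape)
obs_batch[0, :] = test_env.reset()
state = None

for _ in range(500):
    # predict() must see all n_envs rows so the recurrent state shapes
    # line up, but only the first action is applied to the real env.
    actions, state = model.predict(obs_batch, state=state)
    obs, reward, done, info = test_env.step([actions[0]])
    obs_batch[0, :] = obs

Stepping with actions[0] rather than actions mirrors the patched test_env.step([action[0]]) line: the single-env DummyVecEnv expects a list of one action, while predict() returns one action per padded row.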