From 2a658f9aee9e3713e31abb11767a9e2fb64b2c87 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 13:20:52 +0800 Subject: [PATCH 1/8] Adds support for Reset in test fixture This PR adds support for the Reset API to the test fixture. As `TestFixture` is one of the main ways one can get access to the ECM in python when trying to write some scripts for Deep Reinforcement Learning I realized that without `Reset` supported in the `TestFixture` API, end users would have a very hard time using our python APIs (which are actually quite nice). For reference I'm hacking a demo template here: https://github.com/arjo129/gz_deep_rl_experiments/tree/ionic Signed-off-by: Arjo Chakravarty --- include/gz/sim/TestFixture.hh | 6 ++++++ python/src/gz/sim/TestFixture.cc | 11 +++++++++++ src/TestFixture.cc | 29 ++++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/include/gz/sim/TestFixture.hh b/include/gz/sim/TestFixture.hh index 24fb2298a4..6041062cad 100644 --- a/include/gz/sim/TestFixture.hh +++ b/include/gz/sim/TestFixture.hh @@ -96,6 +96,12 @@ class GZ_SIM_VISIBLE TestFixture public: TestFixture &OnPostUpdate(std::function _cb); + /// \brief Wrapper around a system's update callback + /// \param[in] _cb Function to be called every update + /// \return Reference to self. + public: TestFixture &OnReset(std::function _cb); + /// \brief Finalize all the functions and add fixture to server. /// Finalize must be called before running the server, otherwise none of the /// `On*` functions will be called. diff --git a/python/src/gz/sim/TestFixture.cc b/python/src/gz/sim/TestFixture.cc index 826558fbf4..1fa05d23c3 100644 --- a/python/src/gz/sim/TestFixture.cc +++ b/python/src/gz/sim/TestFixture.cc @@ -83,6 +83,17 @@ defineSimTestFixture(pybind11::object module) ), pybind11::return_value_policy::reference, "Wrapper around a system's post-update callback" + ) + .def( + "on_reset", WrapCallbacks( + [](TestFixture* self, std::function _cb) + { + self->OnReset(_cb); + } + ), + pybind11::return_value_policy::reference, + "Wrapper around a system's post-update callback" ); // TODO(ahcorde): This method is not compiling for the following reason: // The EventManager class has an unordered_map which holds a unique_ptr diff --git a/src/TestFixture.cc b/src/TestFixture.cc index 1d02a900ff..6c984b7996 100644 --- a/src/TestFixture.cc +++ b/src/TestFixture.cc @@ -29,7 +29,8 @@ class HelperSystem : public ISystemConfigure, public ISystemPreUpdate, public ISystemUpdate, - public ISystemPostUpdate + public ISystemPostUpdate, + public ISystemReset { // Documentation inherited public: void Configure( @@ -50,6 +51,10 @@ class HelperSystem : public: void PostUpdate(const UpdateInfo &_info, const EntityComponentManager &_ecm) override; + // Documentation inherited + public: void Reset(const UpdateInfo &_info, + EntityComponentManager &_ecm) override; + /// \brief Function to call every time we configure a world public: std::function &_sdf, @@ -68,6 +73,10 @@ class HelperSystem : /// \brief Function to call every post-update public: std::function postUpdateCallback; + + /// \brief Reset callback + public: std::function resetCallback; }; ///////////////////////////////////////////////// @@ -105,6 +114,14 @@ void HelperSystem::PostUpdate(const UpdateInfo &_info, this->postUpdateCallback(_info, _ecm); } +///////////////////////////////////////////////// +void HelperSystem::Reset(const UpdateInfo &_info, + EntityComponentManager &_ecm) +{ + if (this->resetCallback) + this->resetCallback(_info, _ecm); +} + ////////////////////////////////////////////////// class gz::sim::TestFixture::Implementation { @@ -200,6 +217,16 @@ TestFixture &TestFixture::OnPostUpdate(std::function _cb) +{ + if (nullptr != this->dataPtr->helperSystem) + this->dataPtr->helperSystem->resetCallback = std::move(_cb); + return *this; +} + + ////////////////////////////////////////////////// std::shared_ptr TestFixture::Server() const { From 053152f66161de69fa8d9cb243a4020992b3ebec Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 14:38:18 +0800 Subject: [PATCH 2/8] Add support for simulation reset via a publicly callable API This allows us to reset simulations without having to call into gz-transport making the code more readable from an external API. Depends on #2647 Signed-off-by: Arjo Chakravarty --- include/gz/sim/Server.hh | 3 +++ python/src/gz/sim/Server.cc | 4 +++- src/Server.cc | 11 +++++++++++ src/SimulationRunner.cc | 15 ++++++++++++++- src/SimulationRunner.hh | 3 +++ src/TestFixture_TEST.cc | 32 +++++++++++++++++++++++++++++--- 6 files changed, 63 insertions(+), 5 deletions(-) diff --git a/include/gz/sim/Server.hh b/include/gz/sim/Server.hh index 28b5b3feb1..7bc006c8a7 100644 --- a/include/gz/sim/Server.hh +++ b/include/gz/sim/Server.hh @@ -338,6 +338,9 @@ namespace gz /// \brief Stop the server. This will stop all running simulations. public: void Stop(); + /// \brief Reset All runners in this simulation + public: void ResetAll(); + /// \brief Private data private: std::unique_ptr dataPtr; }; diff --git a/python/src/gz/sim/Server.cc b/python/src/gz/sim/Server.cc index c148a08ff2..7ece7ec090 100644 --- a/python/src/gz/sim/Server.cc +++ b/python/src/gz/sim/Server.cc @@ -46,7 +46,9 @@ void defineSimServer(pybind11::object module) .def( "is_running", pybind11::overload_cast<>(&gz::sim::Server::Running, pybind11::const_), - "Get whether the server is running."); + "Get whether the server is running.") + .def("reset_all", &gz::sim::Server::ResetAll, + "Resets all simulation runners under this server."); } } // namespace python } // namespace sim diff --git a/src/Server.cc b/src/Server.cc index 712f273467..602eb1e569 100644 --- a/src/Server.cc +++ b/src/Server.cc @@ -485,6 +485,17 @@ bool Server::RequestRemoveEntity(const Entity _entity, return false; } +////////////////////////////////////////////////// +void Server::ResetAll() +{ + for (auto worldId = 0u; + worldId < this->dataPtr->simRunners.size(); + worldId++) + { + this->dataPtr->simRunners[worldId]->Reset(true, false, false); + } +} + ////////////////////////////////////////////////// void Server::Stop() { diff --git a/src/SimulationRunner.cc b/src/SimulationRunner.cc index baa2712adc..d98f7c1d02 100644 --- a/src/SimulationRunner.cc +++ b/src/SimulationRunner.cc @@ -95,7 +95,6 @@ struct MaybeGilScopedRelease #endif } - ////////////////////////////////////////////////// SimulationRunner::SimulationRunner(const sdf::World &_world, const SystemLoaderPtr &_systemLoader, @@ -1661,3 +1660,17 @@ void SimulationRunner::CreateEntities(const sdf::World &_world) // Store the initial state of the ECM; this->initialEntityCompMgr.CopyFrom(this->entityCompMgr); } + +///////////////////////////////////////////////// +void SimulationRunner::Reset(const bool all, + const bool time, const bool model) +{ + WorldControl control; + std::lock_guard lock(this->msgBufferMutex); + control.rewind = all || time; + if (model) + { + gzwarn << "Model reset not supported" <worldControls.push_back(control); +} \ No newline at end of file diff --git a/src/SimulationRunner.hh b/src/SimulationRunner.hh index 8fe03511e7..8b5ac8bf36 100644 --- a/src/SimulationRunner.hh +++ b/src/SimulationRunner.hh @@ -369,6 +369,9 @@ namespace gz /// \brief Set the next step to be blocking and paused. public: void SetNextStepAsBlockingPaused(const bool value); + /// \brief Reset the current simulation runner + public: void Reset(const bool all, const bool time, const bool model); + /// \brief Updates the physics parameters of the simulation based on the /// Physics component of the world, if any. public: void UpdatePhysicsParams(); diff --git a/src/TestFixture_TEST.cc b/src/TestFixture_TEST.cc index f810ae64c1..21a9b7a9bd 100644 --- a/src/TestFixture_TEST.cc +++ b/src/TestFixture_TEST.cc @@ -72,6 +72,8 @@ TEST_F(TestFixtureTest, Callbacks) unsigned int preUpdates{0u}; unsigned int updates{0u}; unsigned int postUpdates{0u}; + unsigned int resets{0u}; + testFixture. OnConfigure([&](const Entity &_entity, const std::shared_ptr &_sdf, @@ -85,20 +87,33 @@ TEST_F(TestFixtureTest, Callbacks) { this->Updates(_info, _ecm); preUpdates++; - EXPECT_EQ(preUpdates, _info.iterations); + if (resets == 0) + { + EXPECT_EQ(preUpdates, _info.iterations); + } }). OnUpdate([&](const UpdateInfo &_info, EntityComponentManager &_ecm) { this->Updates(_info, _ecm); updates++; - EXPECT_EQ(updates, _info.iterations); + if (resets == 0) + { + EXPECT_EQ(updates, _info.iterations); + } }). OnPostUpdate([&](const UpdateInfo &_info, const EntityComponentManager &_ecm) { this->Updates(_info, _ecm); postUpdates++; - EXPECT_EQ(postUpdates, _info.iterations); + if (resets == 0) + { + EXPECT_EQ(postUpdates, _info.iterations); + } + }). + OnReset([&](const UpdateInfo &, EntityComponentManager &) + { + resets++; }). Finalize(); @@ -109,8 +124,19 @@ TEST_F(TestFixtureTest, Callbacks) EXPECT_EQ(expectedIterations, preUpdates); EXPECT_EQ(expectedIterations, updates); EXPECT_EQ(expectedIterations, postUpdates); + EXPECT_EQ(0u, resets); + + testFixture.Server()->ResetAll(); + + testFixture.Server()->Run(true, expectedIterations, false); + EXPECT_EQ(1u, configures); + EXPECT_EQ(expectedIterations * 2 - 1, preUpdates); + EXPECT_EQ(expectedIterations * 2 - 1, updates); + EXPECT_EQ(expectedIterations * 2 - 1, postUpdates); + EXPECT_EQ(1u, resets); } + ///////////////////////////////////////////////// TEST_F(TestFixtureTest, LoadConfig) { From ff468b79c63314ff5e0e1245c8b42957fbb6230a Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 14:46:24 +0800 Subject: [PATCH 3/8] Style Signed-off-by: Arjo Chakravarty --- src/SimulationRunner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SimulationRunner.cc b/src/SimulationRunner.cc index d98f7c1d02..9fbf41bac3 100644 --- a/src/SimulationRunner.cc +++ b/src/SimulationRunner.cc @@ -1673,4 +1673,4 @@ void SimulationRunner::Reset(const bool all, gzwarn << "Model reset not supported" <worldControls.push_back(control); -} \ No newline at end of file +} From e0ee0dd51e284a45a7638738c23f526245bc69da Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 15:00:36 +0800 Subject: [PATCH 4/8] Style Signed-off-by: Arjo Chakravarty --- src/SimulationRunner.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/SimulationRunner.cc b/src/SimulationRunner.cc index 9fbf41bac3..5656ee6786 100644 --- a/src/SimulationRunner.cc +++ b/src/SimulationRunner.cc @@ -1662,13 +1662,13 @@ void SimulationRunner::CreateEntities(const sdf::World &_world) } ///////////////////////////////////////////////// -void SimulationRunner::Reset(const bool all, - const bool time, const bool model) +void SimulationRunner::Reset(const bool _all, + const bool _time, const bool _model) { WorldControl control; std::lock_guard lock(this->msgBufferMutex); - control.rewind = all || time; - if (model) + control.rewind = _all || _time; + if (_model) { gzwarn << "Model reset not supported" < Date: Fri, 11 Oct 2024 15:01:26 +0800 Subject: [PATCH 5/8] Style Signed-off-by: Arjo Chakravarty --- src/SimulationRunner.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/SimulationRunner.cc b/src/SimulationRunner.cc index 5656ee6786..96ac61d7d4 100644 --- a/src/SimulationRunner.cc +++ b/src/SimulationRunner.cc @@ -95,6 +95,7 @@ struct MaybeGilScopedRelease #endif } + ////////////////////////////////////////////////// SimulationRunner::SimulationRunner(const sdf::World &_world, const SystemLoaderPtr &_systemLoader, From 76adc263390d6ef49158cd82ab1a1424b061a809 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 17:11:39 +0800 Subject: [PATCH 6/8] Typo Signed-off-by: Arjo Chakravarty --- include/gz/sim/Server.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/gz/sim/Server.hh b/include/gz/sim/Server.hh index 7bc006c8a7..ae49bd59f6 100644 --- a/include/gz/sim/Server.hh +++ b/include/gz/sim/Server.hh @@ -338,7 +338,7 @@ namespace gz /// \brief Stop the server. This will stop all running simulations. public: void Stop(); - /// \brief Reset All runners in this simulation + /// \brief Reset all runners in this simulation public: void ResetAll(); /// \brief Private data From 3e83828a5b2d87576acef4337fd2c6f06ddcd481 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Wed, 8 Jan 2025 19:14:34 +0800 Subject: [PATCH 7/8] Address feedback Signed-off-by: Arjo Chakravarty --- python/src/gz/sim/TestFixture.cc | 2 +- src/SimulationRunner.hh | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/src/gz/sim/TestFixture.cc b/python/src/gz/sim/TestFixture.cc index 1fa05d23c3..dc2589cd51 100644 --- a/python/src/gz/sim/TestFixture.cc +++ b/python/src/gz/sim/TestFixture.cc @@ -93,7 +93,7 @@ defineSimTestFixture(pybind11::object module) } ), pybind11::return_value_policy::reference, - "Wrapper around a system's post-update callback" + "Wrapper around a system's reset callback" ); // TODO(ahcorde): This method is not compiling for the following reason: // The EventManager class has an unordered_map which holds a unique_ptr diff --git a/src/SimulationRunner.hh b/src/SimulationRunner.hh index 8b5ac8bf36..03ac9e47f8 100644 --- a/src/SimulationRunner.hh +++ b/src/SimulationRunner.hh @@ -370,6 +370,9 @@ namespace gz public: void SetNextStepAsBlockingPaused(const bool value); /// \brief Reset the current simulation runner + /// \param[in] all - Reset all parameters + /// \param[in] time - Reset the time + /// \param[in] model - Reset the model only [currently unsupported] public: void Reset(const bool all, const bool time, const bool model); /// \brief Updates the physics parameters of the simulation based on the From 15fa86727f197bd3591524453a19b01a47732b72 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Wed, 8 Jan 2025 20:01:28 +0800 Subject: [PATCH 8/8] Add support for individual resets Signed-off-by: Arjo Chakravarty --- include/gz/sim/Server.hh | 4 ++++ python/src/gz/sim/Server.cc | 4 +++- src/Server.cc | 11 +++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/include/gz/sim/Server.hh b/include/gz/sim/Server.hh index ae49bd59f6..9867ca68d7 100644 --- a/include/gz/sim/Server.hh +++ b/include/gz/sim/Server.hh @@ -340,6 +340,10 @@ namespace gz /// \brief Reset all runners in this simulation public: void ResetAll(); + /// \brief Reset a specific runner in this server + /// \param[in] runnerId - The runner which you want to reset + /// \ return False if the runner does not exist, true otherwise. + public: bool Reset(const std::size_t _runnerId); /// \brief Private data private: std::unique_ptr dataPtr; diff --git a/python/src/gz/sim/Server.cc b/python/src/gz/sim/Server.cc index 7ece7ec090..b5688ece9b 100644 --- a/python/src/gz/sim/Server.cc +++ b/python/src/gz/sim/Server.cc @@ -48,7 +48,9 @@ void defineSimServer(pybind11::object module) pybind11::overload_cast<>(&gz::sim::Server::Running, pybind11::const_), "Get whether the server is running.") .def("reset_all", &gz::sim::Server::ResetAll, - "Resets all simulation runners under this server."); + "Resets all simulation runners under this server.") + .def("reset", &gz::sim::Server::Reset, + "Resets a specific simulation runner under this server."); } } // namespace python } // namespace sim diff --git a/src/Server.cc b/src/Server.cc index 602eb1e569..25497fa960 100644 --- a/src/Server.cc +++ b/src/Server.cc @@ -496,6 +496,17 @@ void Server::ResetAll() } } +////////////////////////////////////////////////// +bool Server::Reset(const std::size_t _runnerId) +{ + if (_runnerId >= this->dataPtr->simRunners.size()) + { + return false; + } + this->dataPtr->simRunners[_runnerId]->Reset(true, false, false); + return true; +} + ////////////////////////////////////////////////// void Server::Stop() {