Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for simulation reset via a publicly callable API #2648

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/gz/sim/Server.hh
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ namespace gz
/// \brief Stop the server. This will stop all running simulations.
public: void Stop();

/// \brief Reset all runners in this simulation
public: void ResetAll();
arjo129 marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Private data
private: std::unique_ptr<ServerPrivate> dataPtr;
};
Expand Down
6 changes: 6 additions & 0 deletions include/gz/sim/TestFixture.hh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ class GZ_SIM_VISIBLE TestFixture
public: TestFixture &OnPostUpdate(std::function<void(
const UpdateInfo &, const EntityComponentManager &)> _cb);

/// \brief Wrapper around a system's update callback
/// \param[in] _cb Function to be called every update
/// \return Reference to self.
public: TestFixture &OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb);

/// \brief Finalize all the functions and add fixture to server.
/// Finalize must be called before running the server, otherwise none of the
/// `On*` functions will be called.
Expand Down
4 changes: 3 additions & 1 deletion python/src/gz/sim/Server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ void defineSimServer(pybind11::object module)
.def(
"is_running",
pybind11::overload_cast<>(&gz::sim::Server::Running, pybind11::const_),
"Get whether the server is running.");
"Get whether the server is running.")
.def("reset_all", &gz::sim::Server::ResetAll,
"Resets all simulation runners under this server.");
}
} // namespace python
} // namespace sim
Expand Down
11 changes: 11 additions & 0 deletions python/src/gz/sim/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ defineSimTestFixture(pybind11::object module)
),
pybind11::return_value_policy::reference,
"Wrapper around a system's post-update callback"
)
.def(
"on_reset", WrapCallbacks(
[](TestFixture* self, std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
self->OnReset(_cb);
}
),
pybind11::return_value_policy::reference,
"Wrapper around a system's post-update callback"
);
// TODO(ahcorde): This method is not compiling for the following reason:
// The EventManager class has an unordered_map which holds a unique_ptr
Expand Down
11 changes: 11 additions & 0 deletions src/Server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,17 @@ bool Server::RequestRemoveEntity(const Entity _entity,
return false;
}

//////////////////////////////////////////////////
void Server::ResetAll()
{
for (auto worldId = 0u;
worldId < this->dataPtr->simRunners.size();
worldId++)
{
this->dataPtr->simRunners[worldId]->Reset(true, false, false);
}
}

//////////////////////////////////////////////////
void Server::Stop()
{
Expand Down
14 changes: 14 additions & 0 deletions src/SimulationRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1661,3 +1661,17 @@ void SimulationRunner::CreateEntities(const sdf::World &_world)
// Store the initial state of the ECM;
this->initialEntityCompMgr.CopyFrom(this->entityCompMgr);
}

/////////////////////////////////////////////////
void SimulationRunner::Reset(const bool _all,
const bool _time, const bool _model)
{
WorldControl control;
std::lock_guard<std::mutex> lock(this->msgBufferMutex);
control.rewind = _all || _time;
if (_model)
{
gzwarn << "Model reset not supported" <<std::endl;
}
this->worldControls.push_back(control);
}
3 changes: 3 additions & 0 deletions src/SimulationRunner.hh
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ namespace gz
/// \brief Set the next step to be blocking and paused.
public: void SetNextStepAsBlockingPaused(const bool value);

/// \brief Reset the current simulation runner
public: void Reset(const bool all, const bool time, const bool model);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document parameters.

Copy link
Contributor Author

@arjo129 arjo129 Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left model reset in here because we have some scaffolding for it but, I'm not sure if I should leave it in.


/// \brief Updates the physics parameters of the simulation based on the
/// Physics component of the world, if any.
public: void UpdatePhysicsParams();
Expand Down
29 changes: 28 additions & 1 deletion src/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class HelperSystem :
public ISystemConfigure,
public ISystemPreUpdate,
public ISystemUpdate,
public ISystemPostUpdate
public ISystemPostUpdate,
public ISystemReset
{
// Documentation inherited
public: void Configure(
Expand All @@ -50,6 +51,10 @@ class HelperSystem :
public: void PostUpdate(const UpdateInfo &_info,
const EntityComponentManager &_ecm) override;

// Documentation inherited
public: void Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm) override;

/// \brief Function to call every time we configure a world
public: std::function<void(const Entity &_entity,
const std::shared_ptr<const sdf::Element> &_sdf,
Expand All @@ -68,6 +73,10 @@ class HelperSystem :
/// \brief Function to call every post-update
public: std::function<void(const UpdateInfo &,
const EntityComponentManager &)> postUpdateCallback;

/// \brief Reset callback
public: std::function<void(const UpdateInfo &,
EntityComponentManager &)> resetCallback;
};

/////////////////////////////////////////////////
Expand Down Expand Up @@ -105,6 +114,14 @@ void HelperSystem::PostUpdate(const UpdateInfo &_info,
this->postUpdateCallback(_info, _ecm);
}

/////////////////////////////////////////////////
void HelperSystem::Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm)
{
if (this->resetCallback)
this->resetCallback(_info, _ecm);
}

//////////////////////////////////////////////////
class gz::sim::TestFixture::Implementation
{
Expand Down Expand Up @@ -200,6 +217,16 @@ TestFixture &TestFixture::OnPostUpdate(std::function<void(
return *this;
}

//////////////////////////////////////////////////
TestFixture &TestFixture::OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
if (nullptr != this->dataPtr->helperSystem)
this->dataPtr->helperSystem->resetCallback = std::move(_cb);
return *this;
}


//////////////////////////////////////////////////
std::shared_ptr<Server> TestFixture::Server() const
{
Expand Down
32 changes: 29 additions & 3 deletions src/TestFixture_TEST.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ TEST_F(TestFixtureTest, Callbacks)
unsigned int preUpdates{0u};
unsigned int updates{0u};
unsigned int postUpdates{0u};
unsigned int resets{0u};

testFixture.
OnConfigure([&](const Entity &_entity,
const std::shared_ptr<const sdf::Element> &_sdf,
Expand All @@ -85,20 +87,33 @@ TEST_F(TestFixtureTest, Callbacks)
{
this->Updates(_info, _ecm);
preUpdates++;
EXPECT_EQ(preUpdates, _info.iterations);
if (resets == 0)
{
EXPECT_EQ(preUpdates, _info.iterations);
}
}).
OnUpdate([&](const UpdateInfo &_info, EntityComponentManager &_ecm)
{
this->Updates(_info, _ecm);
updates++;
EXPECT_EQ(updates, _info.iterations);
if (resets == 0)
{
EXPECT_EQ(updates, _info.iterations);
}
}).
OnPostUpdate([&](const UpdateInfo &_info,
const EntityComponentManager &_ecm)
{
this->Updates(_info, _ecm);
postUpdates++;
EXPECT_EQ(postUpdates, _info.iterations);
if (resets == 0)
{
EXPECT_EQ(postUpdates, _info.iterations);
}
}).
OnReset([&](const UpdateInfo &, EntityComponentManager &)
{
resets++;
}).
Finalize();

Expand All @@ -109,8 +124,19 @@ TEST_F(TestFixtureTest, Callbacks)
EXPECT_EQ(expectedIterations, preUpdates);
EXPECT_EQ(expectedIterations, updates);
EXPECT_EQ(expectedIterations, postUpdates);
EXPECT_EQ(0u, resets);

testFixture.Server()->ResetAll();

testFixture.Server()->Run(true, expectedIterations, false);
EXPECT_EQ(1u, configures);
EXPECT_EQ(expectedIterations * 2 - 1, preUpdates);
EXPECT_EQ(expectedIterations * 2 - 1, updates);
EXPECT_EQ(expectedIterations * 2 - 1, postUpdates);
EXPECT_EQ(1u, resets);
}


/////////////////////////////////////////////////
TEST_F(TestFixtureTest, LoadConfig)
{
Expand Down
Loading