Skip to content

Commit

Permalink
Ability to save and restore the simulation in parts (#1282)
Browse files Browse the repository at this point in the history
* Added StateRecorder::SetIsLastPart which allows you to specify that you are restoring the state from multiple streams, only the last part should have this set to true.
* Added StateRecorderFilter::ShouldRestoreContact function which allows you to skip restoring certain constraints. This can be used when restoring a partial snapshot onto a full snapshot by selectively ignoring contacts from one of the two snapshots.

Fixes #1254
  • Loading branch information
jrouwe authored Oct 10, 2024
1 parent 0a7d250 commit 5a3935e
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 95 deletions.
179 changes: 107 additions & 72 deletions Jolt/Physics/Constraints/ContactConstraintManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ void ContactConstraintManager::ManifoldCache::SaveState(StateRecorder &inStream,
inStream.Write(m_kv->GetKey());
}

bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream)
bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream, const StateRecorderFilter *inFilter)
{
JPH_ASSERT(!mIsFinalized);

Expand Down Expand Up @@ -499,72 +499,95 @@ bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &
body_pair_key = all_bp[i]->GetKey();
inStream.Read(body_pair_key);

// Create new entry for this body pair
uint64 body_pair_hash = body_pair_key.GetHash();
BPKeyValue *bp_kv = Create(contact_allocator, body_pair_key, body_pair_hash);
if (bp_kv == nullptr)
// Check if we want to restore this contact
if (inFilter == nullptr || inFilter->ShouldRestoreContact(body_pair_key.mBodyA, body_pair_key.mBodyB))
{
// Out of cache space
success = false;
break;
}
CachedBodyPair &bp = bp_kv->GetValue();

// Read body pair
if (inStream.IsValidating() && i < all_bp.size())
memcpy(&bp, &all_bp[i]->GetValue(), sizeof(CachedBodyPair));
bp.RestoreState(inStream);

// When validating, get all existing manifolds
Array<const MKeyValue *> all_m;
if (inStream.IsValidating())
inReadCache.GetAllManifoldsSorted(all_bp[i]->GetValue(), all_m);

// Read amount of manifolds
uint32 num_manifolds;
if (inStream.IsValidating())
num_manifolds = uint32(all_m.size());
inStream.Read(num_manifolds);

uint32 handle = ManifoldMap::cInvalidHandle;
for (uint32 j = 0; j < num_manifolds; ++j)
{
// Read key
SubShapeIDPair sub_shape_key;
if (inStream.IsValidating() && j < all_m.size())
sub_shape_key = all_m[j]->GetKey();
inStream.Read(sub_shape_key);
uint64 sub_shape_key_hash = sub_shape_key.GetHash();

// Read amount of contact points
uint16 num_contact_points;
if (inStream.IsValidating() && j < all_m.size())
num_contact_points = all_m[j]->GetValue().mNumContactPoints;
inStream.Read(num_contact_points);

// Read manifold
MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, num_contact_points);
if (m_kv == nullptr)
// Create new entry for this body pair
uint64 body_pair_hash = body_pair_key.GetHash();
BPKeyValue *bp_kv = Create(contact_allocator, body_pair_key, body_pair_hash);
if (bp_kv == nullptr)
{
// Out of cache space
success = false;
break;
}
CachedManifold &cm = m_kv->GetValue();
if (inStream.IsValidating() && j < all_m.size())
CachedBodyPair &bp = bp_kv->GetValue();

// Read body pair
if (inStream.IsValidating() && i < all_bp.size())
memcpy(&bp, &all_bp[i]->GetValue(), sizeof(CachedBodyPair));
bp.RestoreState(inStream);

// When validating, get all existing manifolds
Array<const MKeyValue *> all_m;
if (inStream.IsValidating())
inReadCache.GetAllManifoldsSorted(all_bp[i]->GetValue(), all_m);

// Read amount of manifolds
uint32 num_manifolds = 0;
if (inStream.IsValidating())
num_manifolds = uint32(all_m.size());
inStream.Read(num_manifolds);

uint32 handle = ManifoldMap::cInvalidHandle;
for (uint32 j = 0; j < num_manifolds; ++j)
{
memcpy(&cm, &all_m[j]->GetValue(), CachedManifold::sGetRequiredTotalSize(num_contact_points));
cm.mNumContactPoints = uint16(num_contact_points); // Restore num contact points
}
cm.RestoreState(inStream);
cm.mNextWithSameBodyPair = handle;
handle = ToHandle(m_kv);
// Read key
SubShapeIDPair sub_shape_key;
if (inStream.IsValidating() && j < all_m.size())
sub_shape_key = all_m[j]->GetKey();
inStream.Read(sub_shape_key);
uint64 sub_shape_key_hash = sub_shape_key.GetHash();

// Read amount of contact points
uint16 num_contact_points = 0;
if (inStream.IsValidating() && j < all_m.size())
num_contact_points = all_m[j]->GetValue().mNumContactPoints;
inStream.Read(num_contact_points);

// Read manifold
MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, num_contact_points);
if (m_kv == nullptr)
{
// Out of cache space
success = false;
break;
}
CachedManifold &cm = m_kv->GetValue();
if (inStream.IsValidating() && j < all_m.size())
{
memcpy(&cm, &all_m[j]->GetValue(), CachedManifold::sGetRequiredTotalSize(num_contact_points));
cm.mNumContactPoints = uint16(num_contact_points); // Restore num contact points
}
cm.RestoreState(inStream);
cm.mNextWithSameBodyPair = handle;
handle = ToHandle(m_kv);

// Read contact points
for (uint32 k = 0; k < num_contact_points; ++k)
cm.mContactPoints[k].RestoreState(inStream);
// Read contact points
for (uint32 k = 0; k < num_contact_points; ++k)
cm.mContactPoints[k].RestoreState(inStream);
}
bp.mFirstCachedManifold = handle;
}
else
{
// Skip the contact
CachedBodyPair bp;
bp.RestoreState(inStream);
uint32 num_manifolds = 0;
inStream.Read(num_manifolds);
for (uint32 j = 0; j < num_manifolds; ++j)
{
SubShapeIDPair sub_shape_key;
inStream.Read(sub_shape_key);
uint16 num_contact_points;
inStream.Read(num_contact_points);
CachedManifold cm;
cm.RestoreState(inStream);
for (uint32 k = 0; k < num_contact_points; ++k)
cm.mContactPoints[0].RestoreState(inStream);
}
}
bp.mFirstCachedManifold = handle;
}

// When validating, get all existing CCD manifolds
Expand All @@ -585,22 +608,28 @@ bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &
if (inStream.IsValidating() && j < all_m.size())
sub_shape_key = all_m[j]->GetKey();
inStream.Read(sub_shape_key);
uint64 sub_shape_key_hash = sub_shape_key.GetHash();

// Create CCD manifold
MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, 0);
if (m_kv == nullptr)
// Check if we want to restore this contact
if (inFilter == nullptr || inFilter->ShouldRestoreContact(sub_shape_key.GetBody1ID(), sub_shape_key.GetBody2ID()))
{
// Out of cache space
success = false;
break;
// Create CCD manifold
uint64 sub_shape_key_hash = sub_shape_key.GetHash();
MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, 0);
if (m_kv == nullptr)
{
// Out of cache space
success = false;
break;
}
CachedManifold &cm = m_kv->GetValue();
cm.mFlags |= (uint16)CachedManifold::EFlags::CCDContact;
}
CachedManifold &cm = m_kv->GetValue();
cm.mFlags |= (uint16)CachedManifold::EFlags::CCDContact;
}

#ifdef JPH_ENABLE_ASSERTS
mIsFinalized = true;
// We don't finalize until the last part is restored
if (inStream.IsLastPart())
mIsFinalized = true;
#endif

return success;
Expand Down Expand Up @@ -1707,11 +1736,17 @@ void ContactConstraintManager::SaveState(StateRecorder &inStream, const StateRec
mCache[mCacheWriteIdx ^ 1].SaveState(inStream, inFilter);
}

bool ContactConstraintManager::RestoreState(StateRecorder &inStream)
bool ContactConstraintManager::RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter)
{
bool success = mCache[mCacheWriteIdx].RestoreState(mCache[mCacheWriteIdx ^ 1], inStream);
mCacheWriteIdx ^= 1;
mCache[mCacheWriteIdx].Clear();
bool success = mCache[mCacheWriteIdx].RestoreState(mCache[mCacheWriteIdx ^ 1], inStream, inFilter);

// If this is the last part, the cache is finalized
if (inStream.IsLastPart())
{
mCacheWriteIdx ^= 1;
mCache[mCacheWriteIdx].Clear();
}

return success;
}

Expand Down
4 changes: 2 additions & 2 deletions Jolt/Physics/Constraints/ContactConstraintManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class JPH_EXPORT ContactConstraintManager : public NonCopyable
void SaveState(StateRecorder &inStream, const StateRecorderFilter *inFilter) const;

/// Restoring state for replay. Returns false when failed.
bool RestoreState(StateRecorder &inStream);
bool RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter);

private:
/// Local space contact point, used for caching impulses
Expand Down Expand Up @@ -393,7 +393,7 @@ class JPH_EXPORT ContactConstraintManager : public NonCopyable

/// Saving / restoring state for replay
void SaveState(StateRecorder &inStream, const StateRecorderFilter *inFilter) const;
bool RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream);
bool RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream, const StateRecorderFilter *inFilter);

private:
/// Block size used when allocating new blocks in the contact cache
Expand Down
19 changes: 11 additions & 8 deletions Jolt/Physics/PhysicsSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2657,7 +2657,7 @@ void PhysicsSystem::SaveState(StateRecorder &inStream, EStateRecorderState inSta
mConstraintManager.SaveState(inStream, inFilter);
}

bool PhysicsSystem::RestoreState(StateRecorder &inStream)
bool PhysicsSystem::RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter)
{
JPH_PROFILE_FUNCTION();

Expand All @@ -2676,17 +2676,20 @@ bool PhysicsSystem::RestoreState(StateRecorder &inStream)
return false;

// Update bounding boxes for all bodies in the broadphase
Array<BodyID> bodies;
for (const Body *b : mBodyManager.GetBodies())
if (BodyManager::sIsValidBodyPointer(b) && b->IsInBroadPhase())
bodies.push_back(b->GetID());
if (!bodies.empty())
mBroadPhase->NotifyBodiesAABBChanged(&bodies[0], (int)bodies.size());
if (inStream.IsLastPart())
{
Array<BodyID> bodies;
for (const Body *b : mBodyManager.GetBodies())
if (BodyManager::sIsValidBodyPointer(b) && b->IsInBroadPhase())
bodies.push_back(b->GetID());
if (!bodies.empty())
mBroadPhase->NotifyBodiesAABBChanged(&bodies[0], (int)bodies.size());
}
}

if (uint8(state) & uint8(EStateRecorderState::Contacts))
{
if (!mContactManager.RestoreState(inStream))
if (!mContactManager.RestoreState(inStream, inFilter))
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion Jolt/Physics/PhysicsSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class JPH_EXPORT PhysicsSystem : public NonCopyable
void SaveState(StateRecorder &inStream, EStateRecorderState inState = EStateRecorderState::All, const StateRecorderFilter *inFilter = nullptr) const;

/// Restoring state for replay. Returns false if failed.
bool RestoreState(StateRecorder &inStream);
bool RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter = nullptr);

/// Saving state of a single body.
void SaveBodyState(const Body &inBody, StateRecorder &inStream) const;
Expand Down
65 changes: 65 additions & 0 deletions Jolt/Physics/StateRecorder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,61 @@ enum class EStateRecorderState : uint8
All = Global | Bodies | Contacts | Constraints ///< Save all state
};

/// Bitwise OR operator for EStateRecorderState
constexpr EStateRecorderState operator | (EStateRecorderState inLHS, EStateRecorderState inRHS)
{
return EStateRecorderState(uint8(inLHS) | uint8(inRHS));
}

/// Bitwise AND operator for EStateRecorderState
constexpr EStateRecorderState operator & (EStateRecorderState inLHS, EStateRecorderState inRHS)
{
return EStateRecorderState(uint8(inLHS) & uint8(inRHS));
}

/// Bitwise XOR operator for EStateRecorderState
constexpr EStateRecorderState operator ^ (EStateRecorderState inLHS, EStateRecorderState inRHS)
{
return EStateRecorderState(uint8(inLHS) ^ uint8(inRHS));
}

/// Bitwise NOT operator for EStateRecorderState
constexpr EStateRecorderState operator ~ (EStateRecorderState inAllowedDOFs)
{
return EStateRecorderState(~uint8(inAllowedDOFs));
}

/// Bitwise OR assignment operator for EStateRecorderState
constexpr EStateRecorderState & operator |= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
{
ioLHS = ioLHS | inRHS;
return ioLHS;
}

/// Bitwise AND assignment operator for EStateRecorderState
constexpr EStateRecorderState & operator &= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
{
ioLHS = ioLHS & inRHS;
return ioLHS;
}

/// Bitwise XOR assignment operator for EStateRecorderState
constexpr EStateRecorderState & operator ^= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
{
ioLHS = ioLHS ^ inRHS;
return ioLHS;
}

/// User callbacks that allow determining which parts of the simulation should be saved by a StateRecorder
class JPH_EXPORT StateRecorderFilter
{
public:
/// Destructor
virtual ~StateRecorderFilter() = default;

///@name Functions called during SaveState
///@{

/// If the state of a specific body should be saved
virtual bool ShouldSaveBody([[maybe_unused]] const Body &inBody) const { return true; }

Expand All @@ -39,6 +87,15 @@ class JPH_EXPORT StateRecorderFilter

/// If the state of a specific contact should be saved
virtual bool ShouldSaveContact([[maybe_unused]] const BodyID &inBody1, [[maybe_unused]] const BodyID &inBody2) const { return true; }

///@}
///@name Functions called during RestoreState
///@{

/// If the state of a specific contact should be restored
virtual bool ShouldRestoreContact([[maybe_unused]] const BodyID &inBody1, [[maybe_unused]] const BodyID &inBody2) const { return true; }

///@}
};

/// Class that records the state of a physics system. Can be used to check if the simulation is deterministic by putting the recorder in validation mode.
Expand All @@ -59,8 +116,16 @@ class JPH_EXPORT StateRecorder : public StreamIn, public StreamOut
void SetValidating(bool inValidating) { mIsValidating = inValidating; }
bool IsValidating() const { return mIsValidating; }

/// This allows splitting the state in multiple parts. While restoring, only the last part should have this flag set to true.
/// Note that you should ensure that the different parts contain information for disjoint sets of bodies, constraints and contacts.
/// E.g. if you restore the same contact twice, you get undefined behavior. In order to create disjoint sets you can use the StateRecorderFilter.
/// Note that validation is not compatible with restoring a simulation state in multiple parts.
void SetIsLastPart(bool inIsLastPart) { mIsLastPart = inIsLastPart; }
bool IsLastPart() const { return mIsLastPart; }

private:
bool mIsValidating = false;
bool mIsLastPart = true;
};

JPH_NAMESPACE_END
3 changes: 3 additions & 0 deletions Jolt/Physics/StateRecorderImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class JPH_EXPORT StateRecorderImpl final : public StateRecorder
/// Convert the binary data to a string
string GetData() const { return mStream.str(); }

/// Get size of the binary data in bytes
size_t GetDataSize() { return size_t(mStream.tellp()); }

private:
std::stringstream mStream;
};
Expand Down
Loading

0 comments on commit 5a3935e

Please sign in to comment.