Skip to content

Commit

Permalink
[commands] Fix C++ iterator invalidation bug (#7554)
Browse files Browse the repository at this point in the history
Co-authored-by: Joseph Eng <[email protected]>
Co-authored-by: Ryan Blue <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent f46c81c commit 807dffe
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ class CommandScheduler::Impl {
wpi::SmallVector<InterruptAction, 4> interruptActions;
wpi::SmallVector<Action, 4> finishActions;

// Flag and queues for avoiding concurrent modification if commands are
// scheduled/canceled during run
bool inRunLoop = false;
wpi::SmallVector<Command*, 4> toSchedule;
wpi::SmallVector<Command*, 4> toCancelCommands;
wpi::SmallVector<std::optional<Command*>, 4> toCancelInterruptors;
wpi::SmallSet<Command*, 4> endingCommands;

// Map of Command* -> CommandPtr for CommandPtrs transferred to the scheduler
// via Schedule(CommandPtr&&). These are erased (destroyed) at the very end of
// the loop cycle when the command lifecycle is complete.
Expand Down Expand Up @@ -118,11 +110,6 @@ frc::EventLoop* CommandScheduler::GetDefaultButtonLoop() const {
}

void CommandScheduler::Schedule(Command* command) {
if (m_impl->inRunLoop) {
m_impl->toSchedule.emplace_back(command);
return;
}

RequireUngrouped(command);

if (m_impl->disabled || m_impl->scheduledCommands.contains(command) ||
Expand Down Expand Up @@ -208,10 +195,13 @@ void CommandScheduler::Run() {
loopCache->Poll();
m_watchdog.AddEpoch("buttons.Run()");

m_impl->inRunLoop = true;
bool isDisabled = frc::RobotState::IsDisabled();
// Run scheduled commands, remove finished commands.
for (Command* command : m_impl->scheduledCommands) {
// create a new set to avoid iterator invalidation.
for (Command* command : wpi::SmallSet(m_impl->scheduledCommands)) {
if (!IsScheduled(command)) {
continue; // skip as the normal scheduledCommands was modified
}

if (isDisabled && !command->RunsWhenDisabled()) {
Cancel(command, std::nullopt);
continue;
Expand All @@ -224,14 +214,12 @@ void CommandScheduler::Run() {
m_watchdog.AddEpoch(command->GetName() + ".Execute()");

if (command->IsFinished()) {
m_impl->endingCommands.insert(command);
m_impl->scheduledCommands.erase(command);
command->End(false);
for (auto&& action : m_impl->finishActions) {
action(*command);
}
m_impl->endingCommands.erase(command);

m_impl->scheduledCommands.erase(command);
for (auto&& requirement : command->GetRequirements()) {
m_impl->requirements.erase(requirement);
}
Expand All @@ -241,19 +229,6 @@ void CommandScheduler::Run() {
m_impl->ownedCommands.erase(command);
}
}
m_impl->inRunLoop = false;

for (Command* command : m_impl->toSchedule) {
Schedule(command);
}

for (size_t i = 0; i < m_impl->toCancelCommands.size(); i++) {
Cancel(m_impl->toCancelCommands[i], m_impl->toCancelInterruptors[i]);
}

m_impl->toSchedule.clear();
m_impl->toCancelCommands.clear();
m_impl->toCancelInterruptors.clear();

// Add default commands for un-required registered subsystems.
for (auto&& subsystem : m_impl->subsystems) {
Expand Down Expand Up @@ -346,24 +321,14 @@ void CommandScheduler::Cancel(Command* command,
if (!m_impl) {
return;
}
if (m_impl->endingCommands.contains(command)) {
return;
}
if (m_impl->inRunLoop) {
m_impl->toCancelCommands.emplace_back(command);
m_impl->toCancelInterruptors.emplace_back(interruptor);
return;
}
if (!IsScheduled(command)) {
return;
}
m_impl->endingCommands.insert(command);
m_impl->scheduledCommands.erase(command);
command->End(true);
for (auto&& action : m_impl->interruptActions) {
action(*command, interruptor);
}
m_impl->endingCommands.erase(command);
m_impl->scheduledCommands.erase(command);
for (auto&& requirement : m_impl->requirements) {
if (requirement.second == command) {
m_impl->requirements.erase(requirement.first);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <networktables/NetworkTableInstance.h>

#include "CommandTestBase.h"
#include "frc2/command/FunctionalCommand.h"
#include "frc2/command/InstantCommand.h"
#include "frc2/command/RunCommand.h"

using namespace frc2;
class CommandScheduleTest : public CommandTestBase {};
Expand Down Expand Up @@ -95,6 +98,51 @@ TEST_F(CommandScheduleTest, SchedulerCancel) {
EXPECT_FALSE(scheduler.IsScheduled(&command));
}

TEST_F(CommandScheduleTest, CommandKnowsWhenItEnded) {
CommandScheduler scheduler = GetScheduler();

frc2::FunctionalCommand* commandPtr = nullptr;
auto command = frc2::FunctionalCommand(
[] {}, [] {},
[&](auto isForced) {
EXPECT_FALSE(scheduler.IsScheduled(commandPtr))
<< "Command shouldn't be scheduled when its end is called";
},
[] { return true; });
commandPtr = &command;

scheduler.Schedule(commandPtr);
scheduler.Run();
EXPECT_FALSE(scheduler.IsScheduled(commandPtr))
<< "Command should be removed from scheduler when its isFinished() "
"returns true";
}

TEST_F(CommandScheduleTest, ScheduleCommandInCommand) {
CommandScheduler scheduler = GetScheduler();
int counter = 0;
frc2::InstantCommand commandToGetScheduled{[&counter] { counter++; }};

auto command =
frc2::RunCommand([&counter, &scheduler, &commandToGetScheduled] {
scheduler.Schedule(&commandToGetScheduled);
EXPECT_EQ(counter, 1)
<< "Scheduled command's init was not run immediately "
"after getting scheduled";
});

scheduler.Schedule(&command);
scheduler.Run();
EXPECT_EQ(counter, 1) << "Command 2 was not run when it should have been";
EXPECT_TRUE(scheduler.IsScheduled(&commandToGetScheduled))
<< "Command 2 was not added to scheduler";

scheduler.Run();
EXPECT_EQ(counter, 1) << "Command 2 was run when it shouldn't have been";
EXPECT_FALSE(scheduler.IsScheduled(&commandToGetScheduled))
<< "Command 2 did not end when it should have";
}

TEST_F(CommandScheduleTest, NotScheduledCancel) {
CommandScheduler scheduler = GetScheduler();
MockCommand command;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,45 @@ TEST_F(SchedulingRecursionTest, CancelDefaultCommandFromEnd) {
EXPECT_TRUE(scheduler.IsScheduled(&other));
}

TEST_F(SchedulingRecursionTest, CancelNextCommandFromCommand) {
CommandScheduler scheduler = GetScheduler();

frc2::RunCommand* command1Ptr = nullptr;
frc2::RunCommand* command2Ptr = nullptr;
int counter = 0;

auto command1 = frc2::RunCommand([&counter, &command2Ptr, &scheduler] {
scheduler.Cancel(command2Ptr);
counter++;
});
auto command2 = frc2::RunCommand([&counter, &command1Ptr, &scheduler] {
scheduler.Cancel(command1Ptr);
counter++;
});

command1Ptr = &command1;
command2Ptr = &command2;

scheduler.Schedule(&command1);
scheduler.Schedule(&command2);
scheduler.Run();

EXPECT_EQ(counter, 1) << "Second command was run when it shouldn't have been";

// only one of the commands should be canceled.
EXPECT_FALSE(scheduler.IsScheduled(&command1) &&
scheduler.IsScheduled(&command2))
<< "Both commands are running when only one should be";
// one of the commands shouldn't be canceled because the other one is canceled
// first
EXPECT_TRUE(scheduler.IsScheduled(&command1) ||
scheduler.IsScheduled(&command2))
<< "Both commands are canceled when only one should be";

scheduler.Run();
EXPECT_EQ(counter, 2);
}

INSTANTIATE_TEST_SUITE_P(
SchedulingRecursionTests, SchedulingRecursionTest,
testing::Values(Command::InterruptionBehavior::kCancelSelf,
Expand Down

0 comments on commit 807dffe

Please sign in to comment.