Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion tests/framework/callbacks/test_empty_dataloader_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from torch.utils.data import DataLoader, Dataset

from torchtnt.framework._test_utils import Batch, DummyTrainUnit, get_dummy_train_state
from torchtnt.framework._test_utils import (
Batch,
DummyPredictUnit,
DummyTrainUnit,
get_dummy_train_state,
)
from torchtnt.framework.callbacks.empty_dataloader_detector import (
EmptyDataloaderDetectorCallback,
)
Expand Down Expand Up @@ -204,3 +209,19 @@ def __getitem__(self, idx: int) -> Batch:
)

self.assertEqual(callback_with_exception._consecutive_empty_train_epochs, 2)

def test_predict_empty_epoch_detection(self) -> None:
"""Test that empty predict epoch immediately raises an exception."""
callback = EmptyDataloaderDetectorCallback(threshold=2)
state = get_dummy_train_state()
unit = DummyPredictUnit(input_dim=2)

# Set predict progress to 0 steps
unit.predict_progress._num_steps_completed_in_prev_epoch = 0

# Empty predict epoch should immediately raise exception
with self.assertRaisesRegex(
RuntimeError,
"Empty predict epoch detected! Epoch completed 0 steps",
):
callback.on_predict_epoch_end(state, unit)
9 changes: 8 additions & 1 deletion torchtnt/framework/callbacks/empty_dataloader_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit
from torchtnt.framework.unit import TPredictUnit, TTrainUnit

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,3 +54,10 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
raise RuntimeError(error_msg)
else:
self._consecutive_empty_train_epochs = 0

def on_predict_epoch_end(self, state: State, unit: TPredictUnit) -> None:
num_steps_in_predict = unit.predict_progress.num_steps_completed_in_prev_epoch
if num_steps_in_predict == 0:
raise RuntimeError(
"Empty predict epoch detected! Epoch completed 0 steps. This could indicate not enough data in your input."
)
Loading