Custom train loop to perform partial batch updates #20319
leonardcaquot94 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Hi there,
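I'm working on a recursive data processing task in PyTorch Lightning where the batch size decreases with each iteration: samples drop out as their predictions are completed, until no more predictions are needed. This raises two related issues: the batch size does not stay consistent, and processing becomes less efficient as the batch shrinks.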
I think I need to change how training_step and validation_step work. Instead of iterating over a single batch until all of its recursive predictions are made, each call should perform just one recursive step. After each step, completed predictions would be removed from the batch, while the remaining elements stay for further processing. This way, I can keep the batch size more consistent and improve efficiency.
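One possible way to reason about the scheduling described above is a plain-Python simulation, independent of Lightning. This is a minimal sketch under stated assumptions, not Lightning API: `steps_needed` is a hypothetical stand-in for how many recursive predictions each sample requires, and each list of sample ids stands for the batch that one `training_step` call would process. `naive_batches` runs each fixed batch to completion; `rolling_batches` performs one step at a time, drops finished samples and refills the freed slots with new ones.

```python
from collections import deque


def naive_batches(samples, steps_needed, batch_size):
    """Baseline: run each fixed batch until all of its samples are done.

    steps_needed[s] is a hypothetical count of recursive steps for sample s.
    Returns the sample ids processed at every recursive step.
    """
    steps = []
    for start in range(0, len(samples), batch_size):
        remaining = {s: steps_needed[s] for s in samples[start:start + batch_size]}
        while remaining:
            batch = list(remaining)
            steps.append(batch)
            for s in batch:
                remaining[s] -= 1
                if remaining[s] <= 0:
                    del remaining[s]  # finished sample leaves the batch
    return steps


def rolling_batches(samples, steps_needed, batch_size):
    """One recursive step per yielded batch; finished samples are replaced.

    Assumes every sample needs at least one step.
    """
    queue = deque(samples)
    remaining = {}  # sample id -> recursive steps still needed
    while queue or remaining:
        # Refill free slots so the batch stays as full as the data allows.
        while queue and len(remaining) < batch_size:
            s = queue.popleft()
            remaining[s] = steps_needed[s]
        batch = list(remaining)
        yield batch
        # Advance every active sample by one step and drop completed ones.
        for s in batch:
            remaining[s] -= 1
            if remaining[s] <= 0:
                del remaining[s]


samples = [0, 1, 2, 3]
steps_needed = [3, 1, 1, 3]  # hypothetical per-sample step counts
print(naive_batches(samples, steps_needed, 2))
# [[0, 1], [0], [0], [2, 3], [3], [3]]
print(list(rolling_batches(samples, steps_needed, 2)))
# [[0, 1], [0, 2], [0, 3], [3], [3]]
```

With these counts, the naive loop takes six steps, four of which run on a single sample, while the rolling version takes five and keeps the batch full for the first three. The gain comes from refilling the slots freed by completed samples; removing them without refilling only shrinks the batch further. In an actual `LightningModule`, the `remaining` pool would presumably hold tensors (inputs, recurrent state, partial predictions) carried between `training_step` calls, which is an assumption about the design rather than something Lightning provides out of the box.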
Any ideas that could help?