Skip to content

Commit 6e0bb77

Browse files
authored
Merge pull request #55 from mattwarkentin/master
Provide support for minimum and maximum number of epochs
2 parents 2932183 + 25eda79 commit 6e0bb77

File tree

6 files changed

+72
-43
lines changed

6 files changed

+72
-43
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
* Fixed bug in CSV logger callback that was saving the logs as a space delimited file (#52, @mattwarkentin).
44
* Fixed bug in the length of the progress bar for the validation dataset (#52, @mattwarkentin).
55
* `ctx$data` now refers to the current in use `data` instead of always refering to `ctx$train_data`. (#54)
6+
* Allow users to provide the minimum and maximum number of epochs when calling `fit.luz_module_generator()`. Removed `ctx$epochs` from context object and replaced it with `ctx$min_epochs` and `ctx$max_epochs` (#53, @mattwarkentin)
7+
* Early stopping will now only occur if the minimum number of training epochs has been met (#53, @mattwarkentin)
68

79
# luz 0.1.0
810

R/callbacks.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ luz_callback_progress <- luz_callback(
9797
inform(sprintf(
9898
"Epoch %d/%d",
9999
as.integer(ctx$epoch),
100-
as.integer(ctx$epochs)
100+
as.integer(ctx$max_epochs)
101101
))
102102
},
103103
on_train_begin = function() {
@@ -356,7 +356,7 @@ monitor_metrics <- luz_callback(
356356
#'
357357
#' @note
358358
#' This callback adds a `on_early_stopping` callback that can be used to
359-
#' call callbacks after as soon as the model stopped training.
359+
#' call callbacks as soon as the model stops training.
360360
#'
361361
#' @note
362362
#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when
@@ -409,13 +409,16 @@ luz_callback_early_stopping <- luz_callback(
409409
self$patience_counter <- self$patience_counter + 1L
410410
}
411411

412-
if (self$patience_counter >= self$patience) {
412+
if (self$patience_counter >= self$patience &
413+
ctx$epoch >= ctx$min_epochs) {
413414
rlang::signal("Early stopping", class = "early_stopping")
414415
}
415416

416417
},
417418
on_early_stopping = function() {
418-
inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}"))
419+
inform(
420+
glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$max_epochs}")
421+
)
419422
}
420423
)
421424

R/module.R

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,33 @@ get_opt_hparams <- function(module) {
117117
#' @param object An `nn_module` that has been [setup()].
118118
#'
119119
#' @param data (dataloader) A dataloader created with [torch::dataloader()] used
120-
#' for training the model. The dataloader must return a list with at most 2 items.
121-
#' The first item will be used as input for the module and the second will be used
122-
#' as target for the loss function.
120+
#' for training the model. The dataloader must return a list with at most 2
121+
#' items. The first item will be used as input for the module and the second
122+
#' will be used as target for the loss function.
123123
#'
124-
#' @param epochs (int) The number of epochs for training the model.
124+
#' @param epochs (int) The maximum number of epochs for training the model.
125+
#' If a single value is provided, this is taken to be the `max_epochs` and
126+
#' `min_epochs` is set to 0. If a vector of two numbers is provided, the
127+
#' first value is `min_epochs` and the second value is `max_epochs`.
128+
#' The minimum and maximum number of epochs are included in the context
129+
#' object as `ctx$min_epochs` and `ctx$max_epochs`, respectively.
125130
#'
126-
#' @param callbacks (list, optional) A list of callbacks defined with [luz_callback()] that
127-
#' will be called during the training procedure. The callbacks [luz_callback_metrics()],
128-
#' [luz_callback_progress()] and [luz_callback_train_valid()] are always added by default.
131+
#' @param callbacks (list, optional) A list of callbacks defined with
132+
#' [luz_callback()] that will be called during the training procedure. The
133+
#' callbacks [luz_callback_metrics()], [luz_callback_progress()] and
134+
#' [luz_callback_train_valid()] are always added by default.
129135
#'
130-
#' @param valid_data (dataloader, optional) A dataloader created with [torch::dataloader()]
131-
#' that will be used during the validation procedure.
136+
#' @param valid_data (dataloader, optional) A dataloader created with
137+
#' [torch::dataloader()] that will be used during the validation procedure.
132138
#'
133-
#' @param accelerator (accelerator, optional) An optional [accelerator()] object used
134-
#' to configure device placement of the components like [nn_module]s, optimizers
135-
#' and batches of data.
139+
#' @param accelerator (accelerator, optional) An optional [accelerator()] object
140+
#' used to configure device placement of the components like [nn_module]s,
141+
#' optimizers and batches of data.
136142
#'
137-
#' @param verbose (logical, optional) An optional boolean value indicating if the
138-
#' fitting procedure should emmit output to the console during training. By default,
139-
#' it will produce output if [interactive()] is `TRUE`, otherwise it won't print
140-
#' to the console.
143+
#' @param verbose (logical, optional) An optional boolean value indicating if
144+
#' the fitting procedure should emmit output to the console during training.
145+
#' By default, it will produce output if [interactive()] is `TRUE`, otherwise
146+
#' it won't print to the console.
141147
#'
142148
#' @param ... Currently unused,
143149
#'
@@ -147,9 +153,16 @@ get_opt_hparams <- function(module) {
147153
#'
148154
#' @importFrom generics fit
149155
#' @export
150-
fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL,
151-
valid_data = NULL, accelerator = NULL,
152-
verbose = NULL, ...) {
156+
fit.luz_module_generator <- function(
157+
object,
158+
data,
159+
epochs = 10,
160+
callbacks = NULL,
161+
valid_data = NULL,
162+
accelerator = NULL,
163+
verbose = NULL,
164+
...
165+
) {
153166

154167
module <- object
155168
ellipsis::check_dots_empty()
@@ -190,7 +203,10 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
190203
ctx$train_data <- data
191204
ctx$valid_data <- valid_data
192205

193-
ctx$epochs <- epochs
206+
if (length(epochs) == 1) epochs <- c(0, epochs)
207+
ctx$min_epochs <- epochs[[1]]
208+
ctx$max_epochs <- epochs[[2]]
209+
194210
callbacks <- append(default_callbacks(), callbacks)
195211
ctx$callbacks <- initialize_callbacks(callbacks, ctx)
196212

@@ -209,7 +225,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
209225
rlang::with_handlers(
210226
!!! ctx$handlers,
211227
.expr = {
212-
for (epoch in seq_len(ctx$epochs)) {
228+
for (epoch in seq_len(ctx$max_epochs)) {
213229
ctx$epoch <- epoch
214230
ctx$iter <- 0L
215231

man/fit.luz_module_generator.Rd

Lines changed: 22 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/luz_callback_early_stopping.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/ctx.Rmd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ The `ctx` object is used in luz to share information between the training loop a
1717
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
1818
| `valid_data` | Dataloader passed to the `valid_data` argument in `fit`. Modified to yield data in the selected device. |
1919
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
20-
| `epochs` | Total number of epochs the model will be trained on. |
20+
| `min_epochs` | Minimum number of epochs the model will be trained for. |
21+
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
22+
| `max_epochs` | Maximum number of epochs the model will be trained for. |
2123
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
2224
| `epoch` | Current training epoch. |
2325
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

0 commit comments

Comments
 (0)