Skip to content

Commit 836d801

Browse files
author
Matthew T. Warkentin
committed
Pull in changes from remote repo
2 parents 421890e + 2932183 commit 836d801

7 files changed

Lines changed: 20 additions & 8 deletions

File tree

.github/workflows/test-coverage-pak.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
RSPM: https://packagemanager.rstudio.com/cran/__linux__/bionic/latest
2424
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
2525
TORCH_INSTALL: 1
26+
TORCH_TEST: 1
2627
DEBIAN_FRONTEND: 'noninteractive'
2728

2829
steps:

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# luz (development version)
22

3+
* Fixed bug in CSV logger callback that was saving the logs as a space delimited file (#52, @mattwarkentin).
4+
* Fixed bug in the length of the progress bar for the validation dataset (#52, @mattwarkentin).
5+
* `ctx$data` now refers to the current in use `data` instead of always refering to `ctx$train_data`. (#54)
6+
37
# luz 0.1.0
48

59
* Added a `NEWS.md` file to track changes to the package.

R/callbacks.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ luz_callback_progress <- luz_callback(
176176
show_after <- if (getOption("luz.force_progress_bar", FALSE)) 0 else 0.2
177177

178178
format <- paste0(c(format, abbrevs), collapse = " - ")
179-
total <- if (ctx$training) length(ctx$data) else length(ctx$valid_data)
179+
total <- length(ctx$data) # ctx$data is the current dataset - can be the validation or training.
180180

181181
self$pb <- progress::progress_bar$new(
182182
force = getOption("luz.force_progress_bar", FALSE),

R/module.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ fit.luz_module_generator <- function(
199199
ctx$model$ctx <- ctx
200200

201201
ctx$optimizers <- optimizers
202-
ctx$data <- data
202+
203+
ctx$train_data <- data
203204
ctx$valid_data <- valid_data
204205

205206
if (length(epochs) == 1) epochs <- c(0, epochs)
@@ -227,8 +228,10 @@ fit.luz_module_generator <- function(
227228
for (epoch in seq_len(ctx$max_epochs)) {
228229
ctx$epoch <- epoch
229230
ctx$iter <- 0L
230-
ctx$call_callbacks("on_epoch_begin")
231231

232+
ctx$data <- ctx$train_data
233+
234+
ctx$call_callbacks("on_epoch_begin")
232235
ctx$call_callbacks("on_train_begin")
233236

234237
coro::loop(for (batch in ctx$data) {
@@ -244,11 +247,12 @@ fit.luz_module_generator <- function(
244247

245248
if (!is.null(ctx$valid_data)) {
246249

250+
ctx$data <- ctx$valid_data
247251
ctx$call_callbacks("on_valid_begin")
248252

249253
ctx$iter <- 0L
250254
torch::with_no_grad({
251-
coro::loop(for (batch in ctx$valid_data) {
255+
coro::loop(for (batch in ctx$data) {
252256
bind_batch_to_ctx(ctx, batch)
253257
ctx$iter <- ctx$iter + 1L
254258

man/ctx.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/ctx.Rmd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ The `ctx` object is used in luz to share information between the training loop a
1111
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
1212
| `optimizers` | A named list of optimizers used during training. |
1313
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
14-
| `data` | Dataloader passed to the `data` argument in `fit`. Modified to yield data in the selected device. |
14+
| `data` | Current in use dataloader. When training it's `ctx$train_data`, when doing validation its `ctx$valid_data`. It can also be the prediction dataset when in `predict`. |
15+
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
16+
| `train_data` | Dataloader passed to the `data` argument in `fit`. Modified to yield data in the selected device. |
1517
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
1618
| `valid_data` | Dataloader passed to the `valid_data` argument in `fit`. Modified to yield data in the selected device. |
1719
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

tests/testthat/test-callbacks.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ test_that("csv callback", {
182182
luz_callback_csv_logger(tmp)
183183
))
184184

185-
x <- read.table(tmp, header = TRUE)
185+
x <- read.table(tmp, header = TRUE, sep = ",")
186186
expect_equal(nrow(x), 5)
187187
expect_equal(names(x), c("epoch", "set", "loss"))
188188

@@ -192,7 +192,7 @@ test_that("csv callback", {
192192
luz_callback_csv_logger(tmp)
193193
))
194194

195-
x <- read.table(tmp, header = TRUE)
195+
x <- read.table(tmp, header = TRUE, sep = ",")
196196

197197
expect_equal(nrow(x), 10)
198198
expect_equal(names(x), c("epoch", "set", "loss"))

0 commit comments

Comments
 (0)