From ec7b5c4609d4381487ec21748af05e759fed4f63 Mon Sep 17 00:00:00 2001 From: Felipe Carlos Date: Fri, 15 Nov 2024 23:40:35 -0300 Subject: [PATCH] replace probability fractions with NA in classification results --- R/api_classify.R | 7 +-- tests/testthat/test-classification.R | 74 ++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/R/api_classify.R b/R/api_classify.R index e9c238758..74fc88256 100755 --- a/R/api_classify.R +++ b/R/api_classify.R @@ -106,8 +106,6 @@ # Should bbox of resulting tile be updated? update_bbox <- nrow(chunks) != nchunks } - # Compute fractions probability - probs_fractions <- 1 / length(.ml_labels(ml_model)) # Process jobs in parallel block_files <- .jobs_map_parallel_chr(chunks, function(chunk) { # Job block @@ -171,10 +169,9 @@ scale <- .scale(band_conf) if (.has(scale) && scale != 1) { values <- values / scale - probs_fractions <- probs_fractions / scale } - # Mask NA pixels with same probabilities for all classes - values[na_mask, ] <- probs_fractions + # Put NA back in the result + values[na_mask, ] <- NA # Log .debug_log( event = "start_block_data_save", diff --git a/tests/testthat/test-classification.R b/tests/testthat/test-classification.R index 4256d9bbe..3657818f4 100644 --- a/tests/testthat/test-classification.R +++ b/tests/testthat/test-classification.R @@ -56,3 +56,77 @@ test_that("Classify error bands 1", { ) ) }) + +test_that("Classify with NA values", { + # load cube + data_dir <- system.file("extdata/raster/mod13q1", package = "sits") + raster_cube <- sits_cube( + source = "BDC", + collection = "MOD13Q1-6.1", + data_dir = data_dir, + tiles = "012010", + bands = "NDVI", + start_date = "2013-09-14", + end_date = "2014-08-29", + multicores = 2, + progress = FALSE + ) + # preparation - create directory to save NA + data_dir <- paste0(tempdir(), "/na-cube") + dir.create(data_dir, recursive = TRUE, showWarnings = FALSE) + # preparation - insert NA in cube + raster_cube <- sits_apply( + data = raster_cube, + NDVI_NA = ifelse(NDVI > 0.5, NA, NDVI), + output_dir = data_dir + ) + raster_cube <- sits_select(raster_cube, bands = "NDVI_NA") + .fi(raster_cube) <- .fi(raster_cube) |> + dplyr::mutate(band = "NDVI") + # preparation - create a random forest model + rfor_model <- sits_train(samples_modis_ndvi, sits_rfor(num_trees = 40)) + # test classification with NA + class_map <- sits_classify( + data = raster_cube, + ml_model = rfor_model, + output_dir = tempdir(), + progress = FALSE + ) + class_map_rst <- terra::rast(class_map[["file_info"]][[1]][["path"]]) + expect_true(anyNA(class_map_rst[])) +}) + +test_that("Classify with exclusion mask", { + # load cube + data_dir <- system.file("extdata/raster/mod13q1", package = "sits") + raster_cube <- sits_cube( + source = "BDC", + collection = "MOD13Q1-6.1", + data_dir = data_dir, + tiles = "012010", + bands = "NDVI", + start_date = "2013-09-14", + end_date = "2014-08-29", + multicores = 2, + progress = FALSE + ) + # preparation - create a random forest model + rfor_model <- sits_train(samples_modis_ndvi, sits_rfor(num_trees = 40)) + # test classification with NA + class_map <- suppressWarnings( + sits_classify( + data = raster_cube, + ml_model = rfor_model, + output_dir = tempdir(), + exclusion_mask = c( + xmin = -55.63478, + ymin = -11.63328, + xmax = -55.54080, + ymax = -11.56978 + ), + progress = FALSE + ) + ) + class_map_rst <- terra::rast(class_map[["file_info"]][[1]][["path"]]) + expect_true(anyNA(class_map_rst[])) +})