Would it be possible to add support for case weights in TabNet?
This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.
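Here is a toy base-R illustration of the `class_ratio` weighting used in the reprex. Each "bad" row gets weight n_good / n_bad, so both classes carry the same total weight. The factor values are made up, and base `ifelse()` stands in for dplyr's `if_else()` so the snippet has no dependencies.

```r
# Toy data: 4 "good" rows and 2 "bad" rows (made-up values)
cls <- factor(c("good", "good", "good", "good", "bad", "bad"))

# Same formula as in the reprex: n_good / n_bad (here 4 / 2 = 2)
class_ratio <- sum(cls == "good") / sum(cls == "bad")

# Minority ("bad") rows get class_ratio, majority rows get 1
wts <- ifelse(cls == "bad", class_ratio, 1)

# Total weight per class; both classes end up with the same total
tapply(wts, cls, sum)
```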
(I will probably upsample the minority class in the meantime as an alternative approach.)
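A minimal base-R sketch of that upsampling fallback, assuming a data frame with a factor outcome column. `upsample_minority()` is a hypothetical helper, not part of tabnet or tidymodels; inside a recipe, the themis package's `step_upsample()` offers a comparable step.

```r
# Hypothetical helper (not part of tabnet/tidymodels): randomly resample
# rows of every class below the majority size, with replacement, until each
# class has as many rows as the largest class.
upsample_minority <- function(data, outcome, seed = 123) {
  set.seed(seed)  # note: this resets the global RNG state
  counts <- table(data[[outcome]])
  target <- max(counts)
  pieces <- lapply(names(counts), function(lvl) {
    rows <- which(data[[outcome]] == lvl)
    # Index into `rows` via sample.int() to avoid sample()'s length-1 quirk
    extra <- if (length(rows) < target) {
      rows[sample.int(length(rows), target - length(rows), replace = TRUE)]
    } else {
      integer(0)
    }
    data[c(rows, extra), , drop = FALSE]
  })
  do.call(rbind, pieces)
}

# Toy data: 4 "good" rows, 1 "bad" row
toy <- data.frame(
  Class = factor(c("good", "good", "good", "good", "bad")),
  x = 1:5
)
balanced <- upsample_minority(toy, "Class")
table(balanced$Class)
```

Apply it only to the training split (after `initial_split()`), so duplicated minority rows do not leak into the test set.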
This would be the desired workflow:
library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)

data("lending_club", package = "modeldata")

class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

set.seed(1)
tab_mod <- tabnet(epochs = 10) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#> ▆
#> 1. ├─generics::fit(tab_wf, train)
#> 2. └─workflows:::fit.workflow(tab_wf, train)
#> 3. └─workflows::.fit_model(workflow, control)
#> 4. ├─generics::fit(action_model, workflow = workflow, control = control)
#> 5. └─workflows:::fit.action_model(...)
#> 6. └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#> 7. ├─generics::fit_xy(...)
#> 8. └─parsnip::fit_xy.model_spec(...)
#> 9. └─parsnip:::check_case_weights(case_weights, object)
#> 10. └─rlang::abort("Case weights are not enabled by the underlying model implementation.")
Created on 2024-01-12 with reprex v2.0.2
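For context, a weight-aware fit would typically scale each row's contribution to the training loss. The sketch below is an assumption about how that could look, not tabnet's actual implementation: a base-R weighted binary cross-entropy, where `weighted_log_loss()` is a hypothetical helper and the data are made up.

```r
# Hypothetical helper: weighted binary cross-entropy (log loss).
# y: 0/1 outcome, p: predicted probability of class 1, w: case weights
weighted_log_loss <- function(y, p, w = rep(1, length(y))) {
  eps <- 1e-15
  p <- pmin(pmax(p, eps), 1 - eps)  # clip to avoid log(0)
  per_row <- -(y * log(p) + (1 - y) * log(1 - p))
  sum(w * per_row) / sum(w)         # weighted mean of per-row losses
}

y <- c(1, 1, 1, 1, 0, 0)
p <- c(0.9, 0.8, 0.7, 0.9, 0.6, 0.4)
w <- ifelse(y == 0, 2, 1)  # upweight the minority class (n_1 / n_0 = 2)

weighted_log_loss(y, p)     # ordinary mean log loss
weighted_log_loss(y, p, w)  # minority-class errors count twice as much
```

With all weights equal, the function reduces to the ordinary mean log loss; upweighting the minority class raises the penalty for misclassifying it, which is the effect upsampling approximates by duplicating rows.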