
Feature request for case-weights #145

@cgoo4


Would it be possible to add support for case weights in TabNet?

This would help with class imbalance and make it easier to compare (and blend) TabNet and XGBoost results.
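To make the class-imbalance motivation concrete: if each minority row is weighted by the majority-to-minority ratio (the same `class_ratio` the reprex computes), both classes end up with the same total weight, so they would contribute equally to a weighted loss. A minimal base-R sketch on a hypothetical 9-to-3 toy outcome, not the `lending_club` data:

```r
# Hypothetical toy outcome: 9 "good" rows, 3 "bad" rows
outcome <- factor(c(rep("good", 9), rep("bad", 3)), levels = c("bad", "good"))

# Majority-to-minority ratio, as in the reprex (9 / 3 = 3 here)
class_ratio <- sum(outcome == "good") / sum(outcome == "bad")

# Minority ("bad") rows get weight class_ratio, majority rows get 1
case_wts <- ifelse(outcome == "bad", class_ratio, 1)

# Total weight per class: both classes now carry the same total (bad = 9, good = 9)
tapply(case_wts, outcome, sum)
```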

(I will probably upsample the minority class in the meantime as an alternative approach.)
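The upsampling workaround can be sketched in base R: keep every original row and draw extra rows with replacement from each smaller class until every class matches the majority count. `upsample()` is a hypothetical helper written for illustration; within tidymodels, `themis::step_upsample()` provides the same idea as a recipe step.

```r
set.seed(123)

# Hypothetical toy data: 9 "good" rows and 3 "bad" rows
df <- data.frame(
  x = seq_len(12),
  Class = factor(c(rep("good", 9), rep("bad", 3)), levels = c("bad", "good"))
)

# Hypothetical helper: resample each class up to the size of the largest class
upsample <- function(data, outcome) {
  counts <- table(data[[outcome]])
  target <- max(counts)
  pieces <- lapply(names(counts), function(lvl) {
    rows <- which(data[[outcome]] == lvl)
    # Keep all original rows, then add random duplicates to reach the target
    extra <- rows[sample.int(length(rows), target - length(rows), replace = TRUE)]
    data[c(rows, extra), , drop = FALSE]
  })
  do.call(rbind, pieces)
}

balanced <- upsample(df, "Class")
table(balanced$Class)  # bad: 9, good: 9
```

Unlike case weights, this duplicates rows, so the training set grows and the resampled rows are exact copies.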

This would be the desired workflow:

```r
library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")

# Ratio of majority ("good") to minority ("bad") rows
class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

# Up-weight the minority class and store the column as importance weights
lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Class is the outcome; everything except outcome, id and case-weight columns is a predictor
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

set.seed(1)

tab_mod <- tabnet(epochs = 10) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

# Attach the case weights to the workflow
tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#>      ▆
#>   1. ├─generics::fit(tab_wf, train)
#>   2. └─workflows:::fit.workflow(tab_wf, train)
#>   3.   └─workflows::.fit_model(workflow, control)
#>   4.     ├─generics::fit(action_model, workflow = workflow, control = control)
#>   5.     └─workflows:::fit.action_model(...)
#>   6.       └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#>   7.         ├─generics::fit_xy(...)
#>   8.         └─parsnip::fit_xy.model_spec(...)
#>   9.           └─parsnip:::check_case_weights(case_weights, object)
#>  10.             └─rlang::abort("Case weights are not enabled by the underlying model implementation.")
```

Created on 2024-01-12 with reprex v2.0.2
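For context on what the request involves: parsnip stops in `check_case_weights()` because the tabnet model implementation does not declare case-weight support. In models that do accept weights, one common way they enter training is as a weighted mean of per-row losses, `sum(w_i * loss_i) / sum(w_i)`. The base-R sketch below illustrates that idea with made-up predicted probabilities and the same class-ratio weighting; it is an assumption about how support could work, not tabnet's actual loss code.

```r
# Hypothetical outcomes and predicted probabilities of "bad" for 4 rows
outcome <- factor(c("bad", "good", "good", "good"), levels = c("bad", "good"))
p_bad   <- c(0.2, 0.1, 0.3, 0.2)

# Per-row binary cross-entropy (log loss)
per_row_loss <- ifelse(outcome == "bad", -log(p_bad), -log(1 - p_bad))

# Class-ratio weights: 3 "good" / 1 "bad" = 3 for the minority row, 1 otherwise
wts <- ifelse(outcome == "bad", sum(outcome == "good") / sum(outcome == "bad"), 1)

unweighted <- mean(per_row_loss)                      # about 0.57
weighted   <- weighted.mean(per_row_loss, w = wts)    # about 0.92
```

The poorly predicted minority row dominates the weighted loss, which is the effect case weights are meant to have on an imbalanced outcome.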
