
Fitting tabnet with tidymodels

library(tabnet)
library(tidymodels)
library(modeldata)

In this vignette we show how to create a TabNet model using the tidymodels interface.

We are going to use the lending_club dataset available in the modeldata package.

First, let’s split our dataset into training and testing sets so we can later assess the performance of our model:

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)
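A quick sanity check on the split, re-running the same code. It confirms that roughly three quarters of the rows went to training (the `initial_split()` default) and prints the class shares in each set. The outcome is strongly imbalanced, and `rsample` may pool very small strata (see the `pool` argument of `initial_split()`), so it is worth looking at how the rare class was distributed rather than assuming it. The helper `class_share()` is hypothetical, introduced only for this illustration.

```r
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Hypothetical helper: share of each outcome class in a data frame
class_share <- function(df) prop.table(table(df$Class))

# Fraction of rows that went to the training set (default prop is 3/4)
nrow(train) / nrow(lending_club)

# Class shares in each set
class_share(train)
class_share(test)
```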

We now define our pre-processing steps. TabNet handles categorical variables natively, so they need no transformation. Normalizing the numeric variables is still a good idea, though.

rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())
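To see what the recipe actually does, it can be estimated with `prep()` and applied to the training data with `bake()`. The sketch below, which rebuilds the same split, checks that the numeric columns end up centred (mean close to 0) while the factor columns pass through untouched. The variable names `baked` and `num_means` are illustrative only.

```r
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)

rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())

# Estimate means/standard deviations on the training set and apply them to it
baked <- bake(prep(rec), new_data = NULL)

# Numeric columns should now have a mean close to 0 ...
is_num <- vapply(baked, is.numeric, logical(1))
num_means <- colMeans(baked[is_num], na.rm = TRUE)
summary(num_means)

# ... while factor columns are left as they were
names(baked)[vapply(baked, is.factor, logical(1))]
```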

Next, we define our model. We are going to train for 50 epochs with a batch size of 128. There are other hyperparameters, but we are going to use their defaults.

mod <- tabnet(epochs = 50, batch_size = 128) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

We also define our workflow object:

wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)
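At this point nothing has been trained: the workflow only bundles the recipe and the model specification. A minimal sketch of pulling the two pieces back out with `extract_spec_parsnip()` and `extract_preprocessor()` from the workflows package; for brevity it builds the recipe on the full `lending_club` data rather than the training split.

```r
library(tidymodels)
library(tabnet)
library(modeldata)

data("lending_club", package = "modeldata")

rec <- recipe(Class ~ ., lending_club) %>%
  step_normalize(all_numeric())

mod <- tabnet(epochs = 50, batch_size = 128) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)

# The untrained model specification and the unprepared recipe
extract_spec_parsnip(wf)
extract_preprocessor(wf)
```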

We can now define our cross-validation strategy:

folds <- vfold_cv(train, v = 5)
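Each training row appears in exactly one assessment set across the five folds, which the sketch below verifies. Because `bad` is a rare class, it also prints that class's share in each assessment set and shows a variant, `folds_strat`, that passes `strata = Class` to `vfold_cv()`. Whether that variant is preferable here is a judgment call, and `rsample` may still pool very small strata.

```r
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)

folds <- vfold_cv(train, v = 5)

# Size of each assessment set; together they cover the training set once
assess_sizes <- vapply(folds$splits, function(s) nrow(assessment(s)), integer(1))
assess_sizes

# Share of the first outcome level in each assessment set
first_level <- levels(train$Class)[1]
vapply(folds$splits, function(s) mean(assessment(s)$Class == first_level), numeric(1))

# Variant: ask rsample to stratify the folds on the outcome
folds_strat <- vfold_cv(train, v = 5, strata = Class)
```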

Next, we fit the model on each resample:

fit_rs <- wf %>%
  fit_resamples(folds)

After a few minutes, we can collect the cross-validated metrics:

collect_metrics(fit_rs)
# A tibble: 2 x 5
  .metric  .estimator  mean     n  std_err
  <chr>    <chr>      <dbl> <int>    <dbl>
1 accuracy binary     0.946     5 0.000713
2 roc_auc  binary     0.732     5 0.00539 
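With a strongly imbalanced outcome, accuracy alone is hard to interpret: a classifier that always predicts the majority class already scores that class's share of the data. A small sketch that computes this baseline (often called the no-information rate) on the training set, for comparison with the accuracy above; `no_info_rate` is just an illustrative name.

```r
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)

# Accuracy of a model that always predicts the most frequent class
class_counts <- table(train$Class)
no_info_rate <- max(class_counts) / sum(class_counts)
no_info_rate
```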

Finally, we fit the workflow on the full training set and verify the results on the test set:

model <- wf %>% fit(train)
test %>% 
  bind_cols(
    predict(model, test, type = "prob")
  ) %>% 
  roc_auc(Class, .pred_bad)
# A tibble: 1 x 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.710
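`roc_auc()` treats the first level of the outcome factor as the event by default, which is why `.pred_bad` is the column passed above. This assumes `bad` is the first level of `Class`; `levels(test$Class)` shows the actual order. The toy example below uses made-up numbers, not model output, to show how the result flips if the wrong event is used via the `event_level` argument.

```r
library(yardstick)
library(tibble)

# Made-up predictions; "bad" is the first factor level
toy <- tibble(
  Class     = factor(c("bad", "good", "good", "bad"), levels = c("bad", "good")),
  .pred_bad = c(0.9, 0.2, 0.4, 0.6)
)

# Default event_level = "first": .pred_bad is scored against "bad"
auc_first <- roc_auc(toy, Class, .pred_bad)

# Treating "good" as the event with the same column inverts the ranking
auc_second <- roc_auc(toy, Class, .pred_bad, event_level = "second")

auc_first$.estimate   # every "bad" row outranks every "good" row
auc_second$.estimate
```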
