---
title: "Fitting tabnet with tidymodels"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Fitting tabnet with tidymodels}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = FALSE
)
```

```{r setup}
library(tabnet)
library(tidymodels)
library(modeldata)
```

In this vignette we show how to create a TabNet model using the tidymodels interface.

We are going to use the `lending_club` dataset available in the `modeldata` package.

First, let's split our dataset into training and testing sets so we can later assess the performance of our model:

```{r}
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
```

We now define our pre-processing steps. Note that TabNet handles categorical variables, so we don't need to transform them. Normalizing the numeric variables is a good idea, though.

```{r}
rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())
```

Next, we define our model. We are going to train for 50 epochs with a batch size of 128. There are other hyperparameters, but we are going to use the defaults.

```{r}
mod <- tabnet(epochs = 50, batch_size = 128) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")
```

We also define our `workflow` object:

```{r}
wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)
```

We can now define our cross-validation strategy:

```{r}
folds <- vfold_cv(train, v = 5)
```

We then fit the model on each resample:

```{r}
fit_rs <- wf %>%
  fit_resamples(folds)
```

After a few minutes we can get the results:

```{r}
collect_metrics(fit_rs)
```

```
# A tibble: 2 x 5
  .metric  .estimator  mean     n  std_err
1 accuracy binary     0.946     5 0.000713
2 roc_auc  binary     0.732     5 0.00539
```

Finally, we can verify the results on our test set:

```{r}
model <- wf %>% fit(train)
test %>%
  bind_cols(
    predict(model, test, type = "prob")
  ) %>%
  roc_auc(Class, .pred_bad)
```

```
# A tibble: 1 x 3
  .metric .estimator .estimate
1 roc_auc binary         0.710
```
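
ROC AUC summarizes how well the model ranks observations, but it does not show which classes the model actually predicts. As an optional follow-up, the sketch below reuses the fitted `model` and the `test` set from above to build a confusion matrix and a few class-based metrics with yardstick. The object names `preds`, `cm`, `class_metrics_fn` and `class_metrics` are our own illustrative choices, and the metrics shown are only one reasonable selection. No output is shown because the numbers depend on the fitted model.

```{r}
# Assumes `model` (the fitted workflow) and `test` from the chunks above.
# Combine the true labels with hard class predictions and class probabilities.
preds <- test %>%
  bind_cols(
    predict(model, test),                 # adds .pred_class
    predict(model, test, type = "prob")   # adds .pred_bad and .pred_good
  )

# Cross-tabulate predicted classes against the true classes.
cm <- preds %>%
  conf_mat(truth = Class, estimate = .pred_class)
cm

# Bundle several class-based metrics into one function.
# yardstick treats the first factor level ("bad") as the event by default.
class_metrics_fn <- metric_set(accuracy, sensitivity, specificity)
class_metrics <- preds %>%
  class_metrics_fn(truth = Class, estimate = .pred_class)
class_metrics
```

Looking at sensitivity and specificity next to accuracy can help show whether a high accuracy mostly reflects predictions of the more frequent class.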