---
title: "Training a Tabnet model from missing-values dataset"
author: "Christophe Regouby"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Training a Tabnet model from missing-values dataset}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = FALSE
)
```

# Motivation

Real-life training datasets usually contain missing data. The vast majority of deep-learning networks do not handle missing data and thus either stop or crash when values are missing in the predictors. TabNet, however, uses a masking mechanism that we can reuse to cover the missing data in the training set.

As we enter the world of missing data, we have to question the type of missing data we deal with. Data can be missing at random (MAR), for example because of random transmission errors in a sensor dataset, or missing not at random (MNAR), when the fact that a value is missing is related to the value itself or interacts with the values of other predictors for the same sample. The latter is a more complex topic to cover, and we will try to investigate it here through the `ames` dataset.

# Missing-data dataset creation

## Understanding the missing values in `ames`

The `ames` dataset from `{modeldata}` contains a lot of _null values_ that a human analyst clearly understands as an implicit _"missing object"_ described by that value. For example, a pool area of 0 square feet means "no pool", a basement surface of 0 square feet means "no basement", and so on.
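To make the distinction between missingness mechanisms more concrete, here is a small base-R simulation on made-up data (the `pool_area` values, the 20% pool rate and the 10% dropout rate are all hypothetical, not taken from `ames`). The first mechanism drops values independently of anything, like the sensor transmission errors above; the second hides a value precisely because it is zero, which mirrors the "zero means no object" pattern in `ames`. Comparing the missing rate among houses with and without a pool shows the difference.

```r
# Toy illustration of two missingness mechanisms (hypothetical data, not ames).
set.seed(42)
n <- 1000

# About 20% of houses have a pool; "no pool" is encoded as an area of 0.
has_pool  <- runif(n) < 0.2
pool_area <- ifelse(has_pool, round(runif(n, 300, 800)), 0)

# Random dropout: each value is lost with probability 0.1,
# independently of any variable (the sensor transmission-error case).
random_na <- pool_area
random_na[runif(n) < 0.1] <- NA

# Value-dependent missingness: the value is missing because it is zero,
# i.e. because there is no pool.
mnar_na <- pool_area
mnar_na[mnar_na == 0] <- NA

# Share of missing values within each group (FALSE = no pool, TRUE = pool).
missing_rate <- function(x, group) tapply(is.na(x), group, mean)

missing_rate(random_na, has_pool)  # roughly 0.1 in both groups
missing_rate(mnar_na, has_pool)    # 1 without a pool, 0 with a pool
```

Under random dropout the missing rate carries no information about the pool, whereas under the value-dependent mechanism the missingness itself tells you whether a pool exists, which is the kind of relation we hope the model will pick up.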
Many of those variables can be detected visually by inspecting the distribution of their values, for example the `Mas_Vnr_Area` (masonry veneer area) predictor:

```{r}
library(tidymodels, quietly = TRUE)
library(tabnet)
data("ames", package = "modeldata")

# histogram of the masonry veneer area
ggplot(ames, aes(x = Mas_Vnr_Area)) +
  geom_histogram(bins = 30)
```

![ames variable `Mas_Vnr_Area` histogram showing high occurrence of value zero](ames_mas_vnr_hist.png)

We know that it will be extremely difficult for a model to capture an internal representation of such a distribution, so we want to prevent these zero values from distorting the model's internal representation.

## While keeping some room for freedom

Many of those variables come as a pair in the `ames` dataset, one for the qualitative aspect, the other for the quantitative aspect. For example, `Pool_QC` describes the pool condition and has a `No_Pool` level, with `Pool_Area = 0` in that case.

As humans, we have the intuition that knowing whether a pool is present is important for the modeling task. So we want the model to build an internal representation of the implicit `has_pool = FALSE` without having it explicitly in the dataset. To do so, we have to give the model some freedom to infer the "no pool" state, and thus we should not set both variables of the pair to `NA` (`Pool_Area = NA` and `Pool_QC = NA`) at the same time.

## Ames with missing data

Let's turn those missing objects explicitly into `NA`s in a new `ames_missing` dataset.

A quick and dirty way to achieve this on numerical predictors is to `na_if()` zeros on any column whose name is related to surface or area.
Then, following the "room for freedom" rule, we do it carefully on the matching categorical predictors:

```{r}
# numeric columns related to surface or area whose minimum is zero
col_with_zero_as_na <- ames %>%
  select(where(is.numeric)) %>%
  select(matches("_SF|Area|Misc_Val|[Pp]orch$")) %>%
  summarise(across(everything(), min)) %>%
  select(where(~ .x == 0)) %>%
  names()

ames_missing <- ames %>%
  # zeros become explicit NAs on the numeric side
  mutate(across(all_of(col_with_zero_as_na), ~ na_if(.x, 0))) %>%
  # only some categorical predictors are converted, leaving the model
  # room to infer the "no object" state
  mutate(
    Alley = na_if(Alley, "No_Alley_Access"),
    Fence = na_if(Fence, "No_Fence"),
    across(c(Garage_Cond, Garage_Finish), ~ na_if(.x, "No_Garage")),
    across(c(Bsmt_Exposure, BsmtFin_Type_1, BsmtFin_Type_2), ~ na_if(.x, "No_Basement"))
  )

visdat::vis_miss(ames_missing)
```

![ames missing values visualization showing few variables with more than 90% missingness with a global 13% missing](vis_miss_ames.png)

We can see here that the variables are not missing at random, and thus we can expect the model to capture the missingness relation during the pretraining phase.

Note: a better way to convert values into explicit NAs would be to also check whether the qualitative column of the pair refers to `none` or to zero occurrences of the equipment. But this is beyond the scope of this vignette.
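The "room for freedom" rule can be illustrated on a tiny made-up table (the rows below are hypothetical; only the column names mimic `ames`, and `zero_as_na_if_absent()` is a hypothetical helper, not part of any package). Only the numeric member of the pair, `Pool_Area`, is turned into `NA`, while `Pool_QC` keeps its explicit `No_Pool` level, so the implicit "has a pool" information remains recoverable from the data without being given as its own column. As a sketch of the pair check mentioned in the note, the conversion is also guarded by the qualitative column, so a zero area is only treated as missing when the quality column confirms that there is no pool.

```r
# Hypothetical toy data mimicking an ames-style pair of columns.
houses <- data.frame(
  Pool_QC   = factor(c("No_Pool", "Good", "No_Pool", "Excellent", "Good")),
  Pool_Area = c(0, 512, 0, 480, 0)  # last row: zero area despite a pool quality
)

# Turn a zero area into NA only when the paired quality column says the
# equipment is absent; the categorical column itself is left untouched.
zero_as_na_if_absent <- function(area, quality, absent_level) {
  area[area == 0 & quality == absent_level] <- NA
  area
}

houses$Pool_Area <- zero_as_na_if_absent(houses$Pool_Area, houses$Pool_QC, "No_Pool")

# The "no pool" state is still recoverable from Pool_QC alone.
has_pool <- houses$Pool_QC != "No_Pool"
houses
```

The inconsistent last row keeps its zero instead of silently becoming `NA`, which makes such data-quality issues visible rather than hiding them behind a missing value.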
# Model pretraining

Let's pretrain one model for each of those datasets, and analyze the variable importance that emerges after the unsupervised representation learning step.

## Variable importance with raw `ames` dataset

```{r}
ames_rec <- recipe(Sale_Price ~ ., data = ames) %>%
  step_normalize(all_numeric())

# embedding dimension of each categorical predictor: rounded log2 of its number of levels
cat_emb_dim <- map_dbl(ames %>% select(where(is.factor)), ~ round(log2(nlevels(.x))))

ames_pretrain <- tabnet_pretrain(
  ames_rec, data = ames, epochs = 50, cat_emb_dim = cat_emb_dim,
  valid_split = 0.2, verbose = TRUE, batch_size = 2930,
  early_stopping_patience = 3L, early_stopping_tolerance = 1e-4
)
autoplot(ames_pretrain)
```

```
[Epoch 001] Loss: 43.708794 Valid loss: 8066126.500000
[Epoch 002] Loss: 31.463089 Valid loss: 5631984.000000
[Epoch 003] Loss: 23.396217 Valid loss: 3901085.500000
[Epoch 004] Loss: 19.241619 Valid loss: 2947481.750000
[Epoch 005] Loss: 15.032537 Valid loss: 2250338.000000
[Epoch 006] Loss: 12.991020 Valid loss: 1815583.125000
[Epoch 007] Loss: 11.044646 Valid loss: 1533597.875000
[Epoch 008] Loss: 9.114124 Valid loss: 1395840.000000
[Epoch 009] Loss: 8.362211 Valid loss: 1258169.375000
[Epoch 010] Loss: 7.549719 Valid loss: 1064599.500000
[Epoch 011] Loss: 6.808529 Valid loss: 998335.625000
[Epoch 012] Loss: 6.569450 Valid loss: 1047418.500000
[Epoch 013] Loss: 6.606429 Valid loss: 1048583.625000
[Epoch 014] Loss: 6.742617 Valid loss: 993241.312500
[Epoch 015] Loss: 6.806847 Valid loss: 995705.875000
[Epoch 016] Loss: 6.618536 Valid loss: 1026789.625000
[Epoch 017] Loss: 6.593469 Valid loss: 1033726.437500
Early stopping at epoch 017
```

![ames_pretrain model pretraining diagnostic plot](ames_pretrain.png)

Now we capture the columns with missing values, and create a convenience function that colors the `vip::vip()` plot output according to the missingness of each column:

```{r}
# one row per column of ames_missing, flagging whether it contains any NA
col_with_missings <- ames_missing %>%
  summarise(across(everything(), ~ any(is.na(.x)))) %>%
  pivot_longer(everything(), names_to = "Variable", values_to = "has_missing")

vip_color <- function(object, col_has_missing) {
  # variable importance, ordered by increasing importance
  vip_data <- vip::vip(object)$data %>%
    arrange(Importance)
  vis_miss_plus <- left_join(vip_data, col_has_missing, by = "Variable") %>%
    mutate(Variable = factor(Variable, levels = vip_data$Variable))
  ggplot(vis_miss_plus, aes(x = Variable, y = Importance, fill = has_missing)) +
    geom_col() +
    coord_flip() +
    scale_fill_grey()
}

vip_color(ames_pretrain, col_with_missings)
```

![ames_pretrain model variable importance plot](ames_pretrain_vip.png)

In this pretraining run, the `BsmtFin_Type_1`, `BsmtFin_SF_1` and `Bsmt_Exposure` variables appear in the top ten most important variables. Those variables have been flagged as having a few missing values. Note that this result varies a lot from run to run, as it depends heavily on the initialization conditions.

## Variable importance with `ames_missing` dataset

Let's pretrain a new model with the same hyperparameters, but now using the `ames_missing` dataset. To compensate for the 13% missingness already present in `ames_missing`, we lower the `pretraining_ratio` parameter from its default of 0.5 to `0.5 - 0.13 = 0.37`:

```{r}
ames_missing_rec <- recipe(Sale_Price ~ ., data = ames_missing) %>%
  step_normalize(all_numeric())

ames_missing_pretrain <- tabnet_pretrain(
  ames_missing_rec, data = ames_missing, epochs = 50, cat_emb_dim = cat_emb_dim,
  valid_split = 0.2, verbose = TRUE, batch_size = 2930, pretraining_ratio = 0.37,
  early_stopping_patience = 3L, early_stopping_tolerance = 1e-4
)
autoplot(ames_missing_pretrain)
vip_color(ames_missing_pretrain, col_with_missings)
```

```
[Epoch 001] Loss: 56.250610 Valid loss: 40321308.000000
[Epoch 002] Loss: 44.254524 Valid loss: 39138240.000000
[Epoch 003] Loss: 33.992207 Valid loss: 38648800.000000
[Epoch 004] Loss: 26.421488 Valid loss: 37445656.000000
[Epoch 005] Loss: 22.290133 Valid loss: 35814052.000000
...
[Epoch 021] Loss: 10.877335 Valid loss: 20903176.000000
[Epoch 022] Loss: 11.023649 Valid loss: 20772972.000000
[Epoch 023] Loss: 10.819239 Valid loss: 20642806.000000
[Epoch 024] Loss: 10.994371 Valid loss: 20575458.000000
[Epoch 025] Loss: 10.700000 Valid loss: 20449918.000000
[Epoch 026] Loss: 10.902529 Valid loss: 20680102.000000
[Epoch 027] Loss: 10.791571 Valid loss: 20849496.000000
[Epoch 028] Loss: 11.102308 Valid loss: 20995910.000000
Early stopping at epoch 028
```

![ames_missing_pretrain model training diagnostic plot](ames_missing_pretrain.png)

![ames_missing_pretrain model variable importance plot](ames_missing_pretrain_vip.png)

We can see here that no variable with high missingness is present in the top 10 important variables. This seems to be a good sign that the model has captured proper interactions between variables.

# Model training

## Variable importance with raw `ames` dataset

```{r}
# supervised training, starting from the pretrained model
ames_fit <- tabnet_fit(
  ames_rec, data = ames, tabnet_model = ames_pretrain,
  epochs = 50, cat_emb_dim = cat_emb_dim, valid_split = 0.2, verbose = TRUE,
  batch_size = 2930, early_stopping_patience = 5L, early_stopping_tolerance = 1e-4
)
autoplot(ames_fit)
vip_color(ames_fit, col_with_missings)
```

![ames_fit model training diagnostic plot](ames_fit.png)

![ames_fit model training variable importance plot](ames_fit_vip_.png)

Here again, the model relies on two predictors that have missing values in `ames_missing`: `BsmtFin_SF_2` and `Garage_Finish`, with 88% and 5% missingness respectively.
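All the pretraining and training calls above rely on `early_stopping_patience` and `early_stopping_tolerance`. As a rough mental model only, the hypothetical `stop_epoch()` helper below re-implements a generic patience rule: training stops once the validation loss has failed to improve by at least a relative `tolerance` for `patience` consecutive epochs. This is not tabnet's internal code, and it assumes the tolerance is a relative improvement threshold. Fed with the validation losses of the raw `ames` pretraining log with `patience = 3` and `tolerance = 1e-4`, it stops at epoch 17, as that log does.

```r
# Hypothetical patience-based early stopping rule (not tabnet's internal code).
# Returns the epoch at which training would stop, or NA if it never stops.
stop_epoch <- function(valid_loss, patience, tolerance) {
  best <- Inf
  counter <- 0L
  for (epoch in seq_along(valid_loss)) {
    if (valid_loss[epoch] < best * (1 - tolerance)) {
      # relative improvement is large enough: remember it, reset the counter
      best <- valid_loss[epoch]
      counter <- 0L
    } else {
      # no sufficient improvement: one more epoch of lost patience
      counter <- counter + 1L
      if (counter >= patience) return(epoch)
    }
  }
  NA_integer_
}

# Validation losses copied from the raw `ames` pretraining log.
ames_pretrain_valid <- c(
  8066126.5, 5631984.0, 3901085.5, 2947481.75, 2250338.0, 1815583.125,
  1533597.875, 1395840.0, 1258169.375, 1064599.5, 998335.625, 1047418.5,
  1048583.625, 993241.3125, 995705.875, 1026789.625, 1033726.4375
)

stop_epoch(ames_pretrain_valid, patience = 3L, tolerance = 1e-4)
```

Epochs 12 and 13 do not beat epoch 11, but epoch 14 does, which resets the counter; epochs 15 to 17 then exhaust the patience of 3. A larger patience, such as the `5L` used for training, gives the model more epochs to recover from such plateaus.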
## Variable importance with `ames_missing` dataset

```{r}
# supervised training, starting from the model pretrained on ames_missing
ames_missing_fit <- tabnet_fit(
  ames_missing_rec, data = ames_missing, tabnet_model = ames_missing_pretrain,
  epochs = 50, cat_emb_dim = cat_emb_dim, valid_split = 0.2, verbose = TRUE,
  batch_size = 2930, early_stopping_patience = 5L, early_stopping_tolerance = 1e-4
)
autoplot(ames_missing_fit)
vip_color(ames_missing_fit, col_with_missings)
```

![ames_missing_fit model training diagnostic plot](ames_missing_fit.png)

![ames_missing_fit model training variable importance plot](ames_missing_fit_vip.png)

Here we can see one predictor, `Garage_Area`, with 5% missingness in the top 10.

# Conclusion

Even though variable importance varies a lot from one training run to another, we have the intuition that a model trained with explicit missing data will provide better results than its counterpart trained with zero-imputed variables.

In any case, being able to pretrain and fit TabNet models on MAR or MNAR datasets is highly convenient for real-life use cases.