The CISS-VAE model can handle binary and categorical variables, but categorical variables must first be converted into binary dummy variables.
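The Palmer Penguins dataset has both continuous variables (bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g) and categorical variables (species, island, sex), so it makes a good example. We can use the dummy_cols() function from the fastDummies package to create dummy variables for the categorical columns. Set ignore_na = TRUE so that a missing category is recorded as NA in each dummy column instead of getting its own NA column, and set remove_selected_columns = TRUE to drop the original categorical columns once the dummies are created. We also set remove_first_dummy = TRUE, which drops the first level of each variable (Adelie, Biscoe, and female). That level becomes the reference category, encoded by all of its dummies being 0.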
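To see what these three dummy_cols() options do, here is a small sketch on a made-up data frame. The colour column and its levels are hypothetical and not part of the penguins data.

```r
library(fastDummies)

# Toy data: one categorical column with a missing value
toy <- data.frame(
  id = 1:4,
  colour = factor(c("red", "green", NA, "blue"), levels = c("blue", "green", "red"))
)

toy_dummies <- dummy_cols(
  toy,
  select_columns = "colour",
  ignore_na = TRUE,               # no colour_NA column; the NA row gets NA dummies
  remove_first_dummy = TRUE,      # drop "blue", the first level (reference category)
  remove_selected_columns = TRUE  # drop the original colour column
)

toy_dummies
```

The result keeps id and adds colour_green and colour_red. Row 3 has NA in both dummy columns, and row 4 ("blue", the dropped reference level) is all zeros, the same pattern that appears in the penguin dummies below.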
library(tidyverse)
library(kableExtra)
library(reticulate)
library(rCISSVAE)
library(fastDummies)
library(palmerpenguins)
##
## Attaching package: 'palmerpenguins'
## The following objects are masked from 'package:datasets':
##
## penguins, penguins_raw
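data("penguins", package = "palmerpenguins")  # load the penguins tibble
## remove rows that are already incomplete, for illustration purposes
penguins_clean = na.omit(penguins) %>%
  select(year, everything())  # move year to the first column
glue::glue("Dimensions: {paste0(dim(penguins), collapse = ',')}")  # raw data, before na.omit()
## Dimensions: 344,8
head(penguins_clean) %>% kable()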
| year | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex |
|---|---|---|---|---|---|---|---|
| 2007 | Adelie | Torgersen | 39.1 | 18.7 | 181 | 3750 | male |
| 2007 | Adelie | Torgersen | 39.5 | 17.4 | 186 | 3800 | female |
| 2007 | Adelie | Torgersen | 40.3 | 18.0 | 195 | 3250 | female |
| 2007 | Adelie | Torgersen | 36.7 | 19.3 | 193 | 3450 | female |
| 2007 | Adelie | Torgersen | 39.3 | 20.6 | 190 | 3650 | male |
| 2007 | Adelie | Torgersen | 38.9 | 17.8 | 181 | 3625 | female |
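## create penguins_missing by masking about 20% of the cells completely at random
## (call set.seed() first if you want the same mask on every run)
n <- nrow(penguins_clean)
p <- ncol(penguins_clean)
m <- floor(0.20 * n * p)     # number of cells to mask
idx <- sample.int(n * p, m)  # positions in a logical matrix
mask <- matrix(FALSE, nrow = n, ncol = p)
mask[idx] <- TRUE
penguins_missing <- penguins_clean
## anything can be missing except the year (column 1)
for (j in 2:p) {
  ## numeric columns store NaN; factor columns (species, island, sex) store NA
  penguins_missing[[j]][mask[, j]] <- NaN
}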
# quick check of missingness rate
glue::glue("\nMissingness proportion of penguins_missing: {round(mean(is.na(as.matrix(penguins_missing))), 2)}")
## Missingness proportion of penguins_missing: 0.17
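The reported proportion is below 20% because the mask is drawn over all p columns, but the year column is never blanked out, so on average only (p - 1)/p of the drawn cells are actually set to missing. A back-of-the-envelope check, assuming the 333 complete rows that na.omit() leaves in palmerpenguins::penguins and the 8 columns shown above:

```r
n_rows <- 333                         # rows in penguins_clean (complete cases of penguins)
n_cols <- 8                           # columns, including year
n_drawn <- floor(0.20 * n_rows * n_cols)  # cells drawn for masking: 532

# On average (n_cols - 1) / n_cols of the drawn cells fall outside the year column
expected_masked <- n_drawn * (n_cols - 1) / n_cols
expected_prop <- expected_masked / (n_rows * n_cols)

round(expected_prop, 3)
```

This gives roughly 0.175. The realized value varies from run to run because the number of drawn cells that land in the year column is random, which is consistent with the 0.17 reported above.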
## create dummy vars: once for the complete data (kept as ground truth) and once for the data with missingness
penguin_dummies_complete = penguins_clean %>%
dummy_cols(select_columns = c("species", "island", "sex"),
ignore_na = TRUE,
remove_first_dummy = TRUE,
remove_selected_columns = TRUE)
penguin_dummies = penguins_missing %>%
dummy_cols(select_columns = c("species", "island", "sex"),
ignore_na = TRUE,
remove_first_dummy = TRUE,
remove_selected_columns = TRUE)
head(penguin_dummies) %>% kable()
| year | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | species_Chinstrap | species_Gentoo | island_Dream | island_Torgersen | sex_male |
|---|---|---|---|---|---|---|---|---|---|
| 2007 | 39.1 | NaN | 181 | 3750 | 0 | 0 | 0 | 1 | 1 |
| 2007 | 39.5 | 17.4 | 186 | 3800 | NA | NA | NA | NA | 0 |
| 2007 | 40.3 | 18.0 | 195 | 3250 | 0 | 0 | 0 | 1 | 0 |
| 2007 | NaN | NaN | 193 | 3450 | NA | NA | NA | NA | 0 |
| 2007 | 39.3 | NaN | 190 | 3650 | 0 | 0 | NA | NA | NA |
| 2007 | 38.9 | NaN | NaN | NaN | 0 | 0 | 0 | 1 | 0 |
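Now that the dummy variables are created and the data contain missing values, we can build a binary_feature_mask and impute with run_cissvae(). In this example the mask is a logical vector with one entry per column of penguin_dummies, in column order: FALSE for year and the four body measurements, and TRUE for the five dummy columns.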
binary_feature_mask = c(rep(FALSE, 5), rep(TRUE, 5))
glue::glue("Binary Feature Mask: {paste0(binary_feature_mask, collapse = ', ')}")
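Hard-coding rep(FALSE, 5) followed by rep(TRUE, 5) relies on the dummy columns coming last. A slightly more robust sketch derives the mask from the column names instead. It assumes the binary columns are exactly those whose names start with species_, island_, or sex_, as dummy_cols() produced here; the names are copied from the table above.

```r
# Column names of penguin_dummies, copied from the table above
# (with the real object, use names(penguin_dummies) instead)
dummy_names <- c(
  "year", "bill_length_mm", "bill_depth_mm", "flipper_length_mm", "body_mass_g",
  "species_Chinstrap", "species_Gentoo", "island_Dream", "island_Torgersen", "sex_male"
)

# TRUE for every column created by dummy_cols() from species, island, or sex
mask_from_names <- grepl("^(species|island|sex)_", dummy_names)

# Same result as the hard-coded mask: five FALSE, then five TRUE
identical(mask_from_names, c(rep(FALSE, 5), rep(TRUE, 5)))
```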
results = run_cissvae(
data = penguin_dummies,
val_proportion = 0.20, ## small dataset so using higher val proportion
columns_ignore = "year",
binary_feature_mask = binary_feature_mask,
clusters = NULL,
n_clusters = 1,
scale_features = TRUE,
epochs = 500,
debug = FALSE
)
head(results$imputed_dataset)
head(penguin_dummies)
## year bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## 0 2007 39.10000 20.1713 181 3750
## 1 2007 39.50000 17.4000 186 3800
## 2 2007 40.30000 18.0000 195 3250
## 3 2007 36.70000 19.3000 193 3450
## 4 2007 40.38935 20.6000 190 3650
## 5 2007 38.90000 17.8000 181 3625
## species_Chinstrap species_Gentoo island_Dream island_Torgersen sex_male
## 0 2.210167e-15 1.156993e-15 1.339994e-03 6.544447e-07 1.0000000000
## 1 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.0000000000
## 2 0.000000e+00 0.000000e+00 4.124794e-09 4.522174e-01 0.0000000000
## 3 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.0002707173
## 4 3.508456e-16 7.418369e-16 8.792417e-04 1.695923e-06 0.9999980927
## 5 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.0000000000
## # A tibble: 6 × 10
## year bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## <int> <dbl> <dbl> <dbl> <dbl>
## 1 2007 39.1 NaN 181 3750
## 2 2007 39.5 17.4 186 3800
## 3 2007 40.3 18 195 3250
## 4 2007 NaN NaN 193 3450
## 5 2007 39.3 NaN 190 3650
## 6 2007 38.9 NaN NaN NaN
## # ℹ 5 more variables: species_Chinstrap <int>, species_Gentoo <int>,
## # island_Dream <int>, island_Torgersen <int>, sex_male <int>
As shown above, the imputed values for the binary columns are probabilities between 0 and 1 rather than hard 0/1 values, so they need to be converted back to binary. Because imputed_dataset is returned as a data.frame, we can use dplyr's mutate() together with across() to threshold every dummy column at 0.5.
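results$imputed_dataset <- results$imputed_dataset %>%
  mutate(across(
    .cols = matches("species|island|sex"),  # the dummy columns
    .fns = ~ case_when(
      .x > 0.5 ~ 1,   # probability above 0.5 -> 1
      .x <= 0.5 ~ 0,  # probability at or below 0.5 -> 0
      TRUE ~ .x       # anything else (e.g. NA) is left unchanged
    )
  ))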
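Thresholding each dummy column on its own can, in principle, set two dummies from the same categorical to 1 in one row (for example both species_Chinstrap and species_Gentoo), which encodes no valid species. A minimal sketch of a check for this, on made-up probabilities rather than real model output:

```r
library(dplyr)

# Hypothetical imputed probabilities for the two species dummies
# (Adelie is the dropped reference level, encoded by both being 0)
imputed <- data.frame(
  species_Chinstrap = c(0.02, 0.70, 0.60),
  species_Gentoo    = c(0.01, 0.10, 0.55)
)

# Threshold each dummy independently at 0.5, as in the mutate() call above
hard <- imputed %>%
  mutate(across(starts_with("species_"), ~ as.integer(.x > 0.5)))

# Rows with more than one species dummy set to 1 encode no valid species
conflicts <- which(rowSums(select(hard, starts_with("species_"))) > 1)
conflicts
```

Here only row 3 is flagged. On the real data the same rowSums() check can be run on each group of dummy columns in results$imputed_dataset; how to resolve a conflict (for example, keeping only the most probable level) is a modeling choice left to the analyst.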
head(results$imputed_dataset)
head(penguin_dummies)
head(penguin_dummies_complete)
## year bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## 0 2007 39.10000 20.1713 181 3750
## 1 2007 39.50000 17.4000 186 3800
## 2 2007 40.30000 18.0000 195 3250
## 3 2007 36.70000 19.3000 193 3450
## 4 2007 40.38935 20.6000 190 3650
## 5 2007 38.90000 17.8000 181 3625
## species_Chinstrap species_Gentoo island_Dream island_Torgersen sex_male
## 0 0 0 0 0 1
## 1 0 0 0 1 0
## 2 0 0 0 0 0
## 3 0 0 0 1 0
## 4 0 0 0 0 1
## 5 0 0 0 1 0
## # A tibble: 6 × 10
## year bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## <int> <dbl> <dbl> <dbl> <dbl>
## 1 2007 39.1 NaN 181 3750
## 2 2007 39.5 17.4 186 3800
## 3 2007 40.3 18 195 3250
## 4 2007 NaN NaN 193 3450
## 5 2007 39.3 NaN 190 3650
## 6 2007 38.9 NaN NaN NaN
## # ℹ 5 more variables: species_Chinstrap <int>, species_Gentoo <int>,
## # island_Dream <int>, island_Torgersen <int>, sex_male <int>
## # A tibble: 6 × 10
## year bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## <int> <dbl> <dbl> <int> <int>
## 1 2007 39.1 18.7 181 3750
## 2 2007 39.5 17.4 186 3800
## 3 2007 40.3 18 195 3250
## 4 2007 36.7 19.3 193 3450
## 5 2007 39.3 20.6 190 3650
## 6 2007 38.9 17.8 181 3625
## # ℹ 5 more variables: species_Chinstrap <int>, species_Gentoo <int>,
## # island_Dream <int>, island_Torgersen <int>, sex_male <int>
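Printing penguin_dummies_complete next to the imputed data allows an eyeball comparison. To put a number on it, one option is the share of originally missing binary cells that were imputed correctly. A minimal sketch with hypothetical vectors standing in for one dummy column:

```r
# Hypothetical values for one dummy column (e.g. sex_male)
truth   <- c(1L, 0L, 0L, 1L, 0L)  # from the complete data
masked  <- c(1L, NA, 0L, NA, NA)  # same column after masking
imputed <- c(1, 0, 0, 0, 0)       # after thresholding at 0.5

# Score only the cells that were missing before imputation
was_missing <- is.na(masked)
accuracy <- mean(imputed[was_missing] == truth[was_missing])
accuracy
```

Two of the three masked cells are recovered, so the accuracy is 2/3. With the real objects, truth, masked, and imputed would be the same column taken from penguin_dummies_complete, penguin_dummies, and results$imputed_dataset, assuming the rows of the three tables are in the same order.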