DICErClust

Deep Significance Clustering for Clinical Risk Stratification

DICErClust provides an R implementation of the DICE (Deep Significance Clustering) algorithm, a self-supervised deep learning framework that identifies clinically meaningful, risk-stratified patient subgroups from electronic health record (EHR) data.

DICE jointly optimises four objectives: autoencoder reconstruction, cluster cohesion, binary outcome prediction, and a likelihood-ratio test (LRT) significance penalty. The discovered clusters are therefore both data-driven and statistically validated.

Huang Y, Du C, Zhu F, et al. (2021). Self-supervised deep clustering of patient subgroups for heart failure with preserved ejection fraction. J Am Med Inform Assoc, 28, 2394–2403. doi:10.1093/jamia/ocab203


Installation

DICErClust is distributed as a source tarball. Install it once with install.packages(), then load it like any other R package.

install.packages(
  "/path/to/DICErClust_0.1.1.tar.gz",
  repos = NULL, type = "source"
)

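DICErClust depends on the torch package for R. Installing from a local tarball with repos = NULL does not pull in dependencies, so install torch and its runtime binaries before installing the tarball: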

install.packages("torch")
torch::install_torch()   # downloads the LibTorch runtime (~500 MB, once only)
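
To confirm that both the torch R package and the LibTorch runtime are in place before loading DICErClust, a check along the following lines may help. This is a minimal sketch: the helper name torch_ready() is hypothetical and not part of DICErClust, while requireNamespace() is base R and torch_is_installed() is exported by torch.

```r
# Hypothetical helper (not part of DICErClust): returns TRUE only when the
# torch R package is installed and its LibTorch runtime has been downloaded.
torch_ready <- function() {
  requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()
}

if (torch_ready()) {
  message("torch and LibTorch are available.")
} else {
  message("torch is not ready; run install.packages(\"torch\") and torch::install_torch().")
}
```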

Minimal working example

library(DICErClust)

## ── 1. Prepare data ─────────────────────────────────────────────────────────
## DICEr() reads RDS files, each a length-3 list:
##   [[1]] data_x  — numeric matrix n × p (continuous features, LSTM input)
##   [[2]] data_v  — numeric matrix n × q (binary demographics, outcome head)
##   [[3]] data_y  — integer vector 0/1 outcome

set.seed(1); n <- 150L; p <- 5L; q <- 3L
data_dir <- file.path(tempdir(), "dice_demo")
dir.create(data_dir, showWarnings = FALSE)

saveRDS(list(matrix(runif(n * p), n, p),
             matrix(as.numeric(rbinom(n * q, 1, 0.5)), n, q),
             rbinom(n, 1, 0.3)),
        file.path(data_dir, "train.rds"))
saveRDS(list(matrix(runif(50L * p), 50L, p),
             matrix(as.numeric(rbinom(50L * q, 1, 0.5)), 50L, q),
             rbinom(50L, 1, 0.3)),
        file.path(data_dir, "test.rds"))

## ── 2. Configure and train ──────────────────────────────────────────────────
args <- list(
  seed = 1L, input_path = data_dir,
  filename_train = "train.rds", filename_test = "test.rds",
  n_input_fea = p, n_hidden_fea = 3L,
  lstm_layer = 1L, lstm_dropout = 0.0, K_clusters = 2L,
  n_dummy_demov_fea = q, cuda = FALSE, lr = 1e-4,
  init_AE_epoch = 5L, iter = 20L, epoch_in_iter = 2L,
  lambda_AE = 1.0, lambda_classifier = 1.0,
  lambda_outcome = 1.0, lambda_p_value = 1.0
)

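old_wd <- setwd(tempdir())   # DICEr() writes relative to the working directory
tryCatch(
  DICEr(args),               # writes output to hn_3_K_2/part2_AE_nhidden_3/
  finally = setwd(old_wd)    # restore the working directory even if training fails
)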

## ── 3. Load best checkpoint ─────────────────────────────────────────────────
part2_dir <- file.path(tempdir(), "hn_3_K_2", "part2_AE_nhidden_3")
res_train <- readRDS(file.path(part2_dir, "data_train_iter.rds"))
res_test  <- readRDS(file.path(part2_dir, "data_test_iter.rds"))

## Use res_test$pred_C (not $C) for test-set cluster labels
table(res_test$pred_C)
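
Because DICEr() expects each RDS file to hold a length-3 list in a fixed order, it can be worth validating your own data before training. The sketch below uses only base R. The helper name check_dice_input() is hypothetical and not part of DICErClust, and the checks simply mirror the layout described in the data-preparation comments above; the package itself may enforce more or fewer constraints.

```r
# Hypothetical helper (not part of DICErClust): checks that an object follows
# the list(data_x, data_v, data_y) layout described in the example above.
check_dice_input <- function(x, p, q) {
  stopifnot(
    is.list(x), length(x) == 3L,
    # data_x: numeric n x p matrix without missing values
    is.matrix(x[[1]]), is.numeric(x[[1]]), ncol(x[[1]]) == p, !anyNA(x[[1]]),
    # data_v: numeric n x q matrix of 0/1 indicators
    is.matrix(x[[2]]), is.numeric(x[[2]]), ncol(x[[2]]) == q,
    all(x[[2]] %in% c(0, 1)),
    # data_y: 0/1 outcome with one entry per row of data_x and data_v
    nrow(x[[1]]) == nrow(x[[2]]),
    length(x[[3]]) == nrow(x[[1]]),
    all(x[[3]] %in% c(0, 1))
  )
  invisible(TRUE)
}

# Simulated input in the same shape as the example data
set.seed(1)
demo <- list(matrix(runif(20 * 5), 20, 5),
             matrix(as.numeric(rbinom(20 * 3, 1, 0.5)), 20, 3),
             rbinom(20, 1, 0.3))
check_dice_input(demo, p = 5L, q = 3L)   # errors if any check fails
```

In the example above, the same check could be applied to readRDS(file.path(data_dir, "train.rds")) with the values of p and q passed to args.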

Vignettes

Vignette                      Description
Introduction to DICErClust    Package overview, data format, quick start, hyperparameter guide
Heart Failure Example         Full end-to-end analysis on the UCI Heart Failure dataset (AUC = 0.823, χ² = 32.99, p < 0.001)

To browse vignettes after installation:

vignette("DICEr-introduction",      package = "DICErClust")
vignette("heart-failure-example",   package = "DICErClust")

How it works

DICEr runs three sequential stages:

  1. Autoencoder warm-up (init_AE_epoch epochs): an LSTM encoder–decoder is pre-trained on data_x to learn compact latent representations before any clustering begins.

  2. Joint optimisation (iter iterations × epoch_in_iter epochs): k-means clustering alternates with gradient updates that minimise:

    L = λ_AE · L_recon  +  λ_clf · L_cluster  +  λ_out · L_outcome  +  λ_p · L_LRT

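    where G is the likelihood-ratio test statistic and 3.841 is the χ²(1) critical value at α = 0.05. The term L_LRT = 3.841 − G is positive whenever G falls short of that threshold, so it penalises cluster configurations that do not reach significance.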

  3. Model selection: the checkpoint with the lowest test negative log-likelihood that also satisfies LRT p < 0.05 is saved.
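
The LRT penalty in stage 2 and the selection rule in stage 3 can be illustrated outside the package with base R. The sketch below is not DICErClust's internal code. It assumes that G is the deviance difference between an intercept-only logistic model and one that adds cluster membership (K = 2, so one degree of freedom). The data are simulated, and the checkpoints data frame with its column names is made up purely for illustration.

```r
set.seed(42)
n <- 400L
cluster <- rbinom(n, 1, 0.5)                          # simulated labels, K = 2
y <- rbinom(n, 1, ifelse(cluster == 1, 0.45, 0.20))   # outcome risk differs by cluster

# Likelihood-ratio statistic: deviance of the intercept-only model minus
# deviance of the model that adds cluster membership (1 degree of freedom).
fit_null    <- glm(y ~ 1,       family = binomial())
fit_cluster <- glm(y ~ cluster, family = binomial())
G <- deviance(fit_null) - deviance(fit_cluster)

crit    <- qchisq(0.95, df = 1)                  # chi-squared(1) critical value, 3.841459
p_value <- pchisq(G, df = 1, lower.tail = FALSE)
L_LRT   <- crit - G                              # negative once G exceeds the critical value

# Stage 3 on made-up checkpoint summaries (values are illustrative only):
# keep checkpoints with LRT p < 0.05, then take the lowest test NLL.
checkpoints <- data.frame(
  iter     = 1:4,
  test_nll = c(0.61, 0.58, 0.55, 0.57),
  p_value  = c(0.20, 0.03, 0.08, 0.01)
)
eligible <- checkpoints[checkpoints$p_value < 0.05, ]
best     <- eligible[which.min(eligible$test_nll), ]
best$iter   # 4: iteration 3 has the lowest NLL but fails the p < 0.05 filter
```

Note that the significance filter is applied before the NLL comparison, so a checkpoint with a better likelihood can still be rejected if its clusters do not separate the outcome.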


Citation

@article{huang2021self,
  title   = {Self-supervised deep clustering of patient subgroups for
             heart failure with preserved ejection fraction},
  author  = {Huang, Yiye and Du, Changchun and Zhu, Fan and others},
  journal = {Journal of the American Medical Informatics Association},
  volume  = {28},
  pages   = {2394--2403},
  year    = {2021},
  doi     = {10.1093/jamia/ocab203}
}

License

MIT © Sarah Ayton and Yiye Zhang, Weill Cornell Medicine