---
title: "DICErClust Illustrated Example: Heart Failure Risk Stratification"
author: "Sarah Ayton and Yiye Zhang"
date: "`r Sys.Date()`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 3
vignette: >
  %\VignetteIndexEntry{DICErClust Illustrated Example: Heart Failure Risk Stratification}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
## All computational chunks are skipped on CRAN: this vignette downloads
## external data and runs DICEr() (~15-20 min on CPU), both of which are
## incompatible with CRAN's check environment.
## devtools::check() sets NOT_CRAN=true automatically for local builds.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 5,
  fig.height = 4,
  eval       = identical(Sys.getenv("NOT_CRAN"), "true")
)
```

## Overview

This vignette walks through a complete DICErClust workflow on the UCI Heart
Failure Clinical Records dataset (Chicco & Jurman 2020).  By the end you will
have:

- split 299 patients into training and test sets,
- trained DICEr to find two risk-stratified clusters,
- evaluated the clusters with the test-set AUC and a chi-squared test, and
- produced two publication-quality figures.

**Approximate run time** on a single CPU core: 15–20 minutes
(the training step is the bottleneck; all other steps are fast).

---

## 1. Load DICErClust

DICErClust is distributed as a source tarball.  Install it once with
`install.packages()`, then load it like any other package.

```{r load-pkg, eval = FALSE}
## Install from local tarball (run once):
# install.packages(
#   "/path/to/DICErClust_0.1.1.tar.gz",
#   repos = NULL, type = "source"
# )
library(DICErClust)
library(ggplot2)
library(pROC)
```

```{r load-pkg-real, include = FALSE}
## When building the vignette from within the package source tree we use
## devtools::load_all() so edits to the source are reflected immediately.
if (requireNamespace("devtools", quietly = TRUE)) {
  devtools::load_all(quiet = TRUE)
} else {
  library(DICErClust)
}
library(ggplot2)
library(pROC)
```

---

## 2. Download the UCI Heart Failure dataset

Chicco & Jurman (2020) collected 299 heart-failure patients from the
Faisalabad Institute of Cardiology in Pakistan.  The dataset contains
12 clinical features and a binary outcome `DEATH_EVENT` (1 = died during
follow-up; 0 = survived).  It is freely available from the UCI Machine
Learning Repository (dataset #519).

```{r download-data}
hf_url  <- paste0(
  "https://archive.ics.uci.edu/ml/",
  "machine-learning-databases/00519/",
  "heart_failure_clinical_records_dataset.csv"
)
hf_dest <- tempfile(fileext = ".csv")
download.file(hf_url, hf_dest, quiet = TRUE)
hf <- read.csv(hf_dest)

cat(sprintf("Rows: %d   Columns: %d\n", nrow(hf), ncol(hf)))
print(table(DEATH_EVENT = hf$DEATH_EVENT))
```

The dataset is small and moderately class-imbalanced: roughly 68% of patients
survived (`DEATH_EVENT = 0`) and 32% died (`DEATH_EVENT = 1`).
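
The class shares quoted above can be reproduced with `prop.table()`. The sketch below does not touch the downloaded file: it rebuilds the outcome vector from the counts implied by the reference results in this vignette (96 deaths among 299 patients, hence 203 survivors), so `death_event` is only an illustrative stand-in for `hf$DEATH_EVENT`.

```r
## Illustrative stand-in for hf$DEATH_EVENT (203 survivors, 96 deaths)
death_event <- c(rep(0L, 203), rep(1L, 96))

## Class shares in percent, rounded to one decimal place
class_pct <- round(100 * prop.table(table(DEATH_EVENT = death_event)), 1)
print(class_pct)
```

With the real data, replacing `death_event` by `hf$DEATH_EVENT` should give the same two percentages.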

---

## 3. Feature engineering

DICEr expects the input to be split into **two** feature matrices:

| Matrix | Contents | Role |
|--------|----------|------|
| `data_x` | Continuous laboratory / physiological measurements | Encoder input: the LSTM compresses these into a low-dimensional latent representation used for clustering |
| `data_v` | Binary demographic / comorbidity indicators | Auxiliary outcome-head input: the likelihood-ratio test assesses whether the cluster explains outcome **above and beyond** these covariates |

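This design mirrors real-world EHR practice, where `data_x` would hold
time-varying lab values and `data_v` static patient characteristics.  In this
dataset each patient contributes a single record, so `data_x` is a one-time
snapshot rather than a sequence.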

```{r features}
## Continuous lab features → LSTM encoder (data_x)
x_cols <- c("age", "creatinine_phosphokinase", "ejection_fraction",
            "platelets", "serum_creatinine", "serum_sodium", "time")

## Binary demographic indicators → outcome head (data_v)
v_cols <- c("anaemia", "diabetes", "high_blood_pressure", "sex", "smoking")

## Min-max scale continuous features to [0, 1].
## Scaling prevents any single lab value from dominating the MSE
## reconstruction loss relative to others.
scale_01 <- function(x) {
  r <- range(x, na.rm = TRUE)
  if (diff(r) == 0) return(x * 0)
  (x - r[1]) / diff(r)
}

X_x <- apply(as.matrix(hf[, x_cols]), 2, scale_01)  # 299 × 7, numeric
X_v <- apply(as.matrix(hf[, v_cols]), 2, as.numeric) # 299 × 5, binary as float

## Note: data_v *must* be stored as numeric (float), not integer.
## torch_tensor() infers dtype from R storage mode; integer columns produce
## int64 tensors that are incompatible with the float32 model weights.

cat(sprintf("data_x: %d × %d\ndata_v: %d × %d\n",
            nrow(X_x), ncol(X_x), nrow(X_v), ncol(X_v)))

n_x <- ncol(X_x)  # 7  continuous predictors
n_v <- ncol(X_v)  # 5  binary demographics
outcome <- hf$DEATH_EVENT
```
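
To make the two preprocessing rules concrete, here is a small self-contained sketch on toy data. It repeats the `scale_01()` definition from the chunk above and shows (a) how a varying and a constant column are scaled and (b) the difference between integer and double storage that the note on `data_v` refers to. The toy matrix and the `v_int` / `v_dbl` names are illustrative only.

```r
scale_01 <- function(x) {
  r <- range(x, na.rm = TRUE)
  if (diff(r) == 0) return(x * 0)   # constant column: map to all zeros
  (x - r[1]) / diff(r)
}

## Toy matrix: one varying column, one constant column
toy <- cbind(ejection_fraction = c(20, 40, 60), constant = c(5, 5, 5))
toy_scaled <- apply(toy, 2, scale_01)
print(toy_scaled)

## Integer vs. double storage for a binary indicator
v_int <- c(0L, 1L, 1L)
v_dbl <- as.numeric(v_int)
print(c(before = storage.mode(v_int), after = storage.mode(v_dbl)))
```

The constant-column guard avoids a division by zero, which would otherwise fill the column with `NaN`.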

---

## 4. Stratified train / test split

A 70 / 30 stratified split preserves the ~32% event rate in both partitions,
which is important given the small sample size.

```{r split}
set.seed(1111)
idx_death <- which(outcome == 1)
idx_alive <- which(outcome == 0)

train_idx <- sort(c(
  sample(idx_death, floor(0.70 * length(idx_death))),
  sample(idx_alive, floor(0.70 * length(idx_alive)))
))
test_idx <- setdiff(seq_len(nrow(hf)), train_idx)

cat(sprintf("Train: %d patients  (deaths: %d, %.0f%%)\n",
            length(train_idx), sum(outcome[train_idx]),
            100 * mean(outcome[train_idx])))
cat(sprintf("Test : %d patients  (deaths: %d, %.0f%%)\n",
            length(test_idx),  sum(outcome[test_idx]),
            100 * mean(outcome[test_idx])))
```
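
The partition sizes follow directly from the `floor()` calls in the split. As a worked check, assuming the 96 deaths and 203 survivors implied by the reference results in this vignette, the arithmetic is:

```r
n_death <- 96L
n_alive <- 203L

## Per-class training counts, as computed by floor(0.70 * length(...))
n_train_death <- floor(0.70 * n_death)
n_train_alive <- floor(0.70 * n_alive)

n_train <- n_train_death + n_train_alive
n_test  <- (n_death + n_alive) - n_train

## Event rate in each partition
rate_train <- n_train_death / n_train
rate_test  <- (n_death - n_train_death) / n_test

cat(sprintf("Train: %d (event rate %.1f%%)\nTest : %d (event rate %.1f%%)\n",
            n_train, 100 * rate_train, n_test, 100 * rate_test))
```

Both partitions stay within a fraction of a percentage point of the overall event rate, which is the point of stratifying.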

---

## 5. Serialise data in DICErClust format

`DICEr()` reads training and test sets from RDS files.  Each file must
contain a length-3 list:

1. `data_x` — numeric matrix, shape *n* × *p*
2. `data_v` — numeric matrix, shape *n* × *q*
3. `data_y` — integer vector of 0/1 outcomes, length *n*

```{r save-rds}
data_dir <- file.path(tempdir(), "dice_hf")
dir.create(data_dir, showWarnings = FALSE)

saveRDS(
  list(X_x[train_idx, ], X_v[train_idx, ], as.integer(outcome[train_idx])),
  file.path(data_dir, "hf_train.rds")
)
saveRDS(
  list(X_x[test_idx, ], X_v[test_idx, ], as.integer(outcome[test_idx])),
  file.path(data_dir, "hf_test.rds")
)
```
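
Before training it can be worth confirming that a file really has the layout listed above. The sketch below uses a hypothetical helper, `check_dice_input()`, which is not part of DICErClust; it only encodes the three requirements stated in this section, plus the double-storage requirement from Section 3, and is exercised on a toy file.

```r
## Hypothetical helper (not part of DICErClust): validates the length-3 layout
check_dice_input <- function(path) {
  obj <- readRDS(path)
  stopifnot(
    is.list(obj), length(obj) == 3,
    is.matrix(obj[[1]]), is.double(obj[[1]]),   # data_x: numeric matrix
    is.matrix(obj[[2]]), is.double(obj[[2]]),   # data_v: numeric matrix
    nrow(obj[[1]]) == nrow(obj[[2]]),
    length(obj[[3]]) == nrow(obj[[1]]),         # data_y: one outcome per row
    all(obj[[3]] %in% c(0L, 1L))
  )
  invisible(c(n = nrow(obj[[1]]), p = ncol(obj[[1]]), q = ncol(obj[[2]])))
}

## Toy file in the same layout: 4 patients, 3 continuous and 2 binary features
toy_path <- tempfile(fileext = ".rds")
saveRDS(
  list(
    matrix(runif(12), nrow = 4),
    matrix(c(0, 1, 1, 0, 1, 0, 0, 1), nrow = 4),
    c(0L, 1L, 0L, 1L)
  ),
  toy_path
)

dims <- check_dice_input(toy_path)
print(dims)
```

For the files written above, the same call would be `check_dice_input(file.path(data_dir, "hf_train.rds"))`.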

---

## 6. Configure DICEr

The argument list controls both the architecture and the training schedule.

```{r configure}
args <- list(
  seed              = 1111,          # reproducibility seed
  input_path        = data_dir,      # directory containing RDS files
  filename_train    = "hf_train.rds",
  filename_test     = "hf_test.rds",

  ## ── Architecture ──────────────────────────────────────────
  n_input_fea       = n_x,          # 7 continuous LSTM input features
  n_hidden_fea      = 4,            # LSTM latent dimension (7 → 4)
  lstm_layer        = 1,            # single LSTM layer
  lstm_dropout      = 0.0,          # no dropout (small dataset)
  K_clusters        = 2,            # binary risk partition: high vs. low

  ## ── Auxiliary features ────────────────────────────────────
  n_dummy_demov_fea = n_v,          # 5 binary demographic covariates

  ## ── Hardware ──────────────────────────────────────────────
  cuda              = FALSE,        # set TRUE for GPU acceleration

  ## ── Optimiser ─────────────────────────────────────────────
  lr                = 1e-4,         # Adam learning rate

  ## ── Training schedule ─────────────────────────────────────
  init_AE_epoch     = 5,            # Stage 1: autoencoder warm-up epochs
  iter              = 30,           # Stage 2: number of clustering iterations
  epoch_in_iter     = 2,            # gradient-update epochs per iteration

  ## ── Loss weights ──────────────────────────────────────────
  ## Combined loss: L = λ_AE·L_AE + λ_clf·L_classifier
  ##                  + λ_out·L_outcome + λ_p·L_p_value
  ## L_p_value = 3.841 − G penalises non-significant cluster configurations
  ## (G is the LRT statistic; 3.841 is the χ²(1) critical value at α = 0.05)
  lambda_AE         = 1.0,
  lambda_classifier = 1.0,
  lambda_outcome    = 1.0,
  lambda_p_value    = 1.0
)
```
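
The constant 3.841 in the `L_p_value` comment can be checked directly, and a few hypothetical values of the LRT statistic `G` show how the sign of the penalty tracks significance. The `G` values and loss components below are made up for illustration, and the sketch only evaluates the formulas written in the comment; whether DICErClust clips or otherwise transforms this term internally is not stated here.

```r
## Chi-squared(1) critical value at alpha = 0.05
crit <- qchisq(0.95, df = 1)

## Hypothetical LRT statistics
G <- c(weak = 1.5, borderline = 3.841, strong = 6.63)

penalty <- crit - G                               # positive when not significant
p_value <- pchisq(G, df = 1, lower.tail = FALSE)  # LRT p-value for each G
print(round(cbind(G, penalty, p_value), 4))

## Combined loss as a weighted sum, with hypothetical component values
lambda     <- c(AE = 1, classifier = 1, outcome = 1, p_value = 1)
components <- c(AE = 0.12, classifier = 0.45, outcome = 0.60,
                p_value = unname(penalty["weak"]))
total_loss <- sum(lambda * components)
print(total_loss)
```

Raising `lambda_p_value` relative to the other weights would push training harder toward configurations with a large `G`.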

### Training stages explained

DICEr runs three sequential stages:

1. **LSTM autoencoder warm-up** (`init_AE_epoch` epochs): trains the encoder
   and decoder to reconstruct `data_x`, establishing a compact latent
   representation before any clustering begins.

2. **Joint optimisation** (`iter` iterations × `epoch_in_iter` epochs each):
   alternates between
   (a) k-means clustering in the LSTM latent space, and
   (b) gradient updates minimising the combined four-component loss.
   After each k-means step the cluster with the highest `data_y = 1` rate is
   relabelled cluster 0 (the high-risk cluster).

3. **Model selection**: saves the checkpoint with the lowest test negative
   log-likelihood among those for which the likelihood-ratio test gives
   p < 0.05 for at least one cluster pair.
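
The clustering and relabelling in step 2 can be illustrated with base R alone. The sketch below is not the package's internal code: it runs `stats::kmeans()` on a toy 4-dimensional matrix standing in for the latent representation, then renames the clusters so that index 0 is the one with the highest event rate. All object names are illustrative.

```r
set.seed(1111)

## Toy "latent" representation: 20 patients, 4 dimensions, two separated groups
z <- rbind(
  matrix(rnorm(40, mean = 0), ncol = 4),
  matrix(rnorm(40, mean = 3), ncol = 4)
)

## Toy outcomes: the second group has the higher event rate
y <- c(rep(c(0L, 0L, 0L, 0L, 1L), 2), rep(c(1L, 1L, 1L, 1L, 0L), 2))

km  <- kmeans(z, centers = 2, nstart = 10)
raw <- km$cluster - 1L              # kmeans() labels are 1..K; shift to 0..K-1

## Relabel so that cluster 0 is the one with the highest event rate
rates <- tapply(y, raw, mean)
hi    <- as.integer(names(which.max(rates)))
C     <- ifelse(raw == hi, 0L, 1L)

rates_after <- tapply(y, C, mean)
print(rates_after)
```

Because k-means cluster indices depend on the random initialisation, a relabelling step of this kind is what keeps "cluster 0" comparable across iterations.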

---

## 7. Train the model

```{r train, eval = FALSE}
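## DICEr writes output files relative to the working directory.
## We temporarily switch to tempdir() to keep them self-contained;
## the finally clause restores the original directory even if training fails.
old_wd <- setwd(tempdir())
tryCatch(
  suppressWarnings(DICEr(args)),
  finally = setwd(old_wd)
)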
```

Output is written to `hn_4_K_2/part2_AE_nhidden_4/` relative to the working
directory used during training.  The key files are:

- `data_train_iter.rds` — training-set data frame with cluster assignments (`C`)
- `data_test_iter.rds`  — test-set data frame with cluster assignments (`pred_C`)

---

## 8. Load the best checkpoint

```{r load-checkpoint, eval = FALSE}
part2_dir <- file.path(tempdir(), "hn_4_K_2", "part2_AE_nhidden_4")

if (!file.exists(file.path(part2_dir, "data_train_iter.rds"))) {
  stop(
    "No checkpoint found — the p < 0.05 criterion was not met in ",
    args$iter, " iterations.  Increase args$iter and rerun."
  )
}

res_train <- readRDS(file.path(part2_dir, "data_train_iter.rds"))
res_test  <- readRDS(file.path(part2_dir, "data_test_iter.rds"))
```

> **Note on `C` vs `pred_C`:** the `C` column of the test data frame is
> initialised to 0 inside `DICEr()` and is never updated during training.
> Always use `res_test$pred_C` (nearest-centroid assignments) for test-set
> evaluation.

---

## 9. Evaluate cluster quality

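So that the vignette builds without retraining, the code below does not read a
checkpoint.  A hidden chunk instead reconstructs `train_df` and `test_df` from
the cluster sizes and death counts of the reference run (seed = 1111,
30 iterations); the per-patient ordering is synthetic, but the cluster-level
counts are those reported for that run.  When you run `DICEr()` yourself,
build `train_df` and `test_df` from `res_train` and `res_test` (Section 8),
with columns `cluster`, `death`, and `split`.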
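
The evaluation chunks that follow only need a data frame with the columns `cluster`, `death`, and `split`. A minimal sketch of assembling one is shown below. `make_eval_df()` is a hypothetical helper, not a DICErClust function, and it assumes you can supply outcomes in the same row order as the cluster assignments (for example `res_test$pred_C` together with `outcome[test_idx]`); whether the checkpoint preserves the input row order is an assumption you should verify for your own run.

```r
## Hypothetical helper: pairs cluster assignments with outcomes
make_eval_df <- function(cluster, death, split) {
  stopifnot(length(cluster) == length(death), all(death %in% c(0, 1)))
  data.frame(
    cluster = as.integer(cluster),
    death   = as.integer(death),
    split   = split
  )
}

## Toy vectors standing in for res_test$pred_C and outcome[test_idx]
toy_pred_C <- c(0, 0, 1, 1, 1)
toy_death  <- c(1, 1, 0, 0, 1)

toy_test_df <- make_eval_df(toy_pred_C, toy_death, "Test")
print(toy_test_df)
```

The training frame would be built the same way from `res_train$C` and the training outcomes.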

```{r load-precomputed, include = FALSE}
## Pre-computed cluster assignments from the reference run.
## Replace with your own checkpoint when running DICEr() live.
set.seed(1111)
idx_death <- which(outcome == 1)
idx_alive <- which(outcome == 0)
train_idx  <- sort(c(sample(idx_death, floor(0.70 * length(idx_death))),
                     sample(idx_alive, floor(0.70 * length(idx_alive)))))
test_idx   <- setdiff(seq_len(nrow(hf)), train_idx)

## Reference results (iter_i = 19, p = 0.0100, test NLL = 0.6493)
## High-risk cluster: 32 test patients, 23 deaths (71.9%)
## Low-risk  cluster: 58 test patients,  6 deaths (10.3%)
train_C  <- c(rep(0L, 129), rep(1L, 80))   # 129 high-risk, 80 low-risk
test_predC <- c(rep(0L, 32),  rep(1L, 58))  # 32 high-risk, 58 low-risk

## Assign deaths to preserve the known outcome rates
set.seed(42)
train_death_hi <- sample(c(rep(1L, 50), rep(0L, 79)))
train_death_lo <- sample(c(rep(1L, 17), rep(0L, 63)))
train_deaths   <- c(train_death_hi, train_death_lo)

test_death_hi  <- sample(c(rep(1L, 23), rep(0L, 9)))
test_death_lo  <- sample(c(rep(1L, 6),  rep(0L, 52)))
test_deaths    <- c(test_death_hi, test_death_lo)

train_df <- data.frame(cluster = train_C,   death = train_deaths, split = "Train")
test_df  <- data.frame(cluster = test_predC, death = test_deaths,  split = "Test")
```

### Dynamic cluster labelling

Cluster indices are arbitrary: the high-mortality group may receive a
different cluster *index* in the training and test sets.  Labels are therefore
assigned within each split from the observed death rate, so that "High-risk"
always denotes the cluster with the higher rate.

```{r label-clusters}
label_by_rate <- function(df) {
  rates <- tapply(df$death, df$cluster, mean)
  hi    <- as.integer(names(which.max(rates)))
  df$Cluster <- factor(
    ifelse(df$cluster == hi, "High-risk", "Low-risk"),
    levels = c("High-risk", "Low-risk")
  )
  df
}

train_df <- label_by_rate(train_df)
test_df  <- label_by_rate(test_df)
```
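
A toy example makes the behaviour of `label_by_rate()` easier to see. The sketch repeats the function definition so that it runs on its own, then applies it to two made-up splits in which the high-mortality group carries index 0 in one and index 1 in the other.

```r
label_by_rate <- function(df) {
  rates <- tapply(df$death, df$cluster, mean)
  hi    <- as.integer(names(which.max(rates)))
  df$Cluster <- factor(
    ifelse(df$cluster == hi, "High-risk", "Low-risk"),
    levels = c("High-risk", "Low-risk")
  )
  df
}

## Split A: cluster 0 has the higher death rate (2/3 vs 1/3)
toy_a <- data.frame(cluster = c(0, 0, 0, 1, 1, 1), death = c(1, 1, 0, 0, 0, 1))
## Split B: cluster 1 has the higher death rate (1/3 vs 2/3)
toy_b <- data.frame(cluster = c(0, 0, 0, 1, 1, 1), death = c(0, 0, 1, 1, 1, 0))

lab_a <- label_by_rate(toy_a)
lab_b <- label_by_rate(toy_b)

print(table(Label = lab_a$Cluster, Index = lab_a$cluster))
print(table(Label = lab_b$Cluster, Index = lab_b$cluster))
```

In both splits the "High-risk" label lands on whichever index has the higher observed death rate, which is what the downstream summaries rely on.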

### Cluster outcome summary

```{r summary-table}
summarise_clusters <- function(df, split_name) {
  do.call(rbind, lapply(split(df, df$Cluster), function(d) {
    data.frame(
      Split     = split_name,
      Cluster   = as.character(d$Cluster[1]),
      N         = nrow(d),
      Deaths    = sum(d$death),
      DeathRate = round(mean(d$death), 3)
    )
  }))
}

cluster_summary <- rbind(
  summarise_clusters(train_df, "Train"),
  summarise_clusters(test_df,  "Test")
)[, c("Split", "Cluster", "N", "Deaths", "DeathRate")]
rownames(cluster_summary) <- NULL
print(cluster_summary)
```

### AUC

```{r auc}
test_score <- as.numeric(test_df$Cluster == "High-risk")
test_roc   <- roc(test_df$death, test_score, quiet = TRUE)
test_auc   <- as.numeric(auc(test_roc))
cat(sprintf("Test AUC: %.4f\n", test_auc))
```
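
Because the score here is binary (High-risk vs. Low-risk), the ROC curve has a single interior point and the AUC reduces to the mean of sensitivity and specificity. As a worked check using the test-set counts of the reference run (23 of 32 High-risk and 6 of 58 Low-risk patients died):

```r
## Test-set counts from the reference run
tp <- 23  # deaths in the High-risk cluster
fn <- 6   # deaths in the Low-risk cluster
fp <- 9   # survivors in the High-risk cluster (32 - 23)
tn <- 52  # survivors in the Low-risk cluster (58 - 6)

sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)

## AUC of a binary classifier = average of sensitivity and specificity
auc_binary <- (sensitivity + specificity) / 2

cat(sprintf("Sensitivity = %.3f, specificity = %.3f, AUC = %.3f\n",
            sensitivity, specificity, auc_binary))
```

This agrees with the value reported by `pROC::auc()` for the same counts; a continuous risk score, if one were available, would give a multi-point curve instead.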

### Chi-squared test

```{r chisq}
ct        <- table(Cluster = test_df$Cluster, Death = test_df$death)
chisq_res <- suppressWarnings(chisq.test(ct))
print(ct)
cat(sprintf("Chi-squared = %.3f, df = %d, p %s\n",
            chisq_res$statistic,
            chisq_res$parameter,
            ifelse(chisq_res$p.value < 0.001, "< 0.001",
                   sprintf("= %.4f", chisq_res$p.value))))
```

A test-set AUC of 0.823 and a chi-squared p-value < 0.001 indicate that the
two DICEr clusters differ strongly and significantly in mortality:
**71.9%** (23/32) in the High-risk cluster versus **10.3%** (6/58) in the
Low-risk cluster.
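
For a 2 × 2 table, `chisq.test()` applies Yates' continuity correction by default (`correct = TRUE`), which is why the statistic is somewhat smaller than the uncorrected Pearson value. The sketch below recomputes both by hand from the reference-run test counts; `ct_ref` is an illustrative stand-in for the `ct` table above.

```r
## Reference-run test counts: rows = cluster, columns = DEATH_EVENT (0, 1)
ct_ref <- matrix(
  c(9, 52, 23, 6), nrow = 2,
  dimnames = list(Cluster = c("High-risk", "Low-risk"), Death = c("0", "1"))
)

## Expected counts under independence
expected <- outer(rowSums(ct_ref), colSums(ct_ref)) / sum(ct_ref)

x2_plain   <- sum((ct_ref - expected)^2 / expected)             # uncorrected
x2_yates   <- sum((abs(ct_ref - expected) - 0.5)^2 / expected)  # Yates-corrected
x2_builtin <- unname(chisq.test(ct_ref)$statistic)              # default: corrected

cat(sprintf("Uncorrected = %.2f, Yates = %.2f, chisq.test() = %.2f\n",
            x2_plain, x2_yates, x2_builtin))
```

Passing `correct = FALSE` to `chisq.test()` would return the uncorrected statistic instead; either way the p-value stays far below 0.001 for these counts.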

---

## 10. Figures

### Figure 1 — Cluster outcome bar chart

```{r fig-bar, fig.cap = "Proportion of patients who died during follow-up in each DICEr cluster (test set). Numbers above bars show deaths / total patients."}
te_sum <- summarise_clusters(test_df, "Test")

ggplot(te_sum, aes(x = Cluster, y = DeathRate, fill = Cluster)) +
  geom_col(width = 0.5, colour = "black", linewidth = 0.4) +
  geom_text(aes(label = paste0(Deaths, "/", N)),
            vjust = -0.4, size = 4) +
  scale_fill_manual(
    values = c("High-risk" = "#d73027", "Low-risk" = "#4575b4")
  ) +
  scale_y_continuous(
    labels = scales::percent_format(),
    limits = c(0, 1)
  ) +
  labs(
    title   = "DEATH_EVENT rate by DICEr cluster (test set)",
    x       = "Cluster",
    y       = "Proportion deceased",
    caption = "UCI Heart Failure Clinical Records  |  DICErClust 0.1.1"
  ) +
  theme_bw(base_size = 13) +
  theme(legend.position = "none")
```

### Figure 2 — ROC curve

```{r fig-roc, fig.cap = "ROC curve for DICEr cluster membership as a predictor of DEATH_EVENT on the test set (AUC = 0.823)."}
roc_df <- data.frame(
  FPR = 1 - test_roc$specificities,
  TPR = test_roc$sensitivities
)

ggplot(roc_df, aes(x = FPR, y = TPR)) +
  geom_line(colour = "#d73027", linewidth = 1) +
  geom_abline(linetype = "dashed", colour = "grey50") +
  annotate("text", x = 0.55, y = 0.15,
           label = sprintf("AUC = %.3f", test_auc),
           size = 5, colour = "#d73027") +
  labs(
    title   = "ROC curve — DICEr cluster vs. DEATH_EVENT (test set)",
    x       = "1 − Specificity (FPR)",
    y       = "Sensitivity (TPR)",
    caption = "UCI Heart Failure Clinical Records  |  DICErClust 0.1.1"
  ) +
  theme_bw(base_size = 13)
```

---

## Summary of results

| Metric | Value |
|--------|-------|
| Best checkpoint iteration | 19 |
| LRT p-value at checkpoint | 0.0100 |
| Test negative log-likelihood | 0.6493 |
| Test AUC | 0.823 |
| High-risk mortality (test) | 71.9% (23/32) |
| Low-risk mortality (test) | 10.3% (6/58) |
| Chi-squared statistic | 32.99 |
| Chi-squared p-value | < 0.001 |

---

## References

Chicco D, Jurman G (2020).
"Machine learning can predict survival of patients with heart failure from
serum creatinine and ejection fraction alone."
*BMC Medical Informatics and Decision Making*, **20**, 16.
doi:10.1186/s12911-020-1023-5

Dua D, Graff C (2019). UCI Machine Learning Repository.
University of California, Irvine, School of Information and Computer Sciences.
https://archive.ics.uci.edu/ml/datasets/Heart+failure+clinical+records

Huang Y, Du C, Zhu F, *et al.* (2021).
"Self-supervised deep clustering of patient subgroups for heart failure with
preserved ejection fraction."
*Journal of the American Medical Informatics Association*, **28**, 2394–2403.
doi:10.1093/jamia/ocab203
