Model Performance Evaluation
1 Prerequisites
2 Import Synthetic Data
# Read the synthetic pancreatic cancer cohort. `death_os` is read as integer
# and `fnstage` / `grade` as character; all other column types are guessed by
# readr. Every character column is then converted to a factor (in place).
panc_ca <- read_csv(
  "data/synthetic_pancreas.csv",
  col_types = cols(death_os = "i", fnstage = "c", grade = "c")
) %>%
  mutate(across(where(is.character), as.factor))
# Display labels for the dataset's column names (e.g. for facet labels).
lookup <- c(
  # patient characteristics and care setting
  sex = "Gender",
  age = "Age",
  hp_lv = "Medical center",
  urbangp = "Urbanization",
  height = "Height",
  weight = "Weight",
  bmi = "BMI",
  # tumour characteristics
  hist = "Histology",
  subtype = "Subsite",
  fnt = "T stage",
  fnn = "N stage",
  fnm = "M stage",
  fnstage = "AJCC stage",
  grade = "Grade",
  # lifestyle
  smoke_yr = "Years of smoking",
  drink_gp = "Alcohol consumption",
  # treatments
  trtop = "Surgery",
  trtct = "Chemotherapy",
  trtrt = "Radiotherapy",
  trttarget = "Targeted therapy",
  trthorm = "Hormonal therapy",
  trtimmu = "Immunotherapy"
)
# Preview the first 50 rows of the imported cohort.
panc_ca %>%
  head(n = 50)
3 Multiple Imputation
We performed multiple imputation by chained equations (MICE) with random forests to generate 5 complete datasets. The follow-up time and event indicator were excluded from the imputation process.
4 Imputation Convergence Diagnostics
Convergence of the MICE algorithm was assessed using the autocorrelation and the potential scale reduction factor (PSRF) of each imputed variable across iterations, both computed by the `convergence()` function of the mice package. Autocorrelations close to zero and PSRFs close to one indicate adequate convergence.
# Per-iteration convergence diagnostics (autocorrelation `ac` and potential
# scale reduction factor `psrf`) for each imputed variable `vrb`.
panc_mice_convg <- convergence(panc_mice)

panc_mice_convg %>%
  # drop iterations for which neither diagnostic is available
  filter(!if_all(c(ac, psrf), is.na)) %>%
  pivot_longer(c(ac, psrf), names_to = "index") %>%
  # Set the left-to-right facet order. The relevelled factor must overwrite
  # `vrb` itself, since `vrb` is the variable facet_grid() uses below
  # (previously it was written to an unused `vrb2` column and had no effect).
  mutate(vrb = fct_relevel(vrb, "fnt", "fnn", "grade", "height", "weight",
                           "smoke_yr", "drink_gp")) %>%
  ggplot(aes(x = .it, y = value)) +
  # dotted reference lines: target autocorrelation 0, target PSRF 1
  geom_hline(aes(yintercept = ref), linetype = 3,
             data = data.frame(index = c("ac", "psrf"), ref = c(0, 1))) +
  geom_line(colour = 4, na.rm = TRUE) +
  facet_grid(index ~ vrb, scales = "free",
             labeller = labeller(vrb = lookup,
                                 index = c(ac = "Autocorrelation",
                                           psrf = "PSRF"))) +
  labs(x = "Iterations", y = NULL) +
  theme_bw()
Before model performance evaluation, check that all wrapper functions for modeling run correctly without errors.
# Smoke test: train each model wrapper once, with its default settings, on the
# first imputed dataset so that any error surfaces before the long benchmark.
cox_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
orsf_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
xgbst_cox_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
nn_deepsurv_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
nn_coxtime_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
nn_pchazard_fit(
  Surv(Surv_time, death_os) ~ .,
  data = panc_impute$Imp.1
)
5 Benchmark
For each imputed dataset, we applied nested cross-validation to obtain an unbiased estimate of model performance for selecting the optimal model. An outer 5-fold cross-validation loop split the data into five roughly equal subsets; each subset (20%) was used in turn to evaluate the time-dependent performance metrics, while the remaining 80% was used for training. Within each training set, an inner 5-fold cross-validation loop with a random search strategy was used for hyperparameter optimization.
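First, create a list that contains the training and prediction functions for each model. Each randomly searched hyperparameter grid is supplied to the training function as a set of vectors of equal length, one element per candidate configuration. Because the grids are drawn inside each training function, a new grid is sampled every time a model is trained.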
# Training / prediction pairs for every candidate model, passed to bench().
# Each `fit` takes a model formula `f` and training data `x`. The random-search
# hyperparameter grids are drawn inside the closure, so a fresh grid is sampled
# each time a `fit` function is called (presumably once per model per fold).
model_fn <- list(
  # Cox model via cox_fit() with default settings
  cox = list(
    fit = function(f, x) cox_fit(f, x),
    pred = cox_pred
  ),
  # Cox model via cox_fit() with `interaction = TRUE`
  cox2 = list(
    fit = function(f, x) cox_fit(f, x, interaction = TRUE),
    pred = cox_pred
  ),
  # ORSF: 16 random candidates for colsample / alpha, fixed n_tree
  orsf = list(
    fit = function(f, x) {
      orsf_fit(f, x,
               n_tree = 500, # runif(16, 500, 1000),
               colsample = runif(16, 0.1, 0.5),
               alpha = 10^runif(16, -3, -1),
               eval_horizon = 12 * 1:5)
    },
    pred = orsf_pred
  ),
  # XGBoost Cox model: 48 random candidates
  xgbst = list(
    fit = function(f, x) {
      xgbst_cox_fit(f, x,
                    learn_rate = 10^runif(48, -3, -1),
                    tree_depth = sample(4:8, 48, replace = TRUE),
                    colsample = runif(48, 0.1, 0.5),
                    subsample = runif(48, 0.5, 1),
                    penalty = 10^runif(48, -5, 2),
                    mixture = runif(48, 0, 1),
                    n_core = 8)
    },
    pred = xgbst_cox_pred
  ),
  # The three neural-network models below share the same 48-candidate grid
  # design and the same prediction function.
  nn_deepsurv = list(
    fit = function(f, x) {
      nn_deepsurv_fit(f, x,
                      n_layer = sample(4:8, 48, replace = TRUE),
                      frac_layer1 = runif(48, 0.5, 1),
                      width_ratio = runif(48, 0.5, 1),
                      learn_rate = 10^runif(48, -3, -1),
                      dropout = runif(48, 0, 0.5),
                      n_core = 8)
    },
    pred = nn_surv_pred
  ),
  nn_coxtime = list(
    fit = function(f, x) {
      nn_coxtime_fit(f, x,
                     n_layer = sample(4:8, 48, replace = TRUE),
                     frac_layer1 = runif(48, 0.5, 1),
                     width_ratio = runif(48, 0.5, 1),
                     learn_rate = 10^runif(48, -3, -1),
                     dropout = runif(48, 0, 0.5),
                     n_core = 8)
    },
    pred = nn_surv_pred
  ),
  nn_pchazard = list(
    fit = function(f, x) {
      nn_pchazard_fit(f, x,
                      n_layer = sample(4:8, 48, replace = TRUE),
                      frac_layer1 = runif(48, 0.5, 1),
                      width_ratio = runif(48, 0.5, 1),
                      learn_rate = 10^runif(48, -3, -1),
                      dropout = runif(48, 0, 0.5),
                      n_core = 8)
    },
    pred = nn_surv_pred
  )
)
To reduce variability from data partitioning, we repeated the nested cross-validation process 10 times, resulting in a total of $5 \text{ (imputed datasets)} \times 5 \text{ (outer folds)} \times 10 \text{ (repeats)} = 250$ estimates for each performance metric at each time point.
# Silence Python-side warnings emitted through {survivalmodels}.
Sys.setenv(PYTHONWARNINGS = "ignore")

# For each imputed dataset: build 10 repeats of 5-fold CV indices, then run
# bench() on every fold, which trains all models in `model_fn` and evaluates
# them on the held-out rows at 12-60 months in 6-month steps. The result is
# a list (imputations) of lists (folds) of bench() outputs.
cv_bench <- imap(panc_impute, function(imp_df, imp_id) {
  fold_index <- kfold_cv_ind(nrow(imp_df), k = 5, repeats = 10)
  imap(fold_index, function(test_rows, fold_id) {
    cli::cli_h2(paste0(imp_id, "_", fold_id))
    start_time <- Sys.time()
    fold_res <- bench(Surv(Surv_time, death_os) ~ ., imp_df,
                      model_fn = model_fn,
                      pred_horiz = 12 * seq(1, 5, 0.5),
                      test_ind = test_rows)
    end_time <- Sys.time()
    cli::cli_alert_info("Finished in {prettyunits::pretty_dt(end_time - start_time)}.")
    fold_res
  })
})
# saveRDS(cv_bench, "data/cv_bench.rds")
# cv_bench <- readRDS("data/cv_bench.rds")
── Imp.1_Repeat.1_Fold.1 ──
- Model: cox → training... 173ms → testing... ✔
- Model: cox2 → training... 6.3s → testing... ✔
- Model: orsf → training... 17.7s → testing... ✔
- Model: xgbst → training... 14.7s → testing... ✔
- Model: nn_deepsurv → training... 27.5s → testing... ✔
- Model: nn_coxtime → training... 25.7s → testing... ✔
- Model: nn_pchazard → training... 21.9s → testing... ✔
ℹ Finished in 1m 58.7s.
── Imp.1_Repeat.1_Fold.2 ──
- Model: cox → training... 137ms → testing... ✔
- Model: cox2 → training... 4.7s → testing... ✔
- Model: orsf → training... 17.8s → testing... ✔
- Model: xgbst → training... 18.6s → testing... ✔
- Model: nn_deepsurv → training... 27.4s → testing... ✔
- Model: nn_coxtime → training... 21.8s → testing... ✔
- Model: nn_pchazard → training... 25.1s → testing... ✔
ℹ Finished in 2m 0.2s.
continuing...
`cv_bench` is a nested list containing the performance metrics and runtime of each cross-validation fold for each model. In rare cases, deep learning models may fail to fit and produce abnormal predictions (e.g., AUC < 0.5 or IPA < 0). To guard against this, a model's results for a fold are excluded before averaging if, at any time point, sensitivity, precision or AUC does not exceed 0.5, or IPA does not exceed 0.
# Stack the per-fold metric tables into one tibble; `imp` and `fold` hold the
# names of the outer (imputation) and inner (repeat/fold) cv_bench lists.
perf <- cv_bench %>%
  map_dfr(
    function(imp_res) {
      map_dfr(imp_res, function(fold_res) fold_res$perf_metric, .id = "fold")
    },
    .id = "imp"
  ) %>%
  # Keep a model's fold only if every time point passes all four checks.
  filter(all(TPR > 0.5 & PPV > 0.5 & AUC > 0.5 & IPA > 0),
         .by = c(imp, fold, model)) %>%
  # Specificity replaces FPR; F1 is the harmonic mean of TPR and PPV;
  # horizons are converted from months to years.
  mutate(TNR = 1 - FPR,
         FPR = NULL,
         F1 = 2 / (1 / TPR + 1 / PPV),
         times = times / 12)

# Average each metric over imputations and folds, per model and time point.
perf_pool <- perf %>%
  summarise(across(c(TPR, TNR, PPV, F1, AUC, IPA), mean),
            .by = c(model, times))
The code for plotting the time-dependent metrics is shown below:
# Long format for plotting: one row per model x time point x metric.
perf_pool_long <- perf_pool %>%
  pivot_longer(!c(model, times), names_to = "metric") %>%
  # fixed model order (legend order and colour assignment)
  mutate(model = fct_relevel(model, "cox", "cox2", "orsf", "nn_deepsurv",
                             "nn_pchazard", "nn_coxtime", "xgbst")) %>%
  # Two anchor rows without a model: they make the sensitivity (TPR) panel's
  # y range include the observed specificity (TNR) range.
  # NOTE(review): they rely on NA-model rows not being drawn in the plot.
  add_row(times = c(1, 5), metric = "TPR", value = range(perf_pool$TNR)) %>%
  # Panel titles; the "(a)".."(f)" prefixes make the alphabetical facet order
  # match the intended order.
  mutate(metric = unname(c(
    TPR = "(a) Sensitivity",
    TNR = "(b) Specificity",
    PPV = "(c) Precision",
    F1 = "(d) F1 score",
    AUC = "(e) AUC",
    IPA = "(f) sBrier"
  )[metric]))
# Human-readable legend labels, keyed by model id.
model_lookup <- c(
  cox         = "Cox",
  cox2        = "Cox-Interact",
  orsf        = "ORSF",
  xgbst       = "XGBoost",
  nn_deepsurv = "DeepSurv",
  nn_pchazard = "PC-Hazard",
  nn_coxtime  = "Cox-Time"
)
# Okabe-Ito colour-blind-safe palette (see palette.colors(palette = "Okabe-Ito")).
# Colours are keyed by model id so that scale_color_manual() matches them by
# name instead of by factor-level position. The assignment equals the previous
# positional one for the level order cox, cox2, orsf, nn_deepsurv, nn_pchazard,
# nn_coxtime, xgbst, but no longer silently shifts if that order changes.
model_col <- c(
  cox         = "#0072B2",
  cox2        = "#CC79A7",
  orsf        = "#009E73",
  nn_deepsurv = "#E69F00",
  nn_pchazard = "#56B4E9",
  nn_coxtime  = "#F0E442",
  xgbst       = "#D55E00"
)
# Time-dependent performance: one panel per metric, one line per model.
# The y axis is shown as percentages without the "%" sign, with breaks every
# 2 percentage points; rows without a model get no legend entry.
ggplot(perf_pool_long, aes(x = times, y = value, colour = model)) +
  geom_line(linewidth = 1, na.rm = TRUE) +
  facet_wrap(vars(metric), nrow = 2, scales = "free_y", axes = "all") +
  labs(x = "Time (years since cancer diagnosis)", y = NULL) +
  scale_y_continuous(
    breaks = seq(0, 1, by = 0.02),
    labels = scales::label_percent(suffix = "")
  ) +
  scale_color_manual(
    name = "Models",
    na.translate = FALSE,
    labels = model_lookup,
    values = model_col
  ) +
  theme_bw() +
  theme(
    strip.text = element_text(hjust = 0),
    strip.background = element_blank(),
    panel.border = element_blank(),
    axis.line = element_line(color = "black")
  )