SHAP Analysis and Visualization

Author

Tsai, Dai-Rong

1 Prerequisites

library(survival)
library(xgboost)
library(furrr)
library(tidyverse)
library(patchwork)
library(plotshap)

# Load the project's helper functions. Only files ending in .R / .r are
# sourced, so other entries in R/ (sub-directories, data files, notes) cannot
# make source() fail. invisible() suppresses the list returned by lapply().
invisible(lapply(list.files("R/", pattern = "\\.[Rr]$", full.names = TRUE), source))

# ggplot theme setting: black-and-white base theme with serif titles and tags
theme_set(
  theme_bw() +
    theme(title = element_text(family = "serif"),
          plot.tag = element_text(family = "serif"))
)

# Seed R's RNG so the random draws made later in this document are repeatable.
set.seed(1)

2 Import Synthetic Data

# Read the synthetic pancreatic cancer data with explicit types for three
# columns, then derive BMI and turn every character column into a factor.
panc_ca <- read_csv(
  "data/synthetic_pancreas.csv",
  col_types = cols(
    death_os = col_integer(),
    fnstage = col_character(),
    grade = col_character()
  )
) %>%
  mutate(
    # weight / (height / 100)^2; presumably height is in cm and weight in kg
    bmi = weight / (height / 100)^2,
    across(where(is.character), \(col) as.factor(col)),
    # drop the inputs that were only used to build new columns (weight, height)
    .keep = "unused"
  )

# Display labels for the features, keyed by column name
lookup <- c(
  # demographics and hospital
  sex = "Gender",
  age = "Age",
  hp_lv = "Medical center",
  urbangp = "Urbanization",
  # body measurements
  height = "Height",
  weight = "Weight",
  bmi = "BMI",
  # tumour characteristics
  hist = "Histology",
  subtype = "Subsite",
  fnt = "T stage",
  fnn = "N stage",
  fnm = "M stage",
  fnstage = "AJCC stage",
  grade = "Grade",
  # lifestyle
  smoke_yr = "Years of smoking",
  drink_gp = "Alcohol consumption",
  # treatments
  trtop = "Surgery",
  trtct = "Chemotherapy",
  trtrt = "Radiotherapy",
  trttarget = "Targeted therapy",
  trthorm = "Hormonal therapy",
  trtimmu = "Immunotherapy"
)

# Preview the first 50 rows
panc_ca %>% head(50)

3 Hyperparameter Tuning

We used a repeated 5-fold cross-validation (n_repeat = 5 or more) to enhance the robustness of hyperparameter optimization.

# Number of random hyperparameter candidates; every sampled argument below
# supplies one value per candidate.
n_grid <- 64
# system.time() reports how long the whole tuning call takes.
system.time({
  # NOTE(review): xgbst_cox_fit() is not defined in this document (presumably a
  # helper sourced from R/); the argument meanings noted below are inferred
  # from their names -- confirm against its definition.
  #
  # The candidate values come from R's RNG (seeded earlier via set.seed(1)).
  # R evaluates arguments lazily, so the order of the draws depends on when
  # xgbst_cox_fit() forces each argument; do not reorder or hoist these
  # expressions if the tuning result must stay reproducible.
  mod_xgbst <- xgbst_cox_fit(
    Surv(Surv_time, death_os) ~ ., data = panc_ca,
    learn_rate = 10^runif(n_grid, -2.5, -1),  # log-uniform on [10^-2.5, 10^-1]
    tree_depth = sample(4:8, n_grid, replace = TRUE),  # integers 4..8, with replacement
    colsample = runif(n_grid, 0.1, 0.5),  # uniform on [0.1, 0.5]
    subsample = runif(n_grid, 0.5, 1),  # uniform on [0.5, 1]
    penalty = 10^runif(n_grid, -5, 2),  # log-uniform on [10^-5, 10^2]
    mixture = runif(n_grid, 0, 1),  # uniform on [0, 1]
    # repeated 5-fold cross-validation, 5 repeats
    n_fold = 5, n_repeat = 5,
    # presumably the number of parallel workers and a progress-bar switch
    n_core = 8, progress = TRUE
  )
})
   user  system elapsed 
  3.145   0.142  80.097 

The optimal hyperparameter set is as follows:

# Keep exactly one row: the candidate with the smallest best_score.
# slice_min() keeps ties by default, which could return several rows; the
# retraining step reads each column of xgb_param as a single value, so
# n = 1 with with_ties = FALSE guarantees a one-row result.
xgb_param <- mod_xgbst$tuner %>%
  slice_min(best_score, n = 1, with_ties = FALSE)
xgb_param

4 Retraining to Improve Stability

Due to the inherent randomness of XGBoost, the rankings of features or feature pairs with similar importance scores may fluctuate from one retraining run to the next, even when the hyperparameters are fixed. To ensure robust inference, we repeatedly trained the model with the tuned hyperparameters and recorded the SHAP interaction values from each run.

# Refit the model 20 times with the selected hyperparameters and keep the SHAP
# interaction values of every fit as one list element. The elapsed time of
# each repetition is printed as it finishes.
shap_int_lst <- lapply(seq_len(20), function(rep_id) {
  tic <- Sys.time()
  refit <- xgbst_cox_fit(
    Surv(Surv_time, death_os) ~ .,
    data = panc_ca,
    learn_rate = xgb_param$eta,
    tree_depth = xgb_param$max_depth,
    colsample = xgb_param$colsample_bynode,
    subsample = xgb_param$subsample,
    # penalty / mixture are rebuilt from the stored alpha and lambda
    penalty = xgb_param$alpha + xgb_param$lambda,
    mixture = xgb_param$alpha / (xgb_param$alpha + xgb_param$lambda),
    n_round = xgb_param$best_iteration,
    doCV = FALSE
  )
  # predinteraction = TRUE asks predict() for SHAP interaction values
  interaction_vals <- predict(refit$fit, newdata = refit$x, predinteraction = TRUE)
  print(Sys.time() - tic)
  interaction_vals
})
Time difference of 27.92141 secs
Time difference of 28.29427 secs
Time difference of 27.78217 secs
Time difference of 27.80838 secs
Time difference of 27.74697 secs
Time difference of 27.68325 secs
Time difference of 27.62907 secs
Time difference of 27.78623 secs
Time difference of 27.76925 secs
Time difference of 27.79175 secs
Time difference of 27.94803 secs
Time difference of 27.95828 secs
Time difference of 28.08933 secs
Time difference of 27.75989 secs
Time difference of 27.91376 secs
Time difference of 27.73137 secs
Time difference of 28.36571 secs
Time difference of 27.78998 secs
Time difference of 27.68766 secs
Time difference of 27.999 secs

Then we progressively averaged the SHAP interaction values until the importance rankings stabilized. Stability was defined using Kendall’s τ correlation coefficient, calculated between successive rankings of the cumulative mean SHAP interaction values.

# Cumulative mean of the SHAP interaction values: the running sum after k
# repetitions divided by k, for k = 1, ..., number of repetitions.
shap_int_cummean <- shap_int_lst %>%
  accumulate(\(running, nxt) running + nxt) %>%
  map2(seq_along(.), \(total, k) total / k)

# Importance ranking for each cumulative mean
rank_vint <- vint_rank(shap_int_cummean)

# Pairwise Kendall correlation between the ranking columns
cor_vint <- cor(select(rank_vint, starts_with("rank")), method = "kendall")

# Kendall's tau between each cumulative-mean ranking and the next one, i.e. the
# first super-diagonal of cor_vint. The number of repetitions is taken from the
# matrix itself, so this chunk keeps working if the repetition count changes.
n_rep <- ncol(cor_vint)
idx_prev <- seq_len(n_rep - 1)
tau_successive <- data.frame(
  n = idx_prev + 1L,
  tau = cor_vint[cbind(idx_prev, idx_prev + 1L)]
)

# Stability curve; the dashed line marks perfect agreement (tau = 1)
ggplot(tau_successive, aes(x = n, y = tau)) +
  geom_line() +
  geom_point(shape = 21, fill = 4) +
  geom_hline(yintercept = 1, linetype = 2) +
  scale_x_continuous(breaks = seq(0, n_rep, 5)) +
  labs(x = "Training repetitions",
       y = expression(paste("Kendall's  ", tau)))

We selected the 15th cumulative mean, as additional repetitions did not notably improve the stability of the rankings.

# Use the cumulative mean over the first 15 retrained models as the final SHAP
# interaction values.
shap_int <- shap_int_cummean[[15]]
# Sum the interaction array over every dimension after the second, leaving one
# value per (row, column) pair of its first two dimensions.
# NOTE(review): presumably this gives per-observation, per-feature SHAP values
# (interaction values summed over the partner feature) -- confirm against the
# layout returned by predict(..., predinteraction = TRUE).
shap <- rowSums(shap_int, dims = 2)

5 SHAP Visualization

Before plotting, relabeling the categorical features may be needed to improve the appearance of the charts, e.g., removing leading numbers or converting levels that should not be displayed into NA.

# Tidy the factor levels for plotting:
#   1. strip a leading "<digit>_" prefix from the levels of four factors;
#   2. turn the "others" level of hist and subtype into NA.
panc_ca <- panc_ca %>%
  mutate(
    across(
      c(sex, hist, subtype, drink_gp),
      \(f) fct_relabel(f, \(lvl) str_remove(lvl, "^\\d_"))
    ),
    across(
      c(hist, subtype),
      \(f) fct_na_level_to_value(f, "others")
    )
  )

5.1 SHAP Summary Plot

Each point corresponds to an individual in the dataset, colored by its feature value. The SHAP summary plot is obtained by projecting the SHAP dependence plots onto the y-axis. It succinctly visualizes the magnitude, distribution, and direction of each feature’s effect on the predictions.

The SHAP values for the histological types SPN and NET are close, so the SPN points would be hidden behind the NET points. By specifying order_by_val, we can adjust the plotting order of the levels of a categorical feature.

# HTML templates: fmt1 renders one bold, coloured level name; fmt2 appends a
# small "(level/level/...)" legend after the feature label.
fmt1 <- "<span style='color: %s;'><b>%s</b></span>"
fmt2 <- "%s<span style='font-size: 8pt;'>(%s)</span>"

# One axis label per feature, with its levels shown in their plotting colours.
# Non-factor features get the generic levels "low" and "high".
lookup2 <- imap(lookup, \(label, var) {
  feature <- panc_ca[[var]]
  lvl <- if (is.factor(feature)) levels(feature) else c("low", "high")
  # these features get a line break between the label and the level legend
  line_break <- if (var %in% c("hist", "subtype", "urbangp", "smoke_yr", "drink_gp")) {
    "<br>"
  } else {
    ""
  }
  coloured_lvl <- sprintf(fmt1, shap.colors(length(lvl)), lvl)
  sprintf(fmt2, paste0(label, line_break), paste(coloured_lvl, collapse = "/"))
})

# element_markdown() lets the axis text render the HTML built above
plot_shap_summary(
  panc_ca, shap,
  top_n = 12,
  varlab = lookup2,
  order_by_val = c("hist")
) +
  theme(axis.text.y = ggtext::element_markdown())

5.2 SHAP Interaction Matrix

SHAP interaction importance for each feature pair is measured by averaging the absolute SHAP interaction values across all patients. For the heat map, features are arranged along both axes based on their mean interaction importance across all other features.

# Left panel: heat map of the SHAP interaction values
p1 <- plot_shap_heatmap(shap_int, varlab = lookup)

# Right panel: top 12 entries from plot_shap_vint(), with extra top margin
p2 <- plot_shap_vint(
  shap_int,
  top_n = 12,
  varlab = lookup
) +
  theme(plot.margin = unit(c(4, 0, 1, 0), "lines"))

# Place both panels side by side
gridExtra::grid.arrange(grobs = list(p1, p2), nrow = 1)

5.3 SHAP Dependence Plot

5.3.1 Main Effect

5.3.2 Interaction Effect

We exhibit the following interactions in order:

  1. Treatment-Histology interaction
  2. Treatment-Stage interaction
  3. Age-related interaction

The SHAP values of both features in a pair are summed to determine their overall impact.

# Common y-axis window so that all six panels share the same scale
shapint.lim <- c(-2, 1.2)

# Column 1 -- histology with surgery (p1) and with chemotherapy (p2)
p1 <- plot_shap_dependence(
  hist ~ trtop, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim) +
  ggtitle("(1) Treatment-Histology Interaction")
p2 <- plot_shap_dependence(
  hist ~ trtct, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim)

# Column 2 -- AJCC stage with surgery (p3) and with chemotherapy (p4)
p3 <- plot_shap_dependence(
  fnstage ~ trtop, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim) +
  ggtitle("(2) Treatment-Stage Interaction")
p4 <- plot_shap_dependence(
  fnstage ~ trtct, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim)

# Column 3 -- age with chemotherapy (p5) and with histology (p6)
p5 <- plot_shap_dependence(
  age ~ trtct, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim) +
  ggtitle("(3) Age-related Interaction")
p6 <- plot_shap_dependence(
  age ~ hist, panc_ca, shap, shap_int,
  which = 2, varlab = lookup,
  size = 0.5, trend.line.alpha = 0.3, trend.line.width = 0.7
) +
  coord_cartesian(ylim = shapint.lim)

# Titled panels on the top row, their companions underneath
(p1 + p3 + p5) / (p2 + p4 + p6)

5.4 Clinical Vignettes

The clinical vignettes for 3 randomly selected example cases are illustrated using SHAP waterfall plots, along with their predicted survival curves. The cases are arranged in descending order of predicted survival.

# Draw three random row indices to serve as the example cases
sample_id <- sample(nrow(panc_ca), 3)
# Prediction horizons: 0.1 to 5 in steps of 0.1, multiplied by 12
# (presumably years converted to months -- the plot divides by 12 again)
pred_horiz <- 12 * seq(0.1, 5, 0.1)

# NOTE(review): panc_ca had some factor levels relabelled / set to NA earlier
# in this document; confirm that xgbst_cox_pred() still accepts it as new data
# for a model tuned on the original levels.
surv <- xgbst_cox_pred(
  mod_xgbst,
  panc_ca[sample_id, ],
  pred_horiz = pred_horiz
)
# Label columns by horizon and rows by the sampled row index
colnames(surv) <- pred_horiz
rownames(surv) <- sample_id

# Long format: one row per (case, time) with time rescaled by 1/12.
# Cases are ordered by decreasing mean predicted survival, and that order
# defines the "Case 1/2/3" labels.
surv_dat <- surv %>%
  as.data.frame() %>%
  rownames_to_column("id") %>%
  pivot_longer(
    -id,
    names_to = "time",
    values_to = "surv",
    names_transform = \(horizon) as.numeric(horizon) / 12
  ) %>%
  mutate(
    id = fct_reorder(id, surv, mean, .desc = TRUE),
    label = paste("Case", as.integer(id))
  )

# Predicted survival curves of the example cases, one coloured line per case.
# The legend has no title and sits inside the panel near the top-right corner;
# the wide left margin is part of the original layout.
surv_plot <- ggplot(surv_dat, aes(x = time, y = surv, group = label, colour = label)) +
  geom_line(linewidth = 1) +
  scale_y_continuous(labels = scales::label_percent()) +
  guides(colour = guide_legend(NULL)) +
  labs(x = "Time (years since cancer diagnosis)\n",
       y = "Survival probability",  # fixed typo: was "propability"
       title = "(D) Survival Curves") +
  theme(plot.title = element_text(vjust = 2),
        plot.margin = unit(c(1, 1, 1, 5), "lines"),
        legend.position = "inside",
        legend.position.inside = c(0.85, 0.85),
        legend.background = element_rect(fill = "transparent"))

# One SHAP waterfall plot per example case, in the same order as the survival
# curve labels (levels of surv_dat$id). Each level is the sampled row index
# stored as text, so as.integer() recovers the row to plot.
case_ids <- levels(surv_dat$id)
shap_plot <- map2(case_ids, seq_along(case_ids), \(id, order) {
  plot_shap_waterfall(panc_ca, shap, which = as.integer(id), varlab = lookup) +
    labs(title = sprintf("(%s) Case %d", LETTERS[order], order)) +
    theme(plot.title = element_text(vjust = 2))
})

# surv_plot must be wrapped in list(): c() on a bare ggplot object can splice
# its internal components into the list instead of appending the plot as a
# single element, which breaks the grob list passed to grid.arrange().
gridExtra::grid.arrange(grobs = c(shap_plot, list(surv_plot)), nrow = 2)