Explore treeheatr

Trang Le

2020-06-12

treeheatr displays a more interpretable decision tree visualization by integrating a heatmap at its terminal nodes. Let’s explore the package treeheatr a little deeper and see what it can do!

library(treeheatr)

Let’s begin

with the penguins dataset! Running the heat_tree() function can be as simple as:

dat_raw <- na.omit(penguins)
heat_tree(dat_raw, target_lab = 'species')
#> Registered S3 method overwritten by 'seriation':
#>   method         from 
#>   reorder.hclust gclus

But we can adjust a few graphical parameters. We can add a custom layout for a subset of the nodes by specifying it in the custom_layout parameter. We can also relax the p-value threshold p_thres to include features that are less important in classifying the samples but still appear in the decision tree, or set show_all_feats = TRUE to include all features, even the ones that were not used to build the tree.

heat_tree(
  dat_raw, target_lab = 'species',
  # target_cols = c('#E69F00', '#56B4E9', '#009E73'),
  target_cols = c('darkorange','purple','cyan4'),
  custom_layout = data.frame(id = 1, x = 0.1, y = 1), 
  show_all_feats = TRUE,
  panel_space = 0.05, target_space = 0.2, tree_space_bottom = 0.1, heat_rel_height = 0.4)

We can also customize our heat tree by passing parameters through to the different ggparty geoms. These list parameters are named *_vars. For example:

heat_tree(
  dat_raw, target_lab = 'species',
  par_node_vars = list(
    label.size = 0.2,
    label.padding = ggplot2::unit(0.1, 'lines'),
    line_list = list(
      ggplot2::aes(label = paste('Node', id)),
      ggplot2::aes(label = splitvar),
      ggplot2::aes(label = paste('p =', formatC(p.value, format = 'e', digits = 2)))),
    line_gpar = list(
      list(size = 8),
      list(size = 8),
      list(size = 6)),
    id = 'inner'),
  terminal_vars = list(size = 0),
  cont_legend = TRUE, cate_legend = TRUE,
  edge_vars = list(size = 1, color = 'grey'))

Smart node layout

These extreme visualizations may not be very interpretable, but they serve to show how well the node layout generalizes as the tree grows in size. The implemented smart layout weighs the x-position of the parent node according to the level of its child nodes so as to avoid crossing of tree branches. This relative weight can be adjusted with the lev_fac parameter in heat_tree(). The default lev_fac = 1.3 seems to produce aesthetically pleasing trees, independent of the tree size.

In the next figure, lev_fac = 1 (top) places each parent node exactly in the middle of its child nodes (note a few branch crossings), in contrast with the default lev_fac = 1.3 (bottom).

heat_tree(wine_quality_red, target_lab = 'target', lev_fac = 1, title = 'lev_fac = 1')
heat_tree(wine_quality_red, target_lab = 'target', title = 'lev_fac = 1.3')

Clustering

Unless you turn it off (clust_feats = FALSE, clust_samps = FALSE), treeheatr automatically performs clustering when organizing the heatmap. To order the features, clustering is run separately on the two groups of features, continuous and categorical, across all samples (including the class label, unless clust_class = FALSE). To order the samples, clustering is run on the samples within each terminal node, using all features (not only the displayed ones). treeheatr uses cluster::daisy() with the Gower metric to incorporate both continuous and nominal categorical feature types. Note that cluster::daisy() may throw this warning if your dataset contains binary features:

binary variable(s) treated as interval scaled

but in general this warning is safe to ignore because the goal of clustering here is to improve the interpretability of the tree-based model, not to make precise inferences about each cluster.
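
If you prefer to keep the original ordering of your data, a minimal sketch using the penguins data from above and the clust_* arguments named in this section:

heat_tree(
  dat_raw, target_lab = 'species',
  clust_feats = FALSE,  # keep features in their original column order
  clust_samps = FALSE)  # keep samples in their original row order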

Mixed data/feature types

As shown above in the penguins example, treeheatr supports mixed feature types.

For continuous variables/features, we can choose to either percentize (scale-rank), normalize (subtract the min and divide by the max) or scale (subtract the mean and divide by the standard deviation) each feature. Depending on what we want to show in the heatmap, one transformation method can be more effective than the others. Details on the strengths and weaknesses of different types of data transformation for heatmap display can be found in this vignette of the heatmaply package.
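
As a quick comparison, here is a sketch that assumes the transformation is selected through heat_tree()'s trans_type argument (check ?heat_tree for the exact argument name and accepted values):

# percentize: scale each feature by its rank
heat_tree(dat_raw, target_lab = 'species', trans_type = 'percentize')
# normalize: subtract the min and divide by the max
heat_tree(dat_raw, target_lab = 'species', trans_type = 'normalize')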

We highly recommend that, when dealing with mixed feature types, the user supply feat_types to indicate whether a feature should be considered 'numeric' (continuous) or 'factor' (categorical), as shown below. When feat_types is not specified, treeheatr automatically infers each column type from the original dataset.
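
Here is a sketch of what that might look like with the penguins data, assuming feat_types accepts a named character vector mapping column names to 'numeric' or 'factor' (the column names are illustrative; match them to your own dataset):

heat_tree(
  dat_raw, target_lab = 'species',
  # explicitly declare how these columns should be treated
  feat_types = c(flipper_length_mm = 'numeric', island = 'factor', sex = 'factor'))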

Regression

In general, a regression task is more difficult to interpret with a decision tree than a classification task. However, a heatmap may shed some light on how the tree groups the samples in different terminal nodes. Removing the terminal node labels may also show the groups more clearly. Here’s an example:

heat_tree(data = galaxy,
          target_lab = 'target',
          task = 'regression',
          terminal_vars = NULL,
          tree_space_bottom = 0)

You’re the Warren Beatty of your heat_tree()

Anyone got that Food Wishes reference?

You can manually define your own tree for the custom_tree argument following the partykit vignette.

As an example, we will examine the datasets of COVID-19 cases in Wuhan from 2020-01-10 to 2020-02-18 from a recent study, together with its conditional decision tree.

First, a quick simplification of the column names:

library(dplyr)
library(partykit)

selected_train <- train_covid %>%
  select(
    LDH = 'Lactate dehydrogenase',
    hs_CRP = 'High sensitivity C-reactive protein',
    Lymphocyte = '(%)lymphocyte',
    outcome = Type2
  ) %>%
  na.omit()

selected_test <- test_covid %>%
  select(
    LDH = 'Lactate dehydrogenase',
    hs_CRP = 'High sensitivity C-reactive protein',
    Lymphocyte = '(%)lymphocyte',
    'outcome'
  )

We now apply the tree structure from Figure 2 of the original study (showing only the training set for now):

# first argument indicates the index of the feature used for splitting

split_ldh <- partysplit(1L, breaks = 365)
split_crp <- partysplit(2L, breaks = 41.2)
split_lymp <- partysplit(3L, breaks = 14.7)

custom_tree <- partynode(1L, split = split_ldh , kids = list(
  partynode(2L, split = split_crp, kids = list(
    partynode(3L, info = 'Survival'),
    partynode(4L, split = split_lymp, kids = list(
      partynode(5L, info = 'Death'),
      partynode(6L, info = 'Survival'))))),
  partynode(7L, info = 'Death')))

heat_tree(
  selected_train,
  target_lab = 'outcome',
  label_map = c(`1` = 'Death', `0` = 'Survival'),
  custom_tree = custom_tree)

Apply the learned tree to an external/holdout/test/validation dataset

You can print measures evaluating the conditional decision tree’s performance by setting print_eval = TRUE. By default, we show five measures for classification tasks and four measures for regression tasks.

You can also choose to show performance based on any other set of appropriate metrics listed on the yardstick reference page, for example, with

metrics = yardstick::metric_set(yardstick::f_meas)

as a heat_tree() argument to show the F score (a combination of precision and recall).
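
For instance, a sketch that combines a custom metric set with the test-set evaluation shown at the end of this section (yardstick::f_meas and yardstick::bal_accuracy are existing yardstick metrics; the remaining arguments mirror the final call below):

heat_tree(
  selected_train,
  data_test = selected_test,
  target_lab = 'outcome',
  label_map = c(`1` = 'Death', `0` = 'Survival'),
  print_eval = TRUE,
  metrics = yardstick::metric_set(yardstick::f_meas, yardstick::bal_accuracy),
  custom_tree = custom_tree)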

Warning: We do not recommend print_eval on the training set because these measures may give you an over-optimistic view of how the tree performs (they reflect pure training accuracy, not cross-validated performance).

Let’s now apply the custom tree we (or really, Yan et al.) designed earlier and see how it performs on the test set:

heat_tree(
  selected_train,
  data_test = selected_test,
  target_lab = 'outcome',
  label_map = c(`1` = 'Death', `0` = 'Survival'),
  print_eval = TRUE,
  custom_tree = custom_tree,
  lev_fac = 3
)