First, we load a few R packages
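A minimal sketch of the packages this lecture relies on (the exact list in the original chunk may differ):

```r
# Packages used throughout this lecture (a likely set; the original
# chunk may load a slightly different list).
library(tidyverse)    # data wrangling and plotting
library(rpart)        # recursive partitioning trees
library(rpart.plot)   # plotting rpart trees
library(caret)        # unified interface for training and tuning models
```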

Attribution: A lot of the material for this lecture came from the following resources

Motivation

In the last lecture, we described two types of machine learning algorithms: linear approaches (including linear regression, generalized linear models (GLMs), and discriminant analysis) and model-free approaches (such as \(k\)-nearest neighbors). The linear approaches were limited in that the partition of the prediction space had to be linear (or, in the case of QDA, quadratic).

Today, we look at a set of powerful, popular, and well-studied methods that adapt to higher dimensions, allow the regions to take more complex shapes, and in some cases still produce models that are interpretable.

We will focus on decision trees (including both regression and classification decision trees) and their extension to random forests.

Decision trees

Decision trees can be applied to both regression and classification problems. We first consider regression problems, and then move on to classification.

Motivating example 1

Let’s use a decision tree to decide what to eat for lunch!

Suppose the things that matter to you are

  1. the location of the restaurants and
  2. waiting time

What we would like to do is partition all the options for what to eat based on our ideal waiting time and the distance we are willing to travel, and then predict how much lunch will cost.

The figure below shows a decision tree. It consists of splitting rules, starting at the top of the tree, and has the following components:

  • The tree grows from the root Whatever Food, which contains all possible food in the world.
  • Segments of the tree are known as branches.
  • An internal node splits at some threshold; the two sides of the split represent two separate regions.
  • Leaves (also called regions or terminal nodes) are the final decisions. Multiple leaves may point to the same label.

[image source]

We can also convert the tree into different regions for classification:

[image source]

The regions are

  • \(R_1 = \{X | \texttt{ wait } < 5, \texttt{ distance } < 100\}\) (Rice)
  • \(R_2 = \{X | \texttt{ wait } < 15, \texttt{ distance } > 100\}\) (Steak)
  • \(R_3 = \{X | \texttt{ wait } > 5, \texttt{ distance } < 100\}\) (Noodles)
  • \(R_4 = \{X | \texttt{ wait } > 15, \texttt{ distance } > 100\}\) (Burger)

Regression decision trees operate by predicting an outcome variable \(Y\) by partitioning the feature (predictor) space. So here we will consider another dimension (cost, in this case):

[image source]

The predicted cost for those restaurants is the mean cost for the restaurants in the individual regions.

Motivating example 2

Consider the following dataset containing information on 572 different Italian olive oils from multiple regions in Italy.

## # A tibble: 572 x 10
##    Region Area  palmitic palmitoleic stearic oleic linoleic linolenic
##    <fct>  <fct>    <dbl>       <dbl>   <dbl> <dbl>    <dbl>     <dbl>
##  1 North… Sout…     1075          75     226  7823      672        36
##  2 North… Sout…     1088          73     224  7709      781        31
##  3 North… Sout…      911          54     246  8113      549        31
##  4 North… Sout…      966          57     240  7952      619        50
##  5 North… Sout…     1051          67     259  7771      672        50
##  6 North… Sout…      911          49     268  7924      678        51
##  7 North… Sout…      922          66     264  7990      618        49
##  8 North… Sout…     1100          61     235  7728      734        39
##  9 North… Sout…     1082          60     239  7745      709        46
## 10 North… Sout…     1037          55     213  7944      633        26
## # … with 562 more rows, and 2 more variables: arachidic <dbl>,
## #   eicosenoic <dbl>

We are interested in building a classification tree where Area is the outcome variable. How many areas are there?

## 
##       Sardinia Northern Italy Southern Italy 
##             98            151            323

OK there are three areas.

Let’s just consider two measured predictors: linoleic and eicosenoic. Suppose we wanted to predict the olive oil’s area using these two predictors. What method would you use?

Note that we can describe a classification algorithm using only these two predictors that would work pretty much perfectly:

The prediction algorithm inferred from the figure above is what we call a decision tree. If eicosenoic is larger than 6.5, predict Southern Italy. If not, then if linoleic is larger than \(1053.5\), predict Sardinia, and Northern Italy otherwise.

We can draw this decision tree like this:

In the figure above we used the rpart() function from the rpart R package, which stands for "Recursive Partitioning and Regression Trees". We'll learn more about what that means in a bit.
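As a hedged sketch (assuming the olive oil data sit in a data frame called olives with columns Area, linoleic, and eicosenoic, as suggested by the printout above), the tree could be fit and drawn like this:

```r
# A minimal sketch, assuming a data frame `olives` with columns
# Area, linoleic, and eicosenoic (names taken from the printout above).
fit <- rpart(Area ~ linoleic + eicosenoic, data = olives, method = "class")
rpart.plot(fit)   # draw the fitted classification tree
```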

Regression Trees

Let’s start with the case of a continuous outcome. The general idea here is to build a decision tree where each terminal node gives a different prediction \(\hat{Y}\) for the outcome \(Y\).

The regression tree model does the following:

  1. Divide or partition the predictor space (that is the possible values for \(X_1\), \(X_2\), … \(X_p\)) into \(J\) distinct and non-overlapping regions, \(R_1, R_2, \ldots, R_J\).
  2. For every observation that falls within region \(R_j\), we make the same prediction, which is simply the mean of the response values for the training observations in \(R_j\).

How to construct regions?

In theory, the regions could have any shape. However, we choose to divide the predictor space into high-dimensional rectangles, or boxes, for simplicity and for ease of interpretation of the resulting predictive model. The goal is to find boxes \(R_1\), … , \(R_J\) that minimize:

\[ RSS = \sum_{j=1}^J \sum_{i \in R_j} (y_i - \hat{y}_{R_j})^2 \] where \(\hat{y}_{R_j}\) is the mean response for the training observations within the \(j^{th}\) box.

This is very computationally intensive because we have to consider every possible partition of the feature space into \(J\) boxes.

Instead, we use a top-down, greedy approach known as recursive binary splitting. The ‘top-down’ approach successively splits the predictor space, and the ‘greedy’ approach means that at each step the best split is made at that particular step, rather than looking ahead and picking a split that will lead to a better tree in some future step.

For example, consider finding a good predictor \(X_j\) and a cutpoint along its axis at which to partition the space. The recursive algorithm looks like this:

  1. First select the predictor \(X_j\) and cutpoint \(s\) such that splitting the predictor space into the regions \(R_1(j,s) = \{X | X_j < s\}\) (aka the region of predictor space in which \(X_j\) takes on a value less than \(s\)) and \(R_2(j,s) = \{X | X_j \geq s \}\) (aka the region of predictor space in which \(X_j\) takes on a value greater than or equal to \(s\)) leads to the greatest possible reduction in the residual sum of squares (RSS), i.e. minimizes:

\[ \sum_{i:\, x_i \in R_1(j,s)} (y_i - \hat{y}_{R_1})^2 + \sum_{i:\, x_i \in R_2(j,s)} (y_i - \hat{y}_{R_2})^2 \]

where \(\hat{y}_{R_1}\) and \(\hat{y}_{R_2}\) are the mean responses for the training observations in \(R_1(j,s)\) and \(R_2(j,s)\), respectively.

Finding the values of \(j\) and \(s\) that minimize the above can be done quickly, especially when the number of features \(p\) is not too large (a brute-force sketch of this single-split search is shown after this list).

  2. Next, we repeat the process, looking for the best predictor and best cutpoint in order to split the data further so as to minimize the RSS within each of the resulting regions.

However, this time, instead of splitting the entire predictor space, we split one of the two previously identified regions. We now have three regions. Again, we look to split one of these three regions further, so as to minimize the RSS.

  3. The process continues until a stopping criterion is reached; for instance, we may continue until no region contains more than five observations.
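Below is a toy brute-force sketch of the single best split described in step 1. All object names are illustrative; a data frame train_df with numeric predictors and a numeric outcome column y is assumed.

```r
# Toy brute-force search for the best single split (j, s), assuming a
# data frame `train_df` with numeric predictors and a numeric outcome `y`
# (all names here are illustrative, not from the lecture code).
best_split <- function(train_df, outcome = "y") {
  predictors <- setdiff(names(train_df), outcome)
  best <- list(rss = Inf, predictor = NA, cutpoint = NA)
  for (j in predictors) {
    for (s in sort(unique(train_df[[j]]))) {
      left  <- train_df[[outcome]][train_df[[j]] <  s]
      right <- train_df[[outcome]][train_df[[j]] >= s]
      if (length(left) == 0 || length(right) == 0) next
      rss <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (rss < best$rss) best <- list(rss = rss, predictor = j, cutpoint = s)
    }
  }
  best
}
```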

Predicting the response

Once the regions \(R_1\),…,\(R_J\) have been created, we predict the response for a given test observation using the mean of the training observations in the region to which that test observation belongs.

Tree pruning

To avoid overfitting the data (meaning poor test-set performance because the tree is very complex), a smaller tree with fewer splits (meaning fewer regions) might lead to lower variance and better interpretation, at the cost of slightly more bias.

A common solution to this is to grow a very large tree \(T_0\) and then prune it back to a subtree. Given a subtree, we can estimate its test error using cross-validation.

Instead of considering every subtree, we use something called cost complexity pruning or weakest link pruning with a nonnegative tuning parameter \(\alpha\). You can read more about Algorithm 8.1 on page 309.

As a brief summary of cost complexity pruning, we borrow an idea (similar to using the lasso to control the complexity of a linear model) for controlling the complexity of a tree:

For each value of \(\alpha\) there corresponds a subtree \(T \subset T_0\) such that

\[ \sum_{m=1}^{|T|} \sum_{x_i \in R_m} (y_i - \hat{y}_{R_m})^2 + \alpha |T| \]

is as small as possible, where \(|T|\) represents the number of terminal nodes of the tree \(T\), \(R_m\) is the rectangle (i.e. the subset of the predictor space) corresponding to the \(m^{th}\) terminal node, and \(\hat{y}_{R_m}\) is the predicted response associated with \(R_m\) – aka the mean of the training observations in \(R_m\).

The idea is that the tuning parameter \(\alpha\) controls a trade-off between the subtree’s complexity and its fit to the training data. When \(\alpha = 0\), the subtree \(T\) will simply equal the original tree \(T_0\), because then the above quantity just measures the training error.

However, as \(\alpha\) increases, there is a price to pay for having a tree with many terminal nodes, so the quantity above will tend to be minimized for a smaller subtree. Hence branches get pruned from the tree in a nested and predictable fashion.

We can select a value of \(\alpha\) using a validation set or using cross-validation. We then return to the full data set and obtain the subtree corresponding to \(\alpha\). This process is summarized in Algorithm 8.1.
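In rpart, the complexity parameter cp plays the role of \(\alpha\) (rescaled by the root-node error). A hedged sketch of pruning, assuming fit is a tree grown with rpart() as in the olive oil example above:

```r
# A sketch of cost complexity pruning with rpart, assuming `fit` is a tree
# grown with rpart() as in the olive oil example above.
printcp(fit)    # cp table with cross-validated error for each subtree size
best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned  <- prune(fit, cp = best_cp)   # subtree corresponding to the chosen cp
```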

Classification trees

A classification tree is very similar to a regression tree, except that it is used to predict a qualitative response rather than a quantitative one. Recall that for a regression tree, the predicted response for an observation is given by the mean response of the training observations that belong to the same terminal node.

In contrast, for a classification tree, we predict that each observation belongs to the most commonly occurring class of training observations in the region to which it belongs. In interpreting the results of a classification tree, we are often interested not only in the class prediction corresponding to a particular terminal node region, but also in the class proportions among the training observations that fall into that region.

We also use recursive binary splitting to grow a classification tree, but we cannot use \(RSS\) as the criterion for making the binary splits. A natural alternative to \(RSS\) is the classification error rate. We assign an observation in a given region to the most commonly occurring class of training observations in that region. Then, the classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class:

\[ E = 1 - \max (\hat{p}_{mk}) \]

where \(\hat{p}_{mk}\) represents the proportion of training observations in the \(m^{th}\) region that are from the \(k^{th}\) class. However, it turns out that classification error is not sufficiently sensitive for tree-growing, and in practice two other measures are preferable.

  1. The Gini index is defined by

\[ G = \sum_{k=1}^K \hat{p}_{mk} (1 - \hat{p}_{mk} ) \]

and is a measure of total variance across the \(K\) classes. It is not hard to see that the Gini index takes on a small value if all of the \(\hat{p}_{mk}\)s are close to zero or one. For this reason the Gini index is referred to as a measure of node purity (a small value indicates that a node contains predominantly observations from a single class).

  2. An alternative to the Gini index is cross-entropy, given by

\[ D = - \sum_{k=1}^K \hat{p}_{mk} \log (\hat{p}_{mk} ) \]

Since \(0 \leq \hat{p}_{mk} \leq 1\), it follows that \(0 \leq -\hat{p}_{mk} \log(\hat{p}_{mk})\).

Like the Gini index, the cross-entropy will take on a small value if the \(m^{th}\) node is pure (aka if \(\hat{p}_{mk}\)s are close to zero or one). In fact, it turns out that the Gini index and the cross-entropy are quite similar numerically.

When building a classification tree, either the Gini index or the cross-entropy are typically used to evaluate the quality of a particular split (since these two approaches are more sensitive to node purity than is the classification error rate). Any of these three approaches might be used when pruning the tree, but the classification error rate is preferable if prediction accuracy of the final pruned tree is the goal.
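As a small illustration of these impurity measures, here are helper functions that take a vector of class proportions \(\hat{p}_{mk}\) for a single node (purely illustrative):

```r
# Node-impurity measures for a vector of class proportions (illustrative).
gini_index    <- function(p_hat) sum(p_hat * (1 - p_hat))
cross_entropy <- function(p_hat) -sum(ifelse(p_hat > 0, p_hat * log(p_hat), 0))

p_pure  <- c(0.98, 0.01, 0.01)   # nearly pure node
p_mixed <- c(1/3, 1/3, 1/3)      # maximally mixed node
gini_index(p_pure);    gini_index(p_mixed)     # small vs. large impurity
cross_entropy(p_pure); cross_entropy(p_mixed)
```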

What is the data?

In this lecture, we are going to build classification algorithms to predict whether or not domestic flights will arrive late to their destinations. To do this, we will use data that come from the hadley/nycflights13 github repo.

“This package contains information about all flights that departed from NYC (e.g. EWR, JFK and LGA) to destinations in the United States, Puerto Rico, and the American Virgin Islands) in 2013: 336,776 flights in total. To help understand what causes delays, it also includes a number of other useful datasets.”

This package provides the following data tables.

  • flights: all flights that departed from NYC in 2013
  • weather: hourly meteorological data for each airport
  • planes: construction information about each plane
  • airports: airport names and locations
  • airlines: translation between two letter carrier codes and names

Data import

Loading the data is very straightforward.

We can peek at what is in each dataset by printing it:
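A sketch of the code that likely produced the output below: attach the package and print two of the tables.

```r
# Attaching nycflights13 makes each table available as a data frame.
library(nycflights13)
flights
airlines
```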

## # A tibble: 336,776 x 19
##     year month   day dep_time sched_dep_time dep_delay arr_time
##    <int> <int> <int>    <int>          <int>     <dbl>    <int>
##  1  2013     1     1      517            515         2      830
##  2  2013     1     1      533            529         4      850
##  3  2013     1     1      542            540         2      923
##  4  2013     1     1      544            545        -1     1004
##  5  2013     1     1      554            600        -6      812
##  6  2013     1     1      554            558        -4      740
##  7  2013     1     1      555            600        -5      913
##  8  2013     1     1      557            600        -3      709
##  9  2013     1     1      557            600        -3      838
## 10  2013     1     1      558            600        -2      753
## # … with 336,766 more rows, and 12 more variables: sched_arr_time <int>,
## #   arr_delay <dbl>, carrier <chr>, flight <int>, tailnum <chr>,
## #   origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, hour <dbl>,
## #   minute <dbl>, time_hour <dttm>
## # A tibble: 16 x 2
##    carrier name                       
##    <chr>   <chr>                      
##  1 9E      Endeavor Air Inc.          
##  2 AA      American Airlines Inc.     
##  3 AS      Alaska Airlines Inc.       
##  4 B6      JetBlue Airways            
##  5 DL      Delta Air Lines Inc.       
##  6 EV      ExpressJet Airlines Inc.   
##  7 F9      Frontier Airlines Inc.     
##  8 FL      AirTran Airways Corporation
##  9 HA      Hawaiian Airlines Inc.     
## 10 MQ      Envoy Air                  
## 11 OO      SkyWest Airlines Inc.      
## 12 UA      United Air Lines Inc.      
## 13 US      US Airways Inc.            
## 14 VX      Virgin America             
## 15 WN      Southwest Airlines Co.     
## 16 YV      Mesa Airlines Inc.

Data wrangling

Next, let’s explore the column names inside each of these datasets.
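One way to produce the listing below (a sketch) is to map each table to its column names:

```r
# Column names of each table (a sketch of the code behind the output below).
lapply(list(flights = flights, airlines = airlines, weather = weather,
            airports = airports, planes = planes), colnames)
```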

## $flights
##  [1] "year"           "month"          "day"            "dep_time"      
##  [5] "sched_dep_time" "dep_delay"      "arr_time"       "sched_arr_time"
##  [9] "arr_delay"      "carrier"        "flight"         "tailnum"       
## [13] "origin"         "dest"           "air_time"       "distance"      
## [17] "hour"           "minute"         "time_hour"     
## 
## $airlines
## [1] "carrier" "name"   
## 
## $weather
##  [1] "origin"     "year"       "month"      "day"        "hour"      
##  [6] "temp"       "dewp"       "humid"      "wind_dir"   "wind_speed"
## [11] "wind_gust"  "precip"     "pressure"   "visib"      "time_hour" 
## 
## $airports
## [1] "faa"   "name"  "lat"   "lon"   "alt"   "tz"    "dst"   "tzone"
## 
## $planes
## [1] "tailnum"      "year"         "type"         "manufacturer"
## [5] "model"        "engines"      "seats"        "speed"       
## [9] "engine"

We see that some of the column names overlap. For example, the column name carrier exists in both flights and airlines. It would be nice to have the full name of the carrier instead of just the abbreviation.

To do this, we can use the join functions from the dplyr package. For example, to join the flights and airlines datasets, we can use the left_join() function:
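A sketch of the join that likely produced the output below:

```r
# Add the full carrier name to flights, keeping a few columns to show it.
flights %>%
  left_join(airlines, by = "carrier") %>%
  select(arr_delay, carrier, name)
```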

## # A tibble: 336,776 x 3
##    arr_delay carrier name                    
##        <dbl> <chr>   <chr>                   
##  1        11 UA      United Air Lines Inc.   
##  2        20 UA      United Air Lines Inc.   
##  3        33 AA      American Airlines Inc.  
##  4       -18 B6      JetBlue Airways         
##  5       -25 DL      Delta Air Lines Inc.    
##  6        12 UA      United Air Lines Inc.   
##  7        19 B6      JetBlue Airways         
##  8       -14 EV      ExpressJet Airlines Inc.
##  9        -8 B6      JetBlue Airways         
## 10         8 AA      American Airlines Inc.  
## # … with 336,766 more rows

Now let’s combine 4 of these datasets together. Note, in each case, I’m carefully specifying what to join each dataset by.
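A hedged reconstruction of that combined table; the `by` columns are assumptions based on the shared column names shown above.

```r
# Join flights to airlines, weather, and planes. The `by` columns below are
# assumptions inferred from the overlapping column names.
flights_all <- flights %>%
  left_join(airlines, by = "carrier") %>%
  left_join(weather,  by = c("origin", "year", "month", "day", "hour")) %>%
  left_join(planes,   by = "tailnum")
flights_all
```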

## # A tibble: 336,776 x 38
##    year.x month   day dep_time sched_dep_time dep_delay arr_time
##     <int> <int> <int>    <int>          <int>     <dbl>    <int>
##  1   2013     1     1      517            515         2      830
##  2   2013     1     1      533            529         4      850
##  3   2013     1     1      542            540         2      923
##  4   2013     1     1      544            545        -1     1004
##  5   2013     1     1      554            600        -6      812
##  6   2013     1     1      554            558        -4      740
##  7   2013     1     1      555            600        -5      913
##  8   2013     1     1      557            600        -3      709
##  9   2013     1     1      557            600        -3      838
## 10   2013     1     1      558            600        -2      753
## # … with 336,766 more rows, and 31 more variables: sched_arr_time <int>,
## #   arr_delay <dbl>, carrier <chr>, flight <int>, tailnum <chr>,
## #   origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, hour <dbl>,
## #   minute <dbl>, time_hour.x <dttm>, name <chr>, temp <dbl>, dewp <dbl>,
## #   humid <dbl>, wind_dir <dbl>, wind_speed <dbl>, wind_gust <dbl>,
## #   precip <dbl>, pressure <dbl>, visib <dbl>, time_hour.y <dttm>,
## #   year.y <int>, type <chr>, manufacturer <chr>, model <chr>,
## #   engines <int>, seats <int>, speed <int>, engine <chr>

Exploratory Data Analysis

The column we are interested in is the arr_delay (arrival delays in minutes) where the negative times represent early arrivals.

What are some variables that you think would be influential on whether or not a plane has a delayed arrival?

One thing might be whether or not it had a delayed departure. Let’s create a plot to see that relationship.

## Warning: Removed 9430 rows containing missing values (geom_point).

Yup, that is strongly related.

OK, how about airline carriers. Are there certain airlines that have more delayed arrivals (on average) compared to other airlines?

## Warning: Removed 9430 rows containing non-finite values (stat_boxplot).

Possibly.

## # A tibble: 16 x 3
##    name                        med_arr_delay `n()`
##    <chr>                               <dbl> <int>
##  1 Frontier Airlines Inc.                  6   685
##  2 AirTran Airways Corporation             5  3260
##  3 Envoy Air                              -1 26397
##  4 ExpressJet Airlines Inc.               -1 54173
##  5 Mesa Airlines Inc.                     -2   601
##  6 JetBlue Airways                        -3 54635
##  7 Southwest Airlines Co.                 -3 12275
##  8 United Air Lines Inc.                  -6 58665
##  9 US Airways Inc.                        -6 20536
## 10 Endeavor Air Inc.                      -7 18460
## 11 SkyWest Airlines Inc.                  -7    32
## 12 Delta Air Lines Inc.                   -8 48110
## 13 American Airlines Inc.                 -9 32729
## 14 Virgin America                         -9  5162
## 15 Hawaiian Airlines Inc.                -13   342
## 16 Alaska Airlines Inc.                  -17   714

What about which of the three airports that the flight originated from?

## # A tibble: 3 x 2
##   origin `median(arr_delay, na.rm = TRUE)`
##   <chr>                              <dbl>
## 1 EWR                                   -4
## 2 JFK                                   -6
## 3 LGA                                   -5

What about the size of the plane? A surrogate variable we could explore is the number of seats on a plane as a proxy for its size.

## Warning: Removed 57759 rows containing missing values (geom_point).

What about the hour of the day that the flight leaves?

## Warning: Removed 9430 rows containing missing values (geom_point).

OK, so let’s create a new column titled arr_delay_status that represents whether or not the plane arrived more than 15 minutes late to its destination. We will also select a subset of variables to consider for the purposes of this lecture. Finally, we drop any rows with NAs and downsample to 5,000 rows to keep the computation small for the lecture.
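A sketch of this wrangling step; the five predictors kept here are guesses based on the EDA above, and the lecture may have used a slightly different set.

```r
# A sketch: the predictors (dep_delay, carrier name, origin, seats, hour)
# and the seed are assumptions, not the lecture's exact code.
set.seed(1234)
flights_df <- flights_all %>%
  mutate(arr_delay_status = factor(as.numeric(arr_delay > 15))) %>%
  select(arr_delay_status, dep_delay, name, origin, seats, hour) %>%
  drop_na() %>%
  sample_n(5000)
dim(flights_df)
```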

## [1] 5000    6

We can also explore whether or not we have a balanced dataset (i.e. we might expect more 0s than 1s; otherwise that would be really bad for the airlines).
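A one-line check of the class balance (column name as assumed above):

```r
# Tabulate the outcome classes (column name as assumed above).
table(flights_df$arr_delay_status)
```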

## 
##    0    1 
## 3811 1189

Data analysis

Split data into train/tune/test

We will split the data into training and testing sets using the createDataPartition() function in the caret package, with the argument p being the proportion of data that goes into training:
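A minimal sketch of the split; the 80/20 proportion is an assumption, consistent with the 4,001 training samples shown below.

```r
# Split into train/test (object names and the 0.8 proportion are assumptions).
set.seed(1234)
in_train  <- createDataPartition(flights_df$arr_delay_status,
                                 p = 0.8, list = FALSE)
train_set <- flights_df[in_train, ]
test_set  <- flights_df[-in_train, ]
```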

Classification trees using rpart

To build a classification tree, we will use the train() function with the method = "rpart" argument from the caret package. We briefly saw this function in our introduction to machine learning lecture. Now you know a bit more about what this means.
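A sketch of the call that likely produced the output below (the object name rpart_fit is illustrative):

```r
# Fit a classification tree with caret's default bootstrap resampling.
rpart_fit <- train(arr_delay_status ~ ., data = train_set, method = "rpart")
rpart_fit
```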

## CART 
## 
## 4001 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 4001, 4001, 4001, 4001, 4001, 4001, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa    
##   0.00105042  0.8931700  0.6871750
##   0.00157563  0.8959278  0.6933095
##   0.59978992  0.8320790  0.3659109
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.00157563.

We can see how we are doing on the training data with the confusionMatrix() function.
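A sketch, using the object names assumed above:

```r
# Training-set confusion matrix for the fitted classification tree.
confusionMatrix(predict(rpart_fit, train_set), train_set$arr_delay_status)
```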

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 2964   85
##          1  296  656
##                                           
##                Accuracy : 0.9048          
##                  95% CI : (0.8953, 0.9137)
##     No Information Rate : 0.8148          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7158          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9092          
##             Specificity : 0.8853          
##          Pos Pred Value : 0.9721          
##          Neg Pred Value : 0.6891          
##              Prevalence : 0.8148          
##          Detection Rate : 0.7408          
##    Detection Prevalence : 0.7621          
##       Balanced Accuracy : 0.8972          
##                                           
##        'Positive' Class : 0               
## 

Note: Kappa (or Cohen’s Kappa) is like classification accuracy, except that it is normalized by the baseline of random chance on your dataset. It is a more useful measure on problems that have an imbalance in the classes (e.g. with a 70-30 split for classes 0 and 1, you can achieve 70% accuracy by predicting that all instances are class 0).

We can plot the model using the rpart.plot() function.
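A sketch, plotting the final rpart model stored inside the caret fit:

```r
# Draw the final tree selected by caret's tuning.
rpart.plot(rpart_fit$finalModel)
```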

Each node shows:

  • the predicted class (a delayed arrival by more than 15 minutes or not),
  • the predicted probability of a delayed arrival,
  • the percentage of observations in the node.

Now, if you look closely above, you will see that there is some tuning going on. We haven’t talked about this yet, but when you use the caret package with method = "rpart", it prunes the tree. The pruning is done using a complexity parameter (cp). This is the \(\alpha\) tuning parameter that we talked about above. If you do not want to use the default tuning grid, you can control this parameter yourself using the tuneGrid argument in train().
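A sketch of supplying the cp grid explicitly; the grid mirrors the cp values in the output below.

```r
# Tune cp over an explicit grid instead of caret's default.
rpart_fit_grid <- train(arr_delay_status ~ ., data = train_set,
                        method = "rpart",
                        tuneGrid = data.frame(cp = c(0, 0.2, 0.4, 0.6)))
rpart_fit_grid
```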

## CART 
## 
## 4001 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 4001, 4001, 4001, 4001, 4001, 4001, ... 
## Resampling results across tuning parameters:
## 
##   cp   Accuracy   Kappa    
##   0.0  0.8813087  0.6563895
##   0.2  0.9014211  0.7108880
##   0.4  0.9014211  0.7108880
##   0.6  0.8439677  0.4214045
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.4.

Next, we set up the parameters using the trainControl() function in the caret package to provide more details on how to train the algorithm in train(). The default is to use the bootstrap, and here number refers to the number of resampling iterations.
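A sketch with 10 bootstrap resamples (object names are illustrative):

```r
# 10 bootstrap resamples instead of the default 25.
fit_control_boot <- trainControl(method = "boot", number = 10)
rpart_fit_boot <- train(arr_delay_status ~ ., data = train_set,
                        method = "rpart", trControl = fit_control_boot)
rpart_fit_boot
```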

## CART 
## 
## 4001 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Bootstrapped (10 reps) 
## Summary of sample sizes: 4001, 4001, 4001, 4001, 4001, 4001, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa    
##   0.00105042  0.8885486  0.6661627
##   0.00157563  0.8922078  0.6752045
##   0.59978992  0.8471774  0.4267014
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.00157563.

Alternatively, we can ask for 5-fold cross-validation (number = 5) to tune our complexity parameter.
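A sketch using 5-fold cross-validation instead of the bootstrap:

```r
# 5-fold cross-validation for tuning cp.
fit_control_cv <- trainControl(method = "cv", number = 5)
rpart_fit_cv <- train(arr_delay_status ~ ., data = train_set,
                      method = "rpart", trControl = fit_control_cv)
rpart_fit_cv
```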

## CART 
## 
## 4001 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3200, 3202, 3201, 3201, 3200 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa    
##   0.00105042  0.8992708  0.6988161
##   0.00157563  0.8987718  0.6992623
##   0.59978992  0.8433092  0.4164130
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.00105042.

You can also try method = "repeatedcv" with number = 5 and repeats = 3 to repeat the 5-fold cross-validation three times.
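A one-line sketch of that control object:

```r
# 5-fold cross-validation repeated 3 times.
fit_control_rcv <- trainControl(method = "repeatedcv", number = 5, repeats = 3)
```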

Classification using glm

We can also compare how the classification tree does against something like logistic regression, which we learned about last time.
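A sketch of the logistic regression fit with the same 5-fold cross-validation (caret uses the binomial family automatically for a two-class factor outcome):

```r
# Logistic regression via caret, plus its training-set confusion matrix.
glm_fit <- train(arr_delay_status ~ ., data = train_set,
                 method = "glm", trControl = fit_control_cv)
glm_fit
confusionMatrix(predict(glm_fit, train_set), train_set$arr_delay_status)
```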

## Generalized Linear Model 
## 
## 4001 samples
##    5 predictor
##    2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3200, 3202, 3200, 3201, 3201 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.9027799  0.7104051
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 2963   86
##          1  291  661
##                                           
##                Accuracy : 0.9058          
##                  95% CI : (0.8963, 0.9147)
##     No Information Rate : 0.8133          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7194          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9106          
##             Specificity : 0.8849          
##          Pos Pred Value : 0.9718          
##          Neg Pred Value : 0.6943          
##              Prevalence : 0.8133          
##          Detection Rate : 0.7406          
##    Detection Prevalence : 0.7621          
##       Balanced Accuracy : 0.8977          
##                                           
##        'Positive' Class : 0               
## 

More details on decision trees

Why use decision trees?

Decision trees for regression and classification have a number of advantages over the more classical classification approaches.

Advantages

  1. Trees are very easy to explain to people. In fact, they are even easier to explain than linear regression!
  2. Some people believe that decision trees more closely mirror human decision-making than do the regression and classification approaches seen in previous lectures.
  3. Trees can be displayed graphically, and are easily interpreted even by a non-expert (especially if they are small).
  4. Trees can easily handle qualitative predictors without the need to create dummy variables.

Disadvantages

  1. Trees generally do not have the same level of predictive accuracy as some of the other regression and classification approaches.

However, by aggregating many decision trees, using methods like bagging, random forests, and boosting, the predictive performance of trees can be substantially improved. We introduce these concepts next.

Bagging

Bootstrap aggregation (or bagging) is a general-purpose technique used to reduce the variance of a statistical learning method. Here, we will use it to improve the performance of decision trees, which suffer from high variance: if we split the training data into two parts at random and fit a decision tree to each half, the results we get could be quite different.

In general, to reduce the variance, one approach is to take many training sets from the population, build a separate prediction model (e.g. a decision tree) using each training set, and average the resulting predictions (or take a majority vote, for classification). In other words, we could calculate \(\hat{f}^{1}(x)\), \(\hat{f}^2(x)\), …, \(\hat{f}^B(x)\) using \(B\) separate training sets, and average them in order to obtain a single low-variance statistical learning model, given by

\[ \hat{f}_{avg}(x) = \frac{1}{B} \sum_{b=1}^B \hat{f}^b(x) \]

Of course, this is not practical because we generally do not have access to multiple training sets.

The key idea here is to use bootstrap samples (i.e. random sampling with replacement) from the (single) training data set. We generate \(B\) different bootstrapped training datasets, train our method on the \(b^{th}\) bootstrapped training set in order to get \(\hat{f}^{*b}(x)\), and finally average all the predictions to obtain

\[ \hat{f}_{bag}(x) = \frac{1}{B} \sum_{b=1}^B \hat{f}^{*b}(x) \]

This is called bagging.

Bagging with regression trees

To apply bagging to regression trees with a quantitative outcome \(Y\) :

  1. Construct \(B\) trees using \(B\) bootstrapped training sets (trees should be deep and not pruned)
  2. Average the resulting predictions

Hence each individual tree has high variance, but low bias. Averaging these \(B\) trees reduces the variance.

Bagging has been demonstrated to give impressive improvements in accuracy by combining together hundreds or even thousands of trees into a single procedure.

Bagging with classification trees

To apply bagging to classification trees with a qualitative outcome \(Y\):

Bagging can be extended to a classification problem using a few possible approaches, but the simplest is as follows.

  1. For a given test observation, we can record the class predicted by each of the \(B\) trees
  2. Average the resulting predictions by taking a majority vote (the overall prediction is the most commonly occurring class among the B predictions)

The number of trees \(B\) is not a critical parameter with bagging. Using a very large value of \(B\) will not lead to overfitting. In practice we use a value of \(B\) sufficiently large that the error has settled down. Using \(B = 100\) is a good starting place.

Ok, let’s try bagging with classification trees. Here, we will use the train() function with the method = "treebag" argument from the caret package.
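A sketch of the bagged-tree fit (reusing the 5-fold cross-validation control from above; object names are illustrative):

```r
# Bagged classification trees via caret.
treebag_fit <- train(arr_delay_status ~ ., data = train_set,
                     method = "treebag", trControl = fit_control_cv)
treebag_fit
```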

Note: How did I know what method to pick?

Use the help file (?train), look at the list of available models on the caret website, or use this:
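```r
# List every model tag that caret's train() understands.
names(getModelInfo())
```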

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        
##   [4] "adaboost"            "amdai"               "ANFIS"              
##   [7] "avNNet"              "awnb"                "awtan"              
##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        
##  [13] "bagFDA"              "bagFDAGCV"           "bam"                
##  [16] "bartMachine"         "bayesglm"            "binda"              
##  [19] "blackboost"          "blasso"              "blassoAveraged"     
##  [22] "bridge"              "brnn"                "BstLm"              
##  [25] "bstSm"               "bstTree"             "C5.0"               
##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           
##  [31] "cforest"             "chaid"               "CSimca"             
##  [34] "ctree"               "ctree2"              "cubist"             
##  [37] "dda"                 "deepboost"           "DENFIS"             
##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            
##  [43] "dwdRadial"           "earth"               "elm"                
##  [46] "enet"                "evtree"              "extraTrees"         
##  [49] "fda"                 "FH.GBML"             "FIR.DM"             
##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            
##  [55] "FS.HGD"              "gam"                 "gamboost"           
##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      
##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h2o"            
##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       
##  [67] "GFS.LT.RS"           "GFS.THRIFT"          "glm.nb"             
##  [70] "glm"                 "glmboost"            "glmnet_h2o"         
##  [73] "glmnet"              "glmStepAIC"          "gpls"               
##  [76] "hda"                 "hdda"                "hdrda"              
##  [79] "HYFIS"               "icr"                 "J48"                
##  [82] "JRip"                "kernelpls"           "kknn"               
##  [85] "knn"                 "krlsPoly"            "krlsRadial"         
##  [88] "lars"                "lars2"               "lasso"              
##  [91] "lda"                 "lda2"                "leapBackward"       
##  [94] "leapForward"         "leapSeq"             "Linda"              
##  [97] "lm"                  "lmStepAIC"           "LMT"                
## [100] "loclda"              "logicBag"            "LogitBoost"         
## [103] "logreg"              "lssvmLinear"         "lssvmPoly"          
## [106] "lssvmRadial"         "lvq"                 "M5"                 
## [109] "M5Rules"             "manb"                "mda"                
## [112] "Mlda"                "mlp"                 "mlpKerasDecay"      
## [115] "mlpKerasDecayCost"   "mlpKerasDropout"     "mlpKerasDropoutCost"
## [118] "mlpML"               "mlpSGD"              "mlpWeightDecay"     
## [121] "mlpWeightDecayML"    "monmlp"              "msaenet"            
## [124] "multinom"            "mxnet"               "mxnetAdam"          
## [127] "naive_bayes"         "nb"                  "nbDiscrete"         
## [130] "nbSearch"            "neuralnet"           "nnet"               
## [133] "nnls"                "nodeHarvest"         "null"               
## [136] "OneR"                "ordinalNet"          "ordinalRF"          
## [139] "ORFlog"              "ORFpls"              "ORFridge"           
## [142] "ORFsvm"              "ownn"                "pam"                
## [145] "parRF"               "PART"                "partDSA"            
## [148] "pcaNNet"             "pcr"                 "pda"                
## [151] "pda2"                "penalized"           "PenalizedLDA"       
## [154] "plr"                 "pls"                 "plsRglm"            
## [157] "polr"                "ppr"                 "PRIM"               
## [160] "protoclass"          "qda"                 "QdaCov"             
## [163] "qrf"                 "qrnn"                "randomGLM"          
## [166] "ranger"              "rbf"                 "rbfDDA"             
## [169] "Rborist"             "rda"                 "regLogistic"        
## [172] "relaxo"              "rf"                  "rFerns"             
## [175] "RFlda"               "rfRules"             "ridge"              
## [178] "rlda"                "rlm"                 "rmda"               
## [181] "rocc"                "rotationForest"      "rotationForestCp"   
## [184] "rpart"               "rpart1SE"            "rpart2"             
## [187] "rpartCost"           "rpartScore"          "rqlasso"            
## [190] "rqnc"                "RRF"                 "RRFglobal"          
## [193] "rrlda"               "RSimca"              "rvmLinear"          
## [196] "rvmPoly"             "rvmRadial"           "SBC"                
## [199] "sda"                 "sdwd"                "simpls"             
## [202] "SLAVE"               "slda"                "smda"               
## [205] "snn"                 "sparseLDA"           "spikeslab"          
## [208] "spls"                "stepLDA"             "stepQDA"            
## [211] "superpc"             "svmBoundrangeString" "svmExpoString"      
## [214] "svmLinear"           "svmLinear2"          "svmLinear3"         
## [217] "svmLinearWeights"    "svmLinearWeights2"   "svmPoly"            
## [220] "svmRadial"           "svmRadialCost"       "svmRadialSigma"     
## [223] "svmRadialWeights"    "svmSpectrumString"   "tan"                
## [226] "tanSearch"           "treebag"             "vbmpRadial"         
## [229] "vglmAdjCat"          "vglmContRatio"       "vglmCumulative"     
## [232] "widekernelpls"       "WM"                  "wsrf"               
## [235] "xgbDART"             "xgbLinear"           "xgbTree"            
## [238] "xyf"

Variable Importance Measures

Bagging typically results in improved accuracy over prediction using a single tree. Unfortunately, however, it can be difficult to interpret the resulting model. Recall that one of the advantages of decision trees is the attractive and easily interpreted diagram that results. However, when we bag a large number of trees, it is no longer possible to represent the resulting statistical learning procedure using a single tree, and it is no longer clear which variables are most important to the procedure. Thus, bagging improves prediction accuracy at the expense of interpretability.

Although the collection of bagged trees is much more difficult to interpret than a single tree, one can obtain an overall summary of the importance of each predictor using the RSS (for bagging regression trees) or the Gini index (for bagging classification trees).

In the case of bagging regression trees, we can record the total amount that the RSS is decreased due to splits over a given predictor, averaged over all \(B\) trees. A large value indicates an important predictor. Similarly, in the context of bagging classification trees, we can add up the total amount that the Gini index is decreased by splits over a given predictor, averaged over all \(B\) trees.

These are known as variable importances.

For example, consider a set of predictors:

The x-axis is “Importance of predictors” calculated as e.g. total amount that the RSS is decreased due to splits over a given predictor, averaged over all \(B\) trees.

You can read about them in Chapter 15 and see an example.
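With caret, variable importances can be extracted from a fitted model with varImp(); a sketch using the bagged-tree fit from above:

```r
# Extract and plot variable importances (object names as assumed above).
treebag_imp <- varImp(treebag_fit)
treebag_imp
plot(treebag_imp)
```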

Random Forests

Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered, a random sample of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors. The split is allowed to use only one of those \(m\) predictors.

A fresh sample of \(m\) predictors is taken at each split, and typically we choose \(m \approx \sqrt{p}\), that is, the number of predictors considered at each split is approximately equal to the square root of the total number of predictors.

In other words, in building a random forest, at each split in the tree, the algorithm is not even allowed to consider a majority of the available predictors. This may sound crazy, but it has a clever rationale. Suppose that there is one very strong predictor in the data set, along with a number of other moderately strong predictors. Then in the collection of bagged trees, most or all of the trees will use this strong predictor in the top split. Consequently, all of the bagged trees will look quite similar to each other. Hence the predictions from the bagged trees will be highly correlated.

Unfortunately, averaging many highly correlated quantities does not lead to as large of a reduction in variance as averaging many uncorrelated quantities. In particular, this means that bagging will not lead to a substantial reduction in variance over a single tree in this setting.

Random forests overcome this problem by forcing each split to consider only a subset of the predictors. Therefore, on average \((p − m)/p\) of the splits will not even consider the strong predictor, and so other predictors will have more of a chance. We can think of this process as decorrelating the trees, thereby making the average of the resulting trees less variable and hence more reliable.

Here we will use the train() function with the method = "rf" argument from the caret package.
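A sketch of the random forest fit, followed by a resampling comparison with the bagged trees (object names are illustrative):

```r
# Random forest via caret, then collect resampling results for comparison.
rf_fit <- train(arr_delay_status ~ ., data = train_set,
                method = "rf", trControl = fit_control_cv)

bagging_results <- resamples(list(treebag = treebag_fit, rf = rf_fit))
summary(bagging_results)
```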

## 
## Call:
## summary.resamples(object = bagging_results)
## 
## Models: treebag, rf 
## Number of resamples: 5 
## 
## Accuracy 
##            Min.  1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## treebag 0.86375 0.881250 0.8825000 0.8810272 0.8862500 0.8913858    0
## rf      0.87375 0.897628 0.8986233 0.8952749 0.9013733 0.9050000    0
## 
## Kappa 
##              Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## treebag 0.6203942 0.6569343 0.6642220 0.6584336 0.6747386 0.6758789    0
## rf      0.6189758 0.7020549 0.7037096 0.6908719 0.7113092 0.7183099    0

Relationship between bagging and random forests

If a random forest is built using \(m = p\), then this amounts simply to bagging.

Boosting

Boosting is another approach for improving the predictions resulting from a decision tree. Instead of bagging (or building a tree on a bootstrap data set, independent of the other trees), boosting grows the trees sequentially: each tree is grown using information from previously grown trees. Boosting does not involve bootstrap sampling; instead each tree is fit on a modified version of the original data set.

To read about the algorithmic details of boosting, check out Algorithm 8.2: Boosting for Regression Trees.

We won’t go into the details, but this is the main idea:

Unlike fitting a single large decision tree to the data, which amounts to fitting the data hard and potentially overfitting, the boosting approach instead learns slowly.

Given the current model, we fit a decision tree to the residuals from the model. That is, we fit a tree using the current residuals, rather than the outcome \(Y\), as the response. We then add this new decision tree into the fitted function in order to update the residuals.

The idea is that we slowly improve \(\hat{f}\) in areas where it does not perform well. In general, statistical learning approaches that learn slowly tend to perform well. Note that in boosting, unlike in bagging, the construction of each tree depends strongly on the trees that have already been grown.

Here we use the method = "gbm" argument in train(), which uses the gbm R package for Generalized Boosted Regression Models.
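A sketch of the boosted fit and the three-way comparison (verbose = FALSE just silences gbm's progress printing; object names are illustrative):

```r
# Boosted trees via gbm, then compare with the earlier fits.
boost_fit <- train(arr_delay_status ~ ., data = train_set,
                   method = "gbm", trControl = fit_control_cv,
                   verbose = FALSE)

summarize_results <- resamples(list(treebag = treebag_fit, rf = rf_fit,
                                    boost = boost_fit))
summary(summarize_results)
```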

## 
## Call:
## summary.resamples(object = summarize_results)
## 
## Models: treebag, rf, boost 
## Number of resamples: 5 
## 
## Accuracy 
##              Min.  1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## treebag 0.8637500 0.881250 0.8825000 0.8810272 0.8862500 0.8913858    0
## rf      0.8737500 0.897628 0.8986233 0.8952749 0.9013733 0.9050000    0
## boost   0.8926342 0.893617 0.9012500 0.9030199 0.9062500 0.9213483    0
## 
## Kappa 
##              Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## treebag 0.6203942 0.6569343 0.6642220 0.6584336 0.6747386 0.6758789    0
## rf      0.6189758 0.7020549 0.7037096 0.6908719 0.7113092 0.7183099    0
## boost   0.6802276 0.6842323 0.7066196 0.7146349 0.7358922 0.7662029    0

For more information on the caret package, you can read through the nice documentation to see what other algorithms are available for decision trees.

Checking test error rate
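A sketch of evaluating a fitted model on the held-out test set (shown here for the boosted model; the output below may correspond to a different model in the original lecture):

```r
# Test-set confusion matrix (object names as assumed above).
confusionMatrix(predict(boost_fit, test_set), test_set$arr_delay_status)
```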

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 730  32
##          1  71 166
##                                           
##                Accuracy : 0.8969          
##                  95% CI : (0.8764, 0.9151)
##     No Information Rate : 0.8018          
##     P-Value [Acc > NIR] : 3.211e-16       
##                                           
##                   Kappa : 0.698           
##                                           
##  Mcnemar's Test P-Value : 0.0001809       
##                                           
##             Sensitivity : 0.9114          
##             Specificity : 0.8384          
##          Pos Pred Value : 0.9580          
##          Neg Pred Value : 0.7004          
##              Prevalence : 0.8018          
##          Detection Rate : 0.7307          
##    Detection Prevalence : 0.7628          
##       Balanced Accuracy : 0.8749          
##                                           
##        'Positive' Class : 0               
##