Classification & Regression Trees

Prediction trees are used to predict a response or class \( Y \) from inputs \( X_1, X_2, \ldots, X_n \). If the response is continuous, the tree is called a regression tree; if it is categorical, it is called a classification tree. At each node of the tree we check the value of one of the inputs \( X_i \) and, depending on the (binary) answer, continue to the left or to the right sub-branch. When we reach a leaf we find the prediction, which is usually a simple statistic of the training data that fall into that leaf (for example, the most common class, or the mean response).

Unlike linear or polynomial regression, which are global models (the predictive formula is supposed to hold over the entire data space), trees partition the data space into parts small enough that a simple, separate model can be applied to each part. The non-leaf part of the tree is just the procedure that determines, for each data point \( x \), which model (i.e., which leaf) we use to predict it.
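To make the descent from root to leaf concrete, here is a minimal sketch in base R. It represents a tree as nested lists, a hypothetical structure made up for this illustration (not the object returned by tree() or rpart()), and the helper names leaf, node and predict_one are likewise made up. The thresholds are loosely based on the iris classification tree shown later in this section.

```r
# Hypothetical tree representation: internal nodes test one input against a
# threshold; leaves only hold a prediction. Not the format used by tree/rpart.
leaf <- function(value) list(value = value)
node <- function(var, cut, left, right)
  list(var = var, cut = cut, left = left, right = right)

predict_one <- function(tree, x) {
  # Walk down until a leaf is reached: left if x[[var]] < cut, else right
  while (is.null(tree$value)) {
    tree <- if (x[[tree$var]] < tree$cut) tree$left else tree$right
  }
  tree$value
}

iris.tree <- node("Petal.Width", 0.8,
                  leaf("setosa"),
                  node("Petal.Width", 1.7,
                       leaf("versicolor"),
                       leaf("virginica")))

predict_one(iris.tree, list(Petal.Width = 0.2))   # "setosa"
predict_one(iris.tree, list(Petal.Width = 2.1))   # "virginica"

# The same tree applied to every row of the built-in iris data
pred <- sapply(seq_len(nrow(iris)), function(i) predict_one(iris.tree, iris[i, ]))
table(pred, iris$Species)
```

Only the internal nodes carry tests; everything the tree "knows" about the response sits in the leaves.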

One of the most comprehensible non-parametric methods is k-nearest-neighbors: find the points which are most similar to you, and do what, on average, they do. There are two big drawbacks to it: first, you're defining “similar” entirely in terms of the inputs, not the response; second, k is constant everywhere, when some points just might have more very-similar neighbors than others. Trees get around both problems: leaves correspond to regions of the input space (a neighborhood), but one where the responses are similar, as well as the inputs being nearby; and their size can vary arbitrarily. Prediction trees are adaptive nearest-neighbor methods. - From here

Regression Trees

Like linear regression, a regression tree outputs an expected value of the response given a certain input.

library(tree)

real.estate <- read.table("cadata.dat", header=TRUE)
tree.model <- tree(log(MedianHouseValue) ~ Longitude + Latitude, data=real.estate)
plot(tree.model)
text(tree.model, cex=.75)

[Figure: the regression tree tree.model; leaf labels are predicted values of log(MedianHouseValue)]

Notice that the leaf values are logs of the price, since the response in the formula passed to tree() is log(MedianHouseValue).

(Note: knitr seems to have output the wrong values above; check the results yourself in R.)

We can compare the tree's partition with the data (darker points are more expensive); the partition seems to capture the global price trend:

price.deciles <- quantile(real.estate$MedianHouseValue, 0:10/10)
cut.prices    <- cut(real.estate$MedianHouseValue, price.deciles, include.lowest=TRUE)
# 11 decile breaks give 10 price bins, so 10 grey levels are needed (darker = more expensive)
plot(real.estate$Longitude, real.estate$Latitude, col=grey(10:1/11)[cut.prices], pch=20, xlab="Longitude", ylab="Latitude")
partition.tree(tree.model, ordvars=c("Longitude","Latitude"), add=TRUE)

[Figure: houses by longitude and latitude, shaded by price decile, with the tree's partition overlaid]

summary(tree.model)

Regression tree:
tree(formula = log(MedianHouseValue) ~ Longitude + Latitude, 
    data = real.estate)
Number of terminal nodes:  12 
Residual mean deviance:  0.166 = 3430 / 20600 
Distribution of residuals:
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
-2.7600 -0.2610 -0.0136  0.0000  0.2630  1.8400 

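Here the deviance is the sum of squared residuals; the residual mean deviance divides it by the number of observations minus the number of terminal nodes, so it is essentially the mean squared error.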
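A small sketch on made-up numbers shows the computation: each observation is predicted by the mean of its leaf, the deviance is the residual sum of squares, and the residual mean deviance divides it by the number of observations minus the number of leaves. The function name residual_mean_deviance is made up for this example.

```r
# Deviance of a regression tree, given each observation's leaf:
# residual sum of squares around the leaf means.
residual_mean_deviance <- function(y, leaf) {
  fitted <- ave(y, leaf)                  # each point gets its leaf's mean
  dev <- sum((y - fitted)^2)              # deviance = residual sum of squares
  df <- length(y) - length(unique(leaf))  # observations minus terminal nodes
  c(deviance = dev, df = df, mean.deviance = dev / df)
}

# Two hypothetical leaves with three observations each
y    <- c(1, 2, 3, 10, 11, 12)
leaf <- c("a", "a", "a", "b", "b", "b")
residual_mean_deviance(y, leaf)   # deviance 4, df 4, mean.deviance 1
```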

The flexibility of a tree is basically controlled by how many leaves it has, since that is how many cells it partitions the input space into. The tree-fitting function has a number of control settings (see tree.control()) which limit how much it will grow: each node has to contain a certain number of points (minsize, mincut), and a node is only split if its deviance is at least a certain fraction (mindev) of the root node's deviance. The default for the latter, mindev, is 0.01; let's turn it down and see what happens:

tree.model2 <- tree(log(MedianHouseValue) ~ Longitude + Latitude, data=real.estate, mindev=0.001)
plot(tree.model2)
text(tree.model2, cex=.75)

[Figure: the larger regression tree tree.model2 (mindev = 0.001)]

summary(tree.model2)

Regression tree:
tree(formula = log(MedianHouseValue) ~ Longitude + Latitude, 
    data = real.estate, mindev = 0.001)
Number of terminal nodes:  68 
Residual mean deviance:  0.105 = 2160 / 20600 
Distribution of residuals:
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
-2.9500 -0.1980 -0.0187  0.0000  0.2000  1.6100 

It is much finer-grained than the previous tree (68 leaves against 12) and matches the actual prices more closely (lower residual mean deviance: 0.105 against 0.166).
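The effect of mindev can be reproduced with a simplified sketch of recursive partitioning on a single numeric input, using base R and synthetic data. It borrows two stopping rules modeled on tree.control(): a node with fewer than minsize points is not split, and neither is a node whose deviance is below mindev times the root deviance. This is not the tree package's actual code (which, for example, also enforces mincut), and the function names node_dev and count_leaves are made up.

```r
# Simplified recursive partitioning on one numeric input (illustration only).
# Stopping rules modeled on tree.control(): do not split a node with fewer
# than `minsize` points, or whose deviance is below mindev * root deviance.
node_dev <- function(y) sum((y - mean(y))^2)   # regression deviance (RSS)

count_leaves <- function(x, y, root.dev, mindev = 0.01, minsize = 10) {
  if (length(y) < minsize || node_dev(y) < mindev * root.dev ||
      length(unique(x)) < 2) {
    return(1)                                  # this node becomes a leaf
  }
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2    # midpoints between values
  sse  <- sapply(cuts, function(thr)
    node_dev(y[x < thr]) + node_dev(y[x >= thr]))
  thr  <- cuts[which.min(sse)]                 # best single split
  left <- x < thr
  # leaves of this subtree = leaves on the left + leaves on the right
  count_leaves(x[left], y[left], root.dev, mindev, minsize) +
    count_leaves(x[!left], y[!left], root.dev, mindev, minsize)
}

set.seed(2)
x <- runif(500, 0, 10)
y <- sin(x) + rnorm(500, sd = 0.2)
root.dev <- node_dev(y)
n.coarse <- count_leaves(x, y, root.dev, mindev = 0.01)
n.fine   <- count_leaves(x, y, root.dev, mindev = 0.001)
c(coarse = n.coarse, fine = n.fine)
```

Because the splits are chosen greedily and deterministically, the coarse tree is a pruned version of the fine one, so lowering mindev can only add leaves.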

Also, we can include all the variables, not only the latitude and longitude:

tree.model3 <- tree(log(MedianHouseValue) ~ ., data=real.estate)
plot(tree.model3)
text(tree.model3, cex=.75)

[Figure: the regression tree tree.model3, fitted on all variables]

summary(tree.model3)

Regression tree:
tree(formula = log(MedianHouseValue) ~ ., data = real.estate)
Variables actually used in tree construction:
[1] "MedianIncome"   "Latitude"       "Longitude"      "MedianHouseAge"
Number of terminal nodes:  15 
Residual mean deviance:  0.132 = 2720 / 20600 
Distribution of residuals:
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
-2.8600 -0.2270 -0.0147  0.0000  0.2070  2.0400 

Classification Trees

Classification trees output the predicted class for a given sample.

Here we use the iris dataset, split into training and test sets:

set.seed(101)
alpha     <- 0.7 # fraction of the samples used for training
inTrain   <- sample(1:nrow(iris), alpha * nrow(iris))
train.set <- iris[inTrain,]
test.set  <- iris[-inTrain,]

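There are two options for the output: a distributional prediction (the probability of each class) or a point prediction (a single class):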

library(tree)

tree.model <- tree(Species ~ Sepal.Width + Petal.Width, data=train.set)
tree.model
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

 1) root 105 200 versicolor ( 0.33 0.36 0.30 )  
   2) Petal.Width < 0.8 35   0 setosa ( 1.00 0.00 0.00 ) *
   3) Petal.Width > 0.8 70 100 versicolor ( 0.00 0.54 0.46 )  
     6) Petal.Width < 1.7 40  20 versicolor ( 0.00 0.92 0.07 )  
      12) Petal.Width < 1.35 20   0 versicolor ( 0.00 1.00 0.00 ) *
      13) Petal.Width > 1.35 20  20 versicolor ( 0.00 0.85 0.15 )  
        26) Sepal.Width < 3.05 14  10 versicolor ( 0.00 0.79 0.21 ) *
        27) Sepal.Width > 3.05 6   0 versicolor ( 0.00 1.00 0.00 ) *
     7) Petal.Width > 1.7 30   9 virginica ( 0.00 0.03 0.97 )  
      14) Petal.Width < 1.85 8   6 virginica ( 0.00 0.12 0.88 ) *
      15) Petal.Width > 1.85 22   0 virginica ( 0.00 0.00 1.00 ) *
summary(tree.model)

Classification tree:
tree(formula = Species ~ Sepal.Width + Petal.Width, data = train.set)
Number of terminal nodes:  6 
Residual mean deviance:  0.208 = 20.6 / 99 
Misclassification error rate: 0.0381 = 4 / 105 
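For a classification tree, the deviance of a node with class counts \( n_k \) is the multinomial deviance \( -2 \sum_k n_k \log(n_k / n) \). The sketch below recomputes the summary above from the six leaves of the printed tree (nodes 2, 12, 26, 27, 14 and 15), with class counts read off their n and yprob columns; the helper name node_deviance is made up.

```r
# Multinomial deviance of one node: -2 * sum(n_k * log(n_k / n));
# classes with n_k = 0 contribute nothing.
node_deviance <- function(counts) {
  nz <- counts[counts > 0]
  -2 * sum(nz * log(nz / sum(counts)))
}

# Class counts (setosa, versicolor, virginica) of leaves 2, 12, 26, 27, 14, 15
leaves <- list(c(35, 0, 0), c(0, 20, 0), c(0, 11, 3),
               c(0, 6, 0),  c(0, 1, 7),  c(0, 0, 22))

total.dev <- sum(sapply(leaves, node_deviance))
n         <- sum(sapply(leaves, sum))
df        <- n - length(leaves)            # observations minus leaves
misclass  <- sum(sapply(leaves, function(k) sum(k) - max(k)))  # non-majority counts

c(deviance = total.dev, df = df, mean.deviance = total.dev / df,
  errors = misclass, error.rate = misclass / n)
```

This gives a deviance of about 20.58 over 99 degrees of freedom (about 0.208) and 4 errors out of 105, in line with the summary printed above.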
# Distributional prediction
my.prediction <- predict(tree.model, test.set) # gives the probability for each class
head(my.prediction)
   setosa versicolor virginica
5       1          0         0
10      1          0         0
12      1          0         0
15      1          0         0
16      1          0         0
18      1          0         0
# Point prediction
# Translate the probability output into a categorical output by picking,
# for each row, the class with the highest probability.
# (which.max() always returns a single index, even when there are ties.)
idx <- apply(my.prediction, 1, which.max)
prediction <- colnames(my.prediction)[idx]
table(prediction, test.set$Species)

prediction   setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         11         1
  virginica       0          1        17
plot(tree.model)
text(tree.model)

[Figure: the classification tree for the iris training set]

# Another way to show the data:
plot(iris$Petal.Width, iris$Sepal.Width, pch=19, col=as.numeric(iris$Species))
partition.tree(tree.model, label="Species", add=TRUE)
legend("topright",legend=unique(iris$Species), col=unique(as.numeric(iris$Species)), pch=19)

[Figure: iris samples by Petal.Width and Sepal.Width, colored by species, with the tree's partition overlaid]

We can prune the tree to prevent overfitting. The function prune.tree() lets us choose how many leaves we want the tree to have (argument best) and returns the best subtree of that size (if no subtree has exactly that size, the next largest is returned).

The argument newdata accepts a new data set on which the candidate subtrees are evaluated when making the pruning decision. If newdata is not given, the data used to build the tree are used.

For classification trees we can also pass method="misclass", so that the pruning measure is the number of misclassifications instead of the deviance; prune.misclass() is a shorthand for this.

pruned.tree <- prune.tree(tree.model, best=4)
plot(pruned.tree)
text(pruned.tree)

[Figure: the pruned tree with 4 leaves]

pruned.prediction <- predict(pruned.tree, test.set, type="class") # give the predicted class
table(pruned.prediction, test.set$Species)

pruned.prediction setosa versicolor virginica
       setosa         15          0         0
       versicolor      0         11         1
       virginica       0          1        17
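prune.tree() chooses among nested subtrees using the cost-complexity measure: the subtree's deviance plus a penalty k times its number of leaves. A minimal sketch with hypothetical, made-up deviances shows how increasing k favors smaller subtrees; it only illustrates the criterion, not how prune.tree() builds the sequence of subtrees, and the helper names are made up.

```r
# Hypothetical candidate subtrees (sizes and deviances are made up)
candidates <- data.frame(size = c(6, 5, 4, 3, 2, 1),
                         dev  = c(20, 21, 24, 40, 90, 230))

# Cost-complexity criterion: D(T) + k * |T|
cost_complexity <- function(dev, size, k) dev + k * size

# Size of the subtree minimizing the criterion for a given penalty k
best_for_k <- function(k) {
  cc <- cost_complexity(candidates$dev, candidates$size, k)
  candidates$size[which.min(cc)]
}

chosen <- sapply(c(0, 2, 5, 20, 100), best_for_k)
chosen   # 6 5 4 3 2: larger penalties select smaller trees
```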

This package can also do K-fold cross-validation using cv.tree() to find the best tree:

# here, let's use all the variables and all the samples
tree.model <- tree(Species ~ ., data=iris)
summary(tree.model)

Classification tree:
tree(formula = Species ~ ., data = iris)
Variables actually used in tree construction:
[1] "Petal.Length" "Petal.Width"  "Sepal.Length"
Number of terminal nodes:  6 
Residual mean deviance:  0.125 = 18 / 144 
Misclassification error rate: 0.0267 = 4 / 150 

cv.model <- cv.tree(tree.model)
plot(cv.model)

[Figure: cross-validated deviance versus tree size]

cv.model$dev  # cross-validated deviance for each candidate tree size in cv.model$size (smaller is better)
[1]  59.09  52.62  52.58  73.34 142.82 336.22
best.size <- cv.model$size[which.min(cv.model$dev)] # which size has the lowest deviance?
best.size
[1] 4
# prune the tree to best.size leaves (if no subtree has exactly that size, the next largest is used)
cv.model.pruned <- prune.misclass(tree.model, best=best.size)
summary(cv.model.pruned)

Classification tree:
snip.tree(tree = tree.model, nodes = c(7L, 12L))
Variables actually used in tree construction:
[1] "Petal.Length" "Petal.Width" 
Number of terminal nodes:  4 
Residual mean deviance:  0.185 = 27 / 146 
Misclassification error rate: 0.0267 = 4 / 150 

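The misclassification rate did not change with the pruning (still 4/150), while the residual mean deviance increased slightly (from 0.125 to 0.185) in exchange for a simpler tree with 4 leaves that uses only two variables.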

Package rpart

The rpart package (recursive partitioning) is faster than tree.

library(rpart)

rpart.tree <- rpart(Species ~ ., data=train.set)
plot(rpart.tree, uniform=TRUE, branch=0.6, margin=0.05)
text(rpart.tree, all=TRUE, use.n=TRUE)
title("Training Set's Classification Tree")

[Figure: Training Set's Classification Tree (rpart)]

predictions <- predict(rpart.tree, test.set, type="class")
table(test.set$Species, predictions)
            predictions
             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         11         1
  virginica       0          2        16
prune.rpart.tree <- prune(rpart.tree, cp=0.02) # pruning the tree
plot(prune.rpart.tree, uniform=TRUE, branch=0.6)
text(prune.rpart.tree, all=TRUE, use.n=TRUE)

[Figure: the rpart tree pruned at cp = 0.02]

An example with different costs for different errors (a loss matrix):

lmat <- matrix(c(0,1,2,
                 1,0,100,
                 2,100,0), ncol = 3)
lmat
     [,1] [,2] [,3]
[1,]    0    1    2
[2,]    1    0  100
[3,]    2  100    0

So misclassifying the 2nd class (versicolor) as the 3rd (virginica), or vice versa, is very costly, while errors involving setosa are cheap.

rpart.tree <- rpart(Species ~ ., data=train.set, parms = list(loss = lmat))
predictions <- predict(rpart.tree, test.set, type="class")
table(test.set$Species, predictions)
            predictions
             setosa versicolor virginica
  setosa         15          0         0
  versicolor      2         10         0
  virginica       6          1        11

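As we see, the algorithm built a different tree to reduce the costly errors: on the test set, versicolor/virginica confusions drop from 3 to 1, at the price of more cheap errors (8 flowers predicted as setosa).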
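The trade-off can be quantified by weighting each cell of a confusion matrix by the corresponding entry of the loss matrix. This sketch copies the loss matrix and the two test-set tables printed above (rows = true class, columns = predicted class); total_loss is a made-up helper name.

```r
# Loss matrix from above (symmetric, so row/column orientation does not matter)
lmat <- matrix(c(0,1,2,
                 1,0,100,
                 2,100,0), ncol = 3)

# Test-set confusion matrices from the two tables above
# (rows = true class, columns = predicted: setosa, versicolor, virginica)
conf.default <- matrix(c(15,  0,  0,
                          0, 11,  1,
                          0,  2, 16), ncol = 3, byrow = TRUE)
conf.loss    <- matrix(c(15,  0,  0,
                          2, 10,  0,
                          6,  1, 11), ncol = 3, byrow = TRUE)

# Total loss: sum of (number of errors of each kind) * (cost of that error)
total_loss <- function(conf, loss) sum(conf * loss)

c(default = total_loss(conf.default, lmat),     # 300
  with.loss = total_loss(conf.loss, lmat))      # 114
```

The loss-aware tree makes more mistakes (9 against 3) but its total cost on this test set is much lower.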

plot(rpart.tree)
text(rpart.tree)

[Figure: the rpart tree fitted with the loss matrix]

A plotting function that gives more control over the parameters:

## Define a plotting function with decent defaults
plot.rpart.obj <- function(rpart.obj, font.size = 0.8) {
    ## plot decision tree
    plot(rpart.obj,
         uniform   = TRUE,   # uniform vertical spacing of the nodes
         branch    = 1,      # branch shape: 1 = square shoulders, 0 = V-shaped
         compress  = FALSE,  # FALSE: leaves at x = 1:nleaves; TRUE: more compact layout
         nspace    = 0.1,    # extra space between a node with children and a leaf (compressed trees only)
         margin    = 0.1,    # extra fraction of white space to leave around the borders
         minbranch = 0.3)    # minimum branch length relative to the average (ignored when uniform = TRUE)

    ## Add text
    text(x      = rpart.obj,
         splits = TRUE,       # label the splits with their criterion
         all    = TRUE,       # label all nodes, not just the terminal ones
         use.n  = TRUE,       # add the number of observations (per class) to the labels
         cex    = font.size)  # font size
}

plot.rpart.obj(rpart.tree, 1)

[Figure: the same tree drawn with plot.rpart.obj()]

The package partykit gives better plotting and text functions; as.party() converts an rpart tree into a party object:

library(partykit)

rparty.tree <- as.party(rpart.tree)
rparty.tree

Model formula:
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

Fitted party:
[1] root
|   [2] Petal.Length < 4.85
|   |   [3] Petal.Length < 2.45: setosa (n = 35, err = 0%)
|   |   [4] Petal.Length >= 2.45
|   |   |   [5] Petal.Length < 4.65: versicolor (n = 29, err = 0%)
|   |   |   [6] Petal.Length >= 4.65: versicolor (n = 7, err = 14%)
|   [7] Petal.Length >= 4.85
|   |   [8] Petal.Length < 5.15: virginica (n = 11, err = 27%)
|   |   [9] Petal.Length >= 5.15: virginica (n = 23, err = 0%)

Number of inner nodes:    4
Number of terminal nodes: 5
plot(rparty.tree)

[Figure: plot of rparty.tree produced by partykit]

Another example, this time a regression tree, using the cu.summary data that ships with rpart:

fit <- rpart(Mileage~Price + Country + Reliability + Type, method="anova", data=cu.summary)
printcp(fit) # display the results

Regression tree:
rpart(formula = Mileage ~ Price + Country + Reliability + Type, 
    data = cu.summary, method = "anova")

Variables actually used in tree construction:
[1] Price Type 

Root node error: 1355/60 = 23

n=60 (57 observations deleted due to missingness)

     CP nsplit rel error xerror  xstd
1 0.623      0      1.00   1.02 0.178
2 0.132      1      0.38   0.54 0.104
3 0.025      2      0.25   0.38 0.085
4 0.012      3      0.22   0.38 0.088
5 0.010      4      0.21   0.40 0.088
plotcp(fit) # visualize cross-validation results

[Figure: plotcp() output, cross-validated relative error versus cp]
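The cp table is usually used to pick a pruning level. Two common rules are the row with the smallest cross-validated error (xerror) and the one-standard-error rule: the smallest tree whose xerror is within one xstd of that minimum (the horizontal line drawn by plotcp() marks this level). The sketch below applies both rules to the values printed by summary(fit); for a fitted rpart object the same table is available as fit$cptable.

```r
# Cross-validation table copied from the summary(fit) output
cptable <- data.frame(CP     = c(0.62289, 0.13206, 0.02544, 0.01160, 0.01000),
                      nsplit = 0:4,
                      xerror = c(1.0194, 0.5389, 0.3806, 0.3835, 0.3999),
                      xstd   = c(0.17820, 0.10441, 0.08544, 0.08766, 0.08764))

# Rule 1: the row with the smallest cross-validated error
row.min <- which.min(cptable$xerror)

# Rule 2 (one-SE rule): the smallest tree within one standard error of the minimum
threshold <- cptable$xerror[row.min] + cptable$xstd[row.min]
row.1se   <- min(which(cptable$xerror <= threshold))

cptable[c(row.min, row.1se), c("CP", "nsplit")]
# With the fitted model one could then prune, e.g.:
# pruned.fit <- prune(fit, cp = cptable$CP[row.1se])
```

Here both rules select the tree with 2 splits (CP = 0.02544).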

summary(fit) # detailed summary of splits
Call:
rpart(formula = Mileage ~ Price + Country + Reliability + Type, 
    data = cu.summary, method = "anova")
  n=60 (57 observations deleted due to missingness)

       CP nsplit rel error xerror    xstd
1 0.62289      0    1.0000 1.0194 0.17820
2 0.13206      1    0.3771 0.5389 0.10441
3 0.02544      2    0.2451 0.3806 0.08544
4 0.01160      3    0.2196 0.3835 0.08766
5 0.01000      4    0.2080 0.3999 0.08764

Variable importance
  Price    Type Country 
     48      42      10 

Node number 1: 60 observations,    complexity param=0.6229
  mean=24.58, MSE=22.58 
  left son=2 (48 obs) right son=3 (12 obs)
  Primary splits:
      Price       < 9446  to the right,  improve=0.6229, (0 missing)
      Type        splits as  LLLRLL,     improve=0.5044, (0 missing)
      Reliability splits as  LLLRR,      improve=0.1263, (11 missing)
      Country     splits as  --LRLRRRLL, improve=0.1244, (0 missing)
  Surrogate splits:
      Type    splits as  LLLRLL,     agree=0.950, adj=0.750, (0 split)
      Country splits as  --LLLLRRLL, agree=0.833, adj=0.167, (0 split)

Node number 2: 48 observations,    complexity param=0.1321
  mean=22.71, MSE=8.498 
  left son=4 (23 obs) right son=5 (25 obs)
  Primary splits:
      Type        splits as  RLLRRL,     improve=0.43850, (0 missing)
      Price       < 12150 to the right,  improve=0.25750, (0 missing)
      Country     splits as  --RRLRL-LL, improve=0.13350, (0 missing)
      Reliability splits as  LLLRR,      improve=0.01637, (10 missing)
  Surrogate splits:
      Price   < 12220 to the right,  agree=0.812, adj=0.609, (0 split)
      Country splits as  --RRLRL-RL, agree=0.646, adj=0.261, (0 split)

Node number 3: 12 observations
  mean=32.08, MSE=8.576 

Node number 4: 23 observations,    complexity param=0.02544
  mean=20.7, MSE=2.907 
  left son=8 (10 obs) right son=9 (13 obs)
  Primary splits:
      Type    splits as  -LR--L,     improve=0.515400, (0 missing)
      Price   < 14960 to the left,   improve=0.131300, (0 missing)
      Country splits as  ----L-R--R, improve=0.007022, (0 missing)
  Surrogate splits:
      Price < 13570 to the right, agree=0.609, adj=0.1, (0 split)

Node number 5: 25 observations,    complexity param=0.0116
  mean=24.56, MSE=6.486 
  left son=10 (14 obs) right son=11 (11 obs)
  Primary splits:
      Price       < 11480 to the right,  improve=0.09693, (0 missing)
      Reliability splits as  LLRRR,      improve=0.07767, (4 missing)
      Type        splits as  L--RR-,     improve=0.04210, (0 missing)
      Country     splits as  --LRRR--LL, improve=0.02202, (0 missing)
  Surrogate splits:
      Country splits as  --LLLL--LR, agree=0.80, adj=0.545, (0 split)
      Type    splits as  L--RL-,     agree=0.64, adj=0.182, (0 split)

Node number 8: 10 observations
  mean=19.3, MSE=2.21 

Node number 9: 13 observations
  mean=21.77, MSE=0.7929 

Node number 10: 14 observations
  mean=23.86, MSE=7.694 

Node number 11: 11 observations
  mean=25.45, MSE=3.521 

# create additional plots
par(mfrow=c(1,2)) # two plots on one page
rsq.rpart(fit) # plot approximate R-square and relative error versus number of splits

Regression tree:
rpart(formula = Mileage ~ Price + Country + Reliability + Type, 
    data = cu.summary, method = "anova")

Variables actually used in tree construction:
[1] Price Type 

Root node error: 1355/60 = 23

n=60 (57 observations deleted due to missingness)

     CP nsplit rel error xerror  xstd
1 0.623      0      1.00   1.02 0.178
2 0.132      1      0.38   0.54 0.104
3 0.025      2      0.25   0.38 0.085
4 0.012      3      0.22   0.38 0.088
5 0.010      4      0.21   0.40 0.088

[Figure: rsq.rpart() plots of R-square and relative error versus number of splits]

par(mfrow=c(1,1)) 

# plot tree
plot(fit, uniform=TRUE, main="Regression Tree for Mileage ")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

[Figure: Regression Tree for Mileage]


# create an attractive PostScript plot of the tree
post(fit, file = "c:/tree2.ps", title = "Regression Tree for Mileage ")

Random Forests

Random forests are an ensemble learning method for classification (and regression) that operate by constructing a multitude of decision trees at training time and outputting the class that is the mode of the classes output by individual trees – Wikipedia
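The "mode of the classes output by individual trees" is a simple majority vote. A minimal sketch in base R, using a hypothetical matrix of per-tree votes; majority_vote is a made-up helper, and its tie-breaking (first class in alphabetical order) is an arbitrary choice of this sketch that randomForest itself may not share.

```r
# Hypothetical votes: one row per sample, one column per tree
votes <- rbind(c("setosa",     "setosa",     "setosa",    "versicolor"),
               c("virginica",  "versicolor", "virginica", "virginica"),
               c("versicolor", "versicolor", "virginica", "versicolor"))

# Mode of one row of votes. table() sorts classes alphabetically and
# which.max() takes the first maximum, so ties go to the first class name.
majority_vote <- function(v) {
  counts <- table(v)
  names(counts)[which.max(counts)]
}

apply(votes, 1, majority_vote)   # "setosa" "virginica" "versicolor"
```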

Check the manual for options and available tools.

library("randomForest")

r <- randomForest(Species ~., data=train.set, importance=TRUE, do.trace=100, ntree=100)
ntree      OOB      1      2      3
  100:   4.76%  0.00%  5.26%  9.38%
print(r)

Call:
 randomForest(formula = Species ~ ., data = train.set, importance = TRUE,      do.trace = 100, ntree = 100) 
               Type of random forest: classification
                     Number of trees: 100
No. of variables tried at each split: 2

        OOB estimate of  error rate: 4.76%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         35          0         0     0.00000
versicolor      0         36         2     0.05263
virginica       0          3        29     0.09375
predictions <- predict(r, test.set)
table(test.set$Species, predictions)
            predictions
             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         12         0
  virginica       0          3        15
# next function gives a graphical depiction of the marginal effect of a variable on the class probability (classification) or response (regression).
partialPlot(r, train.set, Petal.Width, "versicolor")

[Figure: partial dependence of the versicolor class on Petal.Width]

We can extract a given tree or get some information about the ensemble.

t <- getTree(r, k=2) # get the second tree
print(t)
   left daughter right daughter split var split point status prediction
1              2              3         3        2.60      1          0
2              0              0         0        0.00     -1          1
3              4              5         1        6.05      1          0
4              6              7         4        1.75      1          0
5              8              9         4        1.55      1          0
6              0              0         0        0.00     -1          2
7              0              0         0        0.00     -1          3
8             10             11         3        5.00      1          0
9              0              0         0        0.00     -1          3
10             0              0         0        0.00     -1          2
11             0              0         0        0.00     -1          3
treesize(r) # number of terminal nodes in each tree of the ensemble
  [1]  6  6  6  8  6  6  9  8  5  8  7  8  7  5  7  6  8  4  4  5  5  8  8
 [24]  7  5  5  6  8  9  6  5  8  7  7  6  9  5  8  7  7  6  9  6  8  7  5
 [47]  9 11  3  7  6  3  7  5  5  5 10 11  8  6  5  8  5  9  6  7  5  6  8
 [70]  8 10  7  8  5  6  7  6  4  7  7  6  8  7  7  6  3  5  6  8  8  6  7
 [93]  7 10  6  8  6  8  5  8
hist(treesize(r))

[Figure: histogram of treesize(r)]
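Each row of the getTree() output is a node: status -1 marks a terminal node, "split var" is the column index of the predictor, and, per the randomForest documentation, observations whose value is less than or equal to the split point go to the left daughter. The sketch below copies the second tree printed above and classifies a few flowers by hand; predict_gettree is a made-up helper, and the inputs must follow the predictor column order used to fit the forest (Sepal.Length, Sepal.Width, Petal.Length, Petal.Width).

```r
# The second tree, copied from the getTree(r, k=2) output above
t2 <- data.frame(
  left   = c(2, 0, 4, 6, 8, 0, 0, 10, 0, 0, 0),
  right  = c(3, 0, 5, 7, 9, 0, 0, 11, 0, 0, 0),
  var    = c(3, 0, 1, 4, 4, 0, 0, 3, 0, 0, 0),
  point  = c(2.60, 0, 6.05, 1.75, 1.55, 0, 0, 5.00, 0, 0, 0),
  status = c(1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1),
  pred   = c(0, 1, 0, 0, 0, 2, 3, 0, 3, 2, 3))
classes <- c("setosa", "versicolor", "virginica")   # codes 1, 2, 3

predict_gettree <- function(tree, x) {
  i <- 1                                # start at the root (row 1)
  while (tree$status[i] != -1) {        # -1 marks a terminal node
    # numeric split: values <= split point go to the left daughter
    i <- if (x[tree$var[i]] <= tree$point[i]) tree$left[i] else tree$right[i]
  }
  classes[tree$pred[i]]
}

predict_gettree(t2, c(5.1, 3.5, 1.4, 0.2))   # "setosa"
predict_gettree(t2, c(6.9, 3.1, 5.4, 2.1))   # "virginica"
```

This tree has 6 terminal rows, matching the second value returned by treesize(r).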

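We can also tune the structure, i.e., search for the best hyperparameters of the method by grid search, with tune() from e1071. Note that tune() only searches over a grid when it is given a ranges argument (for example ranges = list(mtry = 1:4)); without one, as in the call below, it just evaluates randomForest's default settings by 10-fold cross-validation, and validation.x is ignored under that default sampling scheme: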

library("e1071") # to access 'tune' method

tuned.r <- tune(randomForest, train.x = Species ~ .,
                data = train.set,
                validation.x = test.set)

best.model <- tuned.r$best.model
predictions <- predict(best.model, test.set)
table.random.forest <- table(test.set$Species, predictions)
table.random.forest
            predictions
             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         11         1
  virginica       0          3        15
# computing overall error:
error.rate <- 1 - sum(diag(as.matrix(table.random.forest))) / sum(table.random.forest)
error.rate
[1] 0.08889

Conditional Inference Trees

Conditional inference trees estimate a regression relationship by binary recursive partitioning in a conditional inference framework. Roughly, the algorithm works as follows:
1) Test the global null hypothesis of independence between any of the input variables and the response (which may be multivariate as well). Stop if this hypothesis cannot be rejected. Otherwise select the input variable with strongest association to the response. This association is measured by a p-value corresponding to a test for the partial null hypothesis of a single input variable and the response.
2) Implement a binary split in the selected input variable.
3) Recursively repeat steps 1) and 2).
– party package help file
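Step 1 can be sketched in base R. As a swapped-in stand-in for the permutation tests that ctree() actually uses, this sketch runs a Kruskal-Wallis test of each numeric input against a categorical response, applies a Bonferroni adjustment for testing several inputs, stops if nothing is significant, and otherwise returns the input with the smallest adjusted p-value. select_variable and its alpha default are made up for illustration.

```r
# Step 1 sketch: variable selection by association tests (Kruskal-Wallis as
# a stand-in; ctree() uses its own permutation-test statistics).
select_variable <- function(data, response, alpha = 0.05) {
  inputs <- setdiff(names(data), response)
  p <- sapply(inputs, function(v)
    kruskal.test(data[[v]], data[[response]])$p.value)
  p.adj <- p.adjust(p, method = "bonferroni")   # correct for multiple tests
  if (min(p.adj) > alpha) return(NULL)          # global null not rejected: stop
  inputs[which.min(p.adj)]                      # strongest association
}

select_variable(iris, "Species")   # one of the petal measurements

# An input with the same distribution in both groups: no split
d <- data.frame(x = rep(1:5, 4), y = factor(rep(c("a", "b"), each = 10)))
select_variable(d, "y")            # NULL
```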

library(party)

iris.model <- ctree(Species ~ . , data = train.set)
plot(iris.model)

[Figure: the conditional inference tree iris.model, with class-probability bar plots at the leaves]

predictions <- predict(iris.model, test.set[,-5])
table(predictions, test.set$Species)

predictions  setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         11         1
  virginica       0          1        17
# what are the predicted probabilities for the given samples?
treeresponse(iris.model, newdata=iris[c(10,87,128),])
[[1]]
[1] 1 0 0

[[2]]
[1] 0 1 0

[[3]]
[1] 0.00000 0.03333 0.96667
# get the leaf probabilities shown in the bar plots above:
tapply(treeresponse(iris.model), where(iris.model), unique)
$`2`
[1] 1 0 0

$`5`
[1] 0 1 0

$`6`
[1] 0.0000 0.5714 0.4286

$`7`
[1] 0.00000 0.03333 0.96667
# The tree plot can be customized, e.g. with a custom panel for the inner nodes:
innerWeights <- function(node){
  grid.circle(gp = gpar(fill = "White", col = 1))
  mainlab <- paste( node$psplit$variableName, "\n(n = ")
  mainlab <- paste(mainlab, sum(node$weights),")" , sep = "")
  grid.text(mainlab,gp = gpar(col='red'))
}
plot(iris.model, type='simple', inner_panel = innerWeights)

[Figure: iris.model drawn with type='simple' and the custom inner-node panel]