A Classification Tree Example

Let’s generate data from a model with a nonlinear classification rule.

    set.seed(1)
    n = 500
    x1 = runif(n, -1, 1)
    x2 = runif(n, -1, 1)
    y = rbinom(n, size = 1, prob = ifelse(x1^2 + x2^2 < 0.6, 0.9, 0.1))
    
    par(mar=rep(2,4))
    plot(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19)
    symbols(0, 0, circles = sqrt(0.6), add = TRUE, inches = FALSE, cex = 2)

A classification tree model recursively splits the feature space so that the class labels within each resulting node are as pure as possible. We can fit one with the rpart package:

    library(rpart)
    rpart.fit = rpart(as.factor(y)~x1+x2, data = data.frame(x1, x2, y))
    par(mfrow = c(1, 2))
    
    # the tree structure    
    par(mar=rep(0.5,4))
    plot(rpart.fit)
    text(rpart.fit)    
    
    # and the tuning parameter 
    par(mar=rep(2,4))
    plotcp(rpart.fit)    

    # if you want to peek into the tree 
    
    rpart.fit$cptable
##           CP nsplit rel error    xerror       xstd
## 1 0.17040359      0 1.0000000 1.0000000 0.04984280
## 2 0.14798206      3 0.4843049 0.6816143 0.04612343
## 3 0.01121076      4 0.3363229 0.4125561 0.03885386
## 4 0.01000000      7 0.3004484 0.3721973 0.03730931
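
In practice, rather than hard-coding a value, cp is usually chosen from the cross-validation column of this table. A minimal sketch (here taking the cp with the smallest xerror; the 1-SE rule is another common choice):

    # choose the cp value with the smallest cross-validated error and prune to it
    best.cp = rpart.fit$cptable[which.min(rpart.fit$cptable[, "xerror"]), "CP"]
    prune(rpart.fit, cp = best.cp)

For illustration, we can also prune the tree at a fixed value, say cp = 0.041:
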
    prune(rpart.fit, cp = 0.041)
## n= 500 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 500 223 0 (0.55400000 0.44600000)  
##    2) x2< -0.6444322 90   6 0 (0.93333333 0.06666667) *
##    3) x2>=-0.6444322 410 193 1 (0.47073171 0.52926829)  
##      6) x1>=0.6941279 68   8 0 (0.88235294 0.11764706) *
##      7) x1< 0.6941279 342 133 1 (0.38888889 0.61111111)  
##       14) x2>=0.7484327 53   7 0 (0.86792453 0.13207547) *
##       15) x2< 0.7484327 289  87 1 (0.30103806 0.69896194)  
##         30) x1< -0.6903174 51   9 0 (0.82352941 0.17647059) *
##         31) x1>=-0.6903174 238  45 1 (0.18907563 0.81092437) *

The model proceeds with the following sequence of splits; a rough sketch of the first few splits is given below. Note that steps 5 and 6 are not really beneficial.
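
This sketch is not part of the original output; it simply overlays the cut points from the pruned tree printed above onto the training data:

    # overlay the first four splits (cut points taken from the pruned tree above)
    par(mar = rep(2, 4))
    plot(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19)
    segments(-1, -0.644, 1, -0.644, lwd = 2)           # split 1: x2 < -0.644
    segments(0.694, -0.644, 0.694, 1, lwd = 2)         # split 2: x1 >= 0.694
    segments(-1, 0.748, 0.694, 0.748, lwd = 2)         # split 3: x2 >= 0.748
    segments(-0.690, -0.644, -0.690, 0.748, lwd = 2)   # split 4: x1 < -0.690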

There are many other packages that can perform the same analysis.

    library(tree)
    tree.fit = tree(as.factor(y)~x1+x2, data = data.frame(x1, x2, y))
    plot(tree.fit)
    text(tree.fit)

Gini Impurity vs. Shannon Entropy vs. Misclassification Error

Gini impurity is used in CART, while ID3/C4.5 use Shannon entropy. These measures behave differently from the misclassification error: they prefer “pure” nodes, meaning that the gain from splitting off a terminal node containing only one class is large under Gini and entropy. This is because both measures are strictly concave in the class proportion, whereas the misclassification error is piecewise linear, as illustrated below.
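
A small, self-contained sketch comparing the three measures for a binary node as a function of the proportion p of class 1:

    # impurity of a binary node as a function of the proportion p of class 1
    p = seq(0.001, 0.999, 0.001)
    gini = 2 * p * (1 - p)
    entropy = -p * log2(p) - (1 - p) * log2(1 - p)
    miscls = pmin(p, 1 - p)
    
    # Gini and entropy are strictly concave; misclassification error is
    # piecewise linear, so it rewards nearly pure child nodes much less
    par(mar = rep(2, 4))
    plot(p, entropy, type = "l", lwd = 2, ylim = c(0, 1))
    lines(p, gini, lwd = 2, col = "deepskyblue")
    lines(p, miscls, lwd = 2, col = "darkorange")
    legend("topright", c("entropy", "Gini", "misclassification"),
           col = c("black", "deepskyblue", "darkorange"), lwd = 2)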

Bagging

Bagging (bootstrap aggregating) fits the same tree model on many bootstrap samples of the training data and combines their predictions by majority vote, which stabilizes the jagged decision boundary of a single tree.

    # bagging from ipred package
    library(ipred)
    library(rpart)
    
    set.seed(2)
    n = 1000
    x1 = runif(n, -1, 1)
    x2 = runif(n, -1, 1)
    y = rbinom(n, size = 1, prob = ifelse((x1 + x2 > -0.5) & (x1 + x2 < 0.5) , 0.8, 0.2))
    xgrid = expand.grid(x1 = seq(-1, 1, 0.01), x2 = seq(-1, 1, 0.01))
    par(mfrow=c(1,2), mar=c(0.5, 0.5, 2, 0.5))
    
    # CART
    rpart.fit = rpart(as.factor(y)~x1+x2, data = data.frame(x1, x2, y))
    #rpart.fit = rpart(as.factor(y)~x1+x2, data = data.frame(x1, x2, y)[sample(1:n, n, replace = TRUE), ])
    pred = matrix(predict(rpart.fit, xgrid, type = "class") == 1, 201, 201)
    contour(seq(-1, 1, 0.01), seq(-1, 1, 0.01), pred, levels=0.5, labels="",axes=FALSE)
    points(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19, yaxt="n", xaxt = "n")
    points(xgrid, pch=".", cex=1.2, col=ifelse(pred, "deepskyblue", "darkorange"))
    box()    
    title("CART")
    
    #Bagging
    bag.fit = bagging(as.factor(y)~x1+x2, data = data.frame(x1, x2, y), nbagg = 200, ns = 400)
    pred = matrix(predict(prune(bag.fit), xgrid) == 1, 201, 201)
    contour(seq(-1, 1, 0.01), seq(-1, 1, 0.01), pred, levels=0.5, labels="",axes=FALSE)
    points(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19, yaxt="n", xaxt = "n")
    points(xgrid, pch=".", cex=1.2, col=ifelse(pred, "deepskyblue", "darkorange"))
    box()
    title("Bagging")

Random Forests

In this two-dimensional setting, we don’t see much improvement from using random forests. However, the improvement can be significant in high-dimensional settings; a small experiment illustrating this follows the plot below.

    library(randomForest)
    par(mar=c(0.5, 0.5, 2, 0.5))
    rf.fit = randomForest(cbind(x1, x2), as.factor(y), ntree = 1000, mtry = 1, nodesize = 20, sampsize = 500)
    pred = matrix(predict(rf.fit, xgrid) == 1, 201, 201)
    contour(seq(-1, 1, 0.01), seq(-1, 1, 0.01), pred, levels=0.5, labels="",axes=FALSE)
    points(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19, yaxt="n", xaxt = "n")
    points(xgrid, pch=".", cex=1.2, col=ifelse(pred, "deepskyblue", "darkorange"))
    box()
    title("Random Forests")

Random Forests vs. Kernel

I wrote a small function that extracts the kernel weights from a random forest when predicting a testing point \(x\). These weights are essentially the counts of how many times each training observation falls into the same terminal node as \(x\) across the trees. Since the prediction at \(x\) is essentially a weighted average of the training observations with these weights, this is basically a kernel averaging approach. However, the kernel weights are adaptive to the true structure.

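The function itself is not shown here, but a minimal sketch of the idea looks like the following, using the nodes = TRUE option of predict.randomForest to obtain terminal node memberships (the exact plotting details may differ from the figures shown):

    # a possible sketch of plotRFKernel: counts, for a target point onex, how
    # many trees place each training observation in the same terminal node
    plotRFKernel <- function(rf.fit, x, onex)
    {
      # terminal node membership of the training data: an n x ntree matrix
      train.nodes = attr(predict(rf.fit, x, nodes = TRUE), "nodes")
      
      # terminal node membership of the target point in every tree
      newx = x[1, , drop = FALSE]
      newx[1, ] = onex
      onex.nodes = attr(predict(rf.fit, newx, nodes = TRUE), "nodes")
      
      # kernel weight: number of trees sharing a terminal node with onex,
      # rescaled to [0, 1]
      wt = rowSums(sweep(train.nodes, 2, as.vector(onex.nodes), FUN = "=="))
      wt = wt / max(wt)
      
      # display the weights as circle sizes (y is taken from the workspace for
      # coloring, in the same spirit as the Gaussian kernel block below)
      plot(x[, 1], x[, 2], type = "n", axes = FALSE, ann = FALSE)
      points(x[, 1], x[, 2], cex = 4*wt^(2/3), pch = 1, lwd = 2)
      points(x[, 1], x[, 2], col = ifelse(y == 1, "deepskyblue", "darkorange"),
             pch = 19, cex = 0.75)
      points(onex[1], onex[2], pch = 4, col = "red", cex = 4, lwd = 6)
      box()
    }
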
    # fit a random forest model
    rf.fit = randomForest(cbind(x1, x2), as.factor(y), ntree = 300, mtry = 1, nodesize = 20, keep.inbag = TRUE)
    pred = matrix(predict(rf.fit, xgrid) == 1, 201, 201)
    
    par(mfrow=c(1,2), mar=c(0.5, 0.5, 2, 0.5))

    # check the kernel weight at different points

    plotRFKernel(rf.fit, data.frame(cbind(x1, x2)), c(-0.1, 0.4))
    plotRFKernel(rf.fit, data.frame(cbind(x1, x2)), c(0, 0.6))

In contrast, here are the regular Gaussian kernel weights (after some tuning), which depend only on the Euclidean distance to the target point rather than on the underlying structure. This difference becomes important when \(p\) is large.

    # Gaussian kernel weights
    onex = c(-0.1, 0.4)
    h = 0.2
    wt = exp(-0.5*rowSums(sweep(cbind(x1, x2), 2, onex, FUN = "-")^2)/h^2)
    contour(seq(-1, 1, 0.01), seq(-1, 1, 0.01), pred, levels=0.5, labels="",axes=FALSE)
    points(x1, x2, cex = 4*wt^(2/3), pch = 1, cex.axis=1.25, lwd = 2)
    points(x1, x2, col = ifelse(y == 1, "deepskyblue", "darkorange"), pch = 19, cex = 0.75, yaxt="n", xaxt = "n")
    points(xgrid, pch=".", cex=1.2, col=ifelse(pred, "deepskyblue", "darkorange"))
    points(onex[1], onex[2], pch = 4, col = "red", cex =4, lwd = 6)        
    box()