Skip to contents

Install and Load Package

Install and load the GitHub version of the RLT package. Do not use the CRAN version.

  # install.packages("devtools")
  # devtools::install_github("teazrq/RLT")
  library(RLT)
## RLT and Random Forests v4.2.6
## pre-release at github.com/teazrq/RLT

Load the other packages used in this guide with library(): randomForestSRC, randomForest, and ranger.

Benchmark Against Existing Packages

We generate a dataset with 1800 observations (800 for training and 1000 for testing) and 30 variables: 15 continuous variables and 15 categorical variables, each with ten categories.

  # Set seed for reproducibility
  set.seed(1)

  # Define data size: n rows are simulated in total (trainn + testn)
  trainn <- 800
  testn <- 1000
  n <- trainn + testn
  p <- 30

  # Generate p / 2 continuous variables (X1, standard normal) and p / 2
  # categorical variables (X2, integer codes 0, ..., 9)
  X1 <- matrix(rnorm(n * p / 2), n, p / 2)
  # Alternative all-continuous design:
  #X2 <- matrix(rnorm(n * p / 2), n, p / 2)
  X2 <- matrix(as.integer(runif(n * p / 2) * 10), n, p / 2)

  # Combine continuous and categorical variables into a data frame (X)
  X <- data.frame(X1, X2)

  # Convert the second half of the columns in X to factors
  X[, (p / 2 + 1):p] <- lapply(X[, (p / 2 + 1):p], as.factor)

  # Generate outcomes (y)
  # `logit` maps a linear predictor to a probability (i.e. it is the inverse
  # logit). plogis() is used instead of exp(x) / (1 + exp(x)) because the
  # latter evaluates to Inf / Inf = NaN once exp(x) overflows.
  logit <- function(x) stats::plogis(x)

  # Alternative outcome model, kept for reference:
  #  y <- as.factor(rbinom(n, 1, prob = logit(1 + rowSums(X[, 1:5]) + 2 * (X[, p / 2 + 1] %in% c(1, 3)) + rnorm(n))) + 2)

  # The signal depends on the second continuous variable and on whether the
  # last categorical variable falls in {1, 3, 5, 7}. Adding 2 to the 0/1 draws
  # makes the class labels "2" and "3".
  linpred <- 1 + 1 * X[, 2] + 3 * (X[, p] %in% c(1, 3, 5, 7))
  y <- as.factor(rbinom(n, 1, prob = logit(linpred)) + 2)
  
  # Set tuning parameters
  ntrees <- 1000          # number of trees per forest
  ncores <- 10            # requested number of cores / threads
  nmin <- 20              # minimum node size
  mtry <- p / 2           # variables tried at each split
  samplereplace <- TRUE   # resample with replacement?
  sampleprob <- 0.75      # resampling fraction of the training set
  rule <- "best"          # split generating rule
  # `rule` is a single string, so a plain if/else is used instead of the
  # vectorised ifelse(): 0 for the "best" rule, 3 otherwise.
  nsplit <- if (rule == "best") 0 else 3
  importance <- TRUE      # compute variable importance?

  # Split data into training and testing sets: the first trainn rows are used
  # for fitting and the next testn rows for evaluation. seq_len() is used
  # instead of 1:trainn so that a size of 0 yields an empty set rather than
  # the backwards sequence c(1, 0).
  trainX <- X[seq_len(trainn), ]
  trainY <- y[seq_len(trainn)]
  testX <- X[trainn + seq_len(testn), ]
  testY <- y[trainn + seq_len(testn)]
  # Results table: one row per fitted package, one column per recorded
  # quantity. Every cell starts as NA and is filled in after each fit.
  metric <- data.frame(matrix(NA, nrow = 5, ncol = 6))
  dimnames(metric) <- list(
    c("RLT", "randomForestSRC", "randomForest", "ranger", "ranger fast"),
    c("fit.time", "pred.time", "oob.error",
      "pred.error", "obj.size", "ave.tree.size")
  )

  # Fit a classification forest with the RLT package; start_time marks the
  # beginning of the fit so the elapsed time can be recorded afterwards
  start_time <- Sys.time()
  RLTfit <- RLT(
    trainX, trainY,
    model = "classification",
    ntrees = ntrees,
    mtry = mtry,
    nmin = nmin,
    split.gen = rule,
    nsplit = nsplit,
    resample.replace = samplereplace,
    resample.prob = sampleprob,
    importance = importance,
    param.control = list(alpha = 0),
    ncores = ncores,
    verbose = TRUE
  )
## Classification Random Forest ... 
## ---------- Parameters Summary ----------
##               (N, P) = (800, 30)
##           # of trees = 1000
##         (mtry, nmin) = (15, 20)
##       split generate = Best
##             sampling = 0.75 w/ replace
##   (Obs, Var) weights = (No, No)
##           importance = permute
##        reinforcement = No
## ----------------------------------------
## Do not have 10 cores, use maximum 4 cores.
  # Record the RLT results in the first row of the metric table
  metric["RLT", "fit.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Time the prediction on the test set
  start_time <- Sys.time()
  RLTPred <- predict(RLTfit, testX, ncores = ncores)
  metric["RLT", "pred.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Misclassification rates: the fit's own predictions against the training
  # labels, and the test-set predictions against the test labels
  metric["RLT", "oob.error"] <- mean(RLTfit$Prediction != trainY)
  metric["RLT", "pred.error"] <- mean(RLTPred$Prediction != testY)

  # Size of the fitted object and average number of entries per tree
  metric["RLT", "obj.size"] <- object.size(RLTfit)
  metric["RLT", "ave.tree.size"] <-
    mean(lengths(RLTfit$FittedForest$SplitVar))

  # use randomForestSRC
  options(rf.cores = ncores)
  start_time <- Sys.time()
  # samptype must follow samplereplace: "swr" samples with replacement and
  # "swor" samples without replacement. `samplereplace` and `importance` are
  # scalars, so plain if/else is used rather than ifelse().
  # NOTE(review): nodesize is nmin / 2 here while the other packages receive
  # nmin -- presumably deliberate to match node-size conventions; confirm.
  rsffit <- rfsrc(y ~ ., data = data.frame(trainX, "y" = trainY),
                  ntree = ntrees, nodesize = nmin/2, mtry = mtry,
                  samptype = if (samplereplace) "swr" else "swor",
                  nsplit = nsplit, sampsize = trainn*sampleprob,
                  importance = if (importance) "permute" else "none")
  metric[2, 1] <- difftime(Sys.time(), start_time, units = "secs")

  # Time the prediction on the test set
  start_time <- Sys.time()
  rsfpred <- predict(rsffit, data.frame(testX))
  metric[2, 2] <- difftime(Sys.time(), start_time, units = "secs")

  # predicted.oob has one column per class; which.max() picks the column
  # index (1 or 2) and adding 1 maps it onto the class labels 2 and 3
  metric[2, 3] <- mean(apply(rsffit$predicted.oob, 1, which.max) + 1 != trainY)
  metric[2, 4] <- mean(rsfpred$class != testY)

  # Size of the fitted object and average number of nodes per tree
  metric[2, 5] <- object.size(rsffit)
  metric[2, 6] <- rsffit$forest$totalNodeCount / rsffit$ntree
  
  # Fit the same forest with the randomForest package and time the call
  start_time <- Sys.time()
  rf.fit <- randomForest(
    trainX, trainY,
    ntree = ntrees,
    mtry = mtry,
    nodesize = nmin,
    replace = samplereplace,
    sampsize = trainn * sampleprob,
    importance = importance
  )
  metric["randomForest", "fit.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Time the prediction on the test set
  start_time <- Sys.time()
  rf.pred <- predict(rf.fit, testX)
  metric["randomForest", "pred.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Misclassification rates on the training labels and on the test labels
  metric["randomForest", "oob.error"] <- mean(rf.fit$predicted != trainY)
  metric["randomForest", "pred.error"] <- mean(rf.pred != testY)

  # Size of the fitted object; tree size is the per-tree count of entries
  # with a non-zero node status, averaged over trees
  metric["randomForest", "obj.size"] <- object.size(rf.fit)
  metric["randomForest", "ave.tree.size"] <-
    mean(colSums(rf.fit$forest$nodestatus != 0))
  
  # use ranger, splitting unordered factors by partitioning their levels
  start_time <- Sys.time()
  # The importance mode follows the shared `importance` flag, as in the
  # randomForestSRC and randomForest fits, instead of being hard-coded.
  rangerfit <- ranger(y ~ ., data = data.frame(trainX, "y" = trainY),
                      num.trees = ntrees, min.node.size = nmin,
                      mtry = mtry, num.threads = ncores,
                      replace = samplereplace,
                      sample.fraction = sampleprob,
                      importance = if (importance) "permutation" else "none",
                      respect.unordered.factors = "partition")
## Growing trees.. Progress: 60%. Estimated remaining time: 21 seconds.
  # Record the ranger (partition) results in the fourth row of the table
  metric["ranger", "fit.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Time the prediction on the test set
  start_time <- Sys.time()
  rangerpred <- predict(rangerfit, data.frame(testX))
  metric["ranger", "pred.time"] <-
    difftime(Sys.time(), start_time, units = "secs")

  # Misclassification rates on the training labels and on the test labels
  metric["ranger", "oob.error"] <- mean(rangerfit$predictions != trainY)
  metric["ranger", "pred.error"] <- mean(rangerpred$predictions != testY)

  # Size of the fitted object and average number of entries per tree
  metric["ranger", "obj.size"] <- object.size(rangerfit)
  metric["ranger", "ave.tree.size"] <-
    mean(lengths(rangerfit$forest$split.varIDs))
  
  # use ranger without partitioning: respect.unordered.factors is not set,
  # so ranger's default handling of factor variables applies
  start_time <- Sys.time()
  # The importance mode follows the shared `importance` flag, as in the
  # randomForestSRC and randomForest fits, instead of being hard-coded.
  rangerfast <- ranger(y ~ ., data = data.frame(trainX, "y" = trainY),
                       num.trees = ntrees, min.node.size = nmin,
                       mtry = mtry, num.threads = ncores,
                       replace = samplereplace,
                       sample.fraction = sampleprob,
                       importance = if (importance) "permutation" else "none")
  metric[5, 1] <- difftime(Sys.time(), start_time, units = "secs")

  # Time the prediction on the test set
  start_time <- Sys.time()
  rangerpred <- predict(rangerfast, data.frame(testX))
  metric[5, 2] <- difftime(Sys.time(), start_time, units = "secs")

  # Misclassification rates on the training labels and on the test labels
  metric[5, 3] <- mean(rangerfast$predictions != trainY)
  metric[5, 4] <- mean(rangerpred$predictions != testY)

  # Size of the fitted object and average number of entries per tree
  metric[5, 5] <- object.size(rangerfast)
  metric[5, 6] <- mean(unlist(lapply(rangerfast$forest$split.varIDs, length)))
  
  # performance summary: print the table of fit/prediction times (seconds),
  # error rates, object sizes and average tree sizes, one row per package
  metric
##                   fit.time  pred.time oob.error pred.error obj.size ave.tree.size
## RLT              3.8877819 0.04142427   0.18375      0.182  4432112        58.618
## randomForestSRC 42.5205061 0.17455602   0.18250      0.171 12907088        71.766
## randomForest     5.0521028 0.04990530   0.18000      0.176  2122056        50.650
## ranger          52.2684577 0.09444880   0.19250      0.183  2182560        58.514
## ranger fast      0.6899104 0.08232188   0.20125      0.201  2754232        76.386

You can use the get.one.tree() function to peek into a single tree.

  # Print the structure of tree number 1 from the fitted RLT forest
  get.one.tree(RLTfit, 1)
## Tree #1 in the fitted classification forest:
##     SplitVar   SplitValue LeftNode RightNode NodeWeight  Prob.of..2 Prob.of..3
## 1        X2    -0.6926111        2         3        600 0.193333333  0.8066667
## 2   X4.1 (F)  448.0000000        4         5        146 0.390410959  0.6095890
## 3  X15.1 (F)  852.0000000       26        27        454 0.129955947  0.8700441
## 4   X8.1 (F)  306.0000000        6         7         98 0.510204082  0.4897959
## 5        X3     0.6296932       20        21         48 0.145833333  0.8541667
## 6  X15.1 (F)  682.0000000        8         9         57 0.666666667  0.3333333
## 7  X10.1 (F)  442.0000000       14        15         41 0.292682927  0.7073171
## 8       <NA>           NA       NA        NA         12 0.166666667  0.8333333
## 9        X3     1.0694546       10        11         45 0.800000000  0.2000000
## 10 X14.1 (F)  528.0000000       12        13         40 0.875000000  0.1250000
## 11      <NA>           NA       NA        NA          5 0.200000000  0.8000000
## 12      <NA>           NA       NA        NA         34 1.000000000  0.0000000
## 13      <NA>           NA       NA        NA          6 0.166666667  0.8333333
## 14      <NA>           NA       NA        NA         12 0.833333333  0.1666667
## 15       X9     0.9688454       16        17         29 0.068965517  0.9310345
## 16  X3.1 (F)  512.0000000       18        19         28 0.035714286  0.9642857
## 17      <NA>           NA       NA        NA          1 1.000000000  0.0000000
## 18      <NA>           NA       NA        NA         27 0.000000000  1.0000000
## 19      <NA>           NA       NA        NA          1 1.000000000  0.0000000
## 20  X3.1 (F)   32.0000000       22        23         40 0.050000000  0.9500000
## 21      <NA>           NA       NA        NA          8 0.625000000  0.3750000
## 22 X13.1 (F)  256.0000000       24        25         39 0.025641026  0.9743590
## 23      <NA>           NA       NA        NA          1 1.000000000  0.0000000
## 24      <NA>           NA       NA        NA         38 0.000000000  1.0000000
## 25      <NA>           NA       NA        NA          1 1.000000000  0.0000000
## 26  X4.1 (F)  778.0000000       28        29        227 0.215859031  0.7841410
## 27       X3     2.0473906       52        53        227 0.044052863  0.9559471
## 28  X1.1 (F)  892.0000000       30        31        145 0.296551724  0.7034483
## 29 X11.1 (F)    2.0000000       46        47         82 0.073170732  0.9268293
## 30       X4     0.3217051       32        33         38 0.552631579  0.4473684
## 31 X13.1 (F)  256.0000000       36        37        107 0.205607477  0.7943925
## 32  X7.1 (F)  902.0000000       34        35         27 0.777777778  0.2222222
## 33      <NA>           NA       NA        NA         11 0.000000000  1.0000000
## 34      <NA>           NA       NA        NA          7 0.285714286  0.7142857
## 35      <NA>           NA       NA        NA         20 0.950000000  0.0500000
## 36       X5     0.2664197       38        39         96 0.145833333  0.8541667
## 37      <NA>           NA       NA        NA         11 0.727272727  0.2727273
## 38  X7.1 (F) 1018.0000000       40        41         67 0.059701493  0.9402985
## 39       X5     0.9116686       44        45         29 0.344827586  0.6551724
## 40      <NA>           NA       NA        NA         10 0.300000000  0.7000000
## 41      X12     1.1907859       42        43         57 0.017543860  0.9824561
## 42      <NA>           NA       NA        NA         54 0.000000000  1.0000000
## 43      <NA>           NA       NA        NA          3 0.333333333  0.6666667
## 44      <NA>           NA       NA        NA         13 0.692307692  0.3076923
## 45      <NA>           NA       NA        NA         16 0.062500000  0.9375000
## 46      X11    -2.4272374       48        49         72 0.027777778  0.9722222
## 47      <NA>           NA       NA        NA         10 0.400000000  0.6000000
## 48      <NA>           NA       NA        NA          1 1.000000000  0.0000000
## 49  X7.1 (F)  512.0000000       50        51         71 0.014084507  0.9859155
## 50      <NA>           NA       NA        NA         66 0.000000000  1.0000000
## 51      <NA>           NA       NA        NA          5 0.200000000  0.8000000
## 52       X7     1.2252040       54        55        222 0.031531532  0.9684685
## 53      <NA>           NA       NA        NA          5 0.600000000  0.4000000
## 54      X11     2.3314162       56        57        192 0.010416667  0.9895833
## 55       X7     1.4388455       60        61         30 0.166666667  0.8333333
## 56      X11    -2.0074093       58        59        190 0.005263158  0.9947368
## 57      <NA>           NA       NA        NA          2 0.500000000  0.5000000
## 58      <NA>           NA       NA        NA          9 0.111111111  0.8888889
## 59      <NA>           NA       NA        NA        181 0.000000000  1.0000000
## 60      <NA>           NA       NA        NA          7 0.714285714  0.2857143
## 61      <NA>           NA       NA        NA         23 0.000000000  1.0000000