\(\newcommand{\ci}{\perp\!\!\!\perp}\) \(\newcommand{\cA}{\mathcal{A}}\) \(\newcommand{\cB}{\mathcal{B}}\) \(\newcommand{\cC}{\mathcal{C}}\) \(\newcommand{\cD}{\mathcal{D}}\) \(\newcommand{\cE}{\mathcal{E}}\) \(\newcommand{\cF}{\mathcal{F}}\) \(\newcommand{\cG}{\mathcal{G}}\) \(\newcommand{\cH}{\mathcal{H}}\) \(\newcommand{\cI}{\mathcal{I}}\) \(\newcommand{\cJ}{\mathcal{J}}\) \(\newcommand{\cK}{\mathcal{K}}\) \(\newcommand{\cL}{\mathcal{L}}\) \(\newcommand{\cM}{\mathcal{M}}\) \(\newcommand{\cN}{\mathcal{N}}\) \(\newcommand{\cO}{\mathcal{O}}\) \(\newcommand{\cP}{\mathcal{P}}\) \(\newcommand{\cQ}{\mathcal{Q}}\) \(\newcommand{\cR}{\mathcal{R}}\) \(\newcommand{\cS}{\mathcal{S}}\) \(\newcommand{\cT}{\mathcal{T}}\) \(\newcommand{\cU}{\mathcal{U}}\) \(\newcommand{\cV}{\mathcal{V}}\) \(\newcommand{\cW}{\mathcal{W}}\) \(\newcommand{\cX}{\mathcal{X}}\) \(\newcommand{\cY}{\mathcal{Y}}\) \(\newcommand{\cZ}{\mathcal{Z}}\) \(\newcommand{\bA}{\mathbf{A}}\) \(\newcommand{\bB}{\mathbf{B}}\) \(\newcommand{\bC}{\mathbf{C}}\) \(\newcommand{\bD}{\mathbf{D}}\) \(\newcommand{\bE}{\mathbf{E}}\) \(\newcommand{\bF}{\mathbf{F}}\) \(\newcommand{\bG}{\mathbf{G}}\) \(\newcommand{\bH}{\mathbf{H}}\) \(\newcommand{\bI}{\mathbf{I}}\) \(\newcommand{\bJ}{\mathbf{J}}\) \(\newcommand{\bK}{\mathbf{K}}\) \(\newcommand{\bL}{\mathbf{L}}\) \(\newcommand{\bM}{\mathbf{M}}\) \(\newcommand{\bN}{\mathbf{N}}\) \(\newcommand{\bO}{\mathbf{O}}\) \(\newcommand{\bP}{\mathbf{P}}\) \(\newcommand{\bQ}{\mathbf{Q}}\) \(\newcommand{\bR}{\mathbf{R}}\) \(\newcommand{\bS}{\mathbf{S}}\) \(\newcommand{\bT}{\mathbf{T}}\) \(\newcommand{\bU}{\mathbf{U}}\) \(\newcommand{\bV}{\mathbf{V}}\) \(\newcommand{\bW}{\mathbf{W}}\) \(\newcommand{\bX}{\mathbf{X}}\) \(\newcommand{\bY}{\mathbf{Y}}\) \(\newcommand{\bZ}{\mathbf{Z}}\) \(\newcommand{\ba}{\mathbf{a}}\) \(\newcommand{\bb}{\mathbf{b}}\) \(\newcommand{\bc}{\mathbf{c}}\) \(\newcommand{\bd}{\mathbf{d}}\) \(\newcommand{\be}{\mathbf{e}}\) \(\newcommand{\bg}{\mathbf{g}}\) \(\newcommand{\bh}{\mathbf{h}}\) \(\newcommand{\bi}{\mathbf{i}}\) \(\newcommand{\bj}{\mathbf{j}}\) \(\newcommand{\bk}{\mathbf{k}}\) \(\newcommand{\bl}{\mathbf{l}}\) \(\newcommand{\bm}{\mathbf{m}}\) \(\newcommand{\bn}{\mathbf{n}}\) \(\newcommand{\bo}{\mathbf{o}}\) \(\newcommand{\bp}{\mathbf{p}}\) \(\newcommand{\bq}{\mathbf{q}}\) \(\newcommand{\br}{\mathbf{r}}\) \(\newcommand{\bs}{\mathbf{s}}\) \(\newcommand{\bt}{\mathbf{t}}\) \(\newcommand{\bu}{\mathbf{u}}\) \(\newcommand{\bv}{\mathbf{v}}\) \(\newcommand{\bw}{\mathbf{w}}\) \(\newcommand{\bx}{\mathbf{x}}\) \(\newcommand{\by}{\mathbf{y}}\) \(\newcommand{\bz}{\mathbf{z}}\) \(\newcommand{\balpha}{\boldsymbol{\alpha}}\) \(\newcommand{\bbeta}{\boldsymbol{\beta}}\) \(\newcommand{\btheta}{\boldsymbol{\theta}}\) \(\newcommand{\bpi}{\boldsymbol{\pi}}\) \(\newcommand{\hbpi}{\widehat{\boldsymbol{\pi}}}\) \(\newcommand{\bxi}{\boldsymbol{\xi}}\) \(\newcommand{\bmu}{\boldsymbol{\mu}}\) \(\newcommand{\bepsilon}{\boldsymbol{\epsilon}}\) \(\newcommand{\bzero}{\mathbf{0}}\) \(\newcommand{\T}{\text{T}}\) \(\newcommand{\Trace}{\text{Trace}}\) \(\newcommand{\Cov}{\text{Cov}}\) \(\newcommand{\Var}{\text{Var}}\) \(\newcommand{\E}{\mathbb{E}}\) \(\newcommand{\Pr}{\text{Pr}}\) \(\newcommand{\pr}{\text{pr}}\) \(\newcommand{\pdf}{\text{pdf}}\) \(\newcommand{\P}{\text{P}}\) \(\newcommand{\p}{\text{p}}\) \(\newcommand{\One}{\mathbf{1}}\) \(\newcommand{\argmin}{\operatorname*{arg\,min}}\) \(\newcommand{\argmax}{\operatorname*{arg\,max}}\)

Generalized Random Forest

For the most basic background on random forests, we refer to this lecture note on how to split a tree node and this one on how to build a forest. Assume that we understand how a tree finds its split in the regression setting: at a given internal node, it chooses the split that maximizes the variance reduction of the observed outcome \(Y\). The issue for our treatment decision problem is that splitting on this outcome does not lead to terminal nodes that are homogeneous in terms of the treatment effect or the treatment decision. Hence, the tree split needs to facilitate treatment effect estimation. For that task, the generalized random forest (GRF) by Athey, Tibshirani, and Wager (2019) is a framework that can be used. Before going into the details of the GRF, let's first review the basic idea behind its node-splitting rule.

Pseudo-Outcome View of Node Split

At each internal node, the random forest tries to find an axis-aligned split that puts higher-valued outcomes on one side and lower-valued outcomes on the other. This is done by maximizing the variance reduction. However, when the quantity of interest is not directly observed, we may use a pseudo-outcome motivated by the influence function. Roughly speaking, the influence function of one observation measures how much the estimated parameter would change if we perturb the value \(Y_i\) of that subject. One can naively think of this as re-estimating the parameter without subject \(i\) and comparing the result with the whole-sample estimate. A simple example is estimating the mean with i.i.d. samples: the parameter of interest is \(\theta = \E(Y)\), with

\[ \hat \theta = \frac{1}{n} \sum_{i=1}^n Y_i. \]

The influence function can then be computed as1

\[ \begin{aligned} \text{IF}_i(\hat \theta) &= \frac{\frac{1}{n}\left[ \frac{n-1}{n} \sum_j Y_j + Y_i \right] - \frac{1}{n} \sum_j Y_j}{\frac{1}{n}}\\ &= Y_i - \hat\theta \end{aligned} \]

The interpretation of the influence function is that it measures how much the parameter estimate would change if we perturb the value of \(Y_i\). A positive influence function means that putting more weight on observation \(i\) would pull the estimate up, and vice versa. Hence, we can use this influence function (the pseudo-outcome) as the outcome when splitting the tree node, instead of the original outcomes. Of course, in this regression case the split is exactly the same, since the variance reduction on \(Y_i\) is the same as the variance reduction on \(Y_i - \hat\theta\).
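To make this concrete, below is a small numerical check (my own sketch, not part of the original derivation). It evaluates the perturbation at \(\epsilon = 1/n\) for each observation and confirms that it equals \(Y_i - \hat\theta\), and it verifies that the variance reduction of a candidate split is unchanged when \(Y_i\) is replaced by the pseudo-outcome. The helper function var_red is written here only for this check.

  # numerical check: the epsilon = 1/n perturbation equals Y_i - theta_hat,
  # and variance reduction is the same whether we split on Y or on Y - theta_hat
  set.seed(1)
  n = 50
  y = rnorm(n)
  theta_hat = mean(y)

  # Gateaux-type derivative evaluated at epsilon = 1/n for each observation
  eps = 1/n
  if_num = sapply(y, function(yi) (((1 - eps)*theta_hat + eps*yi) - theta_hat) / eps)
  all.equal(if_num, y - theta_hat)
## [1] TRUE

  # variance reduction of a candidate split (helper written only for this check)
  x1 = runif(n)
  left = x1 < 0.5
  var_red = function(out, left) {
    sum((out - mean(out))^2) -
      sum((out[left] - mean(out[left]))^2) -
      sum((out[!left] - mean(out[!left]))^2)
  }
  all.equal(var_red(y, left), var_red(y - theta_hat, left))
## [1] TRUE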

For estimating the treatment effect, we can set up the splitting rule using the following procedure. First, at an internal node, we fit a regression model on all of its data, without any covariates:

\[ Y \sim \beta_0 + \tau A \]

This provides a regression estimate \(\hat \tau\) of the average treatment effect among the samples in that internal node. The estimator has a well-established analytic form from the linear regression literature:

\[ \hat \tau = \frac{\sum_i (A_i - \bar A)(Y_i - \bar Y)}{\sum_i (A_i - \bar A)^2} \]

Then, following the same idea as before, we can calculate the influence function of \(\hat \tau\) with respect to each observation and use it as the pseudo-outcome. It turns out to be

\[ \text{IF}_i(\hat \tau) = \frac{(A_i - \bar A)\left(Y_i - \bar Y - \hat \tau (A_i - \bar A)\right)}{\sum_j (A_j - \bar A)^2} \]
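As a quick illustration, the sketch below uses my own simulated node data (not the grf internals) to verify that the closed-form \(\hat\tau\) matches the slope from lm, to compute the pseudo-outcomes by the formula above, and to check that they sum to zero within the node, so the split is driven entirely by how they separate across the two child nodes.

  # sketch: closed-form tau_hat, the lm() slope, and the pseudo-outcomes
  set.seed(2)
  n = 100
  a = rbinom(n, 1, 0.5)                   # treatment indicator at this node
  y = 1 + 2*a + rnorm(n)                  # outcome with a true effect of 2

  # closed-form slope of the regression Y ~ beta_0 + tau * A
  tau_hat = sum((a - mean(a)) * (y - mean(y))) / sum((a - mean(a))^2)
  all.equal(tau_hat, unname(coef(lm(y ~ a))[2]))
## [1] TRUE

  # pseudo-outcomes (influence-function values) for tau_hat
  rho = (a - mean(a)) * (y - mean(y) - tau_hat * (a - mean(a))) / sum((a - mean(a))^2)
  all.equal(sum(rho), 0)                  # they sum to zero within the node
## [1] TRUE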

These influence functions need to be re-calculated at each internal node; they then serve as the splitting target when searching over all candidate covariates, just as the variance-reduction criterion uses \(Y\) in the regression setting. Please also note that, in an observational study, a propensity score model would be used to adjust for the treatment assignment, and an outcome model of \(\E[Y | X]\) would be used to center the outcome. The effect of this is similar to the variance reduction property we discussed previously. A sketch of this centering step is given below.
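The sketch below illustrates the centering (orthogonalization) idea on a separate toy data set; the variable names x0, a0, y0 are my own and are unrelated to the example in the next section. It estimates \(\E[Y|X]\) and \(\E[A|X]\) with out-of-bag regression forests and passes them to causal_forest through its Y.hat and W.hat arguments; when these are not supplied, causal_forest estimates them internally in a similar way.

  # sketch of the centering step on a toy confounded data set
  library(grf)
  set.seed(3)
  n = 500
  x0 = matrix(runif(n*3), n, 3)
  a0 = rbinom(n, 1, plogis(x0[,1] - 0.5))             # treatment depends on x1
  y0 = x0[,1] + a0 * x0[,2] + rnorm(n, sd = 0.5)      # outcome also depends on x1

  # out-of-bag estimates of E[Y|X] and E[A|X]
  y0.hat = predict(regression_forest(x0, y0))$predictions
  a0.hat = predict(regression_forest(x0, a0))$predictions

  # supply the centering estimates explicitly; causal_forest would otherwise
  # fit them itself before computing the node-level pseudo-outcomes
  cf.centered = causal_forest(x0, y0, a0, Y.hat = y0.hat, W.hat = a0.hat)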

Example: the grf package

These procedures are implemented in the grf package.

  # generate data
  set.seed(1)
  n = 800
  p = 5
  x = matrix(runif(n*p), n, p)

  # propensity score depends on the third covariate
  ps = exp(x[,3]-0.5)/(1+exp(x[,3]-0.5))

  # treatment coded as -1 / 1
  a = rbinom(n, 1, ps)*2-1

  # true optimal treatment: determined by a nonlinear boundary in (x1, x2)
  side = sign(x[,2] - sin(2*pi*x[,1])/3-0.5)

  # outcome: larger reward when the received treatment matches the optimal side
  R <- rnorm(n, mean = ifelse(side==a, 2, 0.5) + 3*x[, 4] + 3*x[, 5], sd = 0.5)

  # fit grf
  library(grf)
  c.forest <- causal_forest(x, R, a)

  # compare the sign of the estimated treatment effect with the true optimal treatment
  table(c.forest$predictions > 0, side)
##        side
##          -1   1
##   FALSE 382  12
##   TRUE   36 370
  mean((c.forest$predictions > 0) == (side == 1))
## [1] 0.94

  # plot the estimated treatment regime
  par(mar=rep(2, 4))
  plot(x[,1], x[,2], pch = 19, xaxt = "n", yaxt = "n", xlab = "", ylab = "", 
       col = ifelse(c.forest$predictions > 0, "deepskyblue", "darkorange"))
  legend("topright", c("Treatment = 1", "Treatment = -1"), pch = 19, 
         col = c("deepskyblue", "darkorange"), cex = 1.3)

A nice feature of using random forests is that we can easily obtain variance estimates and confidence intervals for the conditional treatment effect. A sequence of papers has discussed how to estimate the variance of random forests through the infinitesimal jackknife (Wager, Hastie, and Efron 2014) and U-statistics (Mentch and Hooker 2016). Some details are provided in this lecture note.
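For example, with the causal forest fitted above, variance estimates can be requested directly from predict; this is a short sketch, where estimate.variance is a documented argument of predict for grf forests and the 1.96 multiplier is just the usual normal-approximation interval.

  # variance estimates and approximate 95% confidence intervals for the CATE
  pred = predict(c.forest, estimate.variance = TRUE)
  se = sqrt(pred$variance.estimates)
  lower = pred$predictions - 1.96 * se
  upper = pred$predictions + 1.96 * se
  head(cbind(estimate = pred$predictions, lower = lower, upper = upper))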

Reference

Athey, Susan, Julie Tibshirani, and Stefan Wager. 2019. “Generalized Random Forests.” The Annals of Statistics 47 (2): 1148–78.
Mentch, Lucas, and Giles Hooker. 2016. “Quantifying Uncertainty in Random Forests via Confidence Intervals and Hypothesis Tests.” Journal of Machine Learning Research 17 (26): 1–41.
Wager, Stefan, Trevor Hastie, and Bradley Efron. 2014. “Confidence Intervals for Random Forests: The Jackknife and the Infinitesimal Jackknife.” Journal of Machine Learning Research 15 (1): 1625–51.

  1. This can be calculated directly from the definition of the influence function, \(\lim_{\epsilon \rightarrow 0} \frac{\hat\theta((1-\epsilon) F + \epsilon\delta_x) - \hat\theta(F)}{\epsilon}\), where \(\hat\theta(\cdot)\) is the statistic viewed as a functional of the distribution, \(F\) is the distribution of the sample, and \(\delta_x\) is the distribution that puts all of its mass at \(x\). In our case of estimating the mean, we take \(F\) to be the empirical distribution, \(\epsilon\) to be \(1/n\), and \(x\) to be \(Y_i\).↩︎