\(\newcommand{\ci}{\perp\!\!\!\perp}\) \(\newcommand{\cA}{\mathcal{A}}\) \(\newcommand{\cB}{\mathcal{B}}\) \(\newcommand{\cC}{\mathcal{C}}\) \(\newcommand{\cD}{\mathcal{D}}\) \(\newcommand{\cE}{\mathcal{E}}\) \(\newcommand{\cF}{\mathcal{F}}\) \(\newcommand{\cG}{\mathcal{G}}\) \(\newcommand{\cH}{\mathcal{H}}\) \(\newcommand{\cI}{\mathcal{I}}\) \(\newcommand{\cJ}{\mathcal{J}}\) \(\newcommand{\cK}{\mathcal{K}}\) \(\newcommand{\cL}{\mathcal{L}}\) \(\newcommand{\cM}{\mathcal{M}}\) \(\newcommand{\cN}{\mathcal{N}}\) \(\newcommand{\cO}{\mathcal{O}}\) \(\newcommand{\cP}{\mathcal{P}}\) \(\newcommand{\cQ}{\mathcal{Q}}\) \(\newcommand{\cR}{\mathcal{R}}\) \(\newcommand{\cS}{\mathcal{S}}\) \(\newcommand{\cT}{\mathcal{T}}\) \(\newcommand{\cU}{\mathcal{U}}\) \(\newcommand{\cV}{\mathcal{V}}\) \(\newcommand{\cW}{\mathcal{W}}\) \(\newcommand{\cX}{\mathcal{X}}\) \(\newcommand{\cY}{\mathcal{Y}}\) \(\newcommand{\cZ}{\mathcal{Z}}\) \(\newcommand{\bA}{\mathbf{A}}\) \(\newcommand{\bB}{\mathbf{B}}\) \(\newcommand{\bC}{\mathbf{C}}\) \(\newcommand{\bD}{\mathbf{D}}\) \(\newcommand{\bE}{\mathbf{E}}\) \(\newcommand{\bF}{\mathbf{F}}\) \(\newcommand{\bG}{\mathbf{G}}\) \(\newcommand{\bH}{\mathbf{H}}\) \(\newcommand{\bI}{\mathbf{I}}\) \(\newcommand{\bJ}{\mathbf{J}}\) \(\newcommand{\bK}{\mathbf{K}}\) \(\newcommand{\bL}{\mathbf{L}}\) \(\newcommand{\bM}{\mathbf{M}}\) \(\newcommand{\bN}{\mathbf{N}}\) \(\newcommand{\bO}{\mathbf{O}}\) \(\newcommand{\bP}{\mathbf{P}}\) \(\newcommand{\bQ}{\mathbf{Q}}\) \(\newcommand{\bR}{\mathbf{R}}\) \(\newcommand{\bS}{\mathbf{S}}\) \(\newcommand{\bT}{\mathbf{T}}\) \(\newcommand{\bU}{\mathbf{U}}\) \(\newcommand{\bV}{\mathbf{V}}\) \(\newcommand{\bW}{\mathbf{W}}\) \(\newcommand{\bX}{\mathbf{X}}\) \(\newcommand{\bY}{\mathbf{Y}}\) \(\newcommand{\bZ}{\mathbf{Z}}\) \(\newcommand{\ba}{\mathbf{a}}\) \(\newcommand{\bb}{\mathbf{b}}\) \(\newcommand{\bc}{\mathbf{c}}\) \(\newcommand{\bd}{\mathbf{d}}\) \(\newcommand{\be}{\mathbf{e}}\) \(\newcommand{\bg}{\mathbf{g}}\) 
\(\newcommand{\bh}{\mathbf{h}}\) \(\newcommand{\bi}{\mathbf{i}}\) \(\newcommand{\bj}{\mathbf{j}}\) \(\newcommand{\bk}{\mathbf{k}}\) \(\newcommand{\bl}{\mathbf{l}}\) \(\newcommand{\bm}{\mathbf{m}}\) \(\newcommand{\bn}{\mathbf{n}}\) \(\newcommand{\bo}{\mathbf{o}}\) \(\newcommand{\bp}{\mathbf{p}}\) \(\newcommand{\bq}{\mathbf{q}}\) \(\newcommand{\br}{\mathbf{r}}\) \(\newcommand{\bs}{\mathbf{s}}\) \(\newcommand{\bt}{\mathbf{t}}\) \(\newcommand{\bu}{\mathbf{u}}\) \(\newcommand{\bv}{\mathbf{v}}\) \(\newcommand{\bw}{\mathbf{w}}\) \(\newcommand{\bx}{\mathbf{x}}\) \(\newcommand{\by}{\mathbf{y}}\) \(\newcommand{\bz}{\mathbf{z}}\) \(\newcommand{\RR}{\mathbb{R}}\) \(\newcommand{\NN}{\mathbb{N}}\) \(\newcommand{\balpha}{\boldsymbol{\alpha}}\) \(\newcommand{\bbeta}{\boldsymbol{\beta}}\) \(\newcommand{\btheta}{\boldsymbol{\theta}}\) \(\newcommand{\hpi}{\widehat{\pi}}\) \(\newcommand{\bpi}{\boldsymbol{\pi}}\) \(\newcommand{\hbpi}{\widehat{\boldsymbol{\pi}}}\) \(\newcommand{\bxi}{\boldsymbol{\xi}}\) \(\newcommand{\bmu}{\boldsymbol{\mu}}\) \(\newcommand{\bepsilon}{\boldsymbol{\epsilon}}\) \(\newcommand{\bzero}{\mathbf{0}}\) \(\newcommand{\T}{\text{T}}\) \(\newcommand{\Trace}{\text{Trace}}\) \(\newcommand{\Cov}{\text{Cov}}\) \(\newcommand{\Var}{\text{Var}}\) \(\newcommand{\E}{\mathbb{E}}\) \(\newcommand{\Pr}{\text{Pr}}\) \(\newcommand{\pr}{\text{pr}}\) \(\newcommand{\pdf}{\text{pdf}}\) \(\newcommand{\P}{\text{P}}\) \(\newcommand{\p}{\text{p}}\) \(\newcommand{\One}{\mathbf{1}}\) \(\newcommand{\argmin}{\operatorname*{arg\,min}}\) \(\newcommand{\argmax}{\operatorname*{arg\,max}}\) \(\newcommand{\dtheta}{\frac{\partial}{\partial\theta} }\) \(\newcommand{\ptheta}{\nabla_\theta}\) \(\newcommand{\alert}[1]{\color{darkorange}{#1}}\) \(\newcommand{\alertr}[1]{\color{red}{#1}}\) \(\newcommand{\alertb}[1]{\color{blue}{#1}}\)

1 Overview

We introduced average treatment effect (ATE) estimation in the previous lecture. In this lecture, we take a step further and study conditional average treatment effect (CATE) estimation, which aims to estimate the treatment effect for each individual based on subject-specific covariates. CATE estimation is a key component of personalized medicine, where the goal is to recommend a treatment based on individual characteristics. More precisely, we are interested in two possible quantities. The first is the conditional average treatment effect (CATE)1:

\[ \tau(x) = \E[Y(1) - Y(0) | X = x] \]

The second is the optimal treatment regime, which can be inferred from the CATE:

\[ \begin{aligned} \pi_\text{opt}(x) &= \underset{\pi}{\arg\max} \,\, \E\left[Y(\pi(X)) | X = x\right] \\ &= \mathbf{1}\{\tau(x) > 0\} \end{aligned} \]

where \(\pi(\cdot)\) is a treatment assignment function that takes the patient’s covariates \(X\) as input. The best of all such functions, \(\pi_\text{opt}(x)\), which makes the best decision for every individual, is called the optimal treatment regime. In this lecture, we discuss CATE estimation, which is often done with regression methods. When these methods are used to estimate the optimal treatment regime, they are referred to as indirect learning methods. In contrast, direct learning methods estimate the optimal treatment regime without first estimating the CATE. We will discuss both types of methods.
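To make the link between \(\tau(x)\) and \(\pi_\text{opt}\) concrete, here is a small sketch on hypothetical simulated potential outcomes, assuming \(\tau(x) = x\) purely for illustration. It evaluates the value \(\E[Y(\pi(X))]\) of a few regimes and shows that \(\mathbf{1}\{\tau(x) > 0\}\) does at least as well as treating everyone, treating no one, or assigning treatment at random.

```r
set.seed(1)
n_sim   <- 1e5
x_sim   <- rnorm(n_sim)
tau_sim <- x_sim                    # hypothetical CATE: tau(x) = x
y0_sim  <- rnorm(n_sim)             # potential outcome Y(0)
y1_sim  <- y0_sim + tau_sim         # potential outcome Y(1)

# value of a regime: average outcome if everyone follows pi(X)
regime_value <- function(pi_x) mean(ifelse(pi_x == 1, y1_sim, y0_sim))

values <- c(
  opt    = regime_value(as.integer(tau_sim > 0)),  # 1{tau(x) > 0}
  all    = regime_value(rep(1L, n_sim)),           # treat everyone
  none   = regime_value(rep(0L, n_sim)),           # treat no one
  random = regime_value(rbinom(n_sim, 1, 0.5))     # coin flip
)
round(values, 3)
```

Since \(Y(1) - Y(0) = \tau(X)\) here, the rule gains \(\tau(x)\) exactly where it is positive and avoids it where it is negative, which is why no other regime can beat it in this simulation.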

2 The Virtual Twins with Random Forests

The most straightforward idea is to fit regression models separately to the treated and untreated groups and use the difference of their predictions to infer the treatment decision. Throughout, we assume that there is no unmeasured confounding. The approach involves two steps:

  • In the first step, we learn the regression functions \(\E[Y | X, A = 1]\) and \(\E[Y | X, A = 0]\).
  • In the second step, we use the estimated CATE to define a label for each individual \[ d_i = \widehat{\E}[Y | X_i, A = 1] - \widehat{\E}[Y | X_i, A = 0] \] and use this label as the outcome to learn the optimal decision rule \(\pi_\text{opt}(x)\).
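The two steps can be sketched in base R before turning to random forests. This is a minimal illustration under stated assumptions: hypothetical simulated data in which treatment helps only when a covariate `x1` is below 50, linear models (`lm`) in place of random forests in Step 1, and a hand-written one-split "stump" in Step 2 (a hypothetical stand-in for the CART tree used later, which only searches rules of the form "treat if `x1` < cut").

```r
set.seed(1)
n_vt <- 4000
vt <- data.frame(x1 = runif(n_vt, 20, 80), x2 = rnorm(n_vt),
                 A = rbinom(n_vt, 1, 0.5))
vt$tau <- ifelse(vt$x1 < 50, 1, -1)          # hypothetical true CATE
vt$Y <- vt$x2 + vt$A * vt$tau + rnorm(n_vt)

# Step 1: separate outcome regressions, then predict both arms for everyone
fit1 <- lm(Y ~ x1 + x2, data = vt[vt$A == 1, ])
fit0 <- lm(Y ~ x1 + x2, data = vt[vt$A == 0, ])
d <- predict(fit1, newdata = vt) - predict(fit0, newdata = vt)
label <- d > 0                               # "treat" label d_i > 0

# Step 2: summarize the labels with one split of the form "treat if x1 < cut"
cuts <- seq(20, 80, by = 0.5)
err <- sapply(cuts, function(cut) mean(label != (vt$x1 < cut)))
best_cut <- cuts[which.min(err)]
best_cut
mean((vt$x1 < best_cut) == (vt$tau > 0))     # agreement with the true best rule
```

The stump recovers a cut close to 50 even though Step 1 used a misspecified linear model, because only the sign of \(d_i\) is passed to Step 2.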

A typical example of this type of approach is the Virtual Twins model (Foster, Taylor, and Ruberg 2011). It uses random forests to learn the regression functions in the first step and a single tree model in the second step to summarize the decision rule and make it interpretable. The following sepsis example is mainly based on the vignettes of the aVirtualTwins R package. We also need packages for random forests and CART.

Let’s consider a simulated data set derived from the SIDES package. The data in .csv format can be downloaded from our course website. The data set contains 470 patients and 14 variables, listed below.

  • Health: Health outcome (larger is better)
  • THERAPY: 1 for active treatment, 0 for the control treatment
  • TIMFIRST: Time from first sepsis-organ fail to start drug
  • AGE: Patient age in years
  • BLLPLAT: Baseline local platelets
  • blSOFA: Sum of baseline sofa score (cardiovascular, hematology, hepatorenal, and respiration scores)
  • BLLCREAT: Base creatinine
  • ORGANNUM: Number of baseline organ failures
  • PRAPACHE: Pre-infusion apache-ii score
  • BLGCS: Base GLASGOW coma scale score
  • BLIL6: Baseline serum IL-6 concentration
  • BLADL: Baseline activity of daily living score
  • BLLBILI: Baseline local bilirubin
  • BEST: The true best treatment suggested by doctors. You should not use this variable when fitting models!

For each patient, sepsis was observed during the hospital stay. Hence, one of the two treatments (indicated by the variable THERAPY) had to be chosen to prevent further adverse events. After treatment, the patient’s health outcome (Health) was measured, with a larger value being a better outcome. The BEST variable is the doctor-suggested best treatment, which would not be observed in practice. It can be regarded as the unknown truth.

Run the cell below to load the data set and display the first few rows; it only takes a few seconds. Make sure that your working directory is set up correctly, for example by putting your .rmd file and the .csv file in the same folder and then opening the .rmd file in RStudio (in that case, read the file as "Sepsis.csv" instead of the relative path below).

  # adjust this path to wherever Sepsis.csv is saved
  Sepsis <- read.csv("../../dataset/Sepsis.csv")
  # remove the first column, which is the observation ID
  Sepsis = Sepsis[, -1]
  head(Sepsis)

2.1 Step 1: Outcome Regressions

We will fit two random forests to model the outcome Health: one based on all patients who received treatment 1 (THERAPY = 1), and the other based on all patients who received treatment 0 (corresponding to our previous label \(-1\)). Denote these two models by \(\hat f_1(x)\) and \(\hat f_0(x)\), respectively. We also use a useful trick in random forests: out-of-bag prediction. This is similar to a leave-one-out approach: the fitted value of a subject is calculated only from the trees that were not trained on this subject. For more details, you may refer to this lecture note. This mechanism prevents over-fitting to a certain extent. We should also tune both mtry and nodesize, but that step is skipped in this demonstration. The following code demonstrates this idea:

  library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.

  # fit model for treatment 0
  model0 <- randomForest(Health ~ . - BEST, 
                         data = Sepsis[Sepsis$THERAPY == 0, ], 
                         ntree=2000,
                         mtry = 5, 
                         nodesize = 5)
  
  # fit model for treatment 1
  model1 <- randomForest(Health ~ . - BEST, 
                         data = Sepsis[Sepsis$THERAPY == 1, ], 
                         ntree=2000,
                         mtry = 5, 
                         nodesize = 5)
  
  # out-of-bag prediction for treatment 0 group with model 0
  model0.treat0 = model0$predicted
  
  # out-of-bag prediction for treatment 1 group with model 1
  model1.treat1 = model1$predicted
  
  # prediction for treatment 1 group with model 0
  model0.treat1 = predict(model0, Sepsis[Sepsis$THERAPY == 1, ])
  
  # prediction for treatment 0 group with model 1
  model1.treat0 = predict(model1, Sepsis[Sepsis$THERAPY == 0, ])

  # combine predictions together 
  pred.treat0 = rep(NA, nrow(Sepsis))
  pred.treat0[Sepsis$THERAPY == 0] = model0.treat0
  pred.treat0[Sepsis$THERAPY == 1] = model0.treat1
  
  pred.treat1 = rep(NA, nrow(Sepsis))
  pred.treat1[Sepsis$THERAPY == 0] = model1.treat0
  pred.treat1[Sepsis$THERAPY == 1] = model1.treat1  
  
  # which is better?
  pred.best = (pred.treat1 > pred.treat0)
  
  # is this good?
  mean(pred.best == Sepsis$BEST)
## [1] 0.8106383

So it looks like we got pretty good accuracy, 81%. This simply takes the optimal decision to be the treatment with the higher predicted outcome under the twin random forest models. However, a random forest is a very complex model and is not interpretable, so a doctor cannot easily see why one treatment is recommended over the other.
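Step 1 relied on the out-of-bag (OOB) predictions that `randomForest` stores in `$predicted`. The mechanism itself can be illustrated with a small base-R sketch. As an assumption for illustration, it bags a 1-nearest-neighbour regression (a swap-in for random forests, chosen because its in-sample fit is exact, which makes over-fitting easy to see) on hypothetical simulated data. Each OOB prediction averages only the bootstrap models that did not see that observation.

```r
set.seed(1)
n_bag <- 100; B <- 200
x_bag <- runif(n_bag)
y_bag <- sin(2 * pi * x_bag) + rnorm(n_bag, sd = 0.3)

# 1-nearest-neighbour prediction at points xnew, using training rows idx
nn1 <- function(idx, xnew) {
  sapply(xnew, function(z) y_bag[idx][which.min(abs(x_bag[idx] - z))])
}

all_sum <- numeric(n_bag); oob_sum <- numeric(n_bag); oob_cnt <- numeric(n_bag)
for (b in seq_len(B)) {
  idx <- sample(n_bag, n_bag, replace = TRUE)   # bootstrap sample
  pred_b <- nn1(idx, x_bag)                     # predict every observation
  all_sum <- all_sum + pred_b
  oob <- !(seq_len(n_bag) %in% idx)             # rows left out of this bootstrap
  oob_sum[oob] <- oob_sum[oob] + pred_b[oob]
  oob_cnt[oob] <- oob_cnt[oob] + 1
}
pred_all <- all_sum / B          # naive fitted values: every model votes
pred_oob <- oob_sum / oob_cnt    # OOB fitted values: only models that never saw the row
mse <- c(in_sample = mean((y_bag - pred_all)^2),
         oob       = mean((y_bag - pred_oob)^2))
mse
```

The naive in-sample error is optimistic because each row is partly predicted by models that memorized it; the OOB error is the honest one, which is why the virtual twins code uses `model0$predicted` and `model1$predicted` for the observed arm.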

2.2 Step 2: A Tree Decision Rule

This second step is optional, but very useful in practice when a simple decision rule is preferred. We construct a single-tree model (CART) to represent the optimal decision rule \(\pi_\text{opt}(x)\). The idea is simple: we create an artificial label, best.label, which is "Treatment 1" if pred.treat1 is larger than pred.treat0 and "Treatment 0" otherwise. Then we use all of our covariates to model the learned best.label.

  library(rpart)

  # fit a decision tree to predict best.label,
  # excluding BEST, Health, and THERAPY from the covariates
  best.label = ifelse(pred.treat1 > pred.treat0, "Treatment 1", "Treatment 0")
  best.label = as.factor(best.label)
  
  rpart.fit <- rpart(best.label ~ . - BEST - Health - THERAPY, data = Sepsis)
  
  # in the coming cells, we will prune the tree
  # start by plotting the cross-validation relative error at different tree sizes
  plotcp(rpart.fit)

After tuning the tree model through the cp parameter, we can view the fitted tree:

  # cut the tree
  cutted.tree = prune(rpart.fit, cp = 0.04)

  library(rpart.plot)
  # make a good plot
  rpart.plot(cutted.tree)

Hence, the conclusion is that when \(\text{Age} < 52\), we should use treatment 1, while for \(\text{Age} \geq 52\), we should use treatment 0 if the APACHE II score is less than 32 and treatment 1 otherwise.
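For use outside R's tree objects, the rule read off the pruned tree can be written as a plain function. `tree_rule` below is a hypothetical helper (not part of rpart) that simply encodes the thresholds stated above, assuming the apache split refers to the `PRAPACHE` column and that treatments are coded 1/0 as in `THERAPY`.

```r
# Hypothetical helper encoding the pruned-tree rule described in the text
tree_rule <- function(age, prapache) {
  ifelse(age < 52, 1L,                    # younger patients: treatment 1
         ifelse(prapache < 32, 0L, 1L))   # older: treatment 0 unless APACHE II >= 32
}

tree_rule(age = c(45, 60, 60), prapache = c(20, 20, 35))
```

On the sepsis data one could then compare, e.g., `mean(tree_rule(Sepsis$AGE, Sepsis$PRAPACHE) == Sepsis$BEST)`, assuming the pruned tree matches the description above; the exact tree predictions are available from `predict(cutted.tree, type = "class")`.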

3 Direct Learning with Linear Regression

We can, of course, fit a linear regression model with interactions between the treatment label and the covariates. If we simply do that and use the fitted model to predict the best treatment, we are doing indirect learning:

\[ \E[Y | X, A] = \boldsymbol \beta^T X + \boldsymbol \alpha^T X \cdot A \]

The following code demonstrates this idea using a simulated data set and the Lasso, fitting a separate model in each treatment arm. Note that we use \(A \in \{-1, 1\}\) as the treatment label since this will benefit our later development.

  set.seed(1)
  n = 100
  p = 200
  X = matrix(rnorm(n*p), n, p)
  A = rbinom(n, 1, 0.5)*2-1
  beta = c(rep(0.2, 10), rep(0, p-10))
  alpha = c(rep(0, p-2), 0.75, -0.75)
  R = X %*% beta + A * (X %*% alpha) + rnorm(n, sd = 0.5)
  
  testn = 500
  testX = matrix(rnorm(testn*p), testn, p)
  TRUE_BEST = sign(testX %*% alpha)
  library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-10

  # fit two regression models
  Apos.fit <- cv.glmnet(x = cbind(X, A)[A == 1,], y = R[A == 1])
  Aneg.fit <- cv.glmnet(x = cbind(X, A)[A == -1,], y = R[A == -1])
  
  # compare the models for new data
  Apos.pred = predict(Apos.fit, newx = cbind(testX, 1))
  Aneg.pred = predict(Aneg.fit, newx = cbind(testX, -1))
  
  # the predicted treatment label
  indirect.pred = sign(Apos.pred - Aneg.pred)
  
  # the accuracy
  mean(indirect.pred == TRUE_BEST)
## [1] 0.876

This seems fairly accurate: we get 88% accuracy, meaning that for around 88% of the new patients we suggest the better treatment.
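The code above fits one Lasso per arm. The single interaction model in the displayed equation can also be fit directly. The sketch below does this with ordinary least squares (`lm`) on a hypothetical low-dimensional simulation (three covariates and assumed coefficients, so no penalty is needed; with \(p = 200 > n\) one would still need the Lasso). Because \(A \in \{-1, 1\}\), the predicted difference is \(\E[Y | X, A = 1] - \E[Y | X, A = -1] = 2\balpha^T X\), so the indirect rule is \(\text{sign}(\widehat\balpha^T x)\).

```r
set.seed(1)
n_lm <- 2000
X_lm <- matrix(rnorm(n_lm * 3), n_lm, 3)
A_lm <- rbinom(n_lm, 1, 0.5) * 2 - 1              # treatment coded -1/1
beta_lm  <- c(1, 0.5, 0)                          # hypothetical main effects
alpha_lm <- c(0, 0.75, -0.75)                     # hypothetical interactions
Y_lm <- drop(X_lm %*% beta_lm + A_lm * (X_lm %*% alpha_lm)) + rnorm(n_lm, sd = 0.5)

# design [X, A*X] matches E[Y | X, A] = beta^T X + alpha^T X * A (no intercept)
Z_lm <- cbind(X_lm, A_lm * X_lm)
fit_lm <- lm(Y_lm ~ Z_lm - 1)
alpha_hat <- unname(coef(fit_lm)[4:6])

# indirect rule on new patients: sign of the predicted difference 2 * alpha^T x
X_new <- matrix(rnorm(500 * 3), 500, 3)
rule_hat <- sign(X_new %*% alpha_hat)
mean(rule_hat == sign(X_new %*% alpha_lm))
```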

3.1 Case 1: Randomized Trial

But let’s also explore an interesting property of the previous model, which underlies the method proposed by Tian et al. (2014). Suppose we do not use \(Y\) as our outcome of interest. Instead, we use \(AY\), the product of the treatment label and the outcome. Then, since \(A^2 = 1\), we have

\[ \begin{aligned} \E(AY| X) =& \E( A | X) \cdot \boldsymbol \beta^T X + \E(A^2 | X) \cdot \boldsymbol \alpha^T X \\ =& \big[ \Pr( A = 1 | X) - \Pr(A = -1 | X) \big] \cdot \boldsymbol \beta^T X + \boldsymbol \alpha^T X \end{aligned} \]

If the data come from a balanced randomized trial, i.e., \(\Pr(A = 1 | X) = \Pr(A = -1 | X) = 0.5\), then the first term is zero and only the second term remains:

\[ \E(AY | X) = \boldsymbol \alpha^T X \]

This suggests a very simple idea for directly modeling the best treatment rule using this modified outcome:

  # fitting the direct learning model
  direct.fit <- cv.glmnet(x = X, y = R*A)

  # direct learning predicted label
  direct.pred <- sign(predict(direct.fit, testX))
  
  # the accuracy
  mean(direct.pred == TRUE_BEST)  
## [1] 0.916

This time, the accuracy (92%) is slightly better than before.
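The identity \(\E(AY | X) = \balpha^T X\) can also be checked numerically. In this hypothetical large simulation (two covariates, balanced randomization, assumed coefficients), an ordinary regression of the modified outcome \(AY\) on \(X\) recovers \(\balpha\) even though \(\bbeta\) is never modeled:

```r
set.seed(1)
n_rt <- 20000
X_rt <- matrix(rnorm(n_rt * 2), n_rt, 2)
A_rt <- rbinom(n_rt, 1, 0.5) * 2 - 1              # balanced trial, A in {-1, 1}
beta_rt  <- c(1, 1)                               # hypothetical main effects
alpha_rt <- c(0.5, -0.5)                          # hypothetical interactions
Y_rt <- drop(X_rt %*% beta_rt + A_rt * (X_rt %*% alpha_rt)) + rnorm(n_rt)

AY_rt <- A_rt * Y_rt                              # modified outcome
gamma_hat <- unname(coef(lm(AY_rt ~ X_rt - 1)))   # E(AY | X) = alpha^T X
rbind(estimate = gamma_hat, truth = alpha_rt)
```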

3.2 Case 2: Observational Study

As one might expect, observational studies can be handled with propensity score weighting. We modify the previous equation using an inverse propensity weight. Of course, we also assume the same conditions as in previous sections, such as no unmeasured confounding and positivity.

\[ \begin{aligned} & \E \left( \frac{A}{\Pr( A | X)}Y \Biggm| X \right) \\ =& \left[ \frac{1}{\Pr( A = 1 | X)}\Pr( A = 1 | X) + \frac{-1}{\Pr( A = -1 | X)} \Pr( A = -1 | X) \right] \cdot \bbeta^T X + \\ &\quad \left[ \frac{1^2}{\Pr( A = 1| X)}\Pr( A = 1| X) + \frac{(-1)^2}{\Pr( A = -1| X)} \Pr( A = -1| X) \right] \balpha^T X \\ =& \left[ 1 - 1 \right] \cdot \bbeta^T X + \left[ 1 + 1 \right] \balpha^T X \\ =& 2 \balpha^T X \end{aligned} \]

So the first term still cancels out, and the constant factor 2 does not change the sign of \(\balpha^T X\), hence not the decision rule. Our job therefore involves one more step: estimate the propensity scores \(\Pr( A | X)\) using, e.g., logistic regression or random forests, and then plug them into the previous linear regression using the subject weights \(1 / \widehat{\Pr}(A = a_i | X = x_i)\).
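A minimal sketch of the weighted version, assuming hypothetical simulated data in which treatment assignment depends on \(X_1\), and a correctly specified logistic propensity model. Because \(\E[1/\Pr(A | X) \mid X] = 2\), weighted least squares of \(AY\) on \(X\) with weights \(1/\widehat{\Pr}(A = a_i | X = x_i)\) targets \(\balpha\) itself, while an unweighted regression of \(AY/\Pr(A | X)\) on \(X\) would target \(2\balpha\); both give the same sign and hence the same rule.

```r
set.seed(1)
n_os <- 50000
X_os <- matrix(rnorm(n_os * 2), n_os, 2)
e_os <- plogis(0.8 * X_os[, 1])                   # hypothetical true Pr(A = 1 | X)
A_os <- ifelse(runif(n_os) < e_os, 1, -1)         # confounded treatment, A in {-1, 1}
beta_os  <- c(1, 1)
alpha_os <- c(0.5, -0.5)
Y_os <- drop(X_os %*% beta_os + A_os * (X_os %*% alpha_os)) + rnorm(n_os)

# Step 1: propensity score by logistic regression (A recoded to 0/1)
A01_os <- as.integer(A_os == 1)
ps_fit <- glm(A01_os ~ X_os, family = binomial)
p1_hat <- fitted(ps_fit)
p_obs  <- ifelse(A_os == 1, p1_hat, 1 - p1_hat)   # estimated Pr(A = a_i | X = x_i)
w_os   <- 1 / p_obs                               # inverse propensity weights

# Step 2: weighted regression of the modified outcome AY on X
AY_os <- A_os * Y_os
fit_ipw <- lm(AY_os ~ X_os - 1, weights = w_os)
alpha_ipw <- unname(coef(fit_ipw))
rbind(estimate = alpha_ipw, truth = alpha_os)
```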

4 CATE Estimation with Causal Forests

We can revisit the example in our grf lecture regarding the influence-function splitting rule. Here, we have a local estimation problem in which the CATE is identified via the following moment condition:

\[ \mathbb{E}\left[ \frac{A - e(X)}{e(X)(1 - e(X))} (Y - m(X)) - \tau(X) \mid X \right] = 0, \]

where \(e(X) = \mathbb{P}(A=1 \mid X)\) is the propensity score and

\[ m(X) = \mathbb{E}[Y \mid X] = e(X) \mathbb{E}[Y \mid X, A=1] + (1 - e(X)) \mathbb{E}[Y \mid X, A=0] \]

is the outcome regression that marginalizes over the treatment. It can be estimated by directly regressing \(Y\) on \(X\) while ignoring the treatment label. This is slightly different from our previous doubly robust moment condition, but we can still verify the zero-mean moment condition by checking that

\[ \mathbb{E}\left[ \frac{A - e(X)}{e(X)(1 - e(X))} m(X) \mid X \right] = 0, \]

since, conditional on \(X\), \(m(X)\) is a constant and \[ \mathbb{E}[ A - e(X) \mid X ] = 0 \] by the definition of the propensity score. Hence, adding or removing this term does not change the moment condition. For the first term, writing \(\mu_a(X) = \mathbb{E}[Y(a) \mid X]\) and using \(A^2 = A\), \(A(1 - A) = 0\), and no unmeasured confounding, we have

\[ \begin{aligned} \mathbb{E}\left[ \frac{A - e(X)}{e(X)(1 - e(X))} Y \mid X \right] &= \mathbb{E}\left[ \frac{A - e(X)}{e(X)(1 - e(X))} \left( Y(1) A + Y(0)(1 - A) \right) \mid X \right] \\ &= \frac{1}{e(X)(1 - e(X))} \Big[ e(X)\mu_1(X) - e(X) \big( e(X) \mu_1(X) + (1 - e(X)) \mu_0(X) \big) \Big] \\ &= \frac{1}{e(X)(1 - e(X))} \left[ e(X)(1 - e(X)) (\mu_1(X) - \mu_0(X)) \right] \\ &= \tau(X) \end{aligned} \]
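The moment condition can be checked by simulation when the nuisance functions are known. The sketch below assumes a hypothetical design with \(e(X) = 1/(1 + \exp(-0.5 X_1))\), \(\mu_0(X) = X_2\) and \(\tau(X) = 1.5 X_1\). It builds the pseudo-outcome \(\frac{A - e(X)}{e(X)(1 - e(X))}(Y - m(X))\) with the true \(e\) and \(m\) and regresses it on \(X\); the slope on \(X_1\) should be close to 1.5. Dropping \(m(X)\) leaves the conditional mean unchanged, as argued above.

```r
set.seed(1)
n_mc  <- 50000
X1_mc <- rnorm(n_mc); X2_mc <- rnorm(n_mc)
e_mc   <- plogis(0.5 * X1_mc)                 # true propensity e(X)
tau_mc <- 1.5 * X1_mc                         # true CATE
A_mc   <- rbinom(n_mc, 1, e_mc)
Y_mc   <- X2_mc + A_mc * tau_mc + rnorm(n_mc) # mu0(X) = X2, mu1(X) = X2 + tau(X)
m_mc   <- X2_mc + e_mc * tau_mc               # m(X) = e mu1 + (1 - e) mu0

w_mc   <- (A_mc - e_mc) / (e_mc * (1 - e_mc))
psi_m  <- w_mc * (Y_mc - m_mc)                # pseudo-outcome with m(X)
psi_nm <- w_mc * Y_mc                         # same, with m(X) removed
coef_m  <- coef(lm(psi_m ~ X1_mc + X2_mc))
coef_nm <- coef(lm(psi_nm ~ X1_mc + X2_mc))
rbind(with_m = coef_m, without_m = coef_nm)   # truth: (0, 1.5, 0)
```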

To preserve the orthogonality condition, we need to estimate both nuisance functions \(e(X)\) and \(m(X)\) on samples that do not include observation \(i\) when evaluating its influence function. The reason is that when \(\hat{e}(X)\) or \(\hat{m}(X)\) depends on observation \(i\), it is correlated with \(Y_i\) or \(A_i\), which breaks the mean-zero moment condition. This can be done with cross-fitting. In practice, this may not be exactly preserved, which usually does not hurt much. The following code uses the grf package to do this. By default, the package estimates \(e(X)\) and \(m(X)\) with separate regression forests and uses their out-of-bag predictions. It also builds honest trees, using half of the samples to construct the tree structure and the other half to estimate the leaf values, and it can provide variance estimates for the CATE (via estimate.variance = TRUE in predict(), as below).

  library(grf)
  
  # Simulated data
  n <- 500
  p <- 5
  X <- matrix(rnorm(n * p), n, p)
  W <- rbinom(n, 1, 0.5)  # randomized treatment indicator (called A in the text)
  tau <- 1.5 * X[,1] # the true CATE function
  
  Y <- X[,2] + X[, 3] + W * tau + rnorm(n)

  par(mfrow=c(1,2))
  # plot data with two regression lines
  plot(X[,1], Y, col = ifelse(W==1, "deepskyblue", "darkorange"), pch=16, xlab="X1", ylab="Y")
  abline(lm(Y[W==1] ~ X[W==1,1]), col="deepskyblue", lwd=2)
  abline(lm(Y[W==0] ~ X[W==0,1]), col="darkorange", lwd=2) 
  
  # Fit a causal forest
  cf <- causal_forest(X, Y, W,
                      honesty = TRUE,
                      honesty.fraction = 0.5)
  
  # Estimate conditional treatment effects
  cf.pred <- predict(cf, estimate.variance = TRUE)
  tau_hat <- cf.pred$predictions
  
  # plot the estimated CATE against the true CATE
  plot(X[,1], tau_hat, pch=16, xlab="X1", ylab="Estimated CATE", col="black")
  points(X[,1], tau, col="red", pch=16, cex = 0.5)
  legend("bottomright", legend=c("Estimated CATE", "True CATE"), 
         col=c("black", "red"), pch=16, cex = 1.5)

An interesting feature of this method, similar to the virtual twins method, is that the final model does not care whether a variable affects the outcome regression function, since the splitting rule forces the model to focus on treatment effect heterogeneity. This can be seen from the variable importance output below. Only the first variable \(X_1\) affects the CATE, while the second and third variables affect the regression function but not the CATE; accordingly, \(X_2\) and \(X_3\) receive importance similar to the pure-noise variables \(X_4\) and \(X_5\). However, because \(m(X)\) is estimated, this causal forest is still an indirect learning method.

  variable_importance(cf)
##            [,1]
## [1,] 0.74158514
## [2,] 0.05624845
## [3,] 0.06119734
## [4,] 0.07624305
## [5,] 0.06472602
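The cross-fitting idea mentioned above can also be sketched without grf. In the base-R sketch below (hypothetical data and models, not what grf does internally), the sample is split into two folds; \(e(X)\) and \(m(X)\) are fit on one fold with logistic and linear regression and evaluated on the other, so no observation's nuisance estimates depend on its own \((Y_i, A_i)\). The pseudo-outcome is then regressed on \(X\). The linear model for \(m(X)\) is deliberately misspecified; because \(\E[A - e(X) \mid X] = 0\), errors in \(\hat m\) do not bias the pseudo-outcome as long as \(\hat e\) is accurate.

```r
set.seed(2)
n_cf <- 20000
d_cf <- data.frame(X1 = rnorm(n_cf), X2 = rnorm(n_cf))
d_cf$A <- rbinom(n_cf, 1, plogis(0.5 * d_cf$X1))          # hypothetical propensity
d_cf$Y <- d_cf$X2 + d_cf$A * 1.5 * d_cf$X1 + rnorm(n_cf)  # true CATE = 1.5 * X1

K <- 2
fold <- sample(rep(seq_len(K), length.out = n_cf))
e_cf <- m_cf <- numeric(n_cf)
for (k in seq_len(K)) {
  train <- d_cf[fold != k, ]
  hold  <- fold == k
  e_fit <- glm(A ~ X1 + X2, family = binomial, data = train)
  m_fit <- lm(Y ~ X1 + X2, data = train)  # misspecified on purpose: true m(X) is nonlinear
  e_cf[hold] <- predict(e_fit, newdata = d_cf[hold, ], type = "response")
  m_cf[hold] <- predict(m_fit, newdata = d_cf[hold, ])
}
d_cf$psi <- (d_cf$A - e_cf) / (e_cf * (1 - e_cf)) * (d_cf$Y - m_cf)
coef_cf <- coef(lm(psi ~ X1 + X2, data = d_cf))
coef_cf    # slope on X1 should be near 1.5
```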

5 The Meta-Learner View

A popular framework for CATE estimation is the meta-learner framework proposed by Künzel et al. (2019). The idea is to decompose CATE estimation into several standard supervised learning problems and then combine their results into the final CATE estimate. Several meta-learners are proposed in that paper, most notably the X-learner. The ideas behind these estimators should be relatively straightforward given what we have discussed so far, so let’s look only at the construction of the X-learner. Its steps are:

  1. Fit outcome models \(\hat{\mu}_1(x)\) and \(\hat{\mu}_0(x)\) as estimates of \[ \mu_1(x) = \E[Y \mid X=x, A=1], \quad \mu_0(x) = \E[Y \mid X=x, A=0]. \]

  2. Create augmented treatment effect data for each individual:

    • For treated individuals (\(A_i=1\)): \[ \tilde{D}_i^{(1)} = Y_i - \hat{\mu}_0(X_i) \]
    • For control individuals (\(A_i=0\)): \[ \tilde{D}_i^{(0)} = \hat{\mu}_1(X_i) - Y_i \]
  3. Fit separate CATE models for each group by regressing the imputed effects on \(X\): \[ \hat{\tau}_1(x) \approx \E[\tilde{D}^{(1)} \mid X=x, A=1], \quad \hat{\tau}_0(x) \approx \E[\tilde{D}^{(0)} \mid X=x, A=0]. \]

  4. Combine the two estimates using an estimated propensity score \(e(x) = \Pr(A=1 \mid X=x)\): \[ \hat{\tau}_X(x) = e(x)\hat{\tau}_0(x) + (1 - e(x))\hat{\tau}_1(x). \]

The following code implements the X-learner with random forests via the ranger package. We reuse the simulated data from the causal forest example above (X, W, Y, and the true CATE tau), so the true CATE is known for evaluation purposes.

  library(ranger)
## 
## Attaching package: 'ranger'
## The following object is masked from 'package:randomForest':
## 
##     importance

  # 1) Outcome models: mu1(x) = E[Y|A=1,X], mu0(x) = E[Y|A=0,X]
  idx1 <- which(W == 1L)
  idx0 <- which(W == 0L)
  df_t1 <- data.frame(y = Y[idx1], x = X[idx1, , drop = FALSE])
  df_t0 <- data.frame(y = Y[idx0], x = X[idx0, , drop = FALSE])
  
  mu1_fit <- ranger(y ~., data = df_t1)
  mu0_fit <- ranger(y ~., data = df_t0)
  
  # 2) Impute pseudo-effects
  # Treated: D^(1) = Y - mu0(X); Control: D^(0) = mu1(X) - Y
  mu0_pred_on_1 <- predict(mu0_fit, data = df_t1)$predictions
  mu1_pred_on_0 <- predict(mu1_fit, data = df_t0)$predictions
  
  D1 <- Y[idx1] - mu0_pred_on_1
  D0 <- mu1_pred_on_0 - Y[idx0]
  
  # 3) CATE models within each group: tau1(x), tau0(x)
  tau1_fit <- ranger(y ~ ., 
                     data = data.frame(x = X[idx1, , drop = FALSE], y = D1))
  
  tau0_fit <- ranger(y ~ ., 
                     data = data.frame(x = X[idx0, , drop = FALSE], y = D0))
  
  # 4.1) Propensity model e(x) = P(A=1|X) using ranger classifier (probability=TRUE)
  prop_fit <- ranger(y ~ ., 
                     data = data.frame(x = X, y = as.factor(W)),
                     probability = TRUE)
  
  e_hat <- prop_fit$predictions[, 2]
  
  # 4.2) Combine: tau_X(x) = e(x)*tau0(x) + (1 - e(x))*tau1(x)
  tau1_hat <- predict(tau1_fit, data = data.frame(x = X))$predictions
  tau0_hat <- predict(tau0_fit, data = data.frame(x = X))$predictions
  tau_hat_X <- e_hat * tau0_hat + (1 - e_hat) * tau1_hat
  
  # 4.3) Evaluation (true CATE known: tau = 1.5*X[,1])
  xlearner_mse <- mean((tau_hat_X - tau)^2)
  cat(sprintf("X-learner (ranger) in-sample CATE MSE: %.4f\n", xlearner_mse))
## X-learner (ranger) in-sample CATE MSE: 0.2597
  
  # (Optional) Simple diagnostic plot against the driver X1
  plot(X[,1], tau_hat_X, pch = 16, xlab = "X1", ylab = "Estimated CATE (X-learner, ranger)")
  points(X[,1], tau, col = "red", pch = 16, cex = 0.6)
  legend("topleft", legend = c("Estimated CATE", "True CATE"),
         col = c("black", "red"), pch = 16, bty = "n")

Following similar ideas, Kennedy (2023) proposed the DR-learner, based on the doubly robust idea introduced before. The goal is to construct a pseudo-outcome whose conditional mean given \(X = x\) equals \(\tau(x)\) if either the outcome model or the propensity model is correctly specified. Note that the theoretical properties of these meta-learners rely on the concept of Neyman orthogonality, a core tool of modern causal inference with machine learning plug-ins. We will not go into the details in this course; interested readers may refer to Chernozhukov et al. (2018).
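A minimal DR-learner sketch in base R, under stated assumptions: hypothetical simulated data, linear outcome models \(\hat\mu_1, \hat\mu_0\), a logistic propensity model \(\hat e\), and a single sample split (cross-fitting over both halves is a common refinement). The nuisance models are fit on the first half; on the second half, the doubly robust pseudo-outcome
\[ \hat\varphi = \hat\mu_1(X) - \hat\mu_0(X) + \frac{A\,(Y - \hat\mu_1(X))}{\hat e(X)} - \frac{(1 - A)(Y - \hat\mu_0(X))}{1 - \hat e(X)} \]
is regressed on \(X\) to estimate \(\tau(x)\).

```r
set.seed(3)
n_dr <- 40000
d_dr <- data.frame(X1 = rnorm(n_dr), X2 = rnorm(n_dr))
d_dr$A <- rbinom(n_dr, 1, plogis(0.5 * d_dr$X1))          # hypothetical propensity
d_dr$Y <- d_dr$X2 + d_dr$A * 1.5 * d_dr$X1 + rnorm(n_dr)  # true CATE = 1.5 * X1

first <- seq_len(n_dr) <= n_dr / 2      # rows are i.i.d., so this is a random split
nuis  <- d_dr[first, ]                  # half used for nuisance models
ev    <- d_dr[!first, ]                 # half used for the pseudo-outcome regression

# nuisance models fit on the first half only
e_fit   <- glm(A ~ X1 + X2, family = binomial, data = nuis)
mu1_fit <- lm(Y ~ X1 + X2, data = nuis[nuis$A == 1, ])
mu0_fit <- lm(Y ~ X1 + X2, data = nuis[nuis$A == 0, ])

e_hat_dr <- predict(e_fit, newdata = ev, type = "response")
mu1_hat  <- predict(mu1_fit, newdata = ev)
mu0_hat  <- predict(mu0_fit, newdata = ev)

# doubly robust (AIPW) pseudo-outcome, then regress it on X
ev$phi <- mu1_hat - mu0_hat +
  ev$A * (ev$Y - mu1_hat) / e_hat_dr -
  (1 - ev$A) * (ev$Y - mu0_hat) / (1 - e_hat_dr)
coef_dr <- coef(lm(phi ~ X1 + X2, data = ev))
coef_dr    # slope on X1 should be near 1.5
```

Any regression method can replace the final `lm`; the linear second stage is used here only because the hypothetical CATE is linear.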


Chernozhukov, Victor, Denis Chetverikov, Mert Demirer, Esther Duflo, Christian Hansen, Whitney Newey, and James Robins. 2018. “Double/Debiased Machine Learning for Treatment and Structural Parameters: Double/Debiased Machine Learning.” The Econometrics Journal 21 (1).
Foster, Jared C, Jeremy MG Taylor, and Stephen J Ruberg. 2011. “Subgroup Identification from Randomized Clinical Trial Data.” Statistics in Medicine 30 (24): 2867–80.
Kennedy, Edward H. 2023. “Towards Optimal Doubly Robust Estimation of Heterogeneous Causal Effects.” Electronic Journal of Statistics 17 (2): 3008–49.
Künzel, Sören R, Jasjeet S Sekhon, Peter J Bickel, and Bin Yu. 2019. “Metalearners for Estimating Heterogeneous Treatment Effects Using Machine Learning.” Proceedings of the National Academy of Sciences 116 (10): 4156–65.
Tian, Lu, Ash A Alizadeh, Andrew J Gentles, and Robert Tibshirani. 2014. “A Simple Method for Estimating Interactions Between a Treatment and a Large Number of Covariates.” Journal of the American Statistical Association 109 (508): 1517–32.

  1. We are sometimes also interested in the average treatment effect within a specific subgroup of individuals, \(E[ Y(1) - Y(0) | G = 1]\), where \(G\) indicates a specific characteristic. This concept is sometimes called a Local Average Treatment Effect (LATE), although in the instrumental variable literature LATE usually refers to the effect among compliers.↩︎