From “naive trees” to Causal Forests
DGP
Consider the following experimental DGP with discrete heterogeneous
treatment effects, but otherwise similar to those in previous
notebooks:
\(p=10\) independent covariates
\(X_1,...,X_k,...,X_{10}\) drawn from a
uniform distribution: \(X_k \sim
uniform(-\pi,\pi)\)
The treatment model is \(W \sim
Bernoulli(\underbrace{2/3}_{e(X)})\)
The potential outcome model of the controls is \(Y(0) = \underbrace{sin(X_1)}_{m_0(X)} +
\varepsilon\), with \(\varepsilon \sim
N(0,1/2)\), where 1/2 is the standard deviation passed to rnorm in the code below
The CATE function is an indicator function \(\tau(X) = 1[X_1 > -0.5\pi]\)
The potential outcome model of the treated is \(Y(1) = m_0(X) + \tau(X) + \varepsilon\),
with \(\varepsilon \sim
N(0,1/2)\)
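As a quick sanity check on this DGP, the average treatment effect follows directly from the CATE function: it is the probability that \(X_1\) exceeds the threshold. A minimal sketch in base R; the name `ate_true` is introduced here for illustration only.

```r
# ATE implied by tau(X) = 1[X_1 > -0.5*pi] with X_1 ~ uniform(-pi, pi):
# probability mass of the uniform distribution above the threshold
ate_true = 1 - punif(-0.5 * pi, min = -pi, max = pi)
ate_true  # 0.75
```

Three quarters of the units therefore have an effect of one and one quarter an effect of zero.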
if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
if (!require("tidyverse")) install.packages("tidyverse", dependencies = TRUE); library(tidyverse)
if (!require("patchwork")) install.packages("patchwork", dependencies = TRUE); library(patchwork)
if (!require("rpart")) install.packages("rpart", dependencies = TRUE); library(rpart)
if (!require("rpart.plot")) install.packages("rpart.plot", dependencies = TRUE); library(rpart.plot)
if (!require("partykit")) install.packages("partykit", dependencies = TRUE); library(partykit)
if (!require("causalTree")) {
if (!require("devtools")) install.packages("devtools", dependencies = TRUE); library(devtools)
install_github("susanathey/causalTree")
}; library(causalTree)
set.seed(12345) # Admittedly a bit seed-hacked in favor of clean illustration
# Set parameters
n = 1000
p = 10
# Draw sample
x = matrix(runif(n*p,-pi,pi),ncol=p)
e = function(x){2/3}
m0 = function(x){sin(x)}
tau = function(x){1*(x>-0.5*pi)}
m1 = function(x){m0(x) + tau(x)}
w = rbinom(n,1,e(x))
y = m0(x[,1]) + w*tau(x[,1]) + rnorm(n,0,1/2)
g1 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=e,size=1) + ylab("e") + xlab("X1")
g2 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=m1,size=1,aes(colour="Y1")) +
stat_function(fun=m0,size=1,aes(colour="Y0")) + ylab("Y") + xlab("X1")
g3 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=tau,size=1) + ylab(expression(tau)) + xlab("X1")
g1 / g2 / g3
T-learner with regression trees
First, we estimate the conditional outcome means separately in the
treated and the control sample and take the difference of the
predicted outcomes as the estimate of the CATE (see slide 19).
- Use regression tree to fit model in control subsample
df = data.frame(x=x,y=y)
tree0 = rpart(y~x,data = df,subset = (w==0))
rpart.plot(tree0)
- Use regression tree to fit model in treated subsample
tree1 = rpart(y~x,data = df,subset= (w==1))
rpart.plot(tree1)
- Plot predicted outcomes and CATEs:
df$apo_tree0 = predict(tree0,newdata=data.frame(x))
df$apo_tree1 = predict(tree1,newdata=data.frame(x))
df$cate_tree = df$apo_tree1 - df$apo_tree0
g1 = ggplot(df) + stat_function(fun=m1,size=1) + ylab("m1") +
geom_point(aes(x=x[,1],y=apo_tree1),shape="square",color="blue")
g2 = ggplot(df) + stat_function(fun=m0,size=1) + ylab("m0") +
geom_point(aes(x=x[,1],y=apo_tree0),shape="square",color="blue")
g3 = ggplot(df) + stat_function(fun=tau,size=1) + ylab(expression(tau)) +
geom_point(aes(x=x[,1],y=cate_tree),shape="square",color="blue")
g1 / g2 / g3
We observe that the predicted CATEs are quite erratic. Approximating
the smooth outcome functions is very challenging for a tree, and it fails
quite dramatically, especially in the control sample, where fewer
observations are available. These errors spill over to the downstream CATE
estimates. The CATE is in principle tailored to be found by a tree
structure, but the more complicated outcome functions distract this
“naive” procedure.
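The T-learner recipe above can also be wrapped in a small helper. The following is a minimal sketch, not the notebook's code: `t_learner_tree` and the `*_demo` objects are hypothetical names introduced here, and the demo data only mimic the DGP of this section. It is meant to be run in a separate R session, so that it neither overwrites the notebook's objects nor advances its random number stream.

```r
library(rpart)

# Hypothetical helper: fit one regression tree per treatment arm and
# return the predicted outcomes and their difference for all units
t_learner_tree = function(x, y, w) {
  dat = data.frame(y = y, x)                   # columns y, X1, ..., Xp
  tree0 = rpart(y ~ ., data = dat[w == 0, ])   # controls only
  tree1 = rpart(y ~ ., data = dat[w == 1, ])   # treated only
  mu0 = predict(tree0, newdata = dat)
  mu1 = predict(tree1, newdata = dat)
  list(mu0 = mu0, mu1 = mu1, cate = mu1 - mu0)
}

# Demo data mimicking the DGP above (only two covariates for speed)
set.seed(1)
n_demo = 500
x_demo = matrix(runif(n_demo * 2, -pi, pi), ncol = 2)
w_demo = rbinom(n_demo, 1, 2/3)
y_demo = sin(x_demo[, 1]) + w_demo * (x_demo[, 1] > -0.5 * pi) + rnorm(n_demo, 0, 1/2)

fit_demo = t_learner_tree(x_demo, y_demo, w_demo)
summary(fit_demo$cate)
```

Fitting on named data frame columns means that `predict` uses exactly the rows passed as `newdata`.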
Causal Tree
Handcoded
Causal Trees are built to estimate CATEs directly and to not be
distracted by potentially more complicated outcome functions.
Before using the package, let’s handcode the main idea to see how and
that it works.
Recall from the lecture that the splitting criterion can be expressed
as \(\max \sum_i
\hat{\tau}^{tree}(X_i)^2\). Below we write code that searches for
the first split. To this end, we iterate over a grid of potential split
points \(s\). At each split point \(s\) and
variable \(j\), we proceed as
follows:
Calculate the effect in the left leaf as mean difference between
treated and control units \(\hat{\tau}_{L(j,s)}\)
Calculate the effect in the right leaf as mean difference between
treated and control units \(\hat{\tau}_{R(j,s)}\)
Calculate the criterion as \(N_L
\hat{\tau}_{L(j,s)}^2 + N_R \hat{\tau}_{R(j,s)}^2\), where \(N_L\) and \(N_R\) are the number of units observed in
the left leaf and the right leaf, respectively.
Finally, we pick the variable and the splitting point that maximize
the criterion. In our DGP, the first split should be at \(X_1 = -0.5\pi \approx -1.6\).
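The step from the lecture's criterion to the leaf-level expression is short: a tree predicts the same effect for every unit in a leaf, so the sum over units collapses into leaf sizes times squared leaf effects. For a single split on variable \(j\) at point \(s\), with units at or below \(s\) going to the left leaf:

```latex
\[
\sum_{i=1}^{N} \hat{\tau}^{tree}(X_i)^2
= \sum_{i:\, X_{ij} \le s} \hat{\tau}_{L(j,s)}^2
+ \sum_{i:\, X_{ij} > s} \hat{\tau}_{R(j,s)}^2
= N_L \hat{\tau}_{L(j,s)}^2 + N_R \hat{\tau}_{R(j,s)}^2
\]
```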
# Hand-coded causal tree
grid = seq(-3,3,0.01)
criterion = matrix(NA,length(grid),p)
colnames(criterion) = paste0("X",1:p)
for (j in 1:p) {
  for (i in seq_along(grid)) {
    # Indicator for being right of cut-off
    right = (x[,j] > grid[i])
    # Calculate the effect as mean differences in the two leaves
    cate_left = mean(y[w==1 & !right]) - mean(y[w==0 & !right])
    cate_right = mean(y[w==1 & right]) - mean(y[w==0 & right])
    # Calculate and store criterion
    criterion[i,j] = (n-sum(right)) * (cate_left)^2 + sum(right) * (cate_right)^2
  }
}
# Find maximum
index_max = which(criterion == max(criterion), arr.ind = TRUE)
# Plot criteria
data.frame(x=grid,criterion) %>%
pivot_longer(cols=-x,names_to = "Variable",values_to = "Criterion") %>%
ggplot(aes(x=x ,y=Criterion,colour=Variable)) + geom_line(size=1) + geom_vline(xintercept=-0.5*pi) +
geom_vline(xintercept=grid[index_max[1]],linetype = "dashed")
The detected splitting point (dashed line) is very close to the
correct splitting point (solid line). The criteria of the other nine
variables are quite flat along the grid, as they should be. Only the
criterion along \(X_1\) rises as the
candidate split approaches the correct one.
Package
The causalTree package implements the procedure, including cross-validation:
# Implemented causalTree adapting specification from R example
ctree = causalTree(y~x, data = df, treatment = w,
                   split.Rule = "CT", cv.option = "CT", split.Honest = TRUE, split.Bucket = FALSE, xval = 5,
                   cp = 0, minsize = 20)
[1] 2
[1] "CT"
opcp = ctree$cptable[,1][which.min(ctree$cptable[,4])]
opfit = prune(ctree, opcp)
df$cate_ct = predict(opfit)
rpart.plot(opfit)
ggplot(df) + stat_function(fun=tau,size=1) + ylab(expression(tau)) +
geom_point(aes(x=x[,1],y=cate_ct),shape="square",color="blue")
The package version does not know that only one split is required,
but finds the correct split value and prunes the tree appropriately via
cross-validation.
T-learner with random forest
Regression trees are obviously not well-suited to estimating smooth
functions. Thus, we next use a self-tuned honest regression forest to
form the separate prediction models:
rf0 = regression_forest(x[w==0,], y[w==0],tune.parameters = "all")
rf1 = regression_forest(x[w==1,], y[w==1],tune.parameters = "all")
df$apo_rf0 = predict(rf0,newdata=x)$predictions
df$apo_rf1 = predict(rf1,newdata=x)$predictions
df$cate_rf = df$apo_rf1 - df$apo_rf0
g1 = ggplot(df) + stat_function(fun=m1,size=1) + ylab("m1") +
geom_point(aes(x=x[,1],y=apo_rf1),shape="square",color="blue")
g2 = ggplot(df) + stat_function(fun=m0,size=1) + ylab("m0") +
geom_point(aes(x=x[,1],y=apo_rf0),shape="square",color="blue")
g3 = ggplot(df) + stat_function(fun=tau,size=1) + ylab(expression(tau)) +
geom_point(aes(x=x[,1],y=cate_rf),shape="square",color="blue")
g1 / g2 / g3
This looks much better in terms of approximating the simple discrete
CATE function. However, some problems in approximating the complicated
outcome functions still spill over to the CATE estimates.
Causal Forest
The Causal Forest uses an approximation of the splitting criterion
that we handcoded above and exploits the resulting weights of an
ensemble of Causal Trees to estimate the CATEs.
Further, we can extract an estimate of the two outcome functions
as
\(\hat{m}(0,X) = \hat{m}(X) -
\hat{e}(X) \hat{\tau}(X)\)
\(\hat{m}(1,X) = \hat{m}(X) +
(1-\hat{e}(X)) \hat{\tau}(X)\)
As \(\hat{m}(X)\) and \(\hat{e}(X)\) are the nuisance parameters of
the Causal Forest, this comes at no additional computational cost.
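The two expressions follow from \(m(X) = E[Y|X] = m(0,X) + e(X)\tau(X)\). The sketch below checks both identities numerically with the true functions of the first DGP; the `*_chk` names are introduced here only to avoid overwriting the notebook's objects.

```r
# Grid over the support of X1 and the true functions of the first DGP
x1_chk  = seq(-pi, pi, length.out = 201)
e_chk   = rep(2/3, length(x1_chk))        # propensity score e(X)
m0_chk  = sin(x1_chk)                     # control outcome m(0,X)
tau_chk = 1 * (x1_chk > -0.5 * pi)        # CATE tau(X)
m1_chk  = m0_chk + tau_chk                # treated outcome m(1,X)

# Conditional mean of the observed outcome: m(X) = m(0,X) + e(X) * tau(X)
m_chk = m0_chk + e_chk * tau_chk

# Backing out the two potential outcome functions from m, e and tau
all.equal(m_chk - e_chk * tau_chk, m0_chk)
all.equal(m_chk + (1 - e_chk) * tau_chk, m1_chk)
```

With estimated instead of true functions the identities hold only approximately, so the quality of the backed-out outcome predictions depends on all three estimates.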
cf = causal_forest(x, y, w,tune.parameters = "all")
df$cate_cf = predict(cf)$predictions
df$apo_cf0 = cf$Y.hat - cf$W.hat * df$cate_cf
df$apo_cf1 = cf$Y.hat + (1-cf$W.hat) * df$cate_cf
g1 = ggplot(df) + stat_function(fun=m1,size=1) + ylab("m1") +
geom_point(aes(x=x[,1],y=apo_cf1),shape="square",color="blue")
g2 = ggplot(df) + stat_function(fun=m0,size=1) + ylab("m0") +
geom_point(aes(x=x[,1],y=apo_cf0),shape="square",color="blue")
g3 = ggplot(df) + stat_function(fun=tau,size=1) + ylab(expression(tau)) +
geom_point(aes(x=x[,1],y=cate_cf),shape="square",color="blue")
g1 / g2 / g3
The Causal Forest pretty much nails the CATE function. In particular, it
is flat in the regions where the effect is constant, unlike the estimates
from the two separate Random Forests above.
Comparison
As we defined the truth, we can calculate the MSE for the different
parameters and methods:
mses = matrix(NA,4,3)
colnames(mses) = c("m(0,X)","m(1,X)","tau(X)")
mses[1,1] = mean( (df$apo_tree0 - m0(x[,1]))^2 )
mses[3,1] = mean( (df$apo_rf0 - m0(x[,1]))^2 )
mses[4,1] = mean( (df$apo_cf0 - m0(x[,1]))^2 )
mses[1,2] = mean( (df$apo_tree1 - m1(x[,1]))^2 )
mses[3,2] = mean( (df$apo_rf1 - m1(x[,1]))^2 )
mses[4,2] = mean( (df$apo_cf1 - m1(x[,1]))^2 )
mses[1,3] = mean( (df$cate_tree - tau(x[,1]))^2 )
mses[2,3] = mean( (df$cate_ct - tau(x[,1]))^2 )
mses[3,3] = mean( (df$cate_rf - tau(x[,1]))^2 )
mses[4,3] = mean( (df$cate_cf - tau(x[,1]))^2 )
data.frame(Method = factor(c("Tree","Causal Tree","Forest","Causal Forest"),
levels=c("Tree","Causal Tree","Forest","Causal Forest")),
mses) %>%
pivot_longer(cols=-Method,names_to = "Target",values_to = "MSE") %>%
ggplot(aes(x=Method,y=MSE)) + geom_point() + facet_wrap(~Target) +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1)) +
geom_hline(yintercept = 0)
As expected, the Causal Forest provides the lowest MSE for the
approximation of the CATE.
Interestingly, it also approximates the control outcome better than
the separate forest. This might be further investigated in a bigger
simulation study, but not in this notebook.
Causal Forest behind the scenes
To get a better understanding of what Causal Forests are doing, consider
the following, now observational, DGP with only two covariates:
\(p=2\) independent covariates
\(X_1,X_{2}\) drawn from a uniform
distribution: \(X_k \sim
uniform(-\pi,\pi)\)
The treatment model is \(W \sim
Bernoulli(\underbrace{\Phi(sin(X_1))}_{e(X)})\), where \(\Phi(\cdot)\) is the standard normal
cumulative distribution function
The potential outcome model of the controls is \(Y(0) = \underbrace{sin(X_1)}_{m_0(X)} +
\varepsilon\), with \(\varepsilon \sim
N(0,1/10)\)
The CATE function is an indicator function \(\tau(X) = 1[X_1 > -0.5\pi]\)
The potential outcome model of the treated is \(Y(1) = m_0(X) + \tau(X) + \varepsilon\),
with \(\varepsilon \sim
N(0,1/10)\)
# Set parameters
p = 2
# Draw sample
x = matrix(runif(n*p,-pi,pi),ncol=p)
e = function(x){pnorm(sin(x))}
m0 = function(x){sin(x)}
tau = function(x){0 + 1*(x>-0.5*pi)}
m1 = function(x){m0(x) + tau(x)}
w = rbinom(n,1,e(x[,1])) # Propensity score depends on X1 only
y = m0(x[,1]) + w*tau(x[,1]) + rnorm(n,0,1/10)
g1 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=e,size=1) + ylab("e") + xlab("X1")
g2 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=m1,size=1,aes(colour="Y1")) +
stat_function(fun=m0,size=1,aes(colour="Y0")) + ylab("Y") + xlab("X1")
g3 = data.frame(x = c(-pi, pi)) %>% ggplot(aes(x)) + stat_function(fun=tau,size=1) + ylab(expression(tau)) + xlab("X1")
g1 / g2 / g3
Weighted residual-on-residual regression at single point
Recall that the Causal Forest can be written as a
residual-on-residual regression with \(x\)-specific weights \(\alpha(x)\):
\[\hat{\tau}^{cf}(x) =
argmin_{\breve{\tau}} \left\{ \sum_{i=1}^N \alpha_i(x) \left[(Y_i -
\hat{m}(X_i)) - \breve{\tau} (W_i - \hat{e}(X_i)) \right]^2
\right\}\]
The weights can be accessed using the get_forest_weights function on a causal_forest object.
Suppose we are interested in estimating the CATE at \(X_1= -3\) with all other covariates set to
zero.
The estimated value can be obtained via the predict function:
# Run CF
cf = causal_forest(x, y, w,tune.parameters = "all")
# Define test point
testx = matrix(c(-3,rep(0,p-1)),nrow=1)
# Check what package predicts
predict(cf,newdata = testx)$predictions
[1] 0.02728292
which is reasonably close to the true value of zero given the small
number of observations.
Alternatively we can handcode it as a weighted residual-on-residual
regression after extracting the underlying weights.
# Get residuals
res_y = y - cf$Y.hat
res_w = w - cf$W.hat
# Replicate handcoded
alphax = get_forest_weights(cf,newdata = testx)[1,]
coef(lm(res_y ~ res_w,weights = alphax))
(Intercept) res_w
0.01304695 0.02728292
The second coefficient (the slope coefficient) is identical to the
estimated CATE.
Remark: The intercept is not strictly necessary, as we know
(and see) that it should be zero. However, the grf package
includes it, so we include it as well to match the grf output
numerically. The differences to the case without constant
are negligible:
# Replicate handcoded w/o constant
coef(lm(res_y ~ 0 + res_w,weights = alphax))
res_w
0.03039751
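Both regressions have closed forms, which makes the role of the intercept explicit. The sketch below uses small made-up residuals and weights (`rw_toy`, `ry_toy` and `a_toy` are hypothetical and not taken from the forest) to check that the slope with a constant is a weighted covariance over a weighted variance, while the slope without a constant skips the centering.

```r
# Made-up treatment residuals, outcome residuals and weights (weights sum to 1)
rw_toy = c(-0.7, 0.3, -0.6, 0.4, -0.5, 0.5)
ry_toy = c( 0.1, 0.2, -0.2, 0.5,  0.0, 0.3)
a_toy  = c( 0.3, 0.3, 0.15, 0.15, 0.05, 0.05)

# With constant: center both residuals at their weighted means first
rw_bar = weighted.mean(rw_toy, a_toy)
ry_bar = weighted.mean(ry_toy, a_toy)
slope_const = sum(a_toy * (rw_toy - rw_bar) * (ry_toy - ry_bar)) /
  sum(a_toy * (rw_toy - rw_bar)^2)

# Without constant: no centering
slope_noconst = sum(a_toy * rw_toy * ry_toy) / sum(a_toy * rw_toy^2)

# Compare with the corresponding weighted least squares fits
all.equal(slope_const,   unname(coef(lm(ry_toy ~ rw_toy, weights = a_toy))[2]))
all.equal(slope_noconst, unname(coef(lm(ry_toy ~ 0 + rw_toy, weights = a_toy))[1]))
```

The two slopes coincide when the weighted means of both residuals are zero, which is why dropping the constant changes little when the residuals are well centered.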
How the weights move for different predictions
Finally, this setting can be used to nicely illustrate what happens
at different values of \(X_1\). The two
residuals can be plotted against each other. Additionally, the size of
each point is proportional to the weight it receives, and the color
indicates lower to higher values of \(X_1\):
# Run same over grid and see how weights move
grid = seq(-3,0,1)
gridx = cbind(grid,matrix(0,length(grid),p-1))
grid_hat = predict(cf,newdata = gridx)$predictions
alpha = get_forest_weights(cf,newdata = gridx)
for (i in 1:length(grid)) {
g1 = data.frame(x=grid,tau_hat=grid_hat) %>%
ggplot(aes(x=x ,y=tau_hat)) + stat_function(fun=tau,size=1) +
geom_line(color="blue") +
geom_point(aes(x=grid[i],y=grid_hat[i]),size=4,color="blue",shape=4)
rorr = lm(res_y ~ res_w,weights = alpha[i,])
g2 = data.frame(res_w,res_y,alpha=alpha[i,],x=x) %>%
ggplot(aes(x=res_w,y=res_y)) + geom_point(aes(size=alpha,color=x[,1]),alpha=0.5) +
geom_abline(intercept=rorr$coefficients[1],slope=rorr$coefficients[2]) +
annotate("text", x = -0.25, y = 1, label = paste0("tau(",toString(grid[i]),") = slope of line = ",
toString(round(rorr$coefficients[2],2))))+
scale_colour_gradient(low = "black", high = "yellow")
print(g1 / g2)
}
The jump from \(X_1 = -2\) to \(X_1=-1\) is most instructive. We see that
before the jump of the true CATE function, units with low values of \(X_1\) receive most weight. Furthermore,
among the units with large weights, the outcome residuals of controls
(with negative treatment residuals) and treated (with positive
treatment residuals) are quite similar. This
results in a nearly horizontal line and thus a CATE estimate close
to zero.
After the jump, units with larger values of \(X_1\) receive most of the weight, resulting
in a clearly positive slope. As the slope of the line is the estimated
CATE, we (rightly) estimate a substantial positive effect after the
jump.
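The shift of the weights can also be summarized numerically instead of being read off the plots. The following is a sketch with a hypothetical helper, `summarise_weights`, which is neither part of the notebook nor of grf. It reports, per target point, the weighted mean of \(X_1\) and the effective sample size, an additional summary chosen here for illustration. The toy weights are made up so that the result can be verified by hand.

```r
# Hypothetical helper: summarize where the weights of each target point
# (one row of alpha per target point) sit along X1
summarise_weights = function(alpha, x1) {
  alpha = as.matrix(alpha)  # dense copy, e.g. of a sparse weight matrix
  data.frame(
    mean_x1 = as.vector(alpha %*% x1) / rowSums(alpha),  # weighted mean of X1
    ess     = rowSums(alpha)^2 / rowSums(alpha^2)        # effective sample size
  )
}

# Made-up example: two target points, four training units
alpha_toy = rbind(c(0.5, 0.5, 0.0, 0.0),
                  c(0.0, 0.0, 0.5, 0.5))
x1_toy = c(-3, -2, 1, 2)
summarise_weights(alpha_toy, x1_toy)
```

With the notebook's objects, `summarise_weights(alpha, x[,1])` would return one row per grid point, assuming the weight matrix converts via `as.matrix`.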
Take-away
Targeting CATEs explicitly works.
The Causal Forest can be thought of as residual-on-residual
regressions where the residuals are always the same, but the weights are
individualized.
Suggestions to play with the toy model
Some suggestions:
Increase/decrease the number of observations
Create different CATE functions
Change the treatment shares
Increase the noise level in the second part and see how little is
visible unless the outcome noise is very small. In real datasets you won’t see much,
but here it helps to build intuition.
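To make these experiments easier, the sampling step can be wrapped in a function. This is a minimal sketch under the assumptions of the two DGPs above; `draw_sample` and its arguments are hypothetical names, not part of the notebook. Run it in a separate session if the notebook's random number stream should stay untouched.

```r
# Hypothetical helper: draw one sample from the toy DGP with adjustable
# sample size, propensity score, CATE function and outcome noise
draw_sample = function(n = 1000, p = 10, noise_sd = 1/2,
                       e_fun   = function(x1) rep(2/3, length(x1)),
                       tau_fun = function(x1) 1 * (x1 > -0.5 * pi)) {
  x = matrix(runif(n * p, -pi, pi), ncol = p)
  w = rbinom(n, 1, e_fun(x[, 1]))
  y = sin(x[, 1]) + w * tau_fun(x[, 1]) + rnorm(n, 0, noise_sd)
  list(x = x, y = y, w = w, tau = tau_fun(x[, 1]))
}

# Example 1: fewer observations, equal treatment shares, a linear CATE
sim1 = draw_sample(n = 500, e_fun = function(x1) rep(1/2, length(x1)),
                   tau_fun = function(x1) x1 / pi)
# Example 2: observational design of the second part with noisier outcomes
sim2 = draw_sample(p = 2, noise_sd = 1, e_fun = function(x1) pnorm(sin(x1)))
str(sim1)
```

The returned `tau` holds the true CATE of each unit, so the MSE comparison above can be repeated for every variation.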
