Handcode Double ML for ATE using AIPW
Implement it using the causalDML package
The Application Notebook builds on the dataset that is kindly provided in the hdm package. The data was used in Chernozhukov and Hansen (2004). Their paper investigates the effect of participation in the employer-sponsored 401(k) retirement savings plan (p401) on net assets (net_tfa). Since then, the data has been used to showcase many new methods. It is not the most comprehensive dataset, with basically ten covariates:
age: age
db: defined benefit pension
educ: education (in years)
fsize: family size
hown: home owner
inc: income (in US $)
male: male
marr: married
pira: participation in individual retirement account (IRA)
twoearn: two earners
However, it is publicly available and the relatively few covariates ensure that the programs do not run too long.
# To install the causalDML package uncomment the following two lines
# library(devtools)
# install_github(repo="MCKnaus/causalDML")
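# Load the packages required for later
library(hdm)        # provides the pension data
library(grf)        # regression_forest()
library(causalDML)  # create_method(), DML_aipw(), APO_dml_atet(), ATE_dml()
library(tidyverse)  # %>% and ggplot()
set.seed(1234) # for replicability
options(scipen = 10) # Switch off scientific notation
# Get the data
data(pension)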
# Outcome
Y = pension$net_tfa
# Treatment
W = pension$p401
# Create main effects matrix
X = model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
We want to estimate the effect of 401(k) participation on net assets, this time without imposing effect homogeneity.
To implement the AIPW we estimate the nuisance parameters \(e(X)=E[W|X]\), \(m(0,X)=E[Y|W=0,X]\) and \(m(1,X)=E[Y|W=1,X]\) using random forest and plug the predictions into the pseudo outcome: \[ \begin{align} \tilde{Y}_{\gamma_0} & = \underbrace{\hat{m}(0,X)}_{\text{outcome predictions}} + \underbrace{\frac{(1-W) (Y - \hat{m}(0,X))}{1-\hat{e}(X)}}_{\text{weighted residuals}} \\ \tilde{Y}_{\gamma_1} & = \underbrace{\hat{m}(1,X)}_{\text{outcome predictions}} + \underbrace{\frac{W (Y - \hat{m}(1,X))}{\hat{e}(X)}}_{\text{weighted residuals}} \\ \tilde{Y}_{ATE} & = \tilde{Y}_{\gamma_1} - \tilde{Y}_{\gamma_0} \\ &= \underbrace{\hat{m}(1,X) - \hat{m}(0,X)}_{\text{outcome predictions}} + \underbrace{\frac{W (Y - \hat{m}(1,X))}{\hat{e}(X)} - \frac{(1-W) (Y - \hat{m}(0,X))}{1-\hat{e}(X)}}_{\text{weighted residuals}} \end{align} \]
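To see the pseudo-outcome formulas at work before turning to the real data, here is a minimal sketch on simulated data. The helper name `aipw_pseudo`, the data generating process and the use of the true nuisance functions instead of estimates are all assumptions made for illustration; none of this is part of the pension application.

```r
# Hypothetical helper that returns the three pseudo-outcomes defined above
aipw_pseudo = function(Y, W, m0, m1, e) {
  Y_t_0 = m0 + (1 - W) * (Y - m0) / (1 - e)
  Y_t_1 = m1 + W * (Y - m1) / e
  list(Y_t_0 = Y_t_0, Y_t_1 = Y_t_1, Y_ate = Y_t_1 - Y_t_0)
}

# Simulated data with a known ATE of 2 (assumption, not the pension data)
set.seed(1)
n_sim = 5000
x_sim = rnorm(n_sim)
e_true = plogis(0.5 * x_sim)              # true propensity score e(x)
w_sim = rbinom(n_sim, 1, e_true)          # treatment
y_sim = x_sim + 2 * w_sim + rnorm(n_sim)  # outcome

# Plug in the true nuisance functions m(0,x) = x and m(1,x) = x + 2
po = aipw_pseudo(y_sim, w_sim, m0 = x_sim, m1 = x_sim + 2, e = e_true)
mean(po$Y_ate)  # should be close to the true ATE of 2
```

In the application the true nuisance functions are unknown, so they are replaced by out-of-sample predictions.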
The theoretical results require that we predict the nuisance parameters out-of-sample. The easiest way to do this is two-fold cross-fitting:
Split the sample into two random subsamples, S1 and S2
Form prediction models in S1, use them to predict in S2
Form prediction models in S2, use them to predict in S1
# 2-fold cross-fitting
n = length(Y)
m0hat = m1hat = ehat = rep(NA,n)
# Draw random indices for sample 1
index_s1 = sample(1:n,n/2)
# Create S1
x1 = X[index_s1,]
w1 = W[index_s1]
y1 = Y[index_s1]
# Create sample 2 with those not in S1
x2 = X[-index_s1,]
w2 = W[-index_s1]
y2 = Y[-index_s1]
# Model in S1, predict in S2
rf = regression_forest(x1,w1)
ehat[-index_s1] = predict(rf,newdata=x2)$predictions
rf = regression_forest(x1[w1==0,],y1[w1==0])
m0hat[-index_s1] = predict(rf,newdata=x2)$predictions
rf = regression_forest(x1[w1==1,],y1[w1==1])
m1hat[-index_s1] = predict(rf,newdata=x2)$predictions
# Model in S2, predict in S1
rf = regression_forest(x2,w2)
ehat[index_s1] = predict(rf,newdata=x1)$predictions
rf = regression_forest(x2[w2==0,],y2[w2==0])
m0hat[index_s1] = predict(rf,newdata=x1)$predictions
rf = regression_forest(x2[w2==1,],y2[w2==1])
m1hat[index_s1] = predict(rf,newdata=x1)$predictions
\[ \begin{align} \tilde{Y}_{\gamma_0} & = \hat{m}(0,X) + \frac{(1-W) (Y - \hat{m}(0,X))}{1-\hat{e}(X)} \\ \tilde{Y}_{\gamma_1} & = \hat{m}(1,X) + \frac{W (Y - \hat{m}(1,X))}{\hat{e}(X)} \end{align} \]
Y_t_0 = m0hat + (1-W)*(Y-m0hat)/(1-ehat)
Y_t_1 = m1hat + W*(Y-m1hat)/ehat
summary(lm(Y_t_0 ~ 1))
Call:
lm(formula = Y_t_0 ~ 1)

Residuals:
    Min      1Q  Median      3Q     Max 
-800176  -15814  -13718   -2679 2706849 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  14152.5      793.7   17.83   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 79040 on 9914 degrees of freedom
mean(Y_t_0)
[1] 14152.48
summary(lm(Y_t_1 ~ 1))
Call:
lm(formula = Y_t_1 ~ 1)

Residuals:
     Min       1Q   Median       3Q      Max 
-1508685   -21560   -13401     5176  3302290 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  25638.2      956.1   26.82   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 95200 on 9914 degrees of freedom
mean(Y_t_1)
[1] 25638.21
\[ \begin{align} \tilde{Y}_{ATE} & = \tilde{Y}_{\gamma_1} - \tilde{Y}_{\gamma_0} \end{align} \]
Y_ate = Y_t_1 - Y_t_0
summary(lm(Y_ate ~ 1))
Call:
lm(formula = Y_ate ~ 1)

Residuals:
     Min       1Q   Median       3Q      Max 
-2635515   -13172    -2157    12260  3290639 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)    11486       1160   9.899   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 115500 on 9914 degrees of freedom
I think this is really neat. We need to be careful in constructing the pseudo-outcomes. However, in the end it boils down to running the simplest OLS you can think of, and we can use the familiar standard errors to test the null hypothesis of a zero effect.
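Why a regression on a constant is enough can be checked by hand: the intercept is the sample mean, and its standard error is the sample standard deviation divided by the square root of the sample size. The sketch below uses a simulated vector (`pseudo`, an assumption for illustration) as a stand-in for a pseudo-outcome such as Y_ate.

```r
# Simulated stand-in for a pseudo-outcome vector (not the pension data)
set.seed(1)
pseudo = rnorm(1000, mean = 2, sd = 5)
n_obs = length(pseudo)

# Regression on a constant
fit = summary(lm(pseudo ~ 1))
est_lm = fit$coefficients[1, "Estimate"]
se_lm = fit$coefficients[1, "Std. Error"]

# The same numbers by hand
est_hand = mean(pseudo)
se_hand = sd(pseudo) / sqrt(n_obs)

all.equal(est_lm, est_hand)  # TRUE
all.equal(se_lm, se_hand)    # TRUE

# 95% confidence interval based on the normal approximation
ci = est_hand + c(-1, 1) * qnorm(0.975) * se_hand
```

The t-test reported by summary() is then just the estimate divided by this standard error.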
2-fold cross-fitting is easy to implement, but especially in small samples, using only 50% of the data to estimate the nuisance parameters might lead to unstable predictions.
Thus, we use the DML_aipw function of the causalDML package to run 5-fold cross-fitting. This package requires us to create the methods that we use because it allows for ensemble methods (for a more detailed intro see the GitHub page). For now, we focus again on honest random forest.
With 5-fold cross-fitting, we split the sample into 5 folds and use 4 folds (80% of the data) to predict the left-out fold (20% of the data). We iterate such that every fold is left out once.
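The fold logic just described can be sketched in a few lines of base R. The helper `cross_fit` and its arguments `fit_fun` and `pred_fun` are hypothetical names introduced for illustration, the data are simulated, and a plain OLS learner stands in for the forest; this is not how causalDML organizes the folds internally.

```r
# Hypothetical helper: out-of-fold predictions from K-fold cross-fitting.
# fit_fun(x, y) returns a fitted model, pred_fun(model, x_new) returns predictions.
cross_fit = function(x, y, K, fit_fun, pred_fun) {
  n_obs = length(y)
  # Random fold labels of (almost) equal size
  fold = sample(rep(1:K, length.out = n_obs))
  pred = rep(NA, n_obs)
  for (k in 1:K) {
    out = fold == k  # observations in the left-out fold
    model = fit_fun(x[!out, , drop = FALSE], y[!out])
    pred[out] = pred_fun(model, x[out, , drop = FALSE])
  }
  list(pred = pred, fold = fold)
}

# Simulated toy data and a simple OLS learner as stand-in for the forest
set.seed(1)
x_sim = matrix(rnorm(500 * 2), ncol = 2)
y_sim = 1 + x_sim[, 1] + rnorm(500)
ols_fit = function(x, y) lm.fit(cbind(1, x), y)
ols_pred = function(model, x_new) as.vector(cbind(1, x_new) %*% model$coefficients)

cf5 = cross_fit(x_sim, y_sim, K = 5, ols_fit, ols_pred)
table(cf5$fold)  # five folds with 100 observations each
anyNA(cf5$pred)  # FALSE: every observation is predicted exactly once
```

For the outcome nuisance parameters, the model would additionally be fitted only on the control (or treated) observations of the four training folds, as in the 2-fold code above.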
# 5-fold cross-fitting with causalDML package
# Create learner
forest = create_method("forest_grf",args=list(tune.parameters = "all"))
# Run AIPW Double ML with 5-fold cross-fitting
aipw = DML_aipw(Y,W,X,ml_w=list(forest),ml_y=list(forest),cf=5)
Let’s first have a look at the estimated average potential outcomes:
summary(aipw$APO)
       APO       SE
0 14165.56 799.3773
1 25820.97 972.8471
The average treatment effect is then just the difference between the two potential outcomes:
summary(aipw$ATE)
          ATE      SE      t         p    
1 - 0 11655.4  1175.5 9.9154 < 2.2e-16 ***
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Finally, we can use the same nuisance parameters to estimate the ATT:
APO_att = APO_dml_atet(Y,aipw$APO$m_mat,aipw$APO$w_mat,aipw$APO$e_mat,aipw$APO$cf_mat)
ATT = ATE_dml(APO_att)
summary(ATT)
1 - 0 14867.9  1903.7 7.8099 6.302e-15 ***
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
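For intuition about why the same nuisance parameters suffice, one standard doubly robust pseudo-outcome for the ATT can be written in the notation used above, with \(\hat{\pi}\) denoting the sample share of treated observations. Whether APO_dml_atet uses exactly this form internally is an assumption that is not verified here.

\[ \begin{align} \tilde{Y}_{ATT} & = \frac{W (Y - \hat{m}(0,X))}{\hat{\pi}} - \frac{(1-W) \, \hat{e}(X) \, (Y - \hat{m}(0,X))}{\hat{\pi} \, (1-\hat{e}(X))}, \qquad \hat{\pi} = \frac{1}{n} \sum_{i=1}^{n} W_i \end{align} \]

Note that this score only involves \(\hat{e}(X)\) and \(\hat{m}(0,X)\); the treated outcome model \(\hat{m}(1,X)\) is not needed for the ATT.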
We find that the ATT is roughly $3000 higher than the ATE:
# Collect the results
Effect = c(aipw$ATE$results[1],ATT$results[1])
se = c(aipw$ATE$results[2],ATT$results[2])
# Plot point estimates with 95% confidence intervals
data.frame(Effect,se,
           Target = c("ATE","ATT"),
           cil = Effect - 1.96*se,
           ciu = Effect + 1.96*se) %>%
  ggplot(aes(x=Target,y=Effect,ymin=cil,ymax=ciu)) + geom_point(size=2.5) + geom_errorbar(width=0.15) +
  geom_hline(yintercept=0) + xlab("Target parameter")
If you want to understand whether ATE and ATT are statistically significantly different from each other, check out application notebook ANB_Generic_DML.