Post-estimation command to extract outcome weights for a causal forest
estimated with the causal_forest function from the grf package.
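An outcome weight expresses an estimate as a weighted sum of the observed outcomes, omega %*% Y (Knaus, 2024). The minimal base-R sketch below does not use grf or this package. It writes the simple difference in means between treated and control units in this form, with weight 1/n1 on each treated outcome and -1/n0 on each control outcome. All variable names are illustrative.

```r
# Toy illustration (no grf): the difference in means as outcome weights
set.seed(1)
n <- 8
W <- rep(c(1, 0), length.out = n)   # alternating treatment indicator
Y <- rnorm(n)

# omega_i = W_i / n1 - (1 - W_i) / n0
n1 <- sum(W)
n0 <- sum(1 - W)
omega <- W / n1 - (1 - W) / n0

# The weighted sum of outcomes reproduces the usual estimator
est_weights <- sum(omega * Y)
est_direct <- mean(Y[W == 1]) - mean(Y[W == 0])
all.equal(est_weights, est_direct)   # TRUE

# Treated weights sum to 1, control weights sum to -1
sum(omega[W == 1])                   # 1
sum(omega[W == 0])                   # -1
```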
Usage
# S3 method for class 'causal_forest'
get_outcome_weights(
object,
...,
S,
newdata = NULL,
S.tau = NULL,
target = "CATE",
checks = TRUE
)
Arguments
- object
An object of class causal_forest, i.e. the result of running causal_forest.
- ...
Passes potentially generic get_outcome_weights options.
- S
A smoother matrix reproducing the outcome predictions used in building the causal_forest. Obtained by calling get_forest_weights() for the regression_forest object that produces the outcome predictions.
- newdata
Corresponds to the newdata option in predict.causal_forest. If NULL, out-of-bag outcome weights are returned; otherwise, outcome weights for the provided test data are returned.
- S.tau
Required if target != "CATE". In that case, S.tau is the CATE smoother obtained from running get_outcome_weights() with target == "CATE".
- target
Target parameter for which outcome weights should be extracted. Currently c("CATE", "ATE") are implemented.
- checks
Default TRUE checks whether the weights numerically replicate the original estimates. Only set FALSE if you know what you are doing and need to save computation time.
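The following sketch shows why S is needed. It rests on one assumption: following the local estimate in Athey et al. (2019), the CATE at a point x is a forest-weighted covariance ratio of the residuals W - W.hat and Y - Y.hat, with both residuals centered at their forest-weighted means. Because Y - Y.hat = (I - S) Y and W.hat does not depend on Y, the estimate is linear in Y. In this base-R sketch, random row-normalized matrices stand in for the forest weights (alpha, hypothetical) and for S. It illustrates the algebra only and is not the package's implementation.

```r
# Hedged sketch: CATE outcome weights implied by a locally centered
# residual-on-residual estimate. alpha and S are random stand-ins.
set.seed(42)
n <- 50
m <- 5                                   # points at which the CATE is estimated

row_normalize <- function(M) M / rowSums(M)   # make each row sum to 1

W <- rbinom(n, 1, 0.5)
Y <- rnorm(n)
W.hat <- rep(0.5, n)                     # treated as fixed (does not depend on Y)
S <- row_normalize(matrix(runif(n * n), n, n))       # outcome smoother: Y.hat = S %*% Y
alpha <- row_normalize(matrix(runif(m * n), m, n))   # hypothetical forest weights, m x n

Y.res <- as.numeric(Y - S %*% Y)
W.res <- W - W.hat

# Direct estimate: forest-weighted covariance ratio of centered residuals
tau_direct <- sapply(seq_len(m), function(j) {
  a <- alpha[j, ]
  W.c <- W.res - sum(a * W.res)
  Y.c <- Y.res - sum(a * Y.res)
  sum(a * W.c * Y.c) / sum(a * W.c^2)
})

# Same estimate as outcome weights: omega = K %*% (I - S), one row per point
K <- t(sapply(seq_len(m), function(j) {
  a <- alpha[j, ]
  W.c <- W.res - sum(a * W.res)
  a * W.c / sum(a * W.c^2)
}))
omega <- K %*% (diag(n) - S)
tau_weights <- as.numeric(omega %*% Y)

all.equal(tau_direct, tau_weights)   # TRUE
```

The centered treatment residuals sum to zero under the forest weights, so the centering of Y - Y.hat drops out. For that reason, omega only needs K and S and no separate centering step.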
Value
A get_outcome_weights object with omega containing the weights and treat containing the treatment.
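The two documented elements can be used together, for example to sum each row of omega separately over treated and control units. The sketch below assumes the returned object can be indexed like a list with elements omega and treat. The object ow is a hypothetical stand-in built by hand, not output of get_outcome_weights().

```r
# Hypothetical stand-in for a returned object: only the two documented
# elements, omega (one row per estimate) and treat (treatment indicator).
set.seed(7)
n <- 6
treat <- c(1, 1, 1, 0, 0, 0)
omega <- rbind(treat / 3 - (1 - treat) / 3,  # difference-in-means weights
               runif(n) - 0.5)               # arbitrary weights for a second row
ow <- list(omega = omega, treat = treat)

# Sum of the weights each estimate places on treated and on control outcomes
weight_sums <- cbind(
  treated = rowSums(ow$omega[, ow$treat == 1, drop = FALSE]),
  control = rowSums(ow$omega[, ow$treat == 0, drop = FALSE])
)
weight_sums
```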
References
Athey, S., Tibshirani, J., & Wager, S. (2019). Generalized random forests. The Annals of Statistics, 47(2), 1148-1178.
Knaus, M. C. (2024). Treatment effect estimators as weighted outcomes. https://arxiv.org/abs/2411.11559.
Examples
library(OutcomeWeights)
# Sample from DGP borrowed from grf documentation
n = 500
p = 10
X = matrix(rnorm(n * p), n, p)
W = rbinom(n, 1, 0.5)
Y = pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)
# Run outcome regression and extract smoother matrix
forest.Y = grf::regression_forest(X, Y)
Y.hat = predict(forest.Y)$predictions
outcome_smoother = grf::get_forest_weights(forest.Y)
# Run causal forest with external Y.hats
c.forest = grf::causal_forest(X, Y, W, Y.hat = Y.hat)
# Predict on out-of-bag training samples
cate.oob = predict(c.forest)$predictions
# Predict on a test grid that varies only the first covariate
X.test = matrix(0, 101, p)
X.test[, 1] = seq(-2, 2, length.out = 101)
cate.test = predict(c.forest, X.test)$predictions
# Calculate outcome weights (out-of-bag and for the test data)
omega_oob = get_outcome_weights(c.forest, S = outcome_smoother)
omega_test = get_outcome_weights(c.forest, S = outcome_smoother, newdata = X.test)
# Observe that they perfectly replicate the original CATEs
all.equal(as.numeric(omega_oob$omega %*% Y),
          as.numeric(cate.oob))
#> [1] TRUE
all.equal(as.numeric(omega_test$omega %*% Y),
          as.numeric(cate.test))
#> [1] TRUE
# The ATE estimate is also perfectly replicated
omega_ate = get_outcome_weights(c.forest, target = "ATE",
                                S = outcome_smoother,
                                S.tau = omega_oob$omega)
all.equal(as.numeric(omega_ate$omega %*% Y),
          as.numeric(grf::average_treatment_effect(c.forest, target.sample = "all")[1]))
#> [1] TRUE
# The omega weights can be plugged into balancing packages like cobalt
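The ATE example passes the CATE smoother as S.tau. The sketch below shows why this suffices, under one assumption: the ATE is the average of the standard AIPW scores built from Y.hat = S Y, tau.hat = S.tau Y and fixed propensity scores W.hat. Every term in those scores is linear in Y, so the ATE equals a single weight vector times Y. The base-R sketch uses random, hypothetical matrices for S and S.tau. Whether these are exactly the scores grf::average_treatment_effect() averages is an assumption, not something this sketch shows.

```r
# Hedged sketch: ATE outcome weights from an outcome smoother S and a CATE
# smoother S.tau, assuming the ATE is the mean of standard AIPW scores.
set.seed(123)
n <- 40

row_normalize <- function(M) M / rowSums(M)   # make each row sum to 1

W <- rbinom(n, 1, 0.5)
Y <- rnorm(n)
W.hat <- runif(n, 0.2, 0.8)                        # propensity scores, fixed
S <- row_normalize(matrix(runif(n * n), n, n))     # stand-in: Y.hat = S %*% Y
S.tau <- matrix(rnorm(n * n, sd = 0.05), n, n)     # stand-in: tau.hat = S.tau %*% Y

Y.hat <- as.numeric(S %*% Y)
tau.hat <- as.numeric(S.tau %*% Y)

# AIPW scores: tau.hat + d * (Y - Y.hat - (W - W.hat) * tau.hat)
d <- (W - W.hat) / (W.hat * (1 - W.hat))
scores <- tau.hat + d * (Y - Y.hat - (W - W.hat) * tau.hat)
ate_direct <- mean(scores)

# The scores equal M %*% Y. A length-n vector times an n x n matrix
# scales row i by the i-th element, which plays the role of diag(v) %*% A.
M <- S.tau + d * (diag(n) - S - (W - W.hat) * S.tau)
omega_ate <- colMeans(M)
ate_weights <- sum(omega_ate * Y)

all.equal(ate_direct, ate_weights)   # TRUE
```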
