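Note: If you run the causal forest code chunks and your computer already
struggles with n=100, you are probably affected by the RStudio issue reported
at https://github.com/rstudio/rstudio/issues/13965#issuecomment-1830519474.
You can solve it by installing this daily build of RStudio:
https://dailies.rstudio.com/version/2023.12.0-daily+336/.
Or just focus on the drawing part ;-)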
Causal Christmas Tree Challenge
I have drawn a (stylized) Christmas tree using two potential outcome
functions and investigate how well causal forests approximate the
resulting CATE function for different sample sizes.
library(tidyverse)
library(grf)
library(patchwork)
set.seed(1234)
# Define the functions
# cw: width of a candle along x; ch: height of a candle above the tree slope.
cw = 0.05
ch = 0.6
# Treated potential outcome m1(x): the tree silhouette.
# Base shape is the tent 3 - 3*|x| (3 at x = 0, 0 at x = -1 and x = 1).
# Six "candles" are placed on it: each starts at s in
# {-0.8, -0.5, -0.2, 0.2, 0.5, 0.8}, covers the interval [s, s + cw) and sits
# at the constant level 3 - 3*|s| + ch.
# cw and ch are looked up in the enclosing environment when m1 is called.
# Vectorised over x; NA inputs yield NA.
# The candle's right edge x == s + cw belongs to the slope. The former sum of
# indicators excluded that point from both neighbouring pieces, so m1 was 0
# there.
m1 = function(x) {
  candle_starts = c(-0.8, -0.5, -0.2, 0.2, 0.5, 0.8)
  y = 3 - 3 * abs(x)
  for (s in candle_starts) {
    # which() drops NA comparisons, so NA inputs keep their NA tent value
    on_candle = which(x >= s & x < s + cw)
    y[on_candle] = 3 - 3 * abs(s) + ch
  }
  y
}
# Untreated potential outcome m0(x): the tree trunk.
# Equals -0.8 for x strictly inside (-0.07, 0.07) and 0 everywhere else.
# Vectorised over x.
# The body used to be broken across two lines in front of the "-" operator.
# Inside braces R treats that as two statements, so the first one (a term
# multiplied by 0) was evaluated and thrown away and only the trunk term was
# returned. The dead statement is dropped; returned values are unchanged.
m0 = function(x) {
  -0.8 * (x > -0.07 & x < 0.07)
}
# True CATE at x: treated minus untreated potential outcome.
tau = function(x) {
  treated = m1(x)
  untreated = m0(x)
  treated - untreated
}
# Plot the two potential outcome fcts
# Both curves share one panel and are drawn in forest green. The two-row data
# frame only supplies the x-range (-1 to 1) on which stat_function() evaluates
# the functions.
# `linewidth` replaces `size`, which ggplot2 deprecated for lines in 3.4.0
# (requires ggplot2 >= 3.4.0).
g2 = data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, linewidth = 1, colour = "forestgreen") +
  stat_function(fun = m0, linewidth = 1, colour = "forestgreen") +
  ylab("Y(w)") + xlab("X1") + theme_bw()
# Plot CATE function
g3 = data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, linewidth = 1) +
  ylab("CATE") + xlab("X1") + theme_bw()
# Print both figures
g2
g3
Causal Forest meets causal christmas tree
I wrote a little function that takes the two potential outcome
functions as inputs, draws the covariates uniformly on [-1, 1], draws a
random treatment \(W \sim \mathrm{Bernoulli}(1/2)\) and adds standard
normal noise to the outcome. Then I run it for sample sizes
100/1000/10000/100000 and observe how the approximation of this very hard
function improves (this takes about 15 minutes on my laptop).
# Simulate one data set from two potential outcome functions, fit a causal
# forest and compare its CATE estimates with the true CATE.
#
# Arguments:
#   m1, m0  vectorised functions of the first covariate giving the treated and
#           untreated potential outcome
#   n       number of observations to simulate
#   p       number of covariates (default 2); each is drawn Uniform(-1, 1) and
#           only the first one enters the outcome
#   ...     forwarded unchanged to causal_forest()
#
# Returns a list with
#   g     ggplot of the estimated CATEs (blue squares) against the first
#         covariate, with the true CATE function drawn as a line
#   RMSE  root mean squared error of the estimates against the true CATE
#         evaluated at the sampled values of the first covariate
cf_estimation = function(m1,m0,n,p=2,...) {
  # Fail early on unusable inputs
  stopifnot(is.function(m1), is.function(m0), length(n) == 1, n >= 1, p >= 1)

  # Get CATE function
  tau = function(x){m1(x) - m0(x)}

  # Draw sample: covariates, fair-coin treatment, outcome with N(0, 1) noise.
  # The order of the random draws is kept so results under a fixed seed match.
  X = matrix(runif(n*p,-1,1),ncol=p)
  W = rbinom(n,1,1/2)
  Y = W*m1(X[,1]) + (1-W)*m0(X[,1]) + rnorm(n,0,1)

  # Run CF
  cf = causal_forest(X, Y, W, ...)
  # predict() is called without new data; grf documents this as returning
  # out-of-bag predictions for the training sample
  cates = predict(cf)$predictions

  # Plot
  # format(..., scientific = FALSE) keeps the title at "n=100000"; toString()
  # turned the double 100000 into "1e+05".
  # `linewidth` replaces `size`, deprecated for lines since ggplot2 3.4.0.
  g = data.frame(x=X[,1],y=cates) %>% ggplot() +
    geom_point(aes(x=x,y=y),shape="square",color="blue") +
    stat_function(fun=tau,linewidth=1) + ylab("CATE") +
    ggtitle(paste0("n=",format(n, scientific = FALSE)))

  # RMSE
  rmse = sqrt(mean((cates - tau(X[,1]))^2))

  # Return results
  list("g" = g,"RMSE" = rmse)
}
# One simulation + causal forest fit per sample size, smallest first.
# tune.parameters = "all" is passed through cf_estimation()'s `...` to
# causal_forest().
n100    = cf_estimation(m1 = m1, m0 = m0, n = 100,    tune.parameters = "all")
n1000   = cf_estimation(m1 = m1, m0 = m0, n = 1000,   tune.parameters = "all")
n10000  = cf_estimation(m1 = m1, m0 = m0, n = 10000,  tune.parameters = "all")
n100000 = cf_estimation(m1 = m1, m0 = m0, n = 100000, tune.parameters = "all")
# 2 x 2 grid of the four estimated-vs-true CATE plots
# (patchwork: `|` places plots side by side, `/` stacks rows)
(n100$g | n1000$g) /
  (n10000$g | n100000$g)

# RMSE per sample size, with a reference line at zero
ggplot(
  data.frame(
    RMSE = c(n100$RMSE, n1000$RMSE, n10000$RMSE, n100000$RMSE),
    n = factor(c("n=100", "n=1000", "n=10000", "n=100000"))
  ),
  aes(x = n, y = RMSE)
) +
  geom_point() +
  theme_bw() +
  geom_hline(yintercept = 0)
Bonus assignment: Your turn (5P)
Draw something of your own using two potential outcome functions (either
something different or a beautified version of my tree) and check how well
causal forests can approximate the CATE function resulting from your drawing.
You can reuse and/or modify the code above for this purpose. Send me your
solutions by 22.12.
LS0tDQp0aXRsZTogIkNhdXNhbCBNTCAtIEJvbnVzIEFzc2lnbm1lbnQiDQojIGF1dGhvcjogIllvdXIgbmFtZSAoc3R1ZGVudCBJRCkiDQpvdXRwdXQ6IA0KICBodG1sX25vdGVib29rOg0KICAgIHRvYzogdHJ1ZQ0KICAgIHRvY19mbG9hdDogdHJ1ZQ0KICAgIGNvZGVfZm9sZGluZzogc2hvdw0KLS0tDQoNCg0KPGJyPg0KDQoqKk5vdGU6KiogSWYgeW91IHJ1biB0aGUgY2F1c2FsIGZvcmVzdCBjb2RlIGNodW5rcyBhbmQgeW91ciBjb21wdXRlciBzdHJ1Z2dsZXMgYWxyZWFkeSB3aXRoIG49MTAwLCB5b3UgcHJvYmFibHkgaGF2ZSB0aGUgaXNzdWUgcmFpc2VkIFtoZXJlXShodHRwczovL2dpdGh1Yi5jb20vcnN0dWRpby9yc3R1ZGlvL2lzc3Vlcy8xMzk2NSNpc3N1ZWNvbW1lbnQtMTgzMDUxOTQ3NCkuIFlvdSBjYW4gc29sdmUgaXQgYnkgaW5zdGFsbGluZyB0aGlzIFtkYWlseSBidWlsZCBSU3R1ZGlvXShodHRwczovL2RhaWxpZXMucnN0dWRpby5jb20vdmVyc2lvbi8yMDIzLjEyLjAtZGFpbHkrMzM2LykuIE9yIGp1c3QgZm9jdXMgb24gdGhlIGRyYXdpbmcgcGFydCA7LSkNCg0KPGJyPg0KDQojIENhdXNhbCBDaHJpc3RtYXMgVHJlZSBDaGFsbGVuZ2UNCg0KSSBoYXZlIGRyYXduIGEgKHN0eWxpemVkKSBjaHJpc3RtYXMgdHJlZSB1c2luZyB0d28gcG90ZW50aWFsIG91dGNvbWVzIGZ1bmN0aW9ucyBhbmQgaW52ZXN0aWdhdGUgaG93IHdlbGwgY2F1c2FsIGZvcmVzdHMgYXBwcm94aW1hdGUgdGhlIHJlc3VsdGluZyBDQVRFIGZ1bmN0aW9uIGZvciBkaWZmZXJlbnQgc2FtcGxlIHNpemVzLg0KDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUYsZmlnLmhlaWdodD00LCBmaWcud2lkdGg9M30NCmxpYnJhcnkodGlkeXZlcnNlKQ0KbGlicmFyeShncmYpDQpsaWJyYXJ5KHBhdGNod29yaykNCg0Kc2V0LnNlZWQoMTIzNCkNCg0KIyBEZWZpbmUgdGhlIGZ1bmN0aW9ucw0KY3cgPSAwLjA1DQpjaCA9IDAuNg0KbTEgPSBmdW5jdGlvbih4KXsoeCA8IC0wLjgpICogKDMqeCArIDMpICsgKCh4ID49IC0wLjgpICYgKHggPCAoLTAuOCtjdykpKSAqICgtMC44KjMgKyAzICsgY2gpICsgKCh4PigtMC44K2N3KSkgJiAoeCA8IC0wLjUpKSAqICgzKnggKyAzKSArICgoeCA+PSAtMC41KSAmICh4IDwgKC0wLjUrY3cpKSkgKiAoLTAuNSozICsgMyArIGNoKSArICgoeD4oLTAuNStjdykpICYgKHggPCAtMC4yKSkgKiAoMyp4ICsgMykgKyAoKHggPj0gLTAuMikgJiAoeCA8ICgtMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KC0wLjIrY3cpKSAmICh4IDwgMCkpICogKDMqeCArIDMpICsgKHggPj0gMCAmIHggPCAwLjIpICogKC0zKnggKyAzKSArICgoeCA+PSAwLjIpICYgKHggPCAoMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KDAuMitjdykpICYgKHggPCAwLjUpKSAqICgtMyp4ICsgMykgKyAoKHggPj0gMC41KSAmICh4IDwgKDAuNStjdykpKSAqICgtMC41KjMgKyAzICsgY2gpICsgKCh4PigwLjUrY3cpKSAmICh4IDwg
MC44KSkgKiAoLTMqeCArIDMpICsgKCh4ID49IDAuOCkgJiAoeCA8ICgwLjgrY3cpKSkgKiAoLTAuOCozICsgMyArIGNoKSArICgoeD4oMC44K2N3KSkpICogKC0zKnggKyAzKSAgfQ0KbTAgPSBmdW5jdGlvbih4KXswICogKCh4PCgtMC43KSkgKyAoeD4oMC4wNykpKSANCiAgICAgICAgICAgICAgICAtIDAuOCAqICAoKHg+KC0wLjA3KSAmIHg8KDAuMDcpKSkgfQ0KdGF1ID0gZnVuY3Rpb24oeCl7bTEoeCkgLSBtMCh4KX0NCg0KIyBQbG90IHRoZSB0d28gcG90ZW50aWFsIG91dGNvbWUgZmN0cw0KZzIgPSBkYXRhLmZyYW1lKHggPSBjKC0xLCAxKSkgJT4lIGdncGxvdChhZXMoeCkpICsgc3RhdF9mdW5jdGlvbihmdW49bTEsc2l6ZT0xLGNvbG91cj0iZm9yZXN0Z3JlZW4iKSArIA0KICBzdGF0X2Z1bmN0aW9uKGZ1bj1tMCxzaXplPTEsY29sb3VyPSJmb3Jlc3RncmVlbiIpICsgeWxhYigiWSh3KSIpICsgeGxhYigiWDEiKSArIHRoZW1lX2J3KCkNCiMgUGxvdCBDQVRFIGZ1bmN0aW9uDQpnMyA9IGRhdGEuZnJhbWUoeCA9IGMoLTEsIDEpKSAlPiUgZ2dwbG90KGFlcyh4KSkgKyBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIHhsYWIoIlgxIikgKyB0aGVtZV9idygpDQpnMg0KZzMNCmBgYA0KDQo8YnI+DQo8YnI+DQoNCiMgQ2F1c2FsIEZvcmVzdCBtZWV0cyBjYXVzYWwgY2hyaXN0bWFzIHRyZWUNCg0KSSB3cm90ZSBhIGxpdHRsZSBmdW5jdGlvbiB0aGF0IHRha2VzIHRoZSB0d28gcG90ZW50aWFsIG91dGNvbWUgZnVuY3Rpb25zIGFzIGlucHV0cywgZHJhd3MgYSByYW5kb20gdHJlYXRtZW50ICRXIFxzaW0gQmVybm91bGxpKDEvMikkIGFuZCBhZGRzIHN0YW5kYXJkIG5vcm1hbCBub2lzZSB0byB0aGUgb3V0Y29tZS4gVGhlbiBJIHJ1biBpdCBmb3Igc2FtcGxlIHNpemVzIDEwMC8xMDAwLzEwMDAwLzEwMDAwMCBhbmQgb2JzZXJ2ZSBob3cgdGhlIGFwcHJveGltYXRpb24gb2YgdGhlIHZlcnkgaGFyZCBmdW5jdGlvbiBpbXByb3ZlcyAocnVucyBhYm91dCAxNSBtaW51dGVzIG9uIG15IGxhcHRvcCkuDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUZ9DQpjZl9lc3RpbWF0aW9uID0gZnVuY3Rpb24obTEsbTAsbixwPTIsLi4uKSB7DQogICMgRGVmaW5lIHRoZSBmdW5jdGlvbnMNCiAgY3cgPSAwLjA1DQogIGNoID0gMC42DQogIA0KICAjIEdldCBDQVRFIGZ1bmN0aW9uDQogIHRhdSA9IGZ1bmN0aW9uKHgpe20xKHgpIC0gbTAoeCl9DQogIA0KICAjIERyYXcgc2FtcGxlDQogIFggPSBtYXRyaXgocnVuaWYobipwLC0xLDEpLG5jb2w9cCkNCiAgVyA9IHJiaW5vbShuLDEsMS8yKQ0KICBZID0gVyptMShYWywxXSkgKyAoMS1XKSptMChYWywxXSkgKyBybm9ybShuLDAsMSkNCiAgDQogICMgUnVuIENGDQogIGNmID0gY2F1c2FsX2ZvcmVzdChYLCBZLCBXLCAuLi4pDQogIGNhdGVzID0gcHJlZGljdChjZikkcHJlZGljdGlvbnMNCg0KICAjIFBsb3QNCiAgZyA9IGRhdGEuZnJhbWUoeD1YWywxXSx5PWNhdGVzKSAl
PiUgZ2dwbG90KCkgKyBnZW9tX3BvaW50KGFlcyh4PXgseT15KSxzaGFwZT0ic3F1YXJlIixjb2xvcj0iYmx1ZSIpICsNCiAgICBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIGdndGl0bGUocGFzdGUwKCJuPSIsdG9TdHJpbmcobikpKQ0KICANCiAgIyBSTVNFDQogIHJtc2UgPSBzcXJ0KG1lYW4oKGNhdGVzIC0gdGF1KFhbLDFdKSleMikpDQoNCiAgIyBSZXR1cm4gcmVzdWx0cw0KICBsaXN0KCJnIiA9IGcsIlJNU0UiID0gcm1zZSkNCn0NCg0KbjEwMCA9IGNmX2VzdGltYXRpb24obTEsbTAsMTAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KbjEwMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMDAsdHVuZS5wYXJhbWV0ZXJzID0gImFsbCIpDQpuMTAwMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMDAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KbjEwMDAwMCA9IGNmX2VzdGltYXRpb24obTEsbTAsMTAwMDAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KYGBgDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUZ9DQoobjEwMCRnIHwgbjEwMDAkZykgLyAobjEwMDAwJGcgfCBuMTAwMDAwJGcpDQoNCmRhdGEuZnJhbWUoUk1TRSA9IGMobjEwMCRSTVNFLG4xMDAwJFJNU0UsbjEwMDAwJFJNU0UsbjEwMDAwMCRSTVNFKSwNCiAgIG4gPSBmYWN0b3IoYygibj0xMDAiLCJuPTEwMDAiLCJuPTEwMDAwIiwibj0xMDAwMDAiKSkpICU+JQ0KICBnZ3Bsb3QoYWVzKHg9bix5PVJNU0UpKSArIGdlb21fcG9pbnQoKSArIHRoZW1lX2J3KCkgKyBnZW9tX2hsaW5lKHlpbnRlcmNlcHQgPSAwKQ0KYGBgDQoNCg0KPGJyPg0KPGJyPg0KDQojIEJvbnVzIGFzc2lnbm1lbnQ6IFlvdXIgdHVybiAoNVApDQoNCkRyYXcgYWxzbyBzb21ldGhpbmcgdXNpbmcgdHdvIHBvdGVudGlhbCBvdXRjb21lIGZ1bmN0aW9ucyAoZWl0aGVyIHNvbWV0aGluZyBkaWZmZXJlbnQgb3IgYmVhdXRpZnkgbXkgdHJlZSkgYW5kIGNoZWNrIGhvdyB3ZWxsIGNhdXNhbCBmb3Jlc3RzIGNhbiBhcHByb3hpbWF0ZSB0aGUgQ0FURSBmdW5jdGlvbiByZXN1bHRpbmcgZnJvbSB5b3VyIGRyYXdpbmcuIFlvdSBjYW4gcmV1c2UgYW5kL29yIG1vZGlmeSB0aGUgY29kZSBhYm92ZSBmb3IgdGhpcyBwdXJwb3NlLiBTZW5kIG1lIHlvdSBzb2x1dGlvbnMgYnkgMjIuMTIuDQoNCg==