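Note: If you run the causal forest code chunks and your computer already struggles with n=100, you are probably affected by the RStudio issue reported here: https://github.com/rstudio/rstudio/issues/13965#issuecomment-1830519474. You can solve it by installing this daily build of RStudio: https://dailies.rstudio.com/version/2023.12.0-daily+336/. Or just focus on the drawing part ;-)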


Causal Christmas Tree Challenge

I have drawn a (stylized) Christmas tree using two potential outcome functions and investigate how well causal forests approximate the resulting CATE function for different sample sizes.

library(tidyverse)
library(grf)
library(patchwork)

# Fix the RNG seed so the random draws below are reproducible.
set.seed(1234)

# Define the functions
# Shape parameters of the flat segments that sit on the tree outline
# (presumably the candles/ornaments of the drawing).
cw <- 0.05  # width of each flat segment
ch <- 0.6   # how far each flat segment is raised above the outline

# Potential outcome under treatment, Y(1): a triangle 3 - 3*|x| (the tree
# outline) with six flat, raised segments starting at
# x = -0.8, -0.5, -0.2, 0.2, 0.5 and 0.8.
#
# x: numeric vector of covariate values.
# Returns a numeric vector of the same length as x.
#
# Each flat segment covers the half-open interval [left, left + cw); every
# other point, including the right end point left + cw itself, lies on the
# triangle. (Writing the pieces as separate "x < left + cw" / "x > left + cw"
# masks would leave the end point uncovered and silently return 0 there.)
# cw and ch are looked up in the enclosing environment when m1 is called.
m1 <- function(x) {
  segment_left <- c(-0.8, -0.5, -0.2, 0.2, 0.5, 0.8)

  # Tree outline: rises with slope 3 up to x = 0, falls with slope -3 after.
  y <- 3 - 3 * abs(x)

  for (left in segment_left) {
    on_segment <- !is.na(x) & x >= left & x < left + cw
    # Constant level: outline height at the segment's left edge plus ch.
    y[on_segment] <- -abs(left) * 3 + 3 + ch
  }
  y
}
# Potential outcome under control, Y(0): zero everywhere except for a dip of
# depth 0.8 on the open interval (-0.07, 0.07) (presumably the tree trunk).
#
# x: numeric vector of covariate values.
# Returns a numeric vector of the same length as x.
#
# Written as a single expression on purpose: in R a line that starts with
# "- ..." after a complete expression is parsed as a new statement, so a
# leading "0 * (...)" term placed on its own line would be evaluated and
# thrown away rather than added.
m0 <- function(x) {
  in_dip <- abs(x) < 0.07
  -0.8 * in_dip
}
# Conditional average treatment effect implied by the two potential-outcome
# functions: the pointwise difference m1(x) - m0(x).
tau <- function(x) {
  treated <- m1(x)
  control <- m0(x)
  treated - control
}

# Figure 1: both potential-outcome curves (m1 and m0) over [-1, 1], drawn in
# the same green so that together they form the picture.
g2 <- ggplot(data.frame(x = c(-1, 1)), aes(x)) +
  stat_function(fun = m1, size = 1, colour = "forestgreen") +
  stat_function(fun = m0, size = 1, colour = "forestgreen") +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw()

# Figure 2: the CATE curve tau over the same range.
g3 <- ggplot(data.frame(x = c(-1, 1)), aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1") +
  theme_bw()

# Show both figures.
g2

g3



Causal Forest meets causal Christmas tree

I wrote a little function that takes the two potential outcome functions as inputs, draws a random treatment \(W \sim \text{Bernoulli}(1/2)\) and adds standard normal noise to the outcome. Then I run it for sample sizes 100/1000/10000/100000 and observe how the approximation of this very hard function improves (this takes about 15 minutes on my laptop).

# Simulate a randomized experiment from two potential-outcome functions, fit a
# causal forest to it and compare the estimated CATEs with the true CATE.
#
# m1, m0: functions of one numeric vector giving the potential outcome under
#         treatment / control as a function of the first covariate.
# n:      sample size.
# p:      number of covariates, all drawn i.i.d. Uniform(-1, 1). Only the first
#         column enters the outcome; the remaining columns are pure noise.
# ...:    forwarded unchanged to grf::causal_forest()
#         (e.g. tune.parameters = "all").
#
# Returns a list with
#   g:    ggplot object showing the predicted CATEs (blue squares) against the
#         true CATE function, titled "n=<n>";
#   RMSE: root mean squared error of the predictions against the true CATE at
#         the sampled covariate values.
cf_estimation <- function(m1, m0, n, p = 2, ...) {
  stopifnot(is.function(m1), is.function(m0), n >= 1, p >= 1)

  # True CATE implied by the two potential-outcome functions
  tau <- function(x) m1(x) - m0(x)

  # Draw sample. The order of the RNG calls (runif, rbinom, rnorm) is kept
  # fixed so that results under a given seed do not change.
  X <- matrix(runif(n * p, -1, 1), ncol = p)
  W <- rbinom(n, 1, 1 / 2)
  x1 <- X[, 1]
  Y <- W * m1(x1) + (1 - W) * m0(x1) + rnorm(n, 0, 1)

  # Run CF. predict() without newdata returns predictions for the training
  # sample (out-of-bag, according to the grf documentation).
  cf <- causal_forest(X, Y, W, ...)
  cates <- predict(cf)$predictions

  # Plot. format(..., scientific = FALSE) keeps the title readable for large
  # n; toString(1e5) would give "1e+05".
  plot_title <- paste0("n=", format(n, scientific = FALSE))
  g <- ggplot(data.frame(x = x1, y = cates)) +
    geom_point(aes(x = x, y = y), shape = "square", color = "blue") +
    stat_function(fun = tau, size = 1) +
    ylab("CATE") +
    ggtitle(plot_title)

  # RMSE
  rmse <- sqrt(mean((cates - tau(x1))^2))

  # Return results
  list("g" = g, "RMSE" = rmse)
}

# Fit one causal forest per sample size. tune.parameters = "all" is passed
# through cf_estimation()'s `...` to causal_forest(). The calls stay in this
# order so that the random number stream is consumed in the same sequence.
n100    <- cf_estimation(m1, m0, 100,    tune.parameters = "all")
n1000   <- cf_estimation(m1, m0, 1000,   tune.parameters = "all")
n10000  <- cf_estimation(m1, m0, 10000,  tune.parameters = "all")
n100000 <- cf_estimation(m1, m0, 100000, tune.parameters = "all")

# 2 x 2 panel of the four fits (patchwork: `|` places side by side, `/` stacks).
top_row <- n100$g | n1000$g
bottom_row <- n10000$g | n100000$g
top_row / bottom_row


# RMSE against the true CATE for each sample size, with a reference line at 0.
rmse_by_n <- data.frame(
  RMSE = c(n100$RMSE, n1000$RMSE, n10000$RMSE, n100000$RMSE),
  n = factor(c("n=100", "n=1000", "n=10000", "n=100000"))
)
ggplot(rmse_by_n, aes(x = n, y = RMSE)) +
  geom_point() +
  theme_bw() +
  geom_hline(yintercept = 0)



Bonus assignment: Your turn (5P)

Also draw something using two potential outcome functions (either something different, or beautify my tree) and check how well causal forests can approximate the CATE function resulting from your drawing. You can reuse and/or modify the code above for this purpose. Send me your solutions by 22.12.

LS0tDQp0aXRsZTogIkNhdXNhbCBNTCAtIEJvbnVzIEFzc2lnbm1lbnQiDQojIGF1dGhvcjogIllvdXIgbmFtZSAoc3R1ZGVudCBJRCkiDQpvdXRwdXQ6IA0KICBodG1sX25vdGVib29rOg0KICAgIHRvYzogdHJ1ZQ0KICAgIHRvY19mbG9hdDogdHJ1ZQ0KICAgIGNvZGVfZm9sZGluZzogc2hvdw0KLS0tDQoNCg0KPGJyPg0KDQoqKk5vdGU6KiogSWYgeW91IHJ1biB0aGUgY2F1c2FsIGZvcmVzdCBjb2RlIGNodW5rcyBhbmQgeW91ciBjb21wdXRlciBzdHJ1Z2dsZXMgYWxyZWFkeSB3aXRoIG49MTAwLCB5b3UgcHJvYmFibHkgaGF2ZSB0aGUgaXNzdWUgcmFpc2VkIFtoZXJlXShodHRwczovL2dpdGh1Yi5jb20vcnN0dWRpby9yc3R1ZGlvL2lzc3Vlcy8xMzk2NSNpc3N1ZWNvbW1lbnQtMTgzMDUxOTQ3NCkuIFlvdSBjYW4gc29sdmUgaXQgYnkgaW5zdGFsbGluZyB0aGlzIFtkYWlseSBidWlsZCBSU3R1ZGlvXShodHRwczovL2RhaWxpZXMucnN0dWRpby5jb20vdmVyc2lvbi8yMDIzLjEyLjAtZGFpbHkrMzM2LykuIE9yIGp1c3QgZm9jdXMgb24gdGhlIGRyYXdpbmcgcGFydCA7LSkNCg0KPGJyPg0KDQojIENhdXNhbCBDaHJpc3RtYXMgVHJlZSBDaGFsbGVuZ2UNCg0KSSBoYXZlIGRyYXduIGEgKHN0eWxpemVkKSBjaHJpc3RtYXMgdHJlZSB1c2luZyB0d28gcG90ZW50aWFsIG91dGNvbWVzIGZ1bmN0aW9ucyBhbmQgaW52ZXN0aWdhdGUgaG93IHdlbGwgY2F1c2FsIGZvcmVzdHMgYXBwcm94aW1hdGUgdGhlIHJlc3VsdGluZyBDQVRFIGZ1bmN0aW9uIGZvciBkaWZmZXJlbnQgc2FtcGxlIHNpemVzLg0KDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUYsZmlnLmhlaWdodD00LCBmaWcud2lkdGg9M30NCmxpYnJhcnkodGlkeXZlcnNlKQ0KbGlicmFyeShncmYpDQpsaWJyYXJ5KHBhdGNod29yaykNCg0Kc2V0LnNlZWQoMTIzNCkNCg0KIyBEZWZpbmUgdGhlIGZ1bmN0aW9ucw0KY3cgPSAwLjA1DQpjaCA9IDAuNg0KbTEgPSBmdW5jdGlvbih4KXsoeCA8IC0wLjgpICogKDMqeCArIDMpICsgKCh4ID49IC0wLjgpICYgKHggPCAoLTAuOCtjdykpKSAqICgtMC44KjMgKyAzICsgY2gpICsgKCh4PigtMC44K2N3KSkgJiAoeCA8IC0wLjUpKSAqICgzKnggKyAzKSArICgoeCA+PSAtMC41KSAmICh4IDwgKC0wLjUrY3cpKSkgKiAoLTAuNSozICsgMyArIGNoKSArICgoeD4oLTAuNStjdykpICYgKHggPCAtMC4yKSkgKiAoMyp4ICsgMykgKyAoKHggPj0gLTAuMikgJiAoeCA8ICgtMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KC0wLjIrY3cpKSAmICh4IDwgMCkpICogKDMqeCArIDMpICsgKHggPj0gMCAmIHggPCAwLjIpICogKC0zKnggKyAzKSArICgoeCA+PSAwLjIpICYgKHggPCAoMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KDAuMitjdykpICYgKHggPCAwLjUpKSAqICgtMyp4ICsgMykgKyAoKHggPj0gMC41KSAmICh4IDwgKDAuNStjdykpKSAqICgtMC41KjMgKyAzICsgY2gpICsgKCh4PigwLjUrY3cpKSAmICh4IDwg
MC44KSkgKiAoLTMqeCArIDMpICsgKCh4ID49IDAuOCkgJiAoeCA8ICgwLjgrY3cpKSkgKiAoLTAuOCozICsgMyArIGNoKSArICgoeD4oMC44K2N3KSkpICogKC0zKnggKyAzKSAgfQ0KbTAgPSBmdW5jdGlvbih4KXswICogKCh4PCgtMC43KSkgKyAoeD4oMC4wNykpKSANCiAgICAgICAgICAgICAgICAtIDAuOCAqICAoKHg+KC0wLjA3KSAmIHg8KDAuMDcpKSkgfQ0KdGF1ID0gZnVuY3Rpb24oeCl7bTEoeCkgLSBtMCh4KX0NCg0KIyBQbG90IHRoZSB0d28gcG90ZW50aWFsIG91dGNvbWUgZmN0cw0KZzIgPSBkYXRhLmZyYW1lKHggPSBjKC0xLCAxKSkgJT4lIGdncGxvdChhZXMoeCkpICsgc3RhdF9mdW5jdGlvbihmdW49bTEsc2l6ZT0xLGNvbG91cj0iZm9yZXN0Z3JlZW4iKSArIA0KICBzdGF0X2Z1bmN0aW9uKGZ1bj1tMCxzaXplPTEsY29sb3VyPSJmb3Jlc3RncmVlbiIpICsgeWxhYigiWSh3KSIpICsgeGxhYigiWDEiKSArIHRoZW1lX2J3KCkNCiMgUGxvdCBDQVRFIGZ1bmN0aW9uDQpnMyA9IGRhdGEuZnJhbWUoeCA9IGMoLTEsIDEpKSAlPiUgZ2dwbG90KGFlcyh4KSkgKyBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIHhsYWIoIlgxIikgKyB0aGVtZV9idygpDQpnMg0KZzMNCmBgYA0KDQo8YnI+DQo8YnI+DQoNCiMgQ2F1c2FsIEZvcmVzdCBtZWV0cyBjYXVzYWwgY2hyaXN0bWFzIHRyZWUNCg0KSSB3cm90ZSBhIGxpdHRsZSBmdW5jdGlvbiB0aGF0IHRha2VzIHRoZSB0d28gcG90ZW50aWFsIG91dGNvbWUgZnVuY3Rpb25zIGFzIGlucHV0cywgZHJhd3MgYSByYW5kb20gdHJlYXRtZW50ICRXIFxzaW0gQmVybm91bGxpKDEvMikkIGFuZCBhZGRzIHN0YW5kYXJkIG5vcm1hbCBub2lzZSB0byB0aGUgb3V0Y29tZS4gVGhlbiBJIHJ1biBpdCBmb3Igc2FtcGxlIHNpemVzIDEwMC8xMDAwLzEwMDAwLzEwMDAwMCBhbmQgb2JzZXJ2ZSBob3cgdGhlIGFwcHJveGltYXRpb24gb2YgdGhlIHZlcnkgaGFyZCBmdW5jdGlvbiBpbXByb3ZlcyAocnVucyBhYm91dCAxNSBtaW51dGVzIG9uIG15IGxhcHRvcCkuDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUZ9DQpjZl9lc3RpbWF0aW9uID0gZnVuY3Rpb24obTEsbTAsbixwPTIsLi4uKSB7DQogICMgRGVmaW5lIHRoZSBmdW5jdGlvbnMNCiAgY3cgPSAwLjA1DQogIGNoID0gMC42DQogIA0KICAjIEdldCBDQVRFIGZ1bmN0aW9uDQogIHRhdSA9IGZ1bmN0aW9uKHgpe20xKHgpIC0gbTAoeCl9DQogIA0KICAjIERyYXcgc2FtcGxlDQogIFggPSBtYXRyaXgocnVuaWYobipwLC0xLDEpLG5jb2w9cCkNCiAgVyA9IHJiaW5vbShuLDEsMS8yKQ0KICBZID0gVyptMShYWywxXSkgKyAoMS1XKSptMChYWywxXSkgKyBybm9ybShuLDAsMSkNCiAgDQogICMgUnVuIENGDQogIGNmID0gY2F1c2FsX2ZvcmVzdChYLCBZLCBXLCAuLi4pDQogIGNhdGVzID0gcHJlZGljdChjZikkcHJlZGljdGlvbnMNCg0KICAjIFBsb3QNCiAgZyA9IGRhdGEuZnJhbWUoeD1YWywxXSx5PWNhdGVzKSAl
PiUgZ2dwbG90KCkgKyBnZW9tX3BvaW50KGFlcyh4PXgseT15KSxzaGFwZT0ic3F1YXJlIixjb2xvcj0iYmx1ZSIpICsNCiAgICBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIGdndGl0bGUocGFzdGUwKCJuPSIsdG9TdHJpbmcobikpKQ0KICANCiAgIyBSTVNFDQogIHJtc2UgPSBzcXJ0KG1lYW4oKGNhdGVzIC0gdGF1KFhbLDFdKSleMikpDQoNCiAgIyBSZXR1cm4gcmVzdWx0cw0KICBsaXN0KCJnIiA9IGcsIlJNU0UiID0gcm1zZSkNCn0NCg0KbjEwMCA9IGNmX2VzdGltYXRpb24obTEsbTAsMTAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KbjEwMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMDAsdHVuZS5wYXJhbWV0ZXJzID0gImFsbCIpDQpuMTAwMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMDAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KbjEwMDAwMCA9IGNmX2VzdGltYXRpb24obTEsbTAsMTAwMDAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KYGBgDQoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUZ9DQoobjEwMCRnIHwgbjEwMDAkZykgLyAobjEwMDAwJGcgfCBuMTAwMDAwJGcpDQoNCmRhdGEuZnJhbWUoUk1TRSA9IGMobjEwMCRSTVNFLG4xMDAwJFJNU0UsbjEwMDAwJFJNU0UsbjEwMDAwMCRSTVNFKSwNCiAgIG4gPSBmYWN0b3IoYygibj0xMDAiLCJuPTEwMDAiLCJuPTEwMDAwIiwibj0xMDAwMDAiKSkpICU+JQ0KICBnZ3Bsb3QoYWVzKHg9bix5PVJNU0UpKSArIGdlb21fcG9pbnQoKSArIHRoZW1lX2J3KCkgKyBnZW9tX2hsaW5lKHlpbnRlcmNlcHQgPSAwKQ0KYGBgDQoNCg0KPGJyPg0KPGJyPg0KDQojIEJvbnVzIGFzc2lnbm1lbnQ6IFlvdXIgdHVybiAoNVApDQoNCkRyYXcgYWxzbyBzb21ldGhpbmcgdXNpbmcgdHdvIHBvdGVudGlhbCBvdXRjb21lIGZ1bmN0aW9ucyAoZWl0aGVyIHNvbWV0aGluZyBkaWZmZXJlbnQgb3IgYmVhdXRpZnkgbXkgdHJlZSkgYW5kIGNoZWNrIGhvdyB3ZWxsIGNhdXNhbCBmb3Jlc3RzIGNhbiBhcHByb3hpbWF0ZSB0aGUgQ0FURSBmdW5jdGlvbiByZXN1bHRpbmcgZnJvbSB5b3VyIGRyYXdpbmcuIFlvdSBjYW4gcmV1c2UgYW5kL29yIG1vZGlmeSB0aGUgY29kZSBhYm92ZSBmb3IgdGhpcyBwdXJwb3NlLiBTZW5kIG1lIHlvdSBzb2x1dGlvbnMgYnkgMjIuMTIuDQoNCg==