Bonus assignment (5P)

Draw something using two potential outcome functions and check how well causal forests can approximate the conditional average treatment effect (CATE) function resulting from your drawing. You only need to modify the bold labels and the first code chunk. Send me your solution by email by 22.12. and/or post your results as a reply to my Bluesky post.

Draw

I have drawn a (stylized) Christmas tree using two potential outcome functions.

The following code chunk defines the functions and plots them. Your task is to modify this part:

if (!require("tidyverse")) install.packages("tidyverse", dependencies = TRUE); library(tidyverse)
if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
if (!require("patchwork")) install.packages("patchwork", dependencies = TRUE); library(patchwork)

# Define the functions
# Together, the two potential outcome functions draw a stylized Christmas tree
# over X1 in [-1, 1]:
#   m1 = treated outcome Y(1): tree outline decorated with rectangular ornaments
#   m0 = untreated outcome Y(0): flat ground with a trunk dipping below zero
cw = 0.05  # width of each ornament
ch = 0.6   # height of each ornament above the tree outline

# Treated potential outcome Y(1).
# The outline is the tent 3 - 3|x| (peak 3 at x = 0, zero at x = -1 and x = 1).
# Each ornament with left edge `a` replaces the outline on [a, a + width) by the
# flat value 3 - 3|a| + height. The outline resumes at exactly a + width, so
# every x gets a value (the old version returned 0 at x == a + cw).
# If ornaments overlap, the one listed later in `notches` wins.
#
# x:       numeric vector of X1 values
# notches: left edges of the ornaments
# width, height: ornament size; default to the globals cw / ch at call time
# Returns a numeric vector of the same length as x (NA where x is NA).
m1 = function(x, notches = c(-0.8, -0.5, -0.2, 0.2, 0.5, 0.8),
              width = cw, height = ch) {
  out = 3 - 3 * abs(x)
  for (a in notches) {
    in_notch = !is.na(x) & x >= a & x < a + width
    out[in_notch] = 3 - 3 * abs(a) + height
  }
  out
}

# Untreated potential outcome Y(0): zero ground, plus a trunk of depth 0.8 on
# the open interval (-0.07, 0.07).
# NOTE: when a formula spans several lines, end each line with the operator.
# A line that *starts* with `-` or `+` is parsed as a new statement, and only
# the last statement's value is returned. The previous version silently
# dropped its first term this way.
m0 = function(x) {
  0 * ((x < -0.7) + (x > 0.07)) -
    0.8 * (x > -0.07 & x < 0.07)
}

# Plot the two potential outcome fcts
# n = 1001 evaluation points: ggplot2's default of 101 points (spacing 0.02) is
# too coarse for the 0.05-wide ornaments and the 0.14-wide trunk, which would
# otherwise render as slanted spikes instead of rectangles.
# `linewidth` replaces the `size` aesthetic for lines (deprecated in ggplot2 3.4.0).
g2 = data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, n = 1001, linewidth = 1, colour = "forestgreen") +
  stat_function(fun = m0, n = 1001, linewidth = 1, colour = "forestgreen") +
  ylab("Y(w)") + xlab("X1") + theme_minimal()
g2

Now plot the resulting CATE:

# True CATE implied by the drawing: tau(x) = m1(x) - m0(x)
tau = function(x) {
  m1(x) - m0(x)
}

# A dense evaluation grid (n = 1001) resolves the narrow steps of the drawing.
# `linewidth` replaces the deprecated `size` aesthetic for lines.
g3 = data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, n = 1001, linewidth = 1) +
  ylab("CATE") + xlab("X1") + theme_minimal()
g3



Causal Forest meets causal Christmas tree

Now we investigate how well causal forests approximate the resulting CATE function for different sample sizes.

To this end, we define a small function that takes the two potential outcome functions as inputs. It draws covariates \(X \sim U(-1,1)^p\) (only \(X_1\) enters the outcome functions), draws a random treatment \(W \sim \mathrm{Bernoulli}(1/2)\), and adds standard normal noise to the outcome. We then call it for sample sizes 100, 1,000, 10,000, and 100,000 and observe how the visual fit and the RMSE improve. This takes about 15 minutes on my laptop; feel free to change the sample sizes to reduce computation time.

# Simulate data from two potential outcome functions, fit a causal forest, and
# compare its CATE estimates with the true CATE tau(x) = m1(x) - m0(x).
#
# m1, m0: vectorized functions of X1 giving the treated / untreated outcomes
# n:      sample size
# p:      number of covariates, drawn i.i.d. U(-1, 1); only the first one (X1)
#         enters m1 and m0, the remaining ones are irrelevant noise covariates
# ...:    further arguments passed on to causal_forest()
# Returns a list with
#   g:    ggplot of the estimated CATEs against X1, with the true CATE as a line
#   RMSE: root mean squared error of the estimates against tau(X1)
cf_estimation = function(m1, m0, n, p = 2, ...) {
  # True CATE implied by the two potential outcome functions
  tau = function(x) {
    m1(x) - m0(x)
  }

  # Draw sample: X ~ U(-1, 1)^p, W ~ Bernoulli(1/2), N(0, 1) outcome noise
  X = matrix(runif(n * p, -1, 1), ncol = p)
  W = rbinom(n, 1, 1/2)
  Y = W * m1(X[, 1]) + (1 - W) * m0(X[, 1]) + rnorm(n, 0, 1)

  # Run CF; predict() without newdata returns out-of-bag CATE estimates
  cf = causal_forest(X, Y, W, ...)
  cates = predict(cf)$predictions

  # Plot estimates against the true CATE. format(scientific = FALSE) keeps the
  # title readable: toString(100000) would render as "1e+05".
  # n = 1001 resolves the narrow steps of tau; `linewidth` replaces the
  # deprecated `size` aesthetic for lines.
  g = data.frame(x = X[, 1], y = cates) %>%
    ggplot() +
    geom_point(aes(x = x, y = y), shape = "square", color = "blue") +
    stat_function(fun = tau, n = 1001, linewidth = 1) +
    ylab("CATE") + xlab("X1") +
    ggtitle(paste0("n=", format(n, scientific = FALSE)))

  # RMSE of the estimates against the true CATE at the sampled X1 values
  rmse = sqrt(mean((cates - tau(X[, 1]))^2))

  # Return results
  list("g" = g, "RMSE" = rmse)
}

set.seed(1234)
# One causal forest per sample size. lapply() evaluates in order, so the random
# number stream is consumed exactly as by four consecutive calls.
fits = lapply(
  c(`n=100` = 100, `n=1000` = 1000, `n=10000` = 10000, `n=100000` = 100000),
  function(size) cf_estimation(m1, m0, size, tune.parameters = "all")
)
n100 = fits[["n=100"]]
n1000 = fits[["n=1000"]]
n10000 = fits[["n=10000"]]
n100000 = fits[["n=100000"]]

# 2 x 2 patchwork grid: estimated vs. true CATE for each sample size
(n100$g | n1000$g) / (n10000$g | n100000$g)


# RMSE by sample size (the horizontal line marks a perfect fit)
data.frame(
  RMSE = unname(vapply(fits, function(fit) fit$RMSE, numeric(1))),
  n = factor(names(fits))
) %>%
  ggplot(aes(x = n, y = RMSE)) +
  geom_point() +
  theme_bw() +
  geom_hline(yintercept = 0)



LS0tDQp0aXRsZTogIkNhdXNhbCBDaHJpc3RtYXMgVHJlZSBDaGFsbGVuZ2UiDQojIGF1dGhvcjogIllvdXIgbmFtZSAoc3R1ZGVudCBJRCkiDQpvdXRwdXQ6IA0KICBodG1sX25vdGVib29rOg0KICAgIHRvYzogdHJ1ZQ0KICAgIHRvY19mbG9hdDogdHJ1ZQ0KICAgIGNvZGVfZm9sZGluZzogc2hvdw0KLS0tDQoNCg0KIyBCb251cyBhc3NpZ25tZW50ICg1UCkNCg0KRHJhdyBzb21ldGhpbmcgdXNpbmcgdHdvIHBvdGVudGlhbCBvdXRjb21lIGZ1bmN0aW9ucyBhbmQgY2hlY2sgaG93IHdlbGwgY2F1c2FsIGZvcmVzdHMgY2FuIGFwcHJveGltYXRlIHRoZSBDQVRFIGZ1bmN0aW9uIHJlc3VsdGluZyBmcm9tIHlvdXIgZHJhd2luZy4gWW91IG9ubHkgbmVlZCB0byBtb2RpZnkgdGhlIGJvbGQgbGFiZWxzIGFuZCB0aGUgZmlyc3QgY29kZSBjaHVuay4gU2VuZCBtZSB5b3VyIHNvbHV0aW9ucyB2aWEgbWFpbCBieSAyMi4xMi4gYW5kL29yIHBvc3QgeW91ciByZXN1bHRzIHVuZGVyIG15IFtCbHVlc2t5XShodHRwczovL2Jza3kuYXBwL3Byb2ZpbGUvbWNrbmF1cy5ic2t5LnNvY2lhbC9wb3N0LzNsZDRiYWg3djdrMnQpIHBvc3QuDQoNCiMjIERyYXcNCg0KSSBoYXZlIGRyYXduIGEgKiooc3R5bGl6ZWQpIGNocmlzdG1hcyB0cmVlKiogdXNpbmcgdHdvIHBvdGVudGlhbCBvdXRjb21lcyBmdW5jdGlvbnMuDQoNClRoaXMgZGVmaW5lcyB0aGUgZnVuY3Rpb25zIGFuZCBwbG90cyB0aGVtLiBZb3VyIHRhc2sgaXMgdG8gbW9kaWZ5IHRoaXMgcGFydDoNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUYsZmlnLmhlaWdodD02LCBmaWcud2lkdGg9NH0NCmlmICghcmVxdWlyZSgidGlkeXZlcnNlIikpIGluc3RhbGwucGFja2FnZXMoInRpZHl2ZXJzZSIsIGRlcGVuZGVuY2llcyA9IFRSVUUpOyBsaWJyYXJ5KHRpZHl2ZXJzZSkNCmlmICghcmVxdWlyZSgiZ3JmIikpIGluc3RhbGwucGFja2FnZXMoImdyZiIsIGRlcGVuZGVuY2llcyA9IFRSVUUpOyBsaWJyYXJ5KGdyZikNCmlmICghcmVxdWlyZSgicGF0Y2h3b3JrIikpIGluc3RhbGwucGFja2FnZXMoInBhdGNod29yayIsIGRlcGVuZGVuY2llcyA9IFRSVUUpOyBsaWJyYXJ5KHBhdGNod29yaykNCg0KIyBEZWZpbmUgdGhlIGZ1bmN0aW9ucw0KY3cgPSAwLjA1DQpjaCA9IDAuNg0KbTEgPSBmdW5jdGlvbih4KXsoeCA8IC0wLjgpICogKDMqeCArIDMpICsgKCh4ID49IC0wLjgpICYgKHggPCAoLTAuOCtjdykpKSAqICgtMC44KjMgKyAzICsgY2gpICsgKCh4PigtMC44K2N3KSkgJiAoeCA8IC0wLjUpKSAqICgzKnggKyAzKSArICgoeCA+PSAtMC41KSAmICh4IDwgKC0wLjUrY3cpKSkgKiAoLTAuNSozICsgMyArIGNoKSArICgoeD4oLTAuNStjdykpICYgKHggPCAtMC4yKSkgKiAoMyp4ICsgMykgKyAoKHggPj0gLTAuMikgJiAoeCA8ICgtMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KC0wLjIrY3cpKSAmICh4IDwgMCkpICogKDMqeCArIDMpICsgKHggPj0gMCAmIHggPCAwLjIpICogKC0zKnggKyAz
KSArICgoeCA+PSAwLjIpICYgKHggPCAoMC4yK2N3KSkpICogKC0wLjIqMyArIDMgKyBjaCkgKyAoKHg+KDAuMitjdykpICYgKHggPCAwLjUpKSAqICgtMyp4ICsgMykgKyAoKHggPj0gMC41KSAmICh4IDwgKDAuNStjdykpKSAqICgtMC41KjMgKyAzICsgY2gpICsgKCh4PigwLjUrY3cpKSAmICh4IDwgMC44KSkgKiAoLTMqeCArIDMpICsgKCh4ID49IDAuOCkgJiAoeCA8ICgwLjgrY3cpKSkgKiAoLTAuOCozICsgMyArIGNoKSArICgoeD4oMC44K2N3KSkpICogKC0zKnggKyAzKSAgfQ0KbTAgPSBmdW5jdGlvbih4KXswICogKCh4PCgtMC43KSkgKyAoeD4oMC4wNykpKSANCiAgICAgICAgICAgICAgICAtIDAuOCAqICAoKHg+KC0wLjA3KSAmIHg8KDAuMDcpKSkgfQ0KDQojIFBsb3QgdGhlIHR3byBwb3RlbnRpYWwgb3V0Y29tZSBmY3RzDQpnMiA9IGRhdGEuZnJhbWUoeCA9IGMoLTEsIDEpKSAlPiUgZ2dwbG90KGFlcyh4KSkgKyBzdGF0X2Z1bmN0aW9uKGZ1bj1tMSxzaXplPTEsY29sb3VyPSJmb3Jlc3RncmVlbiIpICsgDQogIHN0YXRfZnVuY3Rpb24oZnVuPW0wLHNpemU9MSxjb2xvdXI9ImZvcmVzdGdyZWVuIikgKyB5bGFiKCJZKHcpIikgKyB4bGFiKCJYMSIpICsgdGhlbWVfbWluaW1hbCgpDQpnMg0KYGBgDQoNCk5vdyBwbG90IHRoZSByZXN1bHRpbmcgQ0FURToNCmBgYHtyLG1lc3NhZ2U9Rix3YXJuaW5nPUYsZmlnLmhlaWdodD02LCBmaWcud2lkdGg9NH0NCnRhdSA9IGZ1bmN0aW9uKHgpe20xKHgpIC0gbTAoeCl9DQpnMyA9IGRhdGEuZnJhbWUoeCA9IGMoLTEsIDEpKSAlPiUgZ2dwbG90KGFlcyh4KSkgKyBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIHhsYWIoIlgxIikgKyB0aGVtZV9taW5pbWFsKCkNCmczDQpgYGANCg0KPGJyPg0KPGJyPg0KDQojIyBDYXVzYWwgRm9yZXN0IG1lZXRzICoqY2F1c2FsIGNocmlzdG1hcyB0cmVlKioNCg0KTm93IHdlIGludmVzdGlnYXRlIGhvdyB3ZWxsIGNhdXNhbCBmb3Jlc3RzIGFwcHJveGltYXRlIHRoZSByZXN1bHRpbmcgQ0FURSBmdW5jdGlvbiBmb3IgZGlmZmVyZW50IHNhbXBsZSBzaXplcy4NCg0KVG8gdGhpcyBlbmQgd2UgZGVmaW5lIGEgbGl0dGxlIGZ1bmN0aW9uIHRoYXQgdGFrZXMgdGhlIHR3byBwb3RlbnRpYWwgb3V0Y29tZSBmdW5jdGlvbnMgYXMgaW5wdXRzLCBkcmF3cyBhIHJhbmRvbSB0cmVhdG1lbnQgJFcgXHNpbSBCZXJub3VsbGkoMS8yKSQgYW5kIGFkZHMgc3RhbmRhcmQgbm9ybWFsIG5vaXNlIHRvIHRoZSBvdXRjb21lLiBUaGVuIGl0IHJ1bnMgZm9yIHNhbXBsZSBzaXplcyAxMDAvMTAwMC8xMDAwMC8xMDAwMDAgYW5kIHdlIG9ic2VydmUgaG93IHZpc3VhbCBmaXQgYW5kIFJNU0UgaW1wcm92ZSAocnVucyBhYm91dCAxNSBtaW51dGVzIG9uIG15IGxhcHRvcCwgYnV0IGZlZWwgZnJlZSB0byBjaGFuZ2UgaXQgdG8gcmVkdWNlIGNvbXB1dGF0aW9uIHRpbWUpLg0KDQpgYGB7cixtZXNzYWdlPUYsd2FybmluZz1GfQ0KY2ZfZXN0aW1hdGlv
biA9IGZ1bmN0aW9uKG0xLG0wLG4scD0yLC4uLikgew0KICAjIEdldCBDQVRFIGZ1bmN0aW9uDQogIHRhdSA9IGZ1bmN0aW9uKHgpe20xKHgpIC0gbTAoeCl9DQogIA0KICAjIERyYXcgc2FtcGxlDQogIFggPSBtYXRyaXgocnVuaWYobipwLC0xLDEpLG5jb2w9cCkNCiAgVyA9IHJiaW5vbShuLDEsMS8yKQ0KICBZID0gVyptMShYWywxXSkgKyAoMS1XKSptMChYWywxXSkgKyBybm9ybShuLDAsMSkNCiAgDQogICMgUnVuIENGDQogIGNmID0gY2F1c2FsX2ZvcmVzdChYLCBZLCBXLCAuLi4pDQogIGNhdGVzID0gcHJlZGljdChjZikkcHJlZGljdGlvbnMNCg0KICAjIFBsb3QNCiAgZyA9IGRhdGEuZnJhbWUoeD1YWywxXSx5PWNhdGVzKSAlPiUgZ2dwbG90KCkgKyBnZW9tX3BvaW50KGFlcyh4PXgseT15KSxzaGFwZT0ic3F1YXJlIixjb2xvcj0iYmx1ZSIpICsNCiAgICBzdGF0X2Z1bmN0aW9uKGZ1bj10YXUsc2l6ZT0xKSArIHlsYWIoIkNBVEUiKSArIGdndGl0bGUocGFzdGUwKCJuPSIsdG9TdHJpbmcobikpKQ0KICANCiAgIyBSTVNFDQogIHJtc2UgPSBzcXJ0KG1lYW4oKGNhdGVzIC0gdGF1KFhbLDFdKSleMikpDQoNCiAgIyBSZXR1cm4gcmVzdWx0cw0KICBsaXN0KCJnIiA9IGcsIlJNU0UiID0gcm1zZSkNCn0NCg0Kc2V0LnNlZWQoMTIzNCkNCm4xMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMCx0dW5lLnBhcmFtZXRlcnMgPSAiYWxsIikNCm4xMDAwID0gY2ZfZXN0aW1hdGlvbihtMSxtMCwxMDAwLHR1bmUucGFyYW1ldGVycyA9ICJhbGwiKQ0KbjEwMDAwID0gY2ZfZXN0aW1hdGlvbihtMSxtMCwxMDAwMCx0dW5lLnBhcmFtZXRlcnMgPSAiYWxsIikNCm4xMDAwMDAgPSBjZl9lc3RpbWF0aW9uKG0xLG0wLDEwMDAwMCx0dW5lLnBhcmFtZXRlcnMgPSAiYWxsIikNCmBgYA0KDQpgYGB7cixtZXNzYWdlPUYsd2FybmluZz1GfQ0KKG4xMDAkZyB8IG4xMDAwJGcpIC8gKG4xMDAwMCRnIHwgbjEwMDAwMCRnKQ0KDQpkYXRhLmZyYW1lKFJNU0UgPSBjKG4xMDAkUk1TRSxuMTAwMCRSTVNFLG4xMDAwMCRSTVNFLG4xMDAwMDAkUk1TRSksDQogICBuID0gZmFjdG9yKGMoIm49MTAwIiwibj0xMDAwIiwibj0xMDAwMCIsIm49MTAwMDAwIikpKSAlPiUNCiAgZ2dwbG90KGFlcyh4PW4seT1STVNFKSkgKyBnZW9tX3BvaW50KCkgKyB0aGVtZV9idygpICsgZ2VvbV9obGluZSh5aW50ZXJjZXB0ID0gMCkNCmBgYA0KDQoNCjxicj4NCjxicj4NCg0KDQoNCg==