Causal Christmas tree challenge 2022

I drew a (stylized) Christmas tree using two potential outcome functions and investigate how well causal forests approximate the resulting CATE function for different sample sizes.

library(tidyverse)
library(grf)
library(patchwork)

set.seed(1234)

# Define the functions ----
# Treated outcome m1 draws the tree outline: both flanks follow 3 - 3|x|
# (i.e. 3x + 3 for x < 0 and -3x + 3 for x >= 0), and six flat "branch tips"
# of width cw sit ch above the flank value at their left edge.
cw <- 0.05 # width of each branch tip
ch <- 0.6  # height of each branch tip above the flank
e <- function(x) {
  1 / 2
}
m1 <- function(x) {
  tip_starts <- c(-0.8, -0.5, -0.2, 0.2, 0.5, 0.8)
  out <- 3 - 3 * abs(x)
  # Tips are half-open intervals [k, k + cw); every other x keeps the flank
  # value, so no point is left uncovered at a tip boundary.
  for (k in tip_starts) {
    out[which(x >= k & x < k + cw)] <- 3 - 3 * abs(k) + ch
  }
  out
}
# Control outcome m0 is zero except for the trunk: -0.8 on -0.07 < x < 0.07.
# Written as one expression: splitting a sum across lines without a trailing
# operator makes R end the expression at the newline.
m0 <- function(x) {
  -0.8 * (x > -0.07 & x < 0.07)
}
# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# Potential outcome functions of the tree (same colour for both curves)
g2 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "forestgreen") +
  stat_function(fun = m0, size = 1, colour = "forestgreen") +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw()

# Resulting CATE function
g3 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1") +
  theme_bw()

g2

g3



Causal forest meets causal Christmas tree

# Simulate the causal Christmas tree and fit a causal forest to it.
#
# Draws n observations with p covariates ~ U(-1, 1) (only the first one
# enters the outcome functions), assigns treatment W ~ Bernoulli(e(X1)),
# generates Y = W * m1(X1) + (1 - W) * m0(X1) + N(0, 1) noise, fits
# causal_forest() and compares its CATE predictions with the true CATE.
#
# Arguments:
#   n   sample size (>= 1)
#   e   propensity score as a function of X1; defaults to 1/2 (randomized
#       experiment)
#   p   number of covariates (>= 1); columns 2..p are pure noise
#   ... further arguments forwarded to grf::causal_forest(),
#       e.g. tune.parameters = "all"
# Returns: list(g = ggplot of estimated vs. true CATE, RMSE = root mean
#   squared error of the predictions against the true CATE at the sample X1)
cf_meets_cct <- function(n, e = function(x) 1 / 2, p = 2, ...) {
  stopifnot(n >= 1, p >= 1)

  # The design is defined locally so later global redefinitions of
  # m1 / m0 / tau (e.g. by other drawings in this document) cannot leak in.
  cw <- 0.05
  ch <- 0.6
  m1 <- function(x) {
    tip_starts <- c(-0.8, -0.5, -0.2, 0.2, 0.5, 0.8)
    out <- 3 - 3 * abs(x)
    # Half-open tips [k, k + cw) leave no point uncovered at their boundary
    for (k in tip_starts) {
      out[which(x >= k & x < k + cw)] <- 3 - 3 * abs(k) + ch
    }
    out
  }
  m0 <- function(x) {
    -0.8 * (x > -0.07 & x < 0.07)
  }
  tau <- function(x) {
    m1(x) - m0(x)
  }

  # Draw sample (the RNG call order is unchanged, so set.seed reproduces it)
  X <- matrix(runif(n * p, -1, 1), ncol = p)
  W <- rbinom(n, 1, e(X[, 1]))
  Y <- W * m1(X[, 1]) + (1 - W) * m0(X[, 1]) + rnorm(n, 0, 1)

  # Run CF; predict() without new data gives out-of-bag predictions
  cf <- causal_forest(X, Y, W, ...)
  cates <- predict(cf)$predictions

  # Plot; format(scientific = FALSE) keeps titles like "n=100000"
  # instead of "n=1e+05"
  g <- ggplot(data.frame(x = X[, 1], y = cates)) +
    geom_point(aes(x = x, y = y), shape = "square", color = "blue") +
    stat_function(fun = tau, size = 1) +
    ylab("CATE") +
    xlab("X1") +
    ggtitle(paste0("n=", format(n, scientific = FALSE)))

  # RMSE
  rmse <- sqrt(mean((cates - tau(X[, 1]))^2))

  # Return results
  list("g" = g, "RMSE" = rmse)
}

# Fit a tuned causal forest for four sample sizes. Each call draws a fresh
# sample, so the call order matters for reproducibility under set.seed().
n100 <- cf_meets_cct(100, tune.parameters = "all")
n1000 <- cf_meets_cct(1000, tune.parameters = "all")
n10000 <- cf_meets_cct(10000, tune.parameters = "all")
n100000 <- cf_meets_cct(100000, tune.parameters = "all")

# 2 x 2 grid via patchwork: `|` places plots side by side, `/` stacks rows
(n100$g | n1000$g) /
  (n10000$g | n100000$g)


# RMSE of the CATE predictions by sample size (levels listed explicitly so
# the x axis is ordered by n)
ggplot(
  data.frame(
    RMSE = c(n100$RMSE, n1000$RMSE, n10000$RMSE, n100000$RMSE),
    n = factor(
      c("n=100", "n=1000", "n=10000", "n=100000"),
      levels = c("n=100", "n=1000", "n=10000", "n=100000")
    )
  ),
  aes(x = n, y = RMSE)
) +
  geom_point() +
  theme_bw()



1) Your turn (5P)

Draw something of your own using two potential outcome functions (either something different or a beautified version of my tree) and check how well causal forests can approximate the CATE function resulting from your drawing.



Student solutions in alphabetical order

Maren Baumgärtner

As Christmas is the season to spread love, I decided to plot a heart.


# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}


# Upper part of the heart (treated): two unit half-circles centred at
# x = -1 and x = 1, stretched vertically by 1.8. Defined for |x| <= 2.
m1 <- function(x) {
  shifted <- abs(x) - 1
  1.8 * sqrt(1 - shifted^2)
}
# Lower part of the heart (control): runs from -pi at x = 0 up to 0 at
# |x| = 2. Defined for |x| <= 2.
m0 <- function(x) {
  acos(1 - abs(x)) - pi
}

# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# plot heart
# m1/m0/tau are only defined on [-2, 2] (sqrt and acos of out-of-range
# arguments give NaN with warnings). stat_function(xlim = ...) therefore
# evaluates them on [-2, 2] only, while the scale limits keep the wider
# plot margins.
data.frame(x = c(-2, 2)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "#B22222", xlim = c(-2, 2)) +
  stat_function(fun = m0, size = 1, colour = "#B22222", xlim = c(-2, 2)) +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw() +
  xlim(-2.5, 2.5)


# plot CATE
data.frame(x = c(-2, 2)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1, xlim = c(-2, 2)) +
  theme_bw() +
  xlim(-2.2, 2.2) +
  ylab("CATE")


Stefan Glaisner

# Upper half of the unit circle: y = sqrt(1 - x^2)
cir <- function(x) {
  sqrt(1 - x^2)
}

# Break points of the zig-zag pieces
cut0 <- 0.75
cut1 <- 0.5
cut2 <- 0.25
# Slopes of lines that cover a horizontal distance of cut2 (= 0.25) while
# changing by cir(cut0), cir(cut1) and cir(cut2) respectively
slope0 <- cir(cut0) / cut2
slope1 <- cir(cut1) / cut2
slope2 <- cir(cut2) / cut2


# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}

# Treated outcome: upper semicircle with a V-shaped notch on [-cut2, cut2]
# whose tip reaches 0 at x = 0. cir(x) is evaluated for every x, so inputs
# with |x| > 1 give NaN.
m1 <- function(x) {
  left_arc <- x >= -1 & x < -cut2
  notch_down <- x >= -cut2 & x < 0
  notch_up <- x >= 0 & x <= cut2
  right_arc <- x > cut2 & x <= 1
  left_arc * cir(x) +
    notch_down * (cir(cut2) - slope2 * (x + cut2)) +
    notch_up * (slope2 * x) +
    right_arc * cir(x)
}

# Control outcome: lower semicircle with two zig-zag "teeth" on each side.
# NOTE(review): the two sides are not mirror images. On [-cut1, -cut2) the
# line ends at -cir(cut1) and jumps to -cir(cut2) at x = -cut2, while on
# [cut2, cut1) it starts at -cir(cut2) and ends at -cir(cut2) + cir(cut1),
# jumping to 0 at x = cut1. Confirm whether this is intended.
m0 <- function(x) {
  outer_left <- x >= -1 & x < -cut0
  tooth_1 <- x >= -cut0 & x < -cut1
  tooth_2 <- x >= -cut1 & x < -cut2
  centre <- x >= -cut2 & x < cut2
  tooth_3 <- x >= cut2 & x < cut1
  tooth_4 <- x >= cut1 & x < cut0
  outer_right <- x >= cut0 & x <= 1
  outer_left * -cir(x) +
    tooth_1 * (-cir(cut0) + slope0 * (x + cut0)) +
    tooth_2 * (-slope1 * (x + cut1)) +
    centre * -cir(x) +
    tooth_3 * (-cir(cut2) + slope1 * (x - cut2)) +
    tooth_4 * (-slope0 * (x - cut1)) +
    outer_right * -cir(x)
}

# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# Potential outcome functions
data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "forestgreen") +
  stat_function(fun = m0, size = 1, colour = "forestgreen") +
  ylab("Y(w)") +
  xlab("X1")


# CATE function
data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")


Stefan Grochowski

# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}

# Treated outcome: |x| plus the upper unit semicircle (top of the heart)
m1 <- function(x) {
  arc <- sqrt(1 - x^2)
  abs(x) + arc
}
# Control outcome: |x| minus the upper unit semicircle (bottom of the heart)
m0 <- function(x) {
  arc <- sqrt(1 - x^2)
  abs(x) - arc
}
# CATE: m1 - m0, mathematically 2 * sqrt(1 - x^2)
tau <- function(x) {
  m1(x) - m0(x)
}

# Both potential outcome functions together draw the heart
g4 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "red") +
  stat_function(fun = m0, size = 1, colour = "darkred") +
  ylab("Y(w)") +
  xlab("X1")
# CATE implied by the drawing
g5 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")

# Show the heart
g4


# Show the resulting CATE
g5


Jacqueline Gut

A star, whose CATE looks like Batman

# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}
# Treated outcome: upper half of the star, with a peak of 2 at x = 0 and
# flat at 1 for |x| >= 0.5
m1 <- function(x) {
  (x < -1) * 1 +
    (x >= -1 & x < -0.5) * 1 +
    (x >= -0.5 & x < 0) * (2 * x + 2) +
    (x >= 0 & x < 0.5) * (-2 * x + 2) +
    (x >= 0.5) * 1
}
# Control outcome: lower half of the star. The pieces tile the real line;
# the left outer piece includes x = -0.625, mirroring the right outer piece
# that starts at x = 0.625, so that point is not left at 0.
m0 <- function(x) {
  (x < -1) * 1 +
    (x >= -1 & x <= -0.625) * (-x) +
    (x > -0.625 & x < 0) * (1.4 * x + 0.375) +
    (x >= 0 & x < 0.625) * (-1.4 * x + 0.375) +
    (x >= 0.625 & x < 1) * x +
    (x >= 1) * 1
}

# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# The two potential outcome functions together form the star
g2 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "yellow") +
  stat_function(fun = m0, size = 1, colour = "yellow") +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw()

# CATE implied by the star
g3 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1") +
  theme_bw()

g2

g3


Sophia Herrmann

Christmas Ball

# Define the functions
# Half-widths of the three stacked steps that form the cap of the ball
cw1 <- 0.4
cw2 <- 0.15
cw3 <- 0.05

# Height added by each step
ch1 <- 0.3
ch2 <- 0.2
ch3 <- 1

# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}
# Control outcome: lower unit semicircle
m_0 <- function(x) {
  -sqrt(1 - (x^2))
}

# Treated outcome: upper unit semicircle with a stepped cap on [-cw1, cw1].
# NOTE(review): the cap's base level uses the literal 0.15, which equals cw2;
# confirm whether it was meant to follow cw2.
m_1 <- function(x) {
  rim <- sqrt(1 - (0.15^2))
  step1 <- rim + ch1
  step2 <- step1 + ch2
  step3 <- step2 + ch3
  (x < -cw1) * sqrt(1 - (x^2)) +
    (x >= -cw1 & x <= -cw2) * step1 +
    (x > -cw2 & x <= -cw3) * step2 +
    (x > -cw3 & x <= cw3) * step3 +
    (x > cw3 & x <= cw2) * step2 +
    (x > cw2 & x <= cw1) * step1 +
    (x > cw1) * sqrt(1 - (x^2))
}

# CATE: difference of the two potential outcome functions
tau_christmas_ball <- function(x) {
  m_1(x) - m_0(x)
}

# Plot the two PO fcts (most important)
# Separate plots of each function, if needed:
#g_1 <- data.frame(x = c(-1, 1)) %>% ggplot(aes(x)) + stat_function(fun=m_1,size=1,colour="forestgreen") 
#g_2 <- data.frame(x = c(-1, 1)) %>% ggplot(aes(x)) + stat_function(fun=m_0,size=1,colour="forestgreen")

# To zoom in on the top of the ball, append `+ ylim(1.7, 2)` to this plot.
# Left as a comment so it does not run as a separate statement that only
# prints a scale object.
g_christmas_ball <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m_1, size = 1, colour = "forestgreen") +
  stat_function(fun = m_0, size = 1, colour = "forestgreen") +
  ylab("Y(w)") +
  xlab("X1")

# Plot CATE function
g_tau_christmas_ball <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau_christmas_ball, size = 1) +
  ylab("CATE") +
  xlab("X1")

g_christmas_ball

g_tau_christmas_ball


Kevin Kopp

# Define the functions
cw <- 0.001 # horizontal (not used by the functions below)
ch <- 3     # vertical (not used by the functions below)
# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}
# Treated outcome: zig-zag "crown" with slopes of +/-15 between x = -1.2 and
# x = 1; no piece covers x < -1.2 or x >= 1, where it is 0.
m1 <- function(x) {
  (x >= -1.2 & x < -0.999) * (15 * x + 15) +
    (x >= -0.999 & x < -0.8) * (-15 * x - 9) +
    (x >= -0.8 & x < -0.6) * (15 * x + 15) +
    (x >= -0.6 & x < -0.4) * (-15 * x - 3) +
    (x >= -0.4 & x < -0.2) * (15 * x + 9) +
    (x >= -0.2 & x < 0) * (-15 * x + 3) +
    (x >= 0 & x < 0.2) * (15 * x + 3) +
    (x >= 0.2 & x < 0.4) * (-15 * x + 9) +
    (x >= 0.4 & x < 0.6) * (15 * x - 3) +
    (x >= 0.6 & x < 0.8) * (-15 * x + 15) +
    (x >= 0.8 & x < 1) * (15 * x - 9)
}
# Control outcome: zero (the indicator sum is multiplied by 0)
m0 <- function(x) {
  0 * ((x < -0.7) + (x > 0.07))
}
# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# Propensity score plot (not used):
# g1 = data.frame(x = c(-1, 1)) %>% ggplot(aes(x)) + stat_function(fun=e,size=1) + ylab("e") + xlab("X1")

# Line plot of the two potential outcome functions
g2 <- data.frame(x = c(-1.0, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "gold") +
  stat_function(fun = m0, size = 1, colour = "gold") +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw()
# CATE function
g3 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")

# Same drawing with the area under m1 filled in gold
g2_picture <- data.frame(x = c(-1.0, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "black", geom = "area", fill = "gold") +
  stat_function(fun = m0, size = 1, colour = "black") +
  ylab("Y(w)") +
  xlab("X1") +
  theme_bw()

g2_picture


g3 + theme_bw()


Alexandros Parginos Dös

Causal heart

# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}

# Get a (very) edgy heart
# Treated outcome: two triangular humps with peaks 0.5 at x = -0.5 and
# x = 0.5, meeting at 0 at x = 0
m1 <- function(x) {
  (x <= -0.5) * (1 + x) +
    (x > -0.5 & x <= 0) * (-x) +
    (x > 0 & x < 0.5) * x +
    (x >= 0.5) * (1 - x)
}
# Control outcome: a V pointing down to -1 at x = 0
m0 <- function(x) {
  (x <= 0) * -(x + 1) +
    (x > 0) * (x - 1)
}
# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# Propensity score plot (not used):
# g1 = data.frame(x = c(-1, 1)) %>% ggplot(aes(x)) + stat_function(fun=e,size=1) + ylab("e") + xlab("X1")

# The two potential outcome functions together draw the heart
g2 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "red") +
  stat_function(fun = m0, size = 1, colour = "red") +
  ylab("Y(w)") +
  xlab("X1")
# CATE function (built here but not printed below)
g3 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")

g2


Henri Pfleiderer

# Define the functions
jump <- 0.1                 # m0 is 0 on [17 - jump, 17 + jump)
small_jump <- 0.01          # m1 is 0 for x < 6 - small_jump
very_small_jump <- 0.001    # not used by the functions below
height_bottle_neck <- 10    # extra height of the neck above 64
width_bottle_neck <- 1      # neck spans 9 +/- width_bottle_neck / 2
ch <- 0.6                   # not used by the functions below
# Propensity score: randomized treatment with probability 1/2
e <- function(x) {
  1 / 2
}
# Treated outcome: a bottle (parabola with a raised neck) on [6, 12) and a
# wavy level around 30 on [15, 19), constant 30 from x = 19 on
m1 <- function(x) {
  neck_lo <- 9 - width_bottle_neck / 2
  neck_hi <- 9 + width_bottle_neck / 2
  bottle <- (-8 / 9) * x^2 + 16 * x - 18
  (x < 6 - small_jump) * 0 +
    (x >= 6 & x < neck_lo) * bottle +
    (x >= neck_lo & x < neck_hi) * (64 + height_bottle_neck) +
    (x >= neck_hi & x < 12) * bottle +
    (x > 12 & x < 15) * 0 +
    (x >= 15 & x < 19) * (sin(5 * x) + 30) +
    (x >= 19) * 30
}
# Control outcome: a V between 15 and 19 with a drop to 0 around x = 17,
# constant 30 from x = 19 on
m0 <- function(x) {
  (x < 15) * 0 +
    (x >= 15 & x < 17 - jump) * (-7.5 * x + 142.5) +
    (x >= 17 - jump & x < 17 + jump) * 0 +
    (x < 19 & x >= 17 + jump) * (7.5 * x + -112.5) +
    (x >= 19) * 30
}
# CATE: difference of the two potential outcome functions
tau <- function(x) {
  m1(x) - m0(x)
}

# Propensity score plot (not used):
# g1 = data.frame(x = c(-1, 1)) %>% ggplot(aes(x)) + stat_function(fun=e,size=1) + ylab("e") + xlab("X1")

# The two potential outcome functions draw the still life
# (title in German: "sparkling wine bottle and glass with foam, a still life")
g2 <- data.frame(x = c(0, 25)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "hotpink") +
  stat_function(fun = m0, size = 1, colour = "purple") +
  ylab("Y(w)") +
  xlab("X1") +
  ggtitle("Sektflasche und Glas mit Schaum, ein Stillleben von Henri Pfleiderer, 2022")
# CATE function
g3 <- data.frame(x = c(0, 25)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")

g2

g3


Stella Rotter

# Define the functions
# Propensity score: randomized treatment with probability 1/2.
e = function(x){1/2}
# Treated outcome: a sum of indicator-weighted pieces that sketch the church
# outline. Contribution of each term:
#   (x < -1) * (x + 1)                  -> 0 for every x >= -1
#   [-0.7, -0.5) * (-0.7*3 + 3)         -> +0.9
#   (x > -0.7)                          -> +1 for every x > -0.7
#   [-0.62, -0.57) * (0.4*1.2 + 0.1)    -> +0.58
#   [0.5, 0.5) * (0.4*2 + 3)            -> 0 (empty interval, never active)
#   (x > 0.5)                           -> +1 for every x > 0.5
#   [0.6, 0.7) * (0.5*1.2 + 0.5)        -> +1.1
#   [0.8, 1] * (0.5*3 -3.5)             -> -2
# NOTE(review): the empty interval [0.5, 0.5) and the bare indicators
# (x > -0.7) and (x > 0.5) look like possible typos. They are kept because
# the drawing depends on them; confirm the intended shape with the author.
m1 = function(x){(x < -1) * (x + 1) + 
                 ((x >= -0.7) & (x < -0.5)) * (-0.7*3 + 3) + (x > -0.7) +
                 ((x >= -0.62) & (x < -0.57)) * (0.4*1.2 + 0.1) +
                 ((x >= 0.5) & (x < 0.5)) * (0.4*2 + 3) + (x > 0.5) +
                 ((x >= 0.6) & (x < 0.7)) * (0.5*1.2 + 0.5) + 
                 ((x >= 0.8) & (x <= 1)) * (0.5*3 -3.5)}
# Control outcome: the single value 0 (not a vector of length(x)); it is
# recycled when subtracted in tau.
m0 = function(x){0}
# CATE: difference of the two potential outcome functions.
tau = function(x){m1(x) - m0(x)}

# The two potential outcome functions draw the cathedral
g2 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = m1, size = 1, colour = "navajowhite4") +
  stat_function(fun = m0, size = 1, colour = "navajowhite4") +
  ylab("Y(w)") +
  xlab("X1") +
  ggtitle("The Ulmer Münster")
# CATE function
g3 <- data.frame(x = c(-1, 1)) %>%
  ggplot(aes(x)) +
  stat_function(fun = tau, size = 1) +
  ylab("CATE") +
  xlab("X1")

g2

g3


ChatGPT

Not very successful (the code below draws a filled circle rather than a tree), but let’s see what it does next year…

# Code attributed to ChatGPT in this document; it uses base graphics rather
# than ggplot2 and is kept unchanged.
# NOTE: x = sin(t), y = cos(t) - 1 traces a unit circle centred at (0, -1),
# so the filled "tree" is a circle spanning y in [-2, 0], and the trunk
# segment from y = -1.5 to y = -0.5 lies inside it.
# Set the plot dimensions
plot.new()
plot.window(xlim = c(-2, 2), ylim = c(-2, 2))

# Create a sequence of values from 0 to 2*pi in increments of 0.01
# (assigns a global variable named `t`; calls to the base function t() still
# resolve to the function)
t <- seq(0, 2*pi, 0.01)

# Calculate the x and y coordinates for the tree
x <- sin(t)
y <- cos(t) - 1

# Use the polygon function to fill in the tree
polygon(x, y, col = "darkgreen")

# Use the lines function to draw the trunk
lines(c(0, 0), c(-1.5, -0.5), col = "brown", lwd = 2)

