¿Y si ... ? Parte II

Volvamos a nuestro ejemplo tonto, dónde habíamos visto que el T-learner cuando el modelo base es un modelo lineal equivale a tener un modelo saturado (con interacciones).

En estos de los “metalearners” tenemos entre otros, los T-learners vistos en el post anterior , los S-learner y los X-learners.

Los S-learners no es más que usar un solo modelo “Single” para estimar el Conditional Average Treatment Effect , CATE.

Usando el mismo ejemplo sencillo, se tiene que.

set.seed(155)

X <- rnorm(100, 10,1)
W <- rbinom(100, 1, 0.6)

# Me construyo la Y de forma que haya efectos principales e interacción
Y <- 4 + 2 * X + 2 * W + 2 * W * X + rnorm(100, 0, sd = 2)

df <- as.data.frame(cbind(Y,W,X))

df
##            Y W         X
## 1   48.78438 1 10.800067
## 2   25.28644 0 10.707605
## 3   28.39538 0  9.925625
## 4   47.60225 1 10.652555
## 5   46.72225 1  9.992698
## 6   55.15008 1 11.514759
## 7   40.46547 1  9.093717
## 8   22.17879 0  9.157972
## 9   49.44883 1  9.866499
## 10  51.21602 1 11.100414
## 11  46.90193 1 10.287350
## 12  22.88517 0  9.295653
## 13  39.44776 1  9.156142
## 14  40.78560 1  8.513496
## 15  48.04199 1 10.067613
## 16  47.80314 1  9.898276
## 17  25.33331 0 10.578513
## 18  24.15651 0  9.253759
## 19  25.13304 0 10.365123
## 20  25.23243 0 11.040849
## 21  30.45260 0 12.869587
## 22  44.82112 1  9.319895
## 23  25.11998 0  9.830254
## 24  19.99574 0  9.635928
## 25  43.48504 1  9.215349
## 26  41.14271 1  8.356523
## 27  22.81061 0  8.883480
## 28  25.58288 0  9.784855
## 29  44.41997 1  9.404844
## 30  27.84046 0 10.414529
## 31  39.59324 1  9.041776
## 32  51.28215 1 10.442391
## 33  38.53548 1  8.142158
## 34  21.95668 0  9.042216
## 35  46.84521 1  9.724798
## 36  43.87810 1  9.013322
## 37  42.12536 1  9.633154
## 38  45.74959 1 10.873450
## 39  18.78703 0  9.748465
## 40  21.79664 0 10.607739
## 41  37.35355 1  8.361663
## 42  22.53808 0 10.303852
## 43  42.48434 1  9.004360
## 44  49.39156 1 10.580300
## 45  47.92040 1 10.672659
## 46  48.76256 1 11.773249
## 47  23.67107 0  9.875302
## 48  48.76949 1  9.921954
## 49  41.39283 1  8.920843
## 50  42.49853 1  8.688555
## 51  48.09462 1 10.564605
## 52  44.45942 1  9.194570
## 53  45.84477 1  9.438857
## 54  41.94149 1  9.888696
## 55  47.26368 1  9.887931
## 56  51.42203 1 11.055223
## 57  39.17177 1  8.327467
## 58  51.15275 1 10.320770
## 59  50.40525 1 10.585048
## 60  42.49727 1  9.336601
## 61  28.05959 0 10.952144
## 62  49.10409 1 10.562264
## 63  27.15474 0 12.045244
## 64  19.24901 0  8.091111
## 65  47.67471 1 10.241636
## 66  24.39380 0 10.824896
## 67  26.49221 0 10.812256
## 68  38.77565 1  8.358974
## 69  45.05843 1  9.515578
## 70  52.28683 1 11.800317
## 71  23.36347 0  9.797133
## 72  26.84582 0 10.470713
## 73  42.10340 1  9.598281
## 74  39.43318 1  8.326351
## 75  44.69754 1  9.965926
## 76  48.71043 1 10.870054
## 77  24.30603 0  9.038770
## 78  24.54690 0 11.097281
## 79  22.08450 0 10.558284
## 80  51.71144 1 11.264590
## 81  53.69442 1 11.434979
## 82  26.79476 0 12.390173
## 83  40.80879 1  9.520336
## 84  43.63049 1 10.081028
## 85  20.06392 0  8.716013
## 86  41.11569 1  8.556393
## 87  24.45452 0  9.109263
## 88  24.05505 0 10.779678
## 89  41.82733 1  9.990715
## 90  53.17613 1 11.501511
## 91  49.50179 1 11.061493
## 92  20.42382 0  7.543992
## 93  41.57695 1  8.856854
## 94  50.83502 1 11.004920
## 95  41.66118 1  9.274137
## 96  47.30987 1 10.771928
## 97  20.74180 0  9.829798
## 98  24.39354 0 10.412418
## 99  53.71654 1 11.506078
## 100 51.22245 1 10.711256

S-learner

El S-learner sería estimar un sólo modelo y ver la diferencia (en esperanzas) en lo que estima el modelo para cuando W=1 versus lo que estima cuando W=0.

\(E[Y=y | W=1, X=x] - E[Y=y | W=0, X=x]\)

Si hacemos un modelo lineal en este ejemplo, cabe plantearse dos, uno con la interacción

mod_saturado <-  lm(Y ~ W *X , data = df)
summary(mod_saturado)
## 
## Call:
## lm(formula = Y ~ W * X, data = df)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -4.6786 -1.2138  0.1903  1.5419  4.6289 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)   6.9118     3.1354   2.204   0.0299 *  
## W            -1.8395     4.0511  -0.454   0.6508    
## X             1.6981     0.3085   5.504 3.08e-07 ***
## W:X           2.4096     0.4016   6.000 3.49e-08 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.011 on 96 degrees of freedom
## Multiple R-squared:  0.9689, Adjusted R-squared:  0.9679 
## F-statistic: 995.9 on 3 and 96 DF,  p-value: < 2.2e-16

Dónde la estimación CATE para un caso con X=14.

df_w0 <- data.frame(X = 14, W=0)
df_w1 <- data.frame(X = 14, W=1)

predict(mod_saturado, df_w1) - predict(mod_saturado, df_w0)
##        1 
## 31.89504

Que es lo mismo que haber considerado solo los coeficientes cuando W = 1

(cate1 <- coef(mod_saturado)[2] + coef(mod_saturado)[4] * 14 )
##        W 
## 31.89504

Esto ya lo habíamos visto en el post anterior. El tema es que hemos elegido como modelo base el modelo saturado pero podríamos haber elegido otro.

mod_efectos_ppal <- lm(Y ~ W  + X , data = df)
predict(mod_efectos_ppal, df_w1) - predict(mod_efectos_ppal, df_w0)
##        1 
## 22.33547

Y el CATE en este caso está subestimado ya que no hemos tenido en cuenta la interacción (que existe por construcción del efecto).

Podríamos haber elegido otro modelo, y obtener otra estimación del CATE. Usando un árbol por ejemplo, o en caso de tener más variables, cualquier modelo que se os ocurra.

library(rpart)

mod_arbol <-  rpart(Y ~ W  + X , data = df)
predict(mod_arbol, df_w1) - predict(mod_arbol, df_w0)
##        1 
## 27.88051

Total, que el S-learner es eso, usar un sólo modelo y obtener la diferencia entre lo que estima para cuando W = 1 y cuando W = 0.

X-learner

Los X-learner es una forma un poco más inteligente de usar los T-learners. Básicamente se trata de.

  • Estimamos dos modelos, uno con los datos cuando W=0 y otro cuando W=1. Los notamos por

\[\hat{\mu}_{0} = M_{1}(Y^0 \sim X^0)\] y por \[\hat{\mu}_{1} = M_{2}(Y^1 \sim X^1)\]

  • Ahora usamos esos modelos de la siguiente forma, para las observaciones que tengan W=0 utilizamos el modelo \(\hat{\mu}_{1}\), y para las observaciones con W=1 usamos el modelo que se ha estimado usando la otra parte de los datos \(\hat{\mu}_{0}\).

  • Calculamos para cada observación con W=0 la diferencia entre lo observado y lo estimado por el modelo \(\hat{\mu}_{1}\) y lo mismo para las observaciones con W=1. Así tenemos.

\[D_{i}^{0} = \hat{\mu}_{1}(X_{i}^{0})- Y_{i}^{0} \] y \[D_{i}^{1} = Y_{i}^{1} - \hat{\mu}_{0}(X_{i}^{0})\]

  • Volvemos a usar lo del T-learner pero esta vez sobre las diferencias obtenidas en el paso anterior

\[\hat{\tau}_1 = M_{3}(D^1 \sim X^1) \] \[\hat{\tau}_0 = M_{4}(D^0 \sim X^0) \]

  • Hacemos una combinación convexa para obtener

\[\hat{\tau}(x) = ps(x)\hat{\tau}_0(x) + (1- ps(x))\hat{\tau}_1(x) \] Dónde \(ps(x) \in [0,1]\) es una función de pesos con ciertas propiedades, normalmente se suele usar el propensity score, que básicamente es la estimación de la probabilidad de que cada observación pertenezca al tratamiento vs al control.

Y en nuestro ejemplo como sería.

Modelos 1 y 2 usando como modelos base un árbol por ejemplo.

m1 <- rpart(Y ~  X , data = df, subset = (W==0))
m2 <- rpart(Y ~ X , data = df, subset = (W==1))

Diferencias

Usamos modelo 1 para estimar cuando W=1 y el modelo 2 para estimar cuando W = 0

# Con el viejo R-base sería 
df$Difer[df$W==1] <- df$Y[df$W==1] - predict(m1, df[df$W==1, ])
head(df)
##          Y W         X    Difer
## 1 48.78438 1 10.800067 22.14350
## 2 25.28644 0 10.707605       NA
## 3 28.39538 0  9.925625       NA
## 4 47.60225 1 10.652555 22.79507
## 5 46.72225 1  9.992698 21.91507
## 6 55.15008 1 11.514759 28.50920

Y ahora para W=0

df$Difer[df$W==0] <-  predict(m2, df[df$W==0, ]) - df$Y[df$W==0] 
head(df)
##          Y W         X    Difer
## 1 48.78438 1 10.800067 22.14350
## 2 25.28644 0 10.707605 23.69278
## 3 28.39538 0  9.925625 17.87909
## 4 47.60225 1 10.652555 22.79507
## 5 46.72225 1  9.992698 21.91507
## 6 55.15008 1 11.514759 28.50920

Modelamos las diferencias

m3 <- rpart(Difer ~  X , data = df, subset = (W==1))
m4 <- rpart(Difer ~ X , data = df, subset = (W==0))

Combinamos

Modelo para propensity score

glm1 <- glm(W ~ X, data = df, family=binomial)
df$pesos <- predict(glm1, df, type = "response")
df$combinado <- df$pesos * predict(m4, df) + (1-df$pesos) * predict(m3, df) 

head(df[, c("Y", "W", "pesos", "combinado")])
##          Y W     pesos combinado
## 1 48.78438 1 0.6087519  24.38836
## 2 25.28644 0 0.6124915  24.39265
## 3 28.39538 0 0.6435533  22.05802
## 4 47.60225 1 0.6147118  24.39520
## 5 46.72225 1 0.6409317  22.05217
## 6 55.15008 1 0.5794443  25.32695

La estimación del CATE para nuestra nueva x sería

df_nueva_x <- data.frame(X = 14)

predict(glm1, df_nueva_x, type="response") * predict(m4, df_nueva_x) + (1-predict(glm1, df_nueva_x, type="response"))* predict(m3, df_nueva_x) 
##        1 
## 25.44921

Este ejemplo es muy sencillo, y supongo que habría que verlo con muchas más variables y utilizando modelos base más complejos.

No obstante, todo esto de los metalearners no tiene mucho sentido si el grado de solape entre la distribución de las X en el tratamiento y el control no es suficiente, cosa que se intenta arreglar un poco utilizando los propensity scores en el X-learner.

Extra, uso de causalml

En la librería causalml de Uber vienen implmentandos los metalearner entre otras cosas. Usando el mismo ejemplo veamos como se calcularía el CATE.

Nota: He sido incapaz de ver como predecir para mi nueva x, no hay o no he encontrado que funcione un método predict para aplicar el X learner a unos nuevos datos.

from causalml.inference.meta import BaseXRegressor
## The sklearn.utils.testing module is  deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.
from sklearn.linear_model import LinearRegression
# llamamos al df que está en R
df_python = r.df[['Y','W','X','pesos']]
df_python
##             Y    W          X     pesos
## 0   48.784384  1.0  10.800067  0.608752
## 1   25.286438  0.0  10.707605  0.612492
## 2   28.395375  0.0   9.925625  0.643553
## 3   47.602247  1.0  10.652555  0.614712
## 4   46.722247  1.0   9.992698  0.640932
## ..        ...  ...        ...       ...
## 95  47.309873  1.0  10.771928  0.609891
## 96  20.741797  0.0   9.829798  0.647284
## 97  24.393540  0.0  10.412418  0.624340
## 98  53.716540  1.0  11.506078  0.579804
## 99  51.222453  1.0  10.711256  0.612344
## 
## [100 rows x 4 columns]
learner_x = BaseXRegressor(learner=LinearRegression())

X = df_python.X.values.reshape(-1,1)
y = df_python.Y.values
treatment = df_python.W.values
e = df_python.pesos.values
nueva_X = r.df_nueva_x['X'].values.reshape(-1,1)

# estimamos
cate_x = learner_x.fit_predict(X=X, treatment=treatment, y=y, p=e)

print(cate_x)
## [[24.18445071]
##  [23.96165333]
##  [22.07738827]
##  [23.8290041 ]
##  [22.23900667]
##  [25.90657902]
##  [20.07281545]
##  [20.22764413]
##  [21.93491718]
##  [24.90817005]
##  [22.94900375]
##  [20.55940033]
##  [20.22323467]
##  [18.67470923]
##  [22.41952385]
##  [22.01148778]
##  [23.65059347]
##  [20.4584526 ]
##  [23.13640524]
##  [24.76464061]
##  [29.17118624]
##  [20.61781408]
##  [21.84758166]
##  [21.37933037]
##  [20.36590009]
##  [18.29646419]
##  [19.56622485]
##  [21.73818594]
##  [20.82250793]
##  [23.25545554]
##  [19.9476568 ]
##  [23.32259103]
##  [17.77992951]
##  [19.94871811]
##  [21.59347327]
##  [19.87909366]
##  [21.37264782]
##  [24.36127525]
##  [21.65050016]
##  [23.72101678]
##  [18.30885076]
##  [22.98876753]
##  [19.85749887]
##  [23.65489924]
##  [23.8774463 ]
##  [26.52943858]
##  [21.9561297 ]
##  [22.06854152]
##  [19.65625568]
##  [19.09653297]
##  [23.61708013]
##  [20.31583121]
##  [20.90446626]
##  [21.98840222]
##  [21.98655993]
##  [24.7992766 ]
##  [18.2264522 ]
##  [23.02953222]
##  [23.66633862]
##  [20.65807062]
##  [24.55089614]
##  [23.61143985]
##  [27.18483957]
##  [17.65692503]
##  [22.83885017]
##  [24.24427817]
##  [24.21382276]
##  [18.30237035]
##  [21.08933438]
##  [26.59466223]
##  [21.76777266]
##  [23.39083619]
##  [21.28861592]
##  [18.22376141]
##  [22.17449599]
##  [24.35309297]
##  [19.94041482]
##  [24.90061955]
##  [23.60184813]
##  [25.3037691 ]
##  [25.71434126]
##  [28.01598546]
##  [21.10080014]
##  [22.45184837]
##  [19.16269587]
##  [18.77807334]
##  [20.1102733 ]
##  [24.1353215 ]
##  [22.23422848]
##  [25.87465564]
##  [24.81438579]
##  [16.33858204]
##  [19.50206724]
##  [24.67806615]
##  [20.5075564 ]
##  [24.11664769]
##  [21.84648138]
##  [23.25036905]
##  [25.88566055]
##  [23.97045205]]
 
comments powered by Disqus