🦸🦸🦸 | 2025-03-10 | Targeted Learning
Data \(O = (W, A, Y) \sim P_{0}\); observe \(n\) i.i.d. observations \(O_{1}, \ldots, O_{n}\)
Target parameter is the causal risk difference under causal assumptions: \[\begin{align*} \Psi(P_{0}) & = E_{W, 0}\left[ E_{0}(Y \mid A = 1, W) - E_{0}(Y \mid A = 0, W) \right] \\ & = \int_{\mathcal{W}} \left\{ \bar{Q}_{0}(1, w) - \bar{Q}_{0}(0, w) \right\} Q_{W, 0}(dw) \\ & = \Psi(Q_{0}) \end{align*}\] where \[\begin{align} \bar{Q}_{0}(A, W) := E_{0}(Y \mid A, W), \end{align}\] \(Q_{W, 0}\) is the marginal distribution of \(W\) under \(Q_{0}\) and \(Q_{0} = (\bar{Q}_{0}, Q_{W, 0})\).
Target parameter \(\Psi(P_{0})\) only depends on \(P_{0}\) through conditional mean \(\bar{Q}_{0}(A, W)\) and marginal distribution \(Q_{W, 0}\).
Empirical measure \(Q_{W, n}\) estimates marginal distribution \(Q_{W, 0}\) (“nonparametric maximum likelihood estimator”); focus on estimator \(\bar{Q}_{n}\) of \(\bar{Q}_{0}\)
Plug \(Q_{n} = (\bar{Q}_{n}, Q_{W, n})\) into \(\Psi(\cdot)\) (“non-targeted substitution estimator”) \[\begin{align} \psi_{n} = \Psi(Q_{n}) = \frac{1}{n} \sum_{i=1}^{n} \{\bar{Q}_{n}(1, W_{i}) - \bar{Q}_{n}(0, W_{i})\} \end{align}\]
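As a concrete sketch (not from the slides), the non-targeted substitution estimator with one simple parametric choice of \(\bar{Q}_{n}\), assuming a hypothetical data frame `df` with binary `Y`, binary `A`, and covariates `W1`, `W2`:

```r
# One possible Qbar_n: a main-terms logistic regression (illustrative choice only)
fit <- glm(Y ~ A + W1 + W2, family = binomial(), data = df)

Q1 <- predict(fit, newdata = transform(df, A = 1), type = "response")  # Qbar_n(1, W_i)
Q0 <- predict(fit, newdata = transform(df, A = 0), type = "response")  # Qbar_n(0, W_i)

psi_n <- mean(Q1 - Q0)  # plug-in estimate: average over the empirical distribution of W
```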
Want an automated algorithm to nonparametrically (or semiparametrically) estimate \(\bar{Q}_{0}(A, W)\)
An algorithm is an estimator of \(\bar{Q}_{0}\) that maps a data set \(\{(W_{i}, A_{i}, Y_{i})\}_{i=1}^{n}\) into a prediction function, which in turn maps an input \((A, W)\) into a predicted value of \(Y\)
Want to find best algorithm to estimate the function \[\begin{align*} \bar{Q}_{0} : (A, W) \to \bar{Q}_{0}(A, W) \end{align*}\]
Can define “best” in terms of a loss function, which assigns a measure of performance to a candidate function \(\bar{Q}\) when applied to an observation \(O\), i.e. \[\begin{align*} L : (O, \bar{Q}) \to L(O, \bar{Q}) \in \mathbb{R}. \end{align*}\] It is a function of the random variable \(O\) and parameter value \(\bar{Q}\).
E.g. squared error or negative log loss \[\begin{align*} L(O, \bar{Q}) & = [Y - \bar{Q}(A, W)]^{2} \\ L(O, \bar{Q}) & = - \log [ \bar{Q}(A, W)^{Y} \left\{ 1 - \bar{Q}(A, W) \right\}^{1 - Y} ] \end{align*}\]
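A direct transcription of these two losses into R (illustrative; `qhat` stands for the candidate value \(\bar{Q}(A, W)\)):

```r
# Squared error and negative log (binomial) loss for one observation with binary y
sq_loss  <- function(y, qhat) (y - qhat)^2
log_loss <- function(y, qhat) -(y * log(qhat) + (1 - y) * log(1 - qhat))

sq_loss(1, 0.8)   # 0.04
log_loss(1, 0.8)  # -log(0.8), about 0.22
```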
For binary \(Y\), both the squared error and negative log loss target the same function, namely \(\bar{Q}_{0}(A, W) = E_{0}(Y \mid A, W) = P_{0}(Y = 1 \mid A, W)\)
A loss function thus defines the target function \(\bar{Q}_{0}\) as the candidate with optimal performance with respect to that loss among all candidate functions \(\bar{Q}\)
We can define our parameter of interest, \(\bar{Q}_{0}(A, W) = E_{0}(Y \mid A, W)\), as the minimizer of the expected squared error loss: \[\begin{align} \bar{Q}_{0} = \underset{\bar{Q}}{\arg\min} \ E_{0} L(O, \bar{Q}) = \underset{\bar{Q}}{\arg\min} \ E_{0} [Y - \bar{Q}(A, W)]^{2} \end{align}\]
The expected loss \(R(\bar{Q}) := E_{0} L(O, \bar{Q})\) is denoted the risk
“Best” algorithm: want the estimator of the regression function \(\bar{Q}_{0}\) whose realized fit minimizes the expected squared error loss, i.e. the risk.
In \(V\)-fold cross-validation the data are split into \(V\) folds; for, e.g., \(V = 10\) we have training and validation samples \((\mathcal{T}_{1}, \mathcal{V}_{1}), \ldots, (\mathcal{T}_{10}, \mathcal{V}_{10})\), where \(\mathcal{T}_{v}\) and \(\mathcal{V}_{v}\) are of size \(n_{t}\) and \(n_{v}\), respectively
Fit the learner on training sample \(\mathcal{T}_{v}\), denote the resulting fit by \(\bar{Q}_{n, v}\), and compute its risk on the corresponding validation sample \(\mathcal{V}_{v}\); denote the risk estimate by \[\begin{equation} R_{n, v}(\bar{Q}_{n, v}) = \frac{1}{n_{v}} \sum_{i \in \mathcal{V}_{v}} L(O_{i}, \bar{Q}_{n, v}) = \frac{1}{n_{v}} \sum_{i \in \mathcal{V}_{v}} \{Y_{i} - \bar{Q}_{n, v}(A_{i}, W_{i})\}^{2} \end{equation}\]
The cross-validated risk (CV risk) is the average risk over the folds \[\begin{align*} R_{\mathrm{CV}, n}(\bar{Q}_{n}) := \frac{1}{V} \sum_{v = 1}^{V} R_{n, v}(\bar{Q}_{n, v}) \end{align*}\]
CV risk provides a close approximation of the true risk of an estimator
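A bare-bones sketch of computing the CV risk by hand for a single learner (a main-terms logistic regression), again assuming a hypothetical data frame `df` with columns `Y`, `A`, `W1`, `W2`:

```r
V <- 10
set.seed(1)
fold <- sample(rep(1:V, length.out = nrow(df)))   # randomly assign observations to folds

fold_risks <- sapply(1:V, function(v) {
  train <- df[fold != v, ]                        # training sample T_v
  valid <- df[fold == v, ]                        # validation sample V_v
  fit   <- glm(Y ~ A + W1 + W2, family = binomial(), data = train)
  qhat  <- predict(fit, newdata = valid, type = "response")
  mean((valid$Y - qhat)^2)                        # R_{n,v}: risk on the validation fold
})

cv_risk <- mean(fold_risks)                       # R_CV,n: cross-validated risk
```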
[Figure from Eric Polley’s slides.]
For \(K\) candidate estimators \(\bar{Q}_{n}^{1}, \ldots, \bar{Q}_{n}^{K}\), let \[\begin{align*} k_{n}^{\mathrm{cv}} := \underset{k \in \{1, \ldots, K\}}{\arg\min} \ R_{\mathrm{CV}, n}(\bar{Q}_{n}^{k}) \end{align*}\] be the index of the estimator with the smallest CV risk; the corresponding estimator is the discrete super learner.
The oracle selector \[\begin{align*} k_{n}^{\mathrm{or}} := \underset{k \in \{1, \ldots, K\}}{\arg\min} \ R(\bar{Q}_{n}^{k}) \end{align*}\] depends on the true risk \(R(\cdot)\) and is therefore unknown in practice
It serves as a benchmark; can compare performance of \(\bar{Q}^{k_{n}^{\mathrm{cv}}}_{n, v}\) to \(\bar{Q}^{k_{n}^{\mathrm{or}}}_{n, v}\)
van der Vaart, Dudoit & van der Laan (2006) prove an oracle inequality showing that the discrete super learner performs essentially as well as the oracle selector, and the number of candidate algorithms \(K\) is allowed to be very large
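Schematically, suppressing constants, the bounded-loss conditions, and the exact fold-wise definitions (see the paper for the precise statement), the inequality takes roughly the form \[\begin{align*} E_{0}\left[ R\left(\bar{Q}_{n}^{k_{n}^{\mathrm{cv}}}\right) - R(\bar{Q}_{0}) \right] \leq (1 + 2\delta)\, E_{0}\left[ \min_{k} \left\{ R\left(\bar{Q}_{n}^{k}\right) - R(\bar{Q}_{0}) \right\} \right] + C(\delta)\, \frac{1 + \log K}{n_{v}} \quad \text{for any } \delta > 0, \end{align*}\] so the price paid for selecting with the CV risk rather than the true risk is of order \(\log(K) / n_{v}\).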
The weight vector \(\omega^{\mathrm{cv}}_{n}\) is chosen to minimize the CV risk over all convex combinations of the \(K\) candidate estimators (\(\omega_{k} \geq 0\), \(\sum_{k=1}^{K} \omega_{k} = 1\)); the estimator corresponding to the convex combination \(\omega^{\mathrm{cv}}_{n}\) is \(\bar{Q}_{n}^{\omega^{\mathrm{cv}}_{n}} = \sum_{k = 1}^{K} \omega^{\mathrm{cv}}_{n, k} \bar{Q}_{n}^{k}\).
The resulting estimator is the super learner
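For intuition: the default `method.NNLS` in the SuperLearner R package obtains these weights by non-negative least squares of the outcome on the matrix of cross-validated candidate predictions, then normalizes them to sum to one; a rough sketch with toy stand-ins (the data here are made up purely for illustration):

```r
library(nnls)

set.seed(1)
n <- 200
Y <- rnorm(n)
# Z: n x K matrix standing in for the cross-validated predictions of K = 3 candidates
Z <- cbind(Y + rnorm(n, sd = 0.5), Y + rnorm(n, sd = 2), rnorm(n))

fit   <- nnls(Z, Y)                 # non-negative least squares
omega <- fit$x / sum(fit$x)         # normalize to a convex combination (the SL weights)
sl_pred <- as.vector(Z %*% omega)   # super learner prediction = weighted combination
```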
We used cross-validation to find the best convex combination of candidate estimators; how to estimate the risk of the super learner?
Can also estimate the risk of the super learner using cross-validation.
`CV.SuperLearner` splits the data into \(V\) folds and then calls `SuperLearner` on each fold.

[Figure from Raschka 2018.]
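A minimal, self-contained sketch of calling `CV.SuperLearner` (toy data and a deliberately small library, chosen here only for illustration):

```r
library(SuperLearner)

set.seed(1)
n <- 500
X <- data.frame(x = runif(n, -4, 4))
Y <- sin(pi / 2 * X$x) + rnorm(n)   # toy continuous outcome

# Outer V folds evaluate the super learner itself; within each training fold,
# SuperLearner runs its own internal cross-validation to estimate the weights
cv_sl <- CV.SuperLearner(
  Y = Y, X = X,
  family = gaussian(),
  SL.library = c("SL.mean", "SL.glm", "SL.gam"),
  cvControl = list(V = 10)
)
summary(cv_sl)   # CV risk of the super learner, the discrete SL, and each candidate
```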
\[\begin{align*} Y &= -2 \cdot \mathbf{1}(X < -3) + 2.55 \cdot \mathbf{1}(X > -2) - 2 \cdot \mathbf{1}(X > 0) \\ & + 4 \cdot \mathbf{1}(X > 2) - 1 \cdot \mathbf{1}(X > 3) +u \\ Y &= 6 + 0.4X - 0.36X^2 + 0.005X^3 + u \\ Y &= 2.83 \sin\left(\frac{\pi}{2} X\right) +u \\ Y &= 4 \sin(3\pi X) \cdot \mathbf{1}(X > 0) + u \end{align*}\] where \(u \sim N(0, 1)\) and \(X \sim \mathrm{Unif}(-4, 4)\)
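A direct transcription of these data-generating processes into R, assuming (as the tables below suggest) that they correspond to Sims 1–4 in the order listed; the function names are mine:

```r
sim1 <- function(n) {  # piecewise-constant (step) function
  X <- runif(n, -4, 4)
  Y <- -2 * (X < -3) + 2.55 * (X > -2) - 2 * (X > 0) + 4 * (X > 2) - 1 * (X > 3) + rnorm(n)
  data.frame(X = X, Y = Y)
}
sim2 <- function(n) {  # smooth cubic polynomial
  X <- runif(n, -4, 4)
  Y <- 6 + 0.4 * X - 0.36 * X^2 + 0.005 * X^3 + rnorm(n)
  data.frame(X = X, Y = Y)
}
sim3 <- function(n) {  # sine curve
  X <- runif(n, -4, 4)
  Y <- 2.83 * sin(pi / 2 * X) + rnorm(n)
  data.frame(X = X, Y = Y)
}
sim4 <- function(n) {  # high-frequency sine, active only for X > 0
  X <- runif(n, -4, 4)
  Y <- 4 * sin(3 * pi * X) * (X > 0) + rnorm(n)
  data.frame(X = X, Y = Y)
}
```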
library(SuperLearner)

# Generalized additive models with varying degrees of freedom for the smooth terms
gam_learners <- create.Learner(
  "SL.gam", tune = list(deg.gam = c(2, 3, 4))
)
# Bagged trees (ipredbagg) with varying complexity parameter
bag_learners <- create.Learner(
  "SL.ipredbagg", tune = list(cp = c(0.0, 0.1, 0.01))
)
# Single-hidden-layer neural networks with varying numbers of hidden units
nn_learners <- create.Learner(
  "SL.nnet", tune = list(size = c(2, 3, 4, 5))
)
# Loess smoothers with varying span
loess_learners <- create.Learner(
  "SL.loess", tune = list(span = c(0.75, 0.5, 0.25, 0.1))
)

all_learners <- c(
  gam_learners$names,
  nn_learners$names,
  bag_learners$names,
  loess_learners$names,
  "SL.randomForest",
  "SL.xgboost",
  "SL.glm",
  "SL.polymars"
)
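A minimal sketch of fitting the super learner with this library on data from the second DGP above (sample size, seed, and variable names are my own choices):

```r
set.seed(2025)
n <- 500
X <- runif(n, -4, 4)
Y <- 6 + 0.4 * X - 0.36 * X^2 + 0.005 * X^3 + rnorm(n)   # second DGP from above

sl_fit <- SuperLearner(
  Y = Y,
  X = data.frame(X = X),
  family = gaussian(),
  SL.library = all_learners,
  cvControl = list(V = 10)
)

sl_fit$coef     # estimated convex-combination weights (cf. the weight table below)
sl_fit$cvRisk   # cross-validated risk of each candidate learner
```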
Super Learner weights

Algorithm | Sim 1 | Sim 2 | Sim 3 | Sim 4 |
---|---|---|---|---|
SL.glm | 0.0 | 0.0 | 0.0 | 0.0 |
SL.randomForest | 0.0 | 0.0 | 0.0 | 0.864 |
SL.xgboost | 0.322 | 0.0 | 0.0 | 0.136 |
SL.ipredbagg(0.00) | 0.0 | 0.015 | 0.212 | 0.0 |
SL.ipredbagg(0.1) | 0.0 | 0.0 | 0.158 | 0.0 |
SL.ipredbagg(0.01) | 0.0 | 0.0 | 0.0 | 0.0 |
SL.gam(2) | 0.0 | 0.0 | 0.0 | 0.0 |
SL.gam(3) | 0.0 | 0.0 | 0.0 | 0.0 |
SL.gam(4) | 0.0 | 0.0 | 0.0 | 0.0 |
SL.nnet(2) | 0.0 | 0.102 | 0.0 | 0.0 |
SL.nnet(3) | 0.0 | 0.231 | 0.061 | 0.0 |
SL.nnet(4) | 0.381 | 0.0 | 0.07 | 0.0 |
SL.nnet(5) | 0.206 | 0.0 | 0.064 | 0.0 |
SL.polymars | 0.091 | 0.0 | 0.0 | 0.0 |
SL.loess(0.75) | 0.0 | 0.247 | 0.0 | 0.0 |
SL.loess(0.5) | 0.0 | 0.406 | 0.255 | 0.0 |
SL.loess(0.25) | 0.0 | 0.0 | 0.179 | 0.0 |
SL.loess(0.1) | 0.0 | 0.0 | 0.0 | 0.0 |
Simulation results for N = 100 simulations

Algorithm | Sim 1 R² | Sim 1 SE(R²) | Sim 2 R² | Sim 2 SE(R²) | Sim 3 R² | Sim 3 SE(R²) | Sim 4 R² | Sim 4 SE(R²) |
---|---|---|---|---|---|---|---|---|
Super Learner | 0.746 | 0.004 | 0.782 | 0.004 | 0.734 | 0.005 | 0.477 | 0.005 |
Discrete SL | 0.755 | 0.004 | 0.78 | 0.004 | 0.718 | 0.005 | 0.489 | 0.006 |
SL.glm | 0.403 | 0.008 | 0.185 | 0.011 | 0.11 | 0.006 | -0.013 | 0.002 |
SL.randomForest | 0.728 | 0.004 | 0.741 | 0.004 | 0.739 | 0.005 | 0.489 | 0.006 |
SL.xgboost | 0.701 | 0.005 | 0.704 | 0.005 | 0.678 | 0.006 | 0.4 | 0.008 |
SL.ipredbagg(0.00) | 0.754 | 0.004 | 0.754 | 0.004 | 0.715 | 0.005 | 0.107 | 0.007 |
SL.ipredbagg(0.1) | 0.752 | 0.004 | 0.755 | 0.004 | 0.714 | 0.005 | 0.097 | 0.008 |
SL.ipredbagg(0.01) | 0.755 | 0.004 | 0.754 | 0.004 | 0.715 | 0.005 | 0.11 | 0.007 |
SL.gam(2) | 0.465 | 0.007 | 0.701 | 0.004 | 0.237 | 0.005 | -0.004 | 0.003 |
SL.gam(3) | 0.548 | 0.006 | 0.777 | 0.003 | 0.409 | 0.005 | -0.007 | 0.005 |
SL.gam(4) | 0.609 | 0.005 | 0.785 | 0.003 | 0.544 | 0.006 | -0.013 | 0.006 |
SL.nnet(2) | 0.653 | 0.005 | 0.618 | 0.007 | 0.334 | 0.007 | -0.036 | 0.012 |
SL.nnet(3) | 0.716 | 0.005 | 0.769 | 0.004 | 0.721 | 0.005 | -0.012 | 0.013 |
SL.nnet(4) | 0.716 | 0.005 | 0.768 | 0.004 | 0.718 | 0.005 | 0.024 | 0.009 |
SL.nnet(5) | 0.722 | 0.004 | 0.759 | 0.004 | 0.763 | 0.004 | -3.2 | 0.274 |
SL.polymars | 0.72 | 0.004 | 0.768 | 0.004 | 0.758 | 0.004 | -0.041 | 0.008 |
SL.loess(0.75) | 0.586 | 0.005 | 0.786 | 0.003 | 0.449 | 0.009 | -0.02 | 0.006 |
SL.loess(0.5) | 0.654 | 0.005 | 0.78 | 0.004 | 0.646 | 0.007 | -0.013 | 0.006 |
SL.loess(0.25) | 0.697 | 0.005 | 0.769 | 0.004 | 0.767 | 0.004 | -0.097 | 0.007 |
SL.loess(0.1) | 0.705 | 0.005 | 0.734 | 0.004 | 0.73 | 0.005 | 0.538 | 0.009 |
library(tmle)

# Define outcome (Y), treatment (A), and covariates (W)
Y <- df$Y
A <- df$treat
W <- df[, c(
  "age", "educ", "black", "hispan", "married", "nodegree", "re74", "re75"
)]

# TMLE for the binary outcome; the same learner library is used for the
# outcome regression (Q) and the treatment mechanism (g)
tmle_fit <- tmle(
  Y = Y, A = A, W = W,
  Q.SL.library = all_learners,
  g.SL.library = all_learners,
  family = "binomial"
)

# ... and for continuous outcome
Y_cont <- df$re78
tmle_fit_cont <- tmle(
  Y = Y_cont, A = A, W = W,
  Q.SL.library = all_learners,
  g.SL.library = all_learners,
  family = "gaussian"
)
The propensity score is also estimated with a super learner, specified via the `g.SL.library` argument; it is used to target the initial estimate of \(\Psi(Q_{0})\) towards the TMLE (the influence function contains the propensity score).

TMLE results

 | Binary | Continuous |
---|---|---|
Ψ(Q₀) | 0.036 | 640.886 |
CI lower | -0.017 | -104.296 |
CI upper | 0.09 | 1386.068 |
p-value | 0.183 | 0.092 |
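The numbers in this table can be read off the fitted objects; a sketch assuming the `tmle` calls above (to the best of my knowledge the additive effect is stored in the `estimates$ATE` slot):

```r
# Additive treatment effect, 95% CI, and p-value for the binary outcome
tmle_fit$estimates$ATE$psi
tmle_fit$estimates$ATE$CI
tmle_fit$estimates$ATE$pvalue

# ... and for the continuous outcome
tmle_fit_cont$estimates$ATE$psi
tmle_fit_cont$estimates$ATE$CI
tmle_fit_cont$estimates$ATE$pvalue
```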
Super Learner weights

Algorithm | Binary | Continuous |
---|---|---|
SL.gam(2) | 0.0 | 0.036 |
SL.gam(3) | 0.0 | 0.492 |
SL.gam(4) | 0.0 | 0.0 |
SL.randomForest | 0.272 | 0.472 |
SL.xgboost | 0.0 | 0.0 |
SL.glm | 0.728 | 0.0 |
Finally, the HTML results table can be handed to a language model for interpretation, with a prompt along the lines of `f"Please interpret the results:\n{HtmlTable}"`: