causalOT was developed to reproduce the methods in “Optimal transport methods for causal inference”. The functions in the package construct weights that make covariate distributions more similar across treatment groups, and then estimate causal effects. We recommend the Causal Optimal Transport (COT) methods because they are semi- to non-parametric. This document describes some simple uses of the functions in the package and should be enough to get users started.
The main functions of the package, calc_weight and estimate_effect, take the arguments x, a numeric matrix of covariates; z, a treatment indicator in c(0,1); and y, a numeric vector with the outcome data.
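As an illustration of these input forms, the base-R sketch below simulates toy data and checks it. The toy data do not follow the Hainmueller design used next. The checking function, check_inputs, is a hypothetical helper and is not part of causalOT, which performs its own validation.

```r
set.seed(1)
n <- 100
# x: numeric covariate matrix with one row per unit
toy_x <- matrix(rnorm(n * 3), nrow = n, ncol = 3,
                dimnames = list(NULL, c("X1", "X2", "X3")))
# z: treatment indicator taking values in c(0, 1)
toy_z <- rbinom(n, size = 1, prob = plogis(toy_x[, 1]))
# y: numeric outcome vector
toy_y <- drop(toy_x %*% c(1, -1, 0.5)) + 2 * toy_z + rnorm(n)

# Hypothetical helper (not part of causalOT): check the documented forms
check_inputs <- function(x, z, y) {
  stopifnot(is.matrix(x), is.numeric(x),
            all(z %in% c(0, 1)), length(z) == nrow(x),
            is.numeric(y), length(y) == nrow(x))
  invisible(TRUE)
}
check_inputs(toy_x, toy_z, toy_y)
```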
If the data are already in this form, we can supply them directly.
# packages
library(causalOT)
library(torch)
# reproducible seeds
set.seed(1111)
torch_manual_seed(3249)
# generate some data
hainmueller <- Hainmueller$new(n = 512)
hainmueller$gen_data()
x <- hainmueller$get_x()
z <- hainmueller$get_z()
y <- hainmueller$get_y()
# NOT RUN
# weights <- calc_weight(x = x, z = z, y = y)
Note that the calc_weight function will not use the outcome data y when calculating the weights and will not pass it to the internal data constructor. It must be passed to the estimate_effect function later.
However, data sometimes arrive as a data.frame, and users may not know how to turn them into the required objects. For this case, we've supplied the df2dataHolder function, which creates the required data object like so:
df <- data.frame(y = y, z = z, x)
df2dH <- df2dataHolder(treatment.formula = "z ~ .",
outcome.formula = "y ~ .",
data = df)
# NOT RUN
# weights <- calc_weight(x = df2dH, z = NULL, y = NULL)
In this case, we can pass the dataHolder object directly to the calc_weight function in the x argument and ignore the others. The dataHolder object already contains the x, y, and z data internally. Note that df2dataHolder needs both outcome and treatment formulae, since it needs to know which columns are actually confounders for the purposes of calculating the weights!
Finally, if you so desire, you can create a dataHolder object directly.
dH <- dataHolder(x = x, z = z, y = y)
# NOT RUN
# weights <- calc_weight(x = dH, z = NULL, y = NULL)
This may be useful if you plan on reusing the data object.
The weights can be estimated by using the calc_weight function in the package. We select optimal hyperparameters through our bootstrap-based algorithm and target the average treatment effect.
weights <- calc_weight(x = x, z = z,
method = "COT",
estimand = "ATE",
options = list(lambda.bootstrap = Inf,
nboot = 1000L)
)
These weights aim to balance the covariate distributions between treatment groups, which reduces bias in the resulting treatment effect estimates.
We can then estimate effects with the estimate_effect function:
tau_hat <- estimate_effect(causalWeights = weights, y = y)
The estimator generated here is a simple weighted difference in observed outcomes between treatment groups. Note also that we must supply the outcome information in the argument y, since the calc_weight function does not store it when we pass data matrices.
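For intuition, here is a minimal base-R sketch of a weighted difference in means. It assumes nonnegative weights that are normalized to sum to one within each treatment group. The helper weighted_diff is hypothetical and is not causalOT's internal code.

```r
# Hypothetical helper (not causalOT's code): weighted difference in mean
# outcomes, normalizing the weights to sum to one within each group.
weighted_diff <- function(y, z, w) {
  w0 <- w[z == 0] / sum(w[z == 0])  # control weights
  w1 <- w[z == 1] / sum(w[z == 1])  # treated weights
  sum(w1 * y[z == 1]) - sum(w0 * y[z == 0])
}

y_toy <- c(1, 2, 3, 10, 12)
z_toy <- c(0, 0, 0, 1, 1)
weighted_diff(y_toy, z_toy, w = rep(1, 5))         # 11 - 2 = 9
weighted_diff(y_toy, z_toy, w = c(1, 0, 0, 1, 1))  # 11 - 1 = 10
```

With equal weights this reduces to the raw difference in means. Reweighting the controls changes which control outcomes the treated mean is compared against.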
The estimate_effect function returns an object of class causalEffect. This object can be passed to the standard R function confint to calculate asymptotic confidence intervals, or to vcov to calculate the variance of the estimate using the semiparametrically efficient variance formula:
var_tau <- vcov(tau_hat)
ci_tau <- confint(tau_hat, level = 0.95)
This gives the following treatment effect estimate, variance, and confidence interval:
print(coef(tau_hat))
#> estimate
#> 0.007949831
print(var_tau)
#> estimate
#> estimate 0.1887772
print(ci_tau)
#> 2.5 % 97.5 %
#> estimate -0.8436251 0.8595248
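The printed interval appears consistent with a normal-approximation (Wald) interval: the estimate plus or minus \(z_{0.975}\sqrt{\text{variance}}\). The sketch below only checks that arithmetic with the printed numbers. It reflects our reading of the output, not a description of confint's internals.

```r
# Values copied from the printed output above
est   <- 0.007949831
v_est <- 0.1887772
half_width <- qnorm(0.975) * sqrt(v_est)  # about 1.96 standard errors
ci <- c(lower = est - half_width, upper = est + half_width)
ci  # approximately -0.8436 and 0.8595, matching ci_tau
```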
The estimate_effect function can also use outcome models to estimate treatment effects. Several additional arguments control this, and they are demonstrated below:
- model.function: either a character string naming the model you want to run, or the function itself.
- estimate.separately: TRUE or FALSE. Should the model be estimated separately on each treatment group (TRUE) or jointly on the full data (FALSE)?
- augment.estimate: TRUE or FALSE. Should an augmented estimator be used to calculate the final treatment effect?
- normalize.weights: TRUE or FALSE. Should the weights be normalized to sum to one in each treatment group before being used? For all methods except “Logistic”, “Probit”, and “CBPS”, the weights are normalized to sum to one by definition, so this option has no effect for most methods in the package.

The model functions we can use need a few components:
- a formula argument;
- a data argument that accepts a data.frame;
- a weights argument;
- a predict method that accepts a newdata argument.

One such function we could use is lm.
tau_hat_lm <- estimate_effect(causalWeights = weights,
y = y,
model.function = lm,
estimate.separately = TRUE,
augment.estimate = FALSE,
normalize.weights = TRUE)
In this case, separate models will be fit to treated and controls, and the predictions from the model will be used to estimate treatment effects. We can also calculate the augmented (aka doubly robust) estimate with the argument augment.estimate.
tau_hat_dr <- estimate_effect(causalWeights = weights,
y = y,
model.function = lm,
estimate.separately = TRUE,
augment.estimate = TRUE,
normalize.weights = TRUE)
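The following base-R sketch contrasts the outcome-model estimate with the augmented estimate, using one common augmented weighting estimator of the ATE. It first averages the model predictions. It then adds the weighted mean residual in the treated group and subtracts the weighted mean residual in the control group. The helper augmented_ate is hypothetical. Fitting the outcome models by ordinary least squares is our assumption, and causalOT's internals may differ.

```r
# Hypothetical helper (not causalOT's code): ATE from separate linear outcome
# models in each group, with and without augmentation by weighted residuals.
# x: data.frame of covariates; z: 0/1 treatment; y: outcome; w: weights.
augmented_ate <- function(x, z, y, w) {
  dat <- data.frame(y = y, x)
  w0 <- w[z == 0] / sum(w[z == 0])  # control weights sum to one
  w1 <- w[z == 1] / sum(w[z == 1])  # treated weights sum to one
  # outcome models fit by ordinary least squares within each group
  fit0 <- lm(y ~ ., data = dat[z == 0, , drop = FALSE])
  fit1 <- lm(y ~ ., data = dat[z == 1, , drop = FALSE])
  m0 <- predict(fit0, newdata = dat)  # predicted outcome under control
  m1 <- predict(fit1, newdata = dat)  # predicted outcome under treatment
  model_only <- mean(m1 - m0)
  # augmentation: add the weighted mean residual in each group
  correction <- sum(w1 * (y[z == 1] - m1[z == 1])) -
    sum(w0 * (y[z == 0] - m0[z == 0]))
  c(model = model_only, augmented = model_only + correction)
}

# Toy data with a constant treatment effect of 2 and no noise
set.seed(2)
xd <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
zd <- rbinom(200, size = 1, prob = 0.5)
yd <- 1 + xd$x1 - xd$x2 + 2 * zd
est <- augmented_ate(xd, zd, yd, w = runif(200))
est
```

Here the outcome model is correct, so the residuals vanish and both estimates equal the true effect of 2. Suppose instead the group models were fit by weighted least squares with the same normalized weights and an intercept. Then the weighted residuals would sum to zero within each group, and the correction term would vanish.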
We can also fit a weighted OLS by specifying estimate.separately = FALSE:
tau_hat_wols <- estimate_effect(causalWeights = weights,
y = y,
model.function = lm,
estimate.separately = FALSE,
augment.estimate = FALSE,
normalize.weights = TRUE)
This fits a single weighted OLS model on the entire data.
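A weighted OLS estimate of this kind can be sketched in base R as below. The sketch assumes the weights are normalized within each treatment group. It also assumes the model includes a treatment indicator but no treatment-covariate interactions. In that case, the average difference in predictions equals the coefficient on z. The helper wols_ate is hypothetical.

```r
# Hypothetical helper (not causalOT's code): one weighted OLS fit on the full
# data with the treatment indicator as a covariate.
wols_ate <- function(x, z, y, w) {
  w[z == 0] <- w[z == 0] / sum(w[z == 0])  # normalize within controls
  w[z == 1] <- w[z == 1] / sum(w[z == 1])  # normalize within treated
  dat <- data.frame(y = y, z = z, x)
  fit <- lm(y ~ ., data = dat, weights = w)
  # predict for everyone as treated and as control, then average the gap
  d1 <- dat; d1$z <- 1
  d0 <- dat; d0$z <- 0
  mean(predict(fit, newdata = d1) - predict(fit, newdata = d0))
}

set.seed(3)
xd <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
zd <- rbinom(200, size = 1, prob = 0.5)
yd <- 1 + xd$x1 - xd$x2 + 2 * zd + rnorm(200, sd = 0.1)
tau_wols <- wols_ate(xd, zd, yd, w = runif(200))
tau_wols  # close to the simulated effect of 2
```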
An outcome model specific to this package is the function barycentric_projection, which, as the name implies, estimates barycentric projections of the outcome data. Using this function takes a couple of steps. Unlike the case for linear models with lm, we need to think carefully about which sample the data arise from.
To use this function outside of the main causalOT functions, we would run:
df <- data.frame(z = z, y = y, x)
bp <- barycentric_projection(formula = "y ~ x + z",
data = df,
weights = weights,
separate.samples.on = "z",
penalty = 0.01,
cost_function = NULL,
p = 2,
debias = FALSE,
cost.online = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-7)
This solves the optimal transport problem between the samples defined by “z” and stores the dual potentials of the Sinkhorn divergence problem.
Then we can use the predict function to see what the outcomes would be if the samples had arisen from a different distribution. Suppose, for example, that everyone had actually been treated:
newdf <- df
newdf$z <- 1L
preds <- predict(object = bp,
newdata = newdf,
source.sample = df$z)
head(preds)
#> [1] -2.80749441 0.52369862 5.79605734 -1.55036673 0.01254699
#> [6] 1.55753697
The argument source.sample should be a vector that denotes the original treatment group of the samples. This allows the function to use the appropriate dual potentials to calculate the expected outcome.
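To show what a prediction from dual potentials means, here is a conceptual base-R sketch of an entropic barycentric projection. Log-domain Sinkhorn iterations give potentials f and g. These define a transport plan \(P_{ij} = a_i b_j \exp((f_i + g_j - C_{ij})/\varepsilon)\). Each source point's predicted outcome is then \(\sum_j P_{ij} y_j / \sum_j P_{ij}\). The helper bary_proj is hypothetical. It is not causalOT's barycentric_projection, which is built on torch and uses its own cost (the p argument) and penalty conventions.

```r
# Numerically stable log-sum-exp
lse <- function(v) { m <- max(v); m + log(sum(exp(v - m))) }

# Hypothetical helper: transport source points x0 (weights a) to target points
# x1 (weights b) with entropic OT and squared Euclidean cost, then predict each
# source point's outcome as the plan-weighted average of target outcomes y1.
bary_proj <- function(x0, x1, y1, a, b, eps = 0.1, n_iter = 500) {
  x0 <- as.matrix(x0); x1 <- as.matrix(x1)
  C <- outer(rowSums(x0^2), rowSums(x1^2), "+") - 2 * x0 %*% t(x1)
  f <- numeric(nrow(x0)); g <- numeric(nrow(x1))
  for (it in seq_len(n_iter)) {  # log-domain Sinkhorn for the dual potentials
    f <- -eps * apply(C, 1, function(ci) lse(log(b) + (g - ci) / eps))
    g <- -eps * apply(C, 2, function(cj) lse(log(a) + (f - cj) / eps))
  }
  # log transport plan: log a_i + log b_j + (f_i + g_j - C_ij) / eps
  logP <- outer(log(a) + f / eps, log(b) + g / eps, "+") - C / eps
  # normalize each row of the plan and average the target outcomes
  apply(logP, 1, function(lp) sum(exp(lp - lse(lp)) * y1))
}

# Two well-separated source points, each next to one target point
proj <- bary_proj(x0 = matrix(c(0, 10)), x1 = matrix(c(0.1, 10.1)),
                  y1 = c(1, 5), a = c(0.5, 0.5), b = c(0.5, 0.5))
proj  # close to c(1, 5)
```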
When using barycentric_projection within the estimate_effect function, we need to supply some extra arguments through the ... argument:
tau_hat_bp <- estimate_effect(causalWeights = weights,
y = y,
model.function = barycentric_projection,
estimate.separately = FALSE,
augment.estimate = TRUE,
normalize.weights = TRUE,
# special args for barycentric_projection
separate.samples.on = "z",
penalty = 0.01,
cost_function = NULL,
p = 3,
debias = FALSE,
cost.online = "tensorized",
diameter = NULL,
niter = 1000L,
tol = 1e-7,
line_search_fn = "strong_wolfe"
)
print(tau_hat_bp@estimate)
#> [1] -0.02549428
This method currently doesn’t have a variance estimator that accounts for the uncertainty in the weights, but we can use the asymptotic variance estimator of Hahn (1998).
In neither of these cases did we feed data or a formula to the model function. By default, the estimate_effect function regresses the outcome in argument y on all of the covariates from the calc_weight function. It also adjusts for the treatment indicator as appropriate given the selected options. If you want the outcome model to use different covariates from those used to estimate the weights, you can provide new covariates in the argument x.
Note that these data must have the same observation order as the original data and must also be an object of class matrix.
Diagnostics are also an important part of deciding whether the weights perform well. We explore several areas: effective sample size, Pareto smoothed importance sampling diagnostics, mean balance, and distributional balance.
Typically, the effective sample size (ESS) of weights normalized to sum to one is calculated as \(1/\sum_i w_i^2\). It gives a measure of how much information is in the weighted sample. A lower ESS means higher variance for weighted estimates, and it means more of the weight is concentrated on a few individuals. Of course, we can calculate this in causalOT!
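The usual (Kish) formula, \((\sum_i w_i)^2/\sum_i w_i^2\), reduces to \(1/\sum_i w_i^2\) when the weights sum to one. Below is a minimal base-R sketch. The name ess is a hypothetical helper, and causalOT's own function may differ. With equal weights over 247 units, the ESS is exactly 247, the pre-weighting "N eff" for the controls in the summary table later on.

```r
# Kish effective sample size; invariant to rescaling the weights.
# For weights summing to one this is 1 / sum(w^2).
ess <- function(w) sum(w)^2 / sum(w^2)

ess(rep(1 / 247, 247))      # equal weights: ESS equals the group size, 247
ess(c(0.97, rep(0.01, 3)))  # one dominant weight: ESS of about 1.06
```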
However, this measure can fail to diagnose problems with highly variable weights. In response, Vehtari et al. use Pareto smoothed importance sampling (PSIS). We offer some wrapper code to adapt the class causalWeights to the loo package.
This will also return the Pareto smoothed weights and log weights.
If we want to easily examine the PSIS diagnostics, we can pull those out too:
PSIS_diag(raw_psis)
#> $w0
#> $w0$pareto_k
#> [1] -0.2320877
#>
#> $w0$n_eff
#> [1] 115.7568
#>
#>
#> $w1
#> $w1$pareto_k
#> [1] 0.2478191
#>
#> $w1$n_eff
#> [1] 82.99825
We can see that all of the \(k\) values are below the recommended threshold of 0.5, indicating finite variance and that the central limit theorem holds. Note that these effective sample sizes are a bit lower than those from the ESS method above.
Many authors consider the standardized absolute mean difference an important marker of balance; see Stuart (2010). That is,
\[ \frac{|\overline{X}_c - \overline{X}_t|}{\sigma_{\text{pool}}}, \]
where \(\overline{X}_c\) is the mean in the controls, \(\overline{X}_t\) is the mean in the treated, and \(\sigma_{\text{pool}}\) is the pooled standard deviation. We offer such checks in causalOT as well.
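A base-R sketch of this statistic follows. Two choices here are our assumptions. The first is the pooled standard deviation convention, \(\sigma_{\text{pool}} = \sqrt{(s_c^2 + s_t^2)/2}\) with unweighted group variances, which is one common choice. The second is comparing the two groups directly. causalOT's mean_balance may compare groups differently, for example against the full sample when targeting the ATE. The helper std_mean_diff is hypothetical.

```r
# Hypothetical helper: standardized absolute mean difference per column of x,
# with optional weights normalized within each group.
std_mean_diff <- function(x, z, w = rep(1, length(z))) {
  x <- as.matrix(x)
  x0 <- x[z == 0, , drop = FALSE]
  x1 <- x[z == 1, , drop = FALSE]
  # weighted means with weights normalized within each group
  m0 <- colSums(x0 * (w[z == 0] / sum(w[z == 0])))
  m1 <- colSums(x1 * (w[z == 1] / sum(w[z == 1])))
  # pooled SD from the unweighted group variances
  s_pool <- sqrt((apply(x0, 2, var) + apply(x1, 2, var)) / 2)
  abs(m0 - m1) / s_pool
}

x_toy <- matrix(c(0, 2, 1, 3), ncol = 1)
z_toy <- c(0, 0, 1, 1)
std_mean_diff(x_toy, z_toy)                     # |1 - 2| / sqrt(2)
std_mean_diff(x_toy, z_toy, w = c(1, 1, 0, 1))  # |1 - 3| / sqrt(2)
```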
First, we consider mean balance between treatment groups before weighting:
mean_balance(x = hainmueller)
#> X1 X2 X3 X4 X5 X6
#> 1.0889178 0.9320099 0.9322327 0.3741322 0.2798986 0.2644929
and then mean balance between treatment groups after weighting:
mean_balance(x = hainmueller, weights = weights)
#> X1 X2 X3 X4 X5 X6
#> 0.036900048 0.006577199 0.003137536 0.003228830 0.002445118 0.060551316
Pretty good! However, mean balance doesn’t ensure distributional balance.
Ultimately, distributional balance is what we care about in causal inference. Fortunately, we can measure that as well. We consider the 2-Sinkhorn divergence of Genevay et al., since it metrizes convergence in distribution.
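To build intuition for what is being measured, here is a base-R sketch of the debiased Sinkhorn divergence, \(S_\varepsilon(\alpha, \beta) = \mathrm{OT}_\varepsilon(\alpha, \beta) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\alpha, \alpha) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\beta, \beta)\). It uses the squared Euclidean cost and log-domain Sinkhorn iterations. This is not causalOT's implementation, and we do not assume that ot_distance's penalty argument equals eps here. The helper names are hypothetical.

```r
# Numerically stable log-sum-exp
lse <- function(v) { m <- max(v); m + log(sum(exp(v - m))) }

# Entropic OT cost between weighted samples (x, a) and (y, b) with squared
# Euclidean cost, via log-domain Sinkhorn updates of the dual potentials.
entropic_ot <- function(x, y, a, b, eps, n_iter = 500) {
  x <- as.matrix(x); y <- as.matrix(y)
  C <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y)
  f <- numeric(nrow(x)); g <- numeric(nrow(y))
  for (it in seq_len(n_iter)) {
    f <- -eps * apply(C, 1, function(ci) lse(log(b) + (g - ci) / eps))
    g <- -eps * apply(C, 2, function(cj) lse(log(a) + (f - cj) / eps))
  }
  sum(a * f) + sum(b * g)  # dual objective at the current potentials
}

# Debiased Sinkhorn divergence: zero when the two weighted samples coincide
sinkhorn_div <- function(x, y, a, b, eps) {
  entropic_ot(x, y, a, b, eps) -
    0.5 * entropic_ot(x, x, a, a, eps) -
    0.5 * entropic_ot(y, y, b, b, eps)
}

pts <- matrix(seq(0, 0.5, length.out = 10))
wt <- rep(1 / 10, 10)
s_same  <- sinkhorn_div(pts, pts, wt, wt, eps = 1)
s_shift <- sinkhorn_div(pts, pts + 0.5, wt, wt, eps = 1)
c(s_same, s_shift)  # 0, and 0.5^2 = 0.25 up to convergence
```

For the squared Euclidean cost, translating a sample by d adds exactly \(|d|^2\) to the entropic cost between the original and translated samples. The translation leaves the self-transport terms unchanged. So the second value should be very close to 0.25 for any eps.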
Before weighting, distributional balance looks poor:
# controls
ot_distance(x1 = hainmueller$get_x0(), x2 = hainmueller$get_x(),
a = NULL, b = rep(1/512,512),
p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.5311378
# treated
ot_distance(x1 = hainmueller$get_x1(), x2 = hainmueller$get_x(),
a = NULL, b = rep(1/512,512),
p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.4612781
But after weighting, it looks much better!
# controls
ot_distance(x1 = hainmueller$get_x0(), x2 = hainmueller$get_x(),
a = weights@w0, b = rep(1/512,512),
p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.003670398
# treated
ot_distance(x1 = hainmueller$get_x1(), x2 = hainmueller$get_x(),
a = weights@w1, b = rep(1/512,512),
p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.002123788
After Causal Optimal Transport, the distributions are much more similar. We can also simply feed the output of calc_weight directly into the ot_distance function:
ot_distance(x1 = weights, p = 2, penalty = 1e3, debias = TRUE)
#> $pre
#> control treated
#> 0.5311378 0.4612781
#>
#> $post
#> control treated
#> 0.003670398 0.002123788
and S4 method dispatch takes care of the rest.
Finally, we can construct a summary of the optimal transport distances, Pareto k statistics, effective sample size, and mean balance using the summary method:
summarized_cw <- summary(weights)
We can then print the object to the screen:
summarized_cw
#> Diagnostics for causalWeights for estimand ATE
#> Control group
#> pre post
#> OT distance 0.5311378 0.003670398
#> Pareto k NA -0.232087710
#> N eff 247.0000000 115.756836449
#> Avg. std. mean balance 0.3067516 0.017646355
#>
#> Treated group
#> pre post
#> OT distance 0.4612781 2.123788e-03
#> Pareto k NA 2.478191e-01
#> N eff 265.0000000 8.299825e+01
#> Avg. std. mean balance 0.2859156 9.871006e-04
or we can make some diagnostic plots too!
The calc_weight function can also handle other methods. We have implemented methods for logistic and probit regression, the covariate balancing propensity score (CBPS), stable balancing weights (SBW), entropy balancing weights (EntropyBW), and the synthetic control method (SCM).
calc_weight(x = hainmueller, method = "Logistic",
estimand = "ATE")
calc_weight(x = hainmueller, method = "Probit",
estimand = "ATE")
calc_weight(x = hainmueller, method = "CBPS",
estimand = "ATE")
calc_weight(x = hainmueller, method = "SBW")
calc_weight(x = hainmueller, method = "EntropyBW")
calc_weight(x = hainmueller, method = "SCM")
The function also accepts the methods “EnergyBW”, for the energy balancing weights of Huling and Mak (2020), and “NNM”, for nearest neighbor matching with replacement. These are special cases of COT with the penalty parameter \(\lambda\) set to \(\infty\) and \(0\), respectively.
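Energy balancing weights choose weights that make the energy distance between a weighted sample and its target small. The criterion is the energy distance \(2E|X - Y| - E|X - X'| - E|Y - Y'|\), and here is a base-R sketch of its weighted version. The helpers pdist and energy_dist are hypothetical and are not causalOT's implementation.

```r
# Euclidean distance matrix between the rows of u and the rows of v
pdist <- function(u, v) {
  u <- as.matrix(u); v <- as.matrix(v)
  d2 <- outer(rowSums(u^2), rowSums(v^2), "+") - 2 * u %*% t(v)
  d2[d2 < 0] <- 0  # clear tiny negative values from rounding
  sqrt(d2)
}

# Hypothetical helper: weighted energy distance between (x, a) and (y, b),
# with each weight vector summing to one
energy_dist <- function(x, y, a, b) {
  2 * sum(outer(a, b) * pdist(x, y)) -
    sum(outer(a, a) * pdist(x, x)) -
    sum(outer(b, b) * pdist(y, y))
}

set.seed(4)
pts <- matrix(rnorm(20), ncol = 2)
wt <- rep(1 / 10, 10)
energy_dist(pts, pts, wt, wt)      # 0 for identical weighted samples
energy_dist(pts, pts + 1, wt, wt)  # positive once the samples differ
```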
The options argument is a little opaque, so we also provide a function, cotOptions, to help construct it. Its documentation provides more details. The other optimization methods, “SBW”, “EntropyBW”, and “SCM”, provide their own options functions. The options for “Logistic” and “Probit” pass arguments to glm, and “CBPS” passes arguments to the CBPS function in the package of the same name.
The package also provides more flexible optimal transport weights and modeling through object-oriented programming with the R6 package. These functions don't have as many safeguards, and everything is done by reference, so you have to be more careful about what you do. However, they also give you more flexibility in the types of problems you can solve. To learn more, see the vignette on object-oriented solvers.