library(bartcs)
The bartcs package performs confounder selection and estimates treatment effects with Bayesian Additive Regression Trees (BART).
This tutorial uses the Infant Health and Development Program (IHDP) dataset. The dataset includes 6 continuous and 19 binary covariates with a simulated outcome, a cognitive test score. It was first used by Hill (2011). My version of the dataset is the first realization generated by Louizos et al. (2017); you can find other versions in their GitHub repository.
data(ihdp, package = "bartcs")
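Before fitting, it helps to confirm the layout assumed below: y_factual is the (simulated) outcome, treatment is the binary treatment indicator, and columns 6 to 30 hold the 25 covariates. A quick base-R check:
dim(ihdp)                                 # number of rows and columns
str(ihdp[, c("y_factual", "treatment")])  # outcome and treatment indicator
names(ihdp)[6:30]                         # the 25 covariates passed to X below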
fit <- mbart(
  Y = ihdp$y_factual,
  trt = ihdp$treatment,
  X = ihdp[, 6:30],
  num_tree = 10,
  num_chain = 4,
  num_post_sample = 100,
  num_burn_in = 100,
  verbose = FALSE
)
fit
#> `bartcs` fit by `mbart()`
#>
#> mean 2.5% 97.5%
#> ATE 4.011577 3.784315 4.234449
#> Y1 6.421923 6.223755 6.607631
#> Y0 2.410346 2.308209 2.502927
You can get the mean and 95% credible interval of the average treatment effect (ATE) and the potential outcomes Y1 and Y0.
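The reported interval is simply the 2.5% and 97.5% quantiles of the posterior draws. A minimal sketch of that computation, assuming a hypothetical numeric vector ate_draws holding the pooled posterior samples of the ATE:
# ate_draws: hypothetical vector of pooled posterior ATE samples
mean(ate_draws)                       # posterior mean
quantile(ate_draws, c(0.025, 0.975))  # 95% credible interval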
Both sbart() and mbart() fit multiple MCMC chains. summary() provides per-chain results and the Gelman-Rubin statistic to check convergence.
summary(fit)
#> `bartcs` fit by `mbart()`
#>
#> Treatment Value
#> Treated group : 1
#> Control group : 0
#>
#> Tree Parameters
#> Number of Tree : 10 Value of alpha : 0.95
#> Prob. of Grow : 0.28 Value of beta : 2
#> Prob. of Prune : 0.28 Value of nu : 3
#> Prob. of Change : 0.44 Value of q : 0.95
#>
#> Chain Parameters
#> Number of Chains : 4 Number of burn-in : 100
#> Number of Iter : 200 Number of thinning : 0
#> Number of Sample : 100
#>
#> Outcome Diagnostics
#> Gelman-Rubin : 0.9993255
#>
#> Outcome
#> estimand chain 2.5% 1Q mean median 3Q 97.5%
#> ATE 1 3.827927 3.934340 4.000410 4.006176 4.068711 4.148203
#> ATE 2 3.728265 3.920698 3.968997 3.969489 4.031632 4.196675
#> ATE 3 3.782646 3.937966 4.004081 3.998647 4.067501 4.209064
#> ATE 4 3.873411 3.998460 4.072823 4.068607 4.150179 4.255475
#> ATE agg 3.784315 3.944236 4.011577 4.010476 4.080467 4.234449
#> Y1 1 6.219079 6.334693 6.389728 6.393089 6.442329 6.539095
#> Y1 2 6.207831 6.358046 6.414615 6.421666 6.469041 6.650941
#> Y1 3 6.229205 6.366742 6.429787 6.433926 6.492659 6.581297
#> Y1 4 6.262879 6.387433 6.453562 6.461160 6.526814 6.609977
#> Y1 agg 6.223755 6.359395 6.421923 6.423216 6.483039 6.607631
#> Y0 1 2.301260 2.366610 2.389318 2.397295 2.413969 2.465218
#> Y0 2 2.365974 2.412603 2.445618 2.445942 2.474246 2.525244
#> Y0 3 2.346262 2.396049 2.425706 2.428951 2.457778 2.507713
#> Y0 4 2.296413 2.346101 2.380739 2.377355 2.416494 2.454110
#> Y0 agg 2.308209 2.377912 2.410346 2.410937 2.444841 2.502927
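The Gelman-Rubin statistic above is close to 1, which suggests the chains have converged. For reference, the same type of diagnostic can be computed by hand with the coda package; a minimal sketch, assuming hypothetical per-chain vectors chain1 through chain4 of posterior ATE draws:
library(coda)
# chain1..chain4: hypothetical numeric vectors of per-chain posterior ATE draws
chains <- mcmc.list(mcmc(chain1), mcmc(chain2), mcmc(chain3), mcmc(chain4))
gelman.diag(chains)  # potential scale reduction factor; values near 1 indicate convergence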
You can get the posterior inclusion probability (PIP) for each variable.
plot(fit, method = "pip")
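A posterior inclusion probability is simply the fraction of posterior draws in which a variable is selected. A minimal sketch of that computation, assuming a hypothetical logical matrix incl with one row per posterior sample and one column per covariate:
# incl: hypothetical num_sample x num_covariate logical matrix recording
# whether each covariate was used in each posterior draw
pip <- colMeans(incl)        # fraction of draws that include each variable
sort(pip, decreasing = TRUE) # rank variables by inclusion probability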
Since inclusion_plot() is a wrapper around ggcharts::bar_chart(), you can use its arguments to refine the plot.
plot(fit, method = "pip", top_n = 10)
plot(fit, method = "pip", threshold = 0.5)
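Since extra arguments go to ggcharts::bar_chart(), its other options should also pass through; an assumed example using its bar_color argument (check ?ggcharts::bar_chart to confirm which arguments apply):
# bar_color is a ggcharts::bar_chart() argument; that plot() forwards it
# through to the wrapped function is an assumption here
plot(fit, method = "pip", top_n = 10, bar_color = "steelblue")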
With trace_plot(), you can visually check the trace of treatment effects or other parameters.
plot(fit, method = "trace")
plot(fit, method = "trace", "dir_alpha")
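If you prefer working with raw draws, a trace plot is just the sampled values in iteration order. A minimal sketch in base R, assuming a hypothetical vector ate_draws holding one chain's posterior ATE samples:
# ate_draws: hypothetical vector of one chain's posterior ATE samples
plot(ate_draws, type = "l", xlab = "Iteration", ylab = "ATE",
     main = "Trace of ATE")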
count_omp_thread()
#> [1] 6
count_omp_thread() checks whether OpenMP is supported; you need more than one thread for multi-threading. Due to its overhead, multi-threading will not be effective on small and moderate datasets, so I recommend parallelization only for data with at least 10,000 rows.
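A simple guard that encodes this advice (the 10,000-row cutoff is the rule of thumb above, not a package constant):
# Use parallelization only when extra OpenMP threads exist and the data are
# large enough for the threading overhead to pay off (rule of thumb: 1e4 rows)
use_parallel <- count_omp_thread() > 1 && nrow(ihdp) >= 1e4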
For comparison purposes, I will create a dataset with 20,000 rows by bootstrapping from the IHDP dataset. Then, for fast computation, I will set most parameters to 1.
idx  <- sample(nrow(ihdp), 2e4, TRUE)
ihdp <- ihdp[idx, ]
microbenchmark::microbenchmark(
  simple_mbart = mbart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = FALSE
  ),
  parallel_mbart = mbart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = TRUE
  ),
  times = 50
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> simple_mbart 57.8327 68.8058 75.67812 74.86685 77.6620 174.1183 50
#> parallel_mbart 51.2153 59.6585 64.93935 62.62715 66.8636 162.2986 50
The results show that parallelization gives a modest speedup (about 65 ms vs. 76 ms mean runtime) at this data size.