bkmr and bkmrhat
bkmr is a package to implement Bayesian kernel machine regression (BKMR) using Markov chain Monte Carlo (MCMC). Notably, bkmr is missing some key features of Bayesian inference and MCMC diagnostics: 1) no facility for running multiple chains in parallel, 2) no inference across multiple chains, 3) limited posterior summaries of parameters, and 4) limited diagnostics. The bkmrhat package is a lightweight set of functions that fills each of those gaps by enabling post-processing of bkmr output in other packages and by building a small framework for parallel processing.
How to use the bkmrhat package:
1. Fit a BKMR model for a single chain using the kmbayes function from bkmr, or fit multiple parallel chains using kmbayes_parallel from bkmrhat.
2. Perform diagnostics on the single or multiple chains using the kmbayes_diagnose function (which uses functions from the rstan package), OR convert the BKMR fit(s) to mcmc (one chain) or mcmc.list (multiple chains) objects from the coda package using as.mcmc or as.mcmc.list from the bkmrhat package. The coda package has a whole host of inference and diagnostic procedures (but may lag behind some of the diagnostic functions from rstan).
3. Perform posterior summaries using coda functions, or combine the chains from a kmbayes_parallel fit using kmbayes_combine. Final posterior inferences can be made on the combined object, which enables use of bkmr package functions for visual summaries of independent and joint effects of exposures in the bkmr model.
First, simulate some data using the SimData function from the bkmr package.
library("bkmr")
library("bkmrhat")
library("coda")
set.seed(111)
dat <- bkmr::SimData(n = 50, M = 5, ind=1:3, Zgen="realistic")
y <- dat$y
Z <- dat$Z
X <- cbind(dat$X, rnorm(50))
head(cbind(y,Z,X))
## y z1 z2 z3 z4 z5
## [1,] 4.1379128 -0.06359282 -0.02996246 -0.14190647 -0.44089352 -0.1878732
## [2,] 12.0843607 -0.07308834 0.32021690 1.08838691 0.29448354 -1.4609837
## [3,] 7.8859254 0.59604857 0.20602329 0.46218114 -0.03387906 -0.7615902
## [4,] 1.1609768 1.46504863 2.48389356 1.39869461 1.49678590 0.2837234
## [5,] 0.4989372 -0.37549639 0.01159884 1.17891641 -0.05286516 -0.1680664
## [6,] 5.0731242 -0.36904566 -0.49744932 -0.03330522 0.30843805 0.6814844
##
## [1,] 1.0569172 -1.0503824
## [2,] 4.8158570 0.3251424
## [3,] 2.6683461 -2.1048716
## [4,] -0.7492096 -0.9551027
## [5,] -0.5428339 -0.5306399
## [6,] 1.6493251 0.8274405
There is some overhead in parallel processing when using the future package, so the payoff from parallelization may vary by problem. Here it is about a 2-4x speedup, but you will see more benefit at higher numbers of iterations. Note that running multiple shorter chains may not yield as many usable iterations as a single large chain if a substantial burn-in period is needed, but it will enable useful convergence diagnostics. Note also that the future package can implement sequential processing, which effectively turns kmbayes_parallel into a loop but retains all the other advantages of multiple chains.
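For instance, a minimal sketch of the sequential alternative (this assumes only the future package's built-in sequential strategy; the object name kmfit_seq is illustrative, and you can switch back to a parallel plan at any time):
# run the chains one after another in the current session, rather than in parallel
# future::plan(strategy = future::sequential)
# kmfit_seq <- kmbayes_parallel(nchains = 4, y = y, Z = Z, X = X, iter = 1000,
#                               verbose = FALSE, varsel = FALSE)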
# enable parallel processing (up to 4 simultaneous processes here)
future::plan(strategy = future::multisession)
# a single chain of 4000 iterations from the bkmr package
set.seed(111)
system.time(kmfit <- suppressMessages(kmbayes(y = y, Z = Z, X = X, iter = 4000, verbose = FALSE, varsel = FALSE)))
## user system elapsed
## 12.560 0.096 12.727
# 4 parallel chains of 1000 iterations each from the bkmrhat package
set.seed(111)
system.time(kmfit5 <- suppressMessages(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = FALSE)))
## Chain 1
## Chain 2
## Chain 3
## Chain 4
## user system elapsed
## 0.088 0.004 6.551
The diagnostics from the rstan package come from the monitor function (see the help files for that function in the rstan package).
# Using rstan functions (set burnin/warmup to zero for comparability with the coda
# numbers given later; posterior summaries should be performed after excluding warmup/burnin)
singlediag = kmbayes_diagnose(kmfit, warmup=0, digits_summary=2)
## Single chain
## Inference for the input samples (1 chains: each with iter = 4000; warmup = 0):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.1 2.0 0.0 1.00 2820 3194
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 3739 3535
## lambda 3.9 10.0 22.3 11.2 5.9 1.00 346 222
## r1 0.0 0.0 0.1 0.0 0.1 1.01 129 173
## r2 0.0 0.0 0.1 0.0 0.1 1.00 182 181
## r3 0.0 0.0 0.0 0.0 0.0 1.01 158 112
## r4 0.0 0.0 0.1 0.0 0.1 1.03 176 135
## r5 0.0 0.0 0.0 0.0 0.1 1.00 107 114
## sigsq.eps 0.2 0.3 0.5 0.4 0.1 1.00 1262 1563
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# Using rstan functions (multiple chains enable R-hat)
multidiag = kmbayes_diagnose(kmfit5, warmup=0, digits_summary=2)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 0):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.1 2.0 0.0 1.00 1951 1652
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 3204 2578
## lambda 4.5 9.7 24.0 11.4 6.5 1.01 359 510
## r1 0.0 0.0 0.1 0.1 0.2 1.02 133 66
## r2 0.0 0.0 0.2 0.1 0.2 1.02 116 71
## r3 0.0 0.0 0.1 0.0 0.2 1.02 87 92
## r4 0.0 0.0 0.1 0.0 0.1 1.03 119 78
## r5 0.0 0.0 0.5 0.1 0.2 1.07 49 44
## sigsq.eps 0.2 0.3 0.5 0.3 0.1 1.01 655 431
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# using coda functions, not using any burnin (for demonstration only)
kmfitcoda = as.mcmc(kmfit, iterstart = 1)
kmfit5coda = as.mcmc.list(kmfit5, iterstart = 1)
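# in practice, exclude burn-in when summarizing; with bkmr's default of
# discarding the first half of the chain, a post-burn-in conversion would look
# like this (a sketch; the object names are illustrative):
# kmfitcoda_postburn = as.mcmc(kmfit, iterstart = 2001)        # 4000-iteration chain
# kmfit5coda_postburn = as.mcmc.list(kmfit5, iterstart = 501)  # 1000-iteration chains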
# single chain trace plot
traceplot(kmfitcoda)
The trace plots look typical and fine, but trace plots alone don't give a full picture of convergence. Note the apparent quick convergence of a couple of parameters, demonstrated by movement away from the starting value followed by concentration of the remaining samples within a narrow band.
Seeing visual evidence that different chains are sampling from the same marginal distributions is reassuring about the stability of the results.
# multiple chain trace plot
traceplot(kmfit5coda)
Now examine “cross correlation”, which can help identify highly correlated parameters in the posterior; such correlation can be problematic for MCMC sampling. Here the r parameters {r1, ..., r5} appear to be highly correlated with one another. All other things being equal, highly correlated parameters in the posterior mean that more samples are needed than would be needed with uncorrelated parameters.
# multiple cross-correlation plot (combines all samples)
crosscorr(kmfit5coda)
## beta1 beta2 lambda r1 r2
## beta1 1.00000000 0.11196556 0.01818948 0.12780730 0.14054380
## beta2 0.11196556 1.00000000 -0.09886613 -0.07629748 -0.08187748
## lambda 0.01818948 -0.09886613 1.00000000 0.05928473 0.08790397
## r1 0.12780730 -0.07629748 0.05928473 1.00000000 0.80592328
## r2 0.14054380 -0.08187748 0.08790397 0.80592328 1.00000000
## r3 0.15885857 -0.04824946 0.08800776 0.55926182 0.72804955
## r4 0.24884624 -0.02522362 0.03561473 0.68618114 0.76089188
## r5 0.10629353 -0.05275125 0.08188335 0.65339395 0.83482001
## sigsq.eps -0.04299566 0.05842240 -0.32711124 -0.20915310 -0.22953148
## r3 r4 r5 sigsq.eps
## beta1 0.15885857 0.24884624 0.10629353 -0.04299566
## beta2 -0.04824946 -0.02522362 -0.05275125 0.05842240
## lambda 0.08800776 0.03561473 0.08188335 -0.32711124
## r1 0.55926182 0.68618114 0.65339395 -0.20915310
## r2 0.72804955 0.76089188 0.83482001 -0.22953148
## r3 1.00000000 0.74099838 0.76235821 -0.15141692
## r4 0.74099838 1.00000000 0.67017793 -0.14066748
## r5 0.76235821 0.67017793 1.00000000 -0.17857737
## sigsq.eps -0.15141692 -0.14066748 -0.17857737 1.00000000
crosscorr.plot(kmfit5coda)
Now examine “autocorrelation” to identify parameters that have high correlation between subsequent iterations of the MCMC sampler, which can lead to inefficient MCMC sampling. All other things equal, having highly autocorrelated parameters in the posterior means that more samples are needed than would be needed with low-autocorrelation parameters.
# multiple chain trace plot
#autocorr(kmfit5coda) # lots of output
autocorr.plot(kmfit5coda)
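A compact numeric counterpart to these plots is given by coda's autocorr.diag function, which reports autocorrelations at a few default lags (a sketch on the converted objects above):
# numeric autocorrelation summary at default lags
autocorr.diag(kmfit5coda)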
Graphical tools can be limited, and are sometimes difficult to use effectively with scale parameters (of which bkmr has many). Additionally, no single diagnostic is perfect, leading many authors to advocate the use of multiple, complementary diagnostics. Thus, more formal diagnostics are helpful.
Gelman's r-hat gives an interpretable diagnostic: the expected reduction in the standard error of the posterior means if you could run the chains to an infinite size. It gives some idea of when it is reasonable to stop sampling; rules of thumb for using r-hat to decide when to stop are available from several authors (for example, consult the help files for rstan and coda).
Effective sample size is also useful: it estimates the amount of information in your chain, expressed as the number of independent posterior samples it would take to carry the same information (e.g. if we could sample from the posterior directly).
# Gelman's r-hat using coda estimator (will differ from rstan implementation)
gelman.diag(kmfit5coda)
## Potential scale reduction factors:
##
## Point est. Upper C.I.
## beta1 1.00 1.00
## beta2 1.00 1.00
## lambda 1.01 1.02
## r1 1.04 1.11
## r2 1.06 1.11
## r3 1.05 1.09
## r4 1.03 1.05
## r5 1.06 1.16
## sigsq.eps 1.00 1.00
##
## Multivariate psrf
##
## 1.05
# effective sample size
effectiveSize(kmfitcoda)
## beta1 beta2 lambda r1 r2 r3 r4
## 2411.61878 2865.78299 431.49158 87.11091 260.29419 328.45388 181.61903
## r5 sigsq.eps
## 123.06100 1719.70679
effectiveSize(kmfit5coda)
## beta1 beta2 lambda r1 r2 r3 r4
## 1993.69268 3572.49076 399.59987 130.58132 118.60085 139.37144 190.52199
## r5 sigsq.eps
## 89.01537 1246.93322
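coda also offers a running version of the Gelman-Rubin diagnostic, showing how the potential scale reduction factor evolves as iterations accumulate (a sketch; see the gelman.plot help file in coda):
# evolution of the potential scale reduction factor over iterations
gelman.plot(kmfit5coda)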
Posterior kernel marginal densities, 1 chain
# posterior kernel marginal densities using an `mcmc` object (single chain)
densplot(kmfitcoda)
Posterior kernel marginal densities, multiple chains combined. Look for multiple modes that may indicate non-convergence of some chains.
# posterior kernel marginal densities using an `mcmc.list` object (multiple chains)
densplot(kmfit5coda)
Many other diagnostics are available in the coda package.
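For instance, two complementary checks are the Geweke and Heidelberger-Welch diagnostics (a sketch on the objects converted above):
# Geweke diagnostic: z-scores comparing means of early vs. late chain segments
geweke.diag(kmfitcoda)
# Heidelberger-Welch stationarity and half-width tests
heidel.diag(kmfitcoda)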
Finally, the chains from the original kmbayes_parallel fit can be combined into a single chain (see the help files for how to deal with burn-in; the default in bkmr is to use the first half of the chain, which is respected here). The kmbayes_combine function smartly combines the burn-in iterations first and then the post-burn-in iterations, such that the burn-in rules of subsequent functions within the bkmr package are respected. Note that, unlike the as.mcmc.list function, this function combines all iterations into a single chain, so trace plots will not be good diagnostics for the combined object, and it should be used only once one is assured that all chains have converged and the burn-in is acceptable.
With this combined set of samples, you can use any of the post-processing functions from the bkmr package, which are described here: https://jenfb.github.io/bkmr/overview.html. For example, see below the estimation of the posterior mean difference along a series of quantiles of all exposures in Z.
# posterior summaries using `mcmc` and `mcmc.list` objects
summary(kmfitcoda)
##
## Iterations = 1:4000
## Thinning interval = 1
## Number of chains = 1
## Sample size per chain = 4000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta1 1.98349 0.04392 0.0006945 0.0008944
## beta2 0.11999 0.08532 0.0013490 0.0015937
## lambda 11.18052 5.90540 0.0933725 0.2842909
## r1 0.03010 0.06385 0.0010095 0.0068408
## r2 0.03656 0.05690 0.0008997 0.0035270
## r3 0.02131 0.04065 0.0006427 0.0022429
## r4 0.02855 0.06641 0.0010500 0.0049278
## r5 0.02904 0.09543 0.0015089 0.0086023
## sigsq.eps 0.35283 0.08228 0.0013009 0.0019840
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta1 1.90086 1.95474 1.98287 2.01177 2.07218
## beta2 -0.04513 0.06141 0.12001 0.17716 0.28713
## lambda 3.23941 7.01814 10.00979 14.12075 25.93332
## r1 0.01022 0.01232 0.01807 0.02767 0.09818
## r2 0.01018 0.01433 0.02172 0.04049 0.12353
## r3 0.01015 0.01180 0.01488 0.02198 0.05655
## r4 0.01040 0.01299 0.01670 0.02582 0.08533
## r5 0.01025 0.01219 0.01532 0.01951 0.07021
## sigsq.eps 0.22855 0.29302 0.34057 0.39833 0.54838
summary(kmfit5coda)
##
## Iterations = 1:1000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta1 1.98295 0.04719 0.0007461 0.001226
## beta2 0.11392 0.08811 0.0013931 0.001492
## lambda 11.38060 6.53246 0.1032873 0.348441
## r1 0.05443 0.16140 0.0025520 0.018642
## r2 0.07067 0.17568 0.0027778 0.018080
## r3 0.04722 0.15141 0.0023941 0.014086
## r4 0.04203 0.11633 0.0018393 0.008849
## r5 0.06256 0.18346 0.0029008 0.025725
## sigsq.eps 0.34866 0.08658 0.0013690 0.002554
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta1 1.89180 1.95338 1.98273 2.01251 2.0756
## beta2 -0.06110 0.05688 0.11537 0.17157 0.2821
## lambda 3.75062 6.94311 9.69281 14.01495 30.4208
## r1 0.01025 0.01225 0.01663 0.02898 0.6771
## r2 0.01084 0.01579 0.02609 0.04595 0.8337
## r3 0.01013 0.01144 0.01571 0.02300 0.5456
## r4 0.01023 0.01203 0.01711 0.02862 0.2529
## r5 0.01004 0.01201 0.01557 0.02197 0.8570
## sigsq.eps 0.21037 0.28755 0.33934 0.39873 0.5428
# highest posterior density intervals using `mcmc` and `mcmc.list` objects
HPDinterval(kmfitcoda)
## lower upper
## beta1 1.90086937 2.07231225
## beta2 -0.04051413 0.29012992
## lambda 2.68866745 22.62238656
## r1 0.01002094 0.06754473
## r2 0.01002597 0.09683241
## r3 0.01003504 0.04392585
## r4 0.01000511 0.06510500
## r5 0.01005733 0.04755863
## sigsq.eps 0.21411781 0.52229778
## attr(,"Probability")
## [1] 0.95
HPDinterval(kmfit5coda)
## [[1]]
## lower upper
## beta1 1.89276629 2.06413559
## beta2 -0.05156326 0.28395624
## lambda 3.10640350 22.52661684
## r1 0.01015061 0.45345607
## r2 0.01004337 0.23737153
## r3 0.01035515 0.08038502
## r4 0.01008299 0.08787233
## r5 0.01001371 0.35237108
## sigsq.eps 0.19500454 0.52782799
## attr(,"Probability")
## [1] 0.95
##
## [[2]]
## lower upper
## beta1 1.89936901 2.07525060
## beta2 -0.06340550 0.27865199
## lambda 2.38869269 20.88324139
## r1 0.01025450 0.06751305
## r2 0.01008338 0.09635152
## r3 0.01007603 0.07466994
## r4 0.01000690 0.10235090
## r5 0.01004010 0.04759293
## sigsq.eps 0.21207225 0.52893100
## attr(,"Probability")
## [1] 0.95
##
## [[3]]
## lower upper
## beta1 1.88810244 2.07425771
## beta2 -0.04652416 0.29719363
## lambda 3.10602441 24.89619262
## r1 0.01001759 0.07919633
## r2 0.01036142 0.24702810
## r3 0.01007678 0.55489001
## r4 0.01023406 0.12806408
## r5 0.01034510 0.90040470
## sigsq.eps 0.19890368 0.52657850
## attr(,"Probability")
## [1] 0.95
##
## [[4]]
## lower upper
## beta1 1.87836131 2.06908822
## beta2 -0.06136301 0.27665000
## lambda 3.11659093 29.43524129
## r1 0.01000132 0.51647301
## r2 0.01013282 0.62165645
## r3 0.01006649 0.07601939
## r4 0.01005967 0.11054523
## r5 0.01002425 0.66711223
## sigsq.eps 0.20201080 0.53035096
## attr(,"Probability")
## [1] 0.95
# combine multiple chains into a single chain
fitkmccomb = kmbayes_combine(kmfit5)
# For example:
summary(fitkmccomb)
## Fitted object of class 'bkmrfit'
## Iterations: 4000
## Outcome family: gaussian
## Model fit on: 2021-09-07 13:59:41
## Running time: 3.7978 secs
##
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4096024
## 2 r1 0.1970493
## 3 r2 0.2943236
## 4 r3 0.1562891
## 5 r4 0.1860465
## 6 r5 0.1467867
##
## Parameter estimates (based on iterations 2001-4000):
## param mean sd q_2.5 q_97.5
## 1 beta1 1.98212 0.04442 1.89349 2.06781
## 2 beta2 0.11546 0.08379 -0.04764 0.27603
## 3 sigsq.eps 0.35517 0.08322 0.22750 0.54753
## 4 r1 0.02336 0.01801 0.01024 0.07617
## 5 r2 0.03667 0.03489 0.01067 0.13997
## 6 r3 0.02045 0.01566 0.01041 0.06233
## 7 r4 0.02370 0.02456 0.01023 0.10115
## 8 r5 0.01669 0.00787 0.01002 0.04492
## 9 lambda 10.88426 6.03507 3.45492 26.75653
mean.difference <- suppressWarnings(OverallRiskSummaries(fit = fitkmccomb, y = y, Z = Z, X = X,
qs = seq(0.25, 0.75, by = 0.05),
q.fixed = 0.5, method = "exact"))
mean.difference
## quantile est sd
## 1 0.25 -0.7196817 0.11979520
## 2 0.30 -0.5794215 0.09698722
## 3 0.35 -0.3914102 0.08109645
## 4 0.40 -0.2728444 0.04748039
## 5 0.45 -0.1501128 0.02648235
## 6 0.50 0.0000000 0.00000000
## 7 0.55 0.2142025 0.04216533
## 8 0.60 0.3335123 0.05245787
## 9 0.65 0.5136222 0.08448011
## 10 0.70 0.8765201 0.14711644
## 11 0.75 0.9726334 0.15792312
with(mean.difference, {
plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
abline(h=0)
axis(1)
axis(2)
box(bty='l')
})
These results parallel those of the previous section and are given here without comment, other than to note that variable selection is now enabled (varsel = TRUE), that fixed effects (the X variables) are not subject to variable selection, and that it is useful to check the posterior inclusion probabilities to ensure they are similar across chains.
set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.19012 secs elapsed)
## Iteration: 200 (20% completed; 0.36007 secs elapsed)
## Iteration: 300 (30% completed; 0.50004 secs elapsed)
## Iteration: 400 (40% completed; 0.65123 secs elapsed)
## Iteration: 500 (50% completed; 0.79919 secs elapsed)
## Iteration: 600 (60% completed; 1.05529 secs elapsed)
## Iteration: 700 (70% completed; 1.1879 secs elapsed)
## Iteration: 800 (80% completed; 1.325 secs elapsed)
## Iteration: 900 (90% completed; 1.45481 secs elapsed)
## Iteration: 1000 (100% completed; 1.57788 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.19908 secs elapsed)
## Iteration: 200 (20% completed; 0.35866 secs elapsed)
## Iteration: 300 (30% completed; 0.50811 secs elapsed)
## Iteration: 400 (40% completed; 0.6734 secs elapsed)
## Iteration: 500 (50% completed; 0.80807 secs elapsed)
## Iteration: 600 (60% completed; 0.95385 secs elapsed)
## Iteration: 700 (70% completed; 1.08942 secs elapsed)
## Iteration: 800 (80% completed; 1.33994 secs elapsed)
## Iteration: 900 (90% completed; 1.47263 secs elapsed)
## Iteration: 1000 (100% completed; 1.5886 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.2013 secs elapsed)
## Iteration: 200 (20% completed; 0.36381 secs elapsed)
## Iteration: 300 (30% completed; 0.5033 secs elapsed)
## Iteration: 400 (40% completed; 0.65869 secs elapsed)
## Iteration: 500 (50% completed; 0.79246 secs elapsed)
## Iteration: 600 (60% completed; 0.93974 secs elapsed)
## Iteration: 700 (70% completed; 1.0814 secs elapsed)
## Iteration: 800 (80% completed; 1.215 secs elapsed)
## Iteration: 900 (90% completed; 1.35461 secs elapsed)
## Iteration: 1000 (100% completed; 1.48052 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.1957 secs elapsed)
## Iteration: 200 (20% completed; 0.35616 secs elapsed)
## Iteration: 300 (30% completed; 0.5051 secs elapsed)
## Iteration: 400 (40% completed; 0.65644 secs elapsed)
## Iteration: 500 (50% completed; 0.79253 secs elapsed)
## Iteration: 600 (60% completed; 0.93953 secs elapsed)
## Iteration: 700 (70% completed; 1.0725 secs elapsed)
## Iteration: 800 (80% completed; 1.20335 secs elapsed)
## Iteration: 900 (90% completed; 1.34046 secs elapsed)
## Iteration: 1000 (100% completed; 1.46657 secs elapsed)
## user system elapsed
## 0.065 0.004 1.650
bmadiag = kmbayes_diagnose(kmfitbma.list)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 500):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.00 675 1304
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 1866 1877
## lambda 4.6 11.5 32.1 14.1 9.6 1.10 39 58
## r1 0.0 0.0 0.1 0.0 0.0 1.18 50 44
## r2 0.0 0.0 0.1 0.0 0.0 1.17 25 61
## r3 0.0 0.0 0.0 0.0 0.0 1.03 117 60
## r4 0.0 0.0 0.1 0.0 0.0 1.05 72 63
## r5 0.0 0.0 0.0 0.0 0.0 1.13 71 81
## sigsq.eps 0.3 0.4 0.5 0.4 0.1 1.01 432 1064
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# posterior inclusion probabilities from each chain
lapply(kmfitbma.list, function(x) t(ExtractPIPs(x)))
## [[1]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.866" "0.706" "0.350" "0.678" "0.400"
##
## [[2]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.886" "0.940" "0.614" "0.844" "0.724"
##
## [[3]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.764" "0.834" "0.454" "0.736" "0.636"
##
## [[4]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.908" "0.688" "0.476" "0.636" "0.500"
kmfitbma.comb = kmbayes_combine(kmfitbma.list)
summary(kmfitbma.comb)
## Fitted object of class 'bkmrfit'
## Iterations: 4000
## Outcome family: gaussian
## Model fit on: 2021-09-07 13:59:48
## Running time: 1.57833 secs
##
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4611153
## 2 r/delta (overall) 0.3135784
## 3 r/delta (move 1) 0.4030151
## 4 r/delta (move 2) 0.2253240
##
## Parameter estimates (based on iterations 2001-4000):
## param mean sd q_2.5 q_97.5
## 1 beta1 1.97624 0.04434 1.89178 2.06540
## 2 beta2 0.11939 0.08720 -0.05403 0.28919
## 3 sigsq.eps 0.37401 0.09073 0.23451 0.59425
## 4 r1 0.02125 0.01797 0.00000 0.07041
## 5 r2 0.02710 0.02517 0.00000 0.08436
## 6 r3 0.01219 0.02517 0.00000 0.10333
## 7 r4 0.01658 0.01725 0.00000 0.06966
## 8 r5 0.00917 0.01393 0.00000 0.02455
## 9 lambda 14.07624 9.60894 3.71335 37.33710
##
## Posterior inclusion probabilities:
## variable PIP
## 1 z1 0.8560
## 2 z2 0.7920
## 3 z3 0.4735
## 4 z4 0.7235
## 5 z5 0.5650
ExtractPIPs(kmfitbma.comb) # posterior inclusion probabilities
## variable PIP
## 1 z1 0.8560
## 2 z2 0.7920
## 3 z3 0.4735
## 4 z4 0.7235
## 5 z5 0.5650
mean.difference2 <- suppressWarnings(OverallRiskSummaries(fit = kmfitbma.comb, y = y, Z = Z, X = X, qs = seq(0.25, 0.75, by = 0.05),
q.fixed = 0.5, method = "exact"))
mean.difference2
## quantile est sd
## 1 0.25 -0.6366720 0.13485284
## 2 0.30 -0.5089135 0.10837813
## 3 0.35 -0.3335908 0.09907977
## 4 0.40 -0.2439669 0.05375930
## 5 0.45 -0.1425662 0.02649497
## 6 0.50 0.0000000 0.00000000
## 7 0.55 0.1978298 0.04864386
## 8 0.60 0.3023487 0.06120906
## 9 0.65 0.4711916 0.09699283
## 10 0.70 0.8175392 0.16889282
## 11 0.75 0.8970455 0.17107679
with(mean.difference2, {
plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
abline(h=0)
axis(1)
axis(2)
box(bty='l')
})
bkmrhat also has ported versions of the native posterior summarization functions, which allow you to compare how these summaries vary across parallel chains. Note that these should serve as diagnostics, and final posterior inference should be done on the combined chain. The easiest of these functions to demonstrate is the OverallRiskSummaries_parallel function, which simply runs OverallRiskSummaries (from the bkmr package) on each chain and combines the results. Notably, this function fixes the y-axis at zero for the median, so it under-represents overall predictive variation across chains but captures variation in effect estimates across the chains. Ideally, that variation is negligible: if you see differences between chains that would result in different interpretations, you should re-fit the model with more iterations. In this example, the results are reasonably consistent across chains, but one might want to run more iterations if, say, the differences seen across the upper error bounds are of such a magnitude as to be practically meaningful.
set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.14417 secs elapsed)
## Iteration: 200 (20% completed; 0.28502 secs elapsed)
## Iteration: 300 (30% completed; 0.42128 secs elapsed)
## Iteration: 400 (40% completed; 0.56359 secs elapsed)
## Iteration: 500 (50% completed; 0.70225 secs elapsed)
## Iteration: 600 (60% completed; 0.82651 secs elapsed)
## Iteration: 700 (70% completed; 0.98656 secs elapsed)
## Iteration: 800 (80% completed; 1.13531 secs elapsed)
## Iteration: 900 (90% completed; 1.26097 secs elapsed)
## Iteration: 1000 (100% completed; 1.41781 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.13642 secs elapsed)
## Iteration: 200 (20% completed; 0.27816 secs elapsed)
## Iteration: 300 (30% completed; 0.40544 secs elapsed)
## Iteration: 400 (40% completed; 0.53283 secs elapsed)
## Iteration: 500 (50% completed; 0.68806 secs elapsed)
## Iteration: 600 (60% completed; 0.81086 secs elapsed)
## Iteration: 700 (70% completed; 0.9592 secs elapsed)
## Iteration: 800 (80% completed; 1.11268 secs elapsed)
## Iteration: 900 (90% completed; 1.24525 secs elapsed)
## Iteration: 1000 (100% completed; 1.38483 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.26307 secs elapsed)
## Iteration: 200 (20% completed; 0.39403 secs elapsed)
## Iteration: 300 (30% completed; 0.52395 secs elapsed)
## Iteration: 400 (40% completed; 0.66127 secs elapsed)
## Iteration: 500 (50% completed; 0.79387 secs elapsed)
## Iteration: 600 (60% completed; 0.94467 secs elapsed)
## Iteration: 700 (70% completed; 1.08065 secs elapsed)
## Iteration: 800 (80% completed; 1.2203 secs elapsed)
## Iteration: 900 (90% completed; 1.37123 secs elapsed)
## Iteration: 1000 (100% completed; 1.48441 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.1542 secs elapsed)
## Iteration: 200 (20% completed; 0.28201 secs elapsed)
## Iteration: 300 (30% completed; 0.41588 secs elapsed)
## Iteration: 400 (40% completed; 0.54944 secs elapsed)
## Iteration: 500 (50% completed; 0.6911 secs elapsed)
## Iteration: 600 (60% completed; 0.81824 secs elapsed)
## Iteration: 700 (70% completed; 0.96803 secs elapsed)
## Iteration: 800 (80% completed; 1.1152 secs elapsed)
## Iteration: 900 (90% completed; 1.24305 secs elapsed)
## Iteration: 1000 (100% completed; 1.39212 secs elapsed)
## user system elapsed
## 0.061 0.004 1.644
meandifference_par = OverallRiskSummaries_parallel(kmfitbma.list, y = y, Z = Z, X = X ,qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1
## Chain 2
## Chain 3
## Chain 4
head(meandifference_par)
## quantile est sd chain
## 1 0.25 -0.6077766 0.13214887 1
## 2 0.30 -0.4883572 0.11148530 1
## 3 0.35 -0.3070528 0.08312737 1
## 4 0.40 -0.2333888 0.04654883 1
## 5 0.45 -0.1452090 0.02761519 1
## 6 0.50 0.0000000 0.00000000 1
nchains = length(unique(meandifference_par$chain))
with(meandifference_par, {
plot.new()
plot.window(ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
xlim=c(min(quantile), max(quantile)),
ylab= "Mean difference", xlab = "Joint quantile")
for(cch in seq_len(nchains)){
width = diff(quantile)[1]
jit = runif(1, -width/5, width/5)
points(jit+quantile[chain==cch], est[chain==cch], pch=19, col=cch)
segments(x0=jit+quantile[chain==cch], x1=jit+quantile[chain==cch], y0 = est[chain==cch] - 1.96*sd[chain==cch], y1 = est[chain==cch] + 1.96*sd[chain==cch], col=cch)
}
abline(h=0)
axis(1)
axis(2)
box(bty='l')
legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})
regfuns_par = PredictorResponseUnivar_parallel(kmfitbma.list, y = y, Z = Z, X = X ,qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1
## Chain 2
## Chain 3
## Chain 4
head(regfuns_par)
## variable z est se chain
## 1 z1 -2.186199 -1.327148 0.8377480 1
## 2 z1 -2.082261 -1.275372 0.8048473 1
## 3 z1 -1.978323 -1.222335 0.7718192 1
## 4 z1 -1.874385 -1.168099 0.7386821 1
## 5 z1 -1.770446 -1.112734 0.7054585 1
## 6 z1 -1.666508 -1.056314 0.6721764 1
nchains = length(unique(regfuns_par$chain))
# single variable
with(regfuns_par[regfuns_par$variable=="z1",], {
plot.new()
plot.window(ylim=c(min(est - 1.96*se), max(est + 1.96*se)),
xlim=c(min(z), max(z)),
ylab= "Predicted Y", xlab = "Z")
pc = c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")
pc2 = c("#0000001A", "#E69F001A", "#56B4E91A", "#009E731A", "#F0E4421A", "#0072B21A", "#D55E001A", "#CC79A71A", "#9999991A")
for(cch in seq_len(nchains)){
ribbonX = c(z[chain==cch], rev(z[chain==cch]))
ribbonY = c(est[chain==cch] + 1.96*se[chain==cch], rev(est[chain==cch] - 1.96*se[chain==cch]))
polygon(x=ribbonX, y = ribbonY, col=pc2[cch], border=NA)
lines(z[chain==cch], est[chain==cch], pch=19, col=pc[cch])
}
axis(1)
axis(2)
box(bty='l')
legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})
Sometimes you just need to run more samples in an existing chain. For example, you run a bkmr fit for 3 days, only to find you don't have enough samples. A “continued” fit just means starting off at the last iteration and building on an existing set of results by lengthening the Markov chain. Unfortunately, due to how the kmbayes function accepts starting values (in the officially released version), you can't quite do this exactly in many cases (the function will relay a message and possible solutions, if any; the bkmr package authors are aware of this issue). The kmbayes_continue function continues a bkmr fit as well as the bkmr package will allow. The r parameters of the fit must all be initialized at the same value, so kmbayes_continue starts a new MCMC fit at the final values of all parameters from the prior bkmr fit, but sets all of the r parameters to their mean at the last iteration of the previous fit. Additionally, if h.hat parameters are estimated, these are fixed to be above zero to meet similar constraints, either by fixing them at their posterior mean or by setting them to a small positive value. One should inspect trace plots to see whether this causes issues (e.g. if the trace plots demonstrate different patterns in the samples before and after the continuation). Here's an example with a quick check of diagnostics of the first part of the chain, and of the combined chain (which could be used for inference or extended again, if necessary). We caution users that this function creates 2 distinct, if very similar, Markov chains, and to use appropriate caution if trace plots differ before and after each continuation. Nonetheless, in many cases one can act as though all samples are from the same Markov chain.
Note that if you install the development version of the bkmr package, you can continue fits from exactly where they left off, giving you a true, single Markov chain. You can install it via the commented code below.
# install dev version of bkmr to allow true continued fits.
#install.packages("devtools")
#devtools::install_github("jenfb/bkmr")
set.seed(111)
# run 500 initial iterations for a model with only 2 exposures
Z2 = Z[,1:2]
kmfitbma.start <- suppressWarnings(kmbayes(y = y, Z = Z2, X = X, iter = 500, verbose = FALSE, varsel = FALSE))
## Iteration: 50 (10% completed; 0.06677 secs elapsed)
## Iteration: 100 (20% completed; 0.13669 secs elapsed)
## Iteration: 150 (30% completed; 0.19486 secs elapsed)
## Iteration: 200 (40% completed; 0.25969 secs elapsed)
## Iteration: 250 (50% completed; 0.31457 secs elapsed)
## Iteration: 300 (60% completed; 0.38066 secs elapsed)
## Iteration: 350 (70% completed; 0.44475 secs elapsed)
## Iteration: 400 (80% completed; 0.52298 secs elapsed)
## Iteration: 450 (90% completed; 0.58524 secs elapsed)
## Iteration: 500 (100% completed; 0.65066 secs elapsed)
kmbayes_diag(kmfitbma.start)
## Single chain
## Inference for the input samples (1 chains: each with iter = 500; warmup = 250):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.02 184 238
## beta2 0.0 0.1 0.2 0.1 0.1 1.00 259 197
## lambda 6.4 15.0 33.2 17.0 8.2 1.03 32 38
## r1 0.0 0.0 0.0 0.0 0.0 1.20 5 16
## r2 0.0 0.0 0.1 0.0 0.0 1.01 26 28
## sigsq.eps 0.3 0.4 0.5 0.4 0.1 1.00 195 157
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
## mean se_mean sd 2.5% 25%
## beta1 1.96002469 0.002996717 0.04071499 1.88732666 1.93049212
## beta2 0.09934846 0.005601868 0.09100445 -0.06425935 0.03866907
## lambda 17.03096186 1.480162990 8.16066439 5.99261263 10.62650587
## r1 0.02445179 0.006963033 0.01503557 0.01009566 0.01650405
## r2 0.02931652 0.003839555 0.02294888 0.01050284 0.01300914
## sigsq.eps 0.38177766 0.006062528 0.08483668 0.25337749 0.32965181
## 50% 75% 97.5% n_eff Rhat valid Q5
## beta1 1.96168526 1.98761915 2.04042002 192 1.0174270 1 1.89496869
## beta2 0.09494353 0.15892943 0.26004214 263 0.9979147 1 -0.04566104
## lambda 14.97716478 22.17297877 35.85230190 32 1.0322926 1 6.39370841
## r1 0.02078562 0.03098429 0.04827171 12 1.1969882 1 0.01009566
## r2 0.02369039 0.03766762 0.09155390 38 1.0072685 1 0.01213386
## sigsq.eps 0.36834066 0.42589940 0.61832629 187 0.9989252 1 0.26493273
## Q50 Q95 MCSE_Q2.5 MCSE_Q25 MCSE_Q50
## beta1 1.96168526 2.02239974 0.0061439984 0.003896226 0.004052288
## beta2 0.09494353 0.24851228 0.0125216874 0.009226277 0.006555085
## lambda 14.97716478 33.24311222 1.2308679568 0.873763048 1.767460797
## r1 0.02078562 0.04809447 0.0006957382 0.002532962 0.004430201
## r2 0.02369039 0.07473096 0.0011501569 0.001816037 0.004755380
## sigsq.eps 0.36834066 0.53371298 0.0112954790 0.004960890 0.004978161
## MCSE_Q75 MCSE_Q97.5 MCSE_SD Bulk_ESS Tail_ESS
## beta1 0.001302716 0.010539232 0.002122356 184 238
## beta2 0.008043386 0.008643991 0.003965505 259 197
## lambda 3.451601827 1.237140963 1.056843473 32 38
## r1 0.006131920 0.001462048 0.005269340 5 16
## r2 0.002075801 0.013178244 0.002737455 26 28
## sigsq.eps 0.006630407 0.039435542 0.004293256 195 157
# run 2000 additional iterations
moreiterations = kmbayes_continue(kmfitbma.start, iter=2000)
## Modifying r starting values to meet kmbayes initial value constraints (this isn't a perfect continuation)
## This issue can be fixed by updating bkmr to the development version via: install.packages('devtools'); devtools::install_github('jenfb/bkmr')
## Validating control.params...
## Validating starting.values...
## r should be a vector of length equal to the number of columns of Z. Input will be repeated or truncated as necessary.
## Iteration: 201 (10% completed; 0.29177 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.45
## 2 r1 0.23
## 3 r2 0.24
## Iteration: 401 (20% completed; 0.57083 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4475
## 2 r1 0.2075
## 3 r2 0.2675
## Iteration: 601 (30% completed; 0.82862 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4766667
## 2 r1 0.1850000
## 3 r2 0.2700000
## Iteration: 801 (40% completed; 1.08158 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.49625
## 2 r1 0.16500
## 3 r2 0.25875
## Iteration: 1001 (50% completed; 1.34625 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.499
## 2 r1 0.176
## 3 r2 0.251
## Iteration: 1201 (60% completed; 1.63372 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4966667
## 2 r1 0.1741667
## 3 r2 0.2533333
## Iteration: 1401 (70% completed; 1.94303 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4957143
## 2 r1 0.1800000
## 3 r2 0.2571429
## Iteration: 1601 (80% completed; 2.38216 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.490000
## 2 r1 0.181250
## 3 r2 0.269375
## Iteration: 1801 (90% completed; 2.6519 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4800000
## 2 r1 0.1811111
## 3 r2 0.2750000
## Iteration: 2001 (100% completed; 2.91589 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4905
## 2 r1 0.1780
## 3 r2 0.2770
kmbayes_diag(moreiterations)
## Single chain
## Inference for the input samples (1 chains: each with iter = 2500; warmup = 1250):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.00 1219 1219
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 1194 1111
## lambda 3.9 11.8 28.7 14.0 9.3 1.01 80 62
## r1 0.0 0.0 0.1 0.0 0.0 1.01 67 56
## r2 0.0 0.0 0.1 0.0 0.0 1.03 52 39
## sigsq.eps 0.3 0.4 0.6 0.4 0.1 1.00 592 1072
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
## mean se_mean sd 2.5% 25% 50%
## beta1 1.95732832 0.001361673 0.04758648 1.86055186 1.92649484 1.95626700
## beta2 0.10593017 0.002658415 0.09175898 -0.07535931 0.04494582 0.10497201
## lambda 13.95848096 0.974444125 9.26601179 3.67843043 8.53661507 11.80221444
## r1 0.02388718 0.001815641 0.01576315 0.01026043 0.01314430 0.01785491
## r2 0.03816322 0.004231446 0.03411368 0.01083829 0.01615244 0.02628364
## sigsq.eps 0.40307782 0.003681697 0.08799969 0.26548543 0.33873822 0.39310969
## 75% 97.5% n_eff Rhat valid Q5 Q50
## beta1 1.98861636 2.04891445 1218 1.000833 1 1.87703291 1.95626700
## beta2 0.16589538 0.29210556 1194 1.001112 1 -0.04661205 0.10497201
## lambda 17.17489007 36.48067829 93 1.009083 1 3.86142689 11.80221444
## r1 0.02913901 0.06679262 80 1.009978 1 0.01085156 0.01785491
## r2 0.04640817 0.13181877 83 1.026150 1 0.01085182 0.02628364
## sigsq.eps 0.45450760 0.60061908 571 1.000022 1 0.27983410 0.39310969
## Q95 MCSE_Q2.5 MCSE_Q25 MCSE_Q50 MCSE_Q75
## beta1 2.03734964 0.0021170461 0.0019635347 0.001741340 0.002094780
## beta2 0.25184715 0.0070705251 0.0029934904 0.003334672 0.003703081
## lambda 28.67942227 0.8743516725 0.7130727208 0.896066298 0.858318166
## r1 0.05476273 0.0005792735 0.0006850018 0.001656574 0.003657957
## r2 0.11010608 0.0001766024 0.0015788293 0.002436589 0.006589668
## sigsq.eps 0.56269045 0.0031279911 0.0030804472 0.004000827 0.007075587
## MCSE_Q97.5 MCSE_SD Bulk_ESS Tail_ESS
## beta1 0.002408374 0.000963078 1219 1219
## beta2 0.005012008 0.001902498 1194 1111
## lambda 5.677054364 0.727212084 80 62
## r1 0.006089716 0.001288853 67 56
## r2 0.011876064 0.003005615 52 39
## sigsq.eps 0.010455310 0.002604683 592 1072
TracePlot(moreiterations, par="beta")
TracePlot(moreiterations, par="r")
Thanks to Haotian “Howie” Wu for invaluable feedback on early versions of the package.