Introduction to bkmr and bkmrhat

bkmr is a package that implements Bayesian kernel machine regression (BKMR) using Markov chain Monte Carlo (MCMC). Notably, bkmr lacks some key features of Bayesian inference and MCMC diagnostics: 1) no facility for running multiple chains in parallel, 2) no inference across multiple chains, 3) limited posterior summaries of parameters, and 4) limited diagnostics. The bkmrhat package is a lightweight set of functions that fills each of those gaps by enabling post-processing of bkmr output in other packages and by providing a small framework for parallel processing.

How to use the bkmrhat package

  1. Fit a BKMR model for a single chain using the kmbayes function from bkmr, or fit multiple parallel chains using kmbayes_parallel from bkmrhat
  2. Perform diagnostics on single or multiple chains using the kmbayes_diagnose function (which uses functions from the rstan package), OR convert the BKMR fit(s) to mcmc (one chain) or mcmc.list (multiple chains) objects from the coda package using as.mcmc or as.mcmc.list from the bkmrhat package. The coda package has a whole host of inference and diagnostic procedures (but may lag behind some of the diagnostic functions from rstan).
  3. Perform posterior summaries using coda functions or combine chains from a kmbayes_parallel fit using kmbayes_combine. Final posterior inferences can be made on the combined object, which enables use of bkmr package functions for visual summaries of independent and joint effects of exposures in the bkmr model.
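The three steps above can be sketched as follows (a minimal outline, assuming data objects y, Z, and X already exist; all function names are from the bkmr, bkmrhat, and coda packages as described above):

```r
library("bkmr")
library("bkmrhat")
library("coda")

# Step 1: fit multiple parallel chains
fits <- kmbayes_parallel(nchains = 4, y = y, Z = Z, X = X, iter = 1000)

# Step 2: diagnostics, via rstan-based summaries or via coda
kmbayes_diagnose(fits)
fitlist <- as.mcmc.list(fits)
gelman.diag(fitlist)

# Step 3: combine chains into a single object for posterior inference
#  (bkmr summary/plotting functions can then be used on the combined fit)
bigfit <- kmbayes_combine(fits)
```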

First, simulate some data using the SimData function from the bkmr package

library("bkmr")
library("bkmrhat")
library("coda")

set.seed(111)
dat <- bkmr::SimData(n = 50, M = 5, ind=1:3, Zgen="realistic")
y <- dat$y
Z <- dat$Z
X <- cbind(dat$X, rnorm(50))
head(cbind(y,Z,X))
##               y          z1          z2          z3          z4         z5
## [1,]  4.1379128 -0.06359282 -0.02996246 -0.14190647 -0.44089352 -0.1878732
## [2,] 12.0843607 -0.07308834  0.32021690  1.08838691  0.29448354 -1.4609837
## [3,]  7.8859254  0.59604857  0.20602329  0.46218114 -0.03387906 -0.7615902
## [4,]  1.1609768  1.46504863  2.48389356  1.39869461  1.49678590  0.2837234
## [5,]  0.4989372 -0.37549639  0.01159884  1.17891641 -0.05286516 -0.1680664
## [6,]  5.0731242 -0.36904566 -0.49744932 -0.03330522  0.30843805  0.6814844
##                           
## [1,]  1.0569172 -1.0503824
## [2,]  4.8158570  0.3251424
## [3,]  2.6683461 -2.1048716
## [4,] -0.7492096 -0.9551027
## [5,] -0.5428339 -0.5306399
## [6,]  1.6493251  0.8274405

Example 1: single vs multi-chains

There is some overhead in parallel processing when using the future package, so the payoff from parallel processing will vary by problem. Here it is about a 2-4x speedup, but you can see more benefit at higher iteration counts. Note that this approach may not yield as many usable iterations as a single large chain if a substantial burnin period is needed, but it does enable useful convergence diagnostics. Note also that the future package can implement sequential processing, which effectively turns kmbayes_parallel into a loop but retains all the other advantages of running multiple chains.

# enable parallel processing (up to 4 simultaneous processes here)
future::plan(strategy = future::multisession)

# single chain of 4000 iterations from the bkmr package
set.seed(111)
system.time(kmfit <- suppressMessages(kmbayes(y = y, Z = Z, X = X, iter = 4000, verbose = FALSE, varsel = FALSE)))
##    user  system elapsed 
##   4.259   0.264   4.531
# 4 chains of 1000 iterations each from the bkmrhat package
set.seed(111)
system.time(kmfit5 <- suppressMessages(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = FALSE)))
## Chain 1 
## Chain 2 
## Chain 3 
## Chain 4
##    user  system elapsed 
##   0.042   0.003   2.108

Example 2: Diagnostics

The diagnostics from the rstan package come from the monitor function (see the help files for that function in the rstan package).

# Using rstan functions (set burnin/warmup to zero for comparability with the coda numbers given later;
#  posterior summaries should normally be performed after excluding warmup/burnin iterations)
singlediag = kmbayes_diagnose(kmfit, warmup=0, digits_summary=2)
## Single chain
## Inference for the input samples (1 chains: each with iter = 4000; warmup = 0):
## 
##            Q5  Q50  Q95 Mean  SD  Rhat Bulk_ESS Tail_ESS
## beta1     1.9  2.0  2.1  2.0 0.0  1.00     2820     3194
## beta2     0.0  0.1  0.3  0.1 0.1  1.00     3739     3535
## lambda    3.9 10.0 22.3 11.2 5.9  1.00      346      222
## r1        0.0  0.0  0.1  0.0 0.1  1.01      129      173
## r2        0.0  0.0  0.1  0.0 0.1  1.00      182      181
## r3        0.0  0.0  0.0  0.0 0.0  1.01      158      112
## r4        0.0  0.0  0.1  0.0 0.1  1.03      176      135
## r5        0.0  0.0  0.0  0.0 0.1  1.00      107      114
## sigsq.eps 0.2  0.3  0.5  0.4 0.1  1.00     1262     1563
## 
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of 
## effective sample size for bulk and tail quantities respectively (an ESS > 100 
## per chain is considered good), and Rhat is the potential scale reduction 
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# Using rstan functions (multiple chains enable R-hat)
multidiag = kmbayes_diagnose(kmfit5, warmup=0, digits_summary=2)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 0):
## 
##            Q5 Q50  Q95 Mean  SD  Rhat Bulk_ESS Tail_ESS
## beta1     1.9 2.0  2.1  2.0 0.0  1.00     1951     1652
## beta2     0.0 0.1  0.3  0.1 0.1  1.00     3204     2578
## lambda    4.5 9.7 24.0 11.4 6.5  1.01      359      510
## r1        0.0 0.0  0.1  0.1 0.2  1.02      133       66
## r2        0.0 0.0  0.2  0.1 0.2  1.02      116       71
## r3        0.0 0.0  0.1  0.0 0.2  1.02       87       92
## r4        0.0 0.0  0.1  0.0 0.1  1.03      119       78
## r5        0.0 0.0  0.5  0.1 0.2  1.07       49       44
## sigsq.eps 0.2 0.3  0.5  0.3 0.1  1.01      655      431
## 
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of 
## effective sample size for bulk and tail quantities respectively (an ESS > 100 
## per chain is considered good), and Rhat is the potential scale reduction 
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# using coda functions, not using any burnin (for demonstration only)
kmfitcoda = as.mcmc(kmfit, iterstart = 1)
kmfit5coda = as.mcmc.list(kmfit5, iterstart = 1)

# single chain trace plot
traceplot(kmfitcoda)

[plot of chunk diagnostics_1: trace plots for the single-chain fit]

The trace plots look typical and fine, but trace plots alone don't give a full picture of convergence. Note the apparent quick convergence for a couple of parameters, demonstrated by movement away from the starting value and concentration of the remaining samples within a narrow band.

Seeing visual evidence that different chains are sampling from the same marginal distributions is reassuring about the stability of the results.

# multiple chain trace plot
traceplot(kmfit5coda)

[plot of chunk diagnostics_2: trace plots for the four parallel chains]
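Beyond visual inspection, the coda package can quantify between-chain agreement. A brief sketch using the mcmc.list object created above (gelman.diag and effectiveSize are standard coda functions; the multivariate option is disabled here because some parameters may be nearly constant):

```r
# Gelman-Rubin potential scale reduction factors across the 4 chains
#  (point estimates near 1 suggest convergence)
gelman.diag(kmfit5coda, multivariate = FALSE)

# effective sample size per parameter, pooled over chains
effectiveSize(kmfit5coda)
```

These numeric diagnostics complement the rstan-based Rhat and ESS values reported by kmbayes_diagnose above, though the two implementations may differ slightly (rstan uses rank-normalized split chains).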