Using the loo package

The following is excerpted (with minor edits) from our paper: Vehtari, A., Gelman, A., and Gabry, J. (2015), "Efficient implementation of leave-one-out cross-validation and WAIC for evaluating fitted Bayesian models."


This example comes from a survey of residents from a small area in Bangladesh that was affected by arsenic in drinking water. Respondents with elevated arsenic levels in their wells were asked if they were interested in getting water from a neighbor’s well, and a series of logistic regressions were fit to predict this binary response given various information about the households (Gelman and Hill, 2007). Here we fit a model for the well-switching response given two predictors: the arsenic level of the water in the resident’s home, and the distance of the house from the nearest safe well.
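In symbols, writing \(y_i = 1\) if household \(i\) switched wells, the model fit below is the following. Distance is divided by 100, as in the model.matrix call in the R code. The Stan program has no prior statement, so \(\beta\) gets an implicit flat prior:

```latex
\Pr(y_i = 1 \mid \beta) = \operatorname{logit}^{-1}\!\left(\beta_1 + \beta_2 \, \frac{\mathrm{dist}_i}{100} + \beta_3 \, \mathrm{arsenic}_i\right),
\qquad
\operatorname{logit}^{-1}(\eta) = \frac{1}{1 + e^{-\eta}}
```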

The sample size in this example is \(N=3020\). This is not huge, but it is large enough that exact LOO, which requires refitting the model \(N\) times, would be slow, so it is important to have a computational method for LOO that is fast for each data point. On the plus side, with such a large dataset the influence of any single observation is small, so the computations should be stable.
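LOO can be computed without refitting, sketched roughly as follows. Given posterior draws \(\theta^1, \dots, \theta^S\) from the fit to all the data, each leave-one-out predictive density can be estimated by importance sampling with ratios \(r_i^s = 1/p(y_i \mid \theta^s)\). The loo package stabilizes these ratios with Pareto smoothed importance sampling (PSIS), replacing \(r_i^s\) by smoothed weights \(w_i^s\). With \(w_i^s = r_i^s\), the second expression reduces to the first:

```latex
p(y_i \mid y_{-i}) \approx \frac{1}{\frac{1}{S}\sum_{s=1}^{S} \frac{1}{p(y_i \mid \theta^s)}},
\qquad
p(y_i \mid y_{-i}) \approx \frac{\sum_{s=1}^{S} w_i^s \, p(y_i \mid \theta^s)}{\sum_{s=1}^{S} w_i^s},
\qquad
\widehat{\mathrm{elpd}}_{\mathrm{loo}} = \sum_{i=1}^{N} \log p(y_i \mid y_{-i})
```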

Stan and R code

Here is the Stan code for fitting the logistic regression model:

'logistic.stan'

data {
  int<lower=0> N; // number of data points
  int<lower=1> K; // number of predictors (including intercept)
  array[N] int<lower=0, upper=1> y; // outcome
  matrix[N, K] X; // predictors (including intercept)
}
parameters {
  vector[K] beta;
}
model {
  // no prior statement, so beta gets an implicit flat (improper) prior
  y ~ bernoulli_logit(X * beta);
}
generated quantities {
  vector[N] log_lik; // pointwise log-likelihood, saved for the loo package
  for (n in 1:N)
    log_lik[n] = bernoulli_logit_lpmf(y[n] | X[n] * beta);
}

We have defined the log likelihood as a vector named log_lik in the generated quantities block so that the individual terms will be saved by Stan. After running Stan, log_lik can be extracted (using the extract_log_lik function provided in the loo package) as an \(S \times N\) matrix, where \(S\) is the number of simulations (posterior draws) and \(N\) is the number of data points.
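To make the \(S \times N\) layout concrete, here is a minimal R sketch. It computes the same pointwise log-likelihood matrix from a matrix of posterior draws of beta, then estimates WAIC and a plain (unsmoothed) importance-sampling LOO from it. The data and "draws" are simulated, and the helper names (pointwise_log_lik, log_mean_exp, waic_by_hand, is_loo_by_hand) are hypothetical. This is not the loo package's implementation, which uses Pareto smoothed importance sampling for LOO and also reports diagnostics that this sketch omits.

```r
# Hypothetical helper: pointwise Bernoulli-logit log-likelihood, mirroring the
# generated quantities block. Entry [s, n] is log p(y[n] | beta^s).
# draws is an S x K matrix of draws of beta; X is N x K; y is a 0/1 vector.
pointwise_log_lik <- function(draws, X, y) {
  eta <- draws %*% t(X)                                  # S x N linear predictors
  Y <- matrix(y, nrow = nrow(eta), ncol = ncol(eta), byrow = TRUE)
  # log inv_logit(eta) where y = 1, log(1 - inv_logit(eta)) where y = 0
  ifelse(Y == 1, plogis(eta, log.p = TRUE), plogis(-eta, log.p = TRUE))
}

# Numerically stable log(mean(exp(x)))
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# WAIC from an S x N log-likelihood matrix: elpd_waic = lppd - p_waic
waic_by_hand <- function(log_lik) {
  lppd_i <- apply(log_lik, 2, log_mean_exp)    # log pointwise predictive density
  p_waic_i <- apply(log_lik, 2, var)            # pointwise effective no. of parameters
  elpd_i <- lppd_i - p_waic_i
  list(elpd_waic = sum(elpd_i), p_waic = sum(p_waic_i),
       se = sqrt(length(elpd_i) * var(elpd_i)))
}

# Raw importance-sampling LOO, no Pareto smoothing:
# log p(y_i | y_-i) is estimated by -log(mean_s exp(-log_lik[s, i]))
is_loo_by_hand <- function(log_lik) {
  elpd_i <- -apply(-log_lik, 2, log_mean_exp)
  sum(elpd_i)
}

# Usage with simulated data and stand-in "posterior draws" (all made up)
set.seed(1)
N <- 200
S <- 1000
X <- cbind(1, runif(N, 0, 3), runif(N, 0.5, 5))   # intercept, dist/100, arsenic
beta_true <- c(0, -0.9, 0.5)                        # hypothetical coefficients
y <- rbinom(N, 1, plogis(drop(X %*% beta_true)))
draws <- matrix(rnorm(S * 3, mean = rep(beta_true, each = S), sd = 0.05), S, 3)

log_lik <- pointwise_log_lik(draws, X, y)   # S x N, the same layout as extract_log_lik(fit)
waic_by_hand(log_lik)
is_loo_by_hand(log_lik)
```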


Next, we fit the Stan model using the rstan package:

library("rstan")

# Prepare data 
url <- "http://stat.columbia.edu/~gelman/arm/examples/arsenic/wells.dat"
wells <- read.table(url)
X <- model.matrix(~I(dist/100) + arsenic, wells) 
data <- list(N = nrow(wells), K = ncol(X), X = X, y = wells$switch)

# Fit model
fit <- stan("logistic.stan", data = data) # defaults to iter = 2000, chains = 4

and then use the loo package to compute LOO and WAIC:

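library("loo")

# Extract the pointwise log-likelihood, keeping chains separate
# (an iterations x chains x N array) so relative efficiencies can be computed
log_lik <- extract_log_lik(fit, merge_chains = FALSE)
r_eff <- relative_eff(exp(log_lik))

# Compute LOO (via PSIS) and WAIC
loo1 <- loo(log_lik, r_eff = r_eff)
waic1 <- waic(log_lik)
print(loo1)
print(waic1)
# LOO and WAIC are the same out to several decimal places for this model
print(loo1, digits = 4)
print(waic1, digits = 4)


To compare this model to an alternative model for the same data, we can use the loo_compare function.

# First fit a second model using log(arsenic) instead of arsenic
data$X[, "arsenic"] <- log(data$X[, "arsenic"])
fit2 <- stan(fit = fit, data = data)  # reuses the compiled model from fit
log_lik2 <- extract_log_lik(fit2, merge_chains = FALSE)
loo2 <- loo(log_lik2, r_eff = relative_eff(exp(log_lik2)))

# Compare the models; the first row is the model with the higher elpd_loo
loo_compare(loo1, loo2)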
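The comparison rests on pointwise values. The elpd difference between two models is the sum of the per-observation differences, and its standard error is \(\sqrt{N}\) times the standard deviation of those differences. The sketch below is a minimal R version, with a hypothetical helper name and made-up inputs. In current versions of the loo package, the pointwise inputs would come from the elpd_loo column of each fitted object's pointwise matrix.

```r
# Hypothetical helper: elpd difference (model b minus model a) and its
# standard error, from pointwise elpd values for the same N observations
elpd_diff_se <- function(elpd_a, elpd_b) {
  stopifnot(length(elpd_a) == length(elpd_b))
  d <- elpd_b - elpd_a                      # positive entries favor model b
  c(elpd_diff = sum(d), se_diff = sqrt(length(d) * var(d)))
}

# Made-up pointwise values for three observations
elpd_diff_se(c(-1, -2, -3), c(-1, -1, -1))
# elpd_diff = 3, se_diff = sqrt(3) (about 1.732)
```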

References

Gelman, A., and Hill, J. (2007). Data Analysis Using Regression and Multilevel/Hierarchical Models. Cambridge University Press.

Stan Development Team (2015a). Stan: A C++ library for probability and sampling, version 2.6. http://mc-stan.org

Stan Development Team (2015b). RStan, version 2.6. http://mc-stan.org/rstan.html.