In this example we use the package to infer the bias and coefficients of a logistic regression model using stochastic gradient Langevin dynamics with control variates (SGLD-CV). We assume we have data \(\mathbf x_1, \dots, \mathbf x_N\) and response variables \(y_1, \dots, y_N\) with likelihood \[ p(\mathbf X, \mathbf y | \beta, \beta_0 ) = \prod_{i=1}^N \left[ \frac{1}{1+e^{-\beta_0 - \mathbf x_i \beta}} \right]^{y_i} \left[ 1 - \frac{1}{1+e^{-\beta_0 - \mathbf x_i \beta}} \right]^{1-y_i} \]
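Taking logs gives the log-likelihood, which is the quantity the logLik function defined later will return: \[ \log p(\mathbf X, \mathbf y | \beta, \beta_0 ) = \sum_{i=1}^N \left[ y_i \log \pi_i + (1 - y_i) \log (1 - \pi_i) \right], \quad \text{where } \pi_i := \frac{1}{1+e^{-\beta_0 - \mathbf x_i \beta}}. \]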
First let’s load the data. We will use the cover type dataset, which is commonly used to benchmark classification models. We use the version from LIBSVM, which transforms the problem from multiclass to binary.
library(sgmcmc)
# Load data from package
data("covertype")
First we’ll remove 10,000 observations from the original dataset to form a test set, which will be used to check the validity of the algorithm. Then we’ll separate the response variable y from the explanatory variables X. The response variable is the first column in the dataset.
# Hold out 10,000 randomly chosen observations as a test set
testObservations = sample(nrow(covertype), 10^4)
testSet = covertype[testObservations,]
# The response is the first column; the remaining columns are the explanatory variables
X = covertype[-c(testObservations),2:ncol(covertype)]
y = covertype[-c(testObservations),1]
dataset = list( "X" = X, "y" = y )
In the last line we defined the dataset as it will be input to the relevant sgmcmc function. Many of the inputs to functions in sgmcmc are lists. This improves flexibility: models can be specified with multiple parameters and datasets, and separate tuning constants can be set for each parameter. We assume that observations are always accessed along the first dimension of each object, i.e. the point \(\mathbf x_i\) is located at X[i,] rather than X[,i]. Similarly, observation \(i\) of a 3-dimensional object Y would be located at Y[i,,].
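For instance, for a hypothetical 3-dimensional array (not part of the model in this example) holding one matrix per observation, the convention looks as follows
# Hypothetical example of the indexing convention: 5 observations,
# each a 3 x 2 matrix, with the observation index on the first dimension
Y = array(rnorm(5 * 3 * 2), dim = c(5, 3, 2))
Y[1,,]  # observation 1, a 3 x 2 matrix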
Now we want to set the starting values and shapes of our parameters. We can see from the likelihood that we have two parameters: the bias \(\beta_0\) and the coefficients \(\beta\). We’ll simply start both from zero. As with the data, the parameters are specified as a list with the relevant names.
# Get the dimension of X, needed to set shape of params$beta
d = ncol(dataset$X)
params = list( "bias" = 0, "beta" = matrix( rep( 0, d ), nrow = d ) )
Now we’ll define the functions logLik and logPrior; it should now become clear why the list names come in handy. The function logLik should take two inputs: params and dataset. These will be lists with the same names as those you defined for params and dataset earlier. There is one difference though: the objects in the lists will have been automatically converted to TensorFlow objects for you. The params list will contain TensorFlow tensor variables; the dataset list will contain TensorFlow placeholders. The logLik function should take these lists as input and return the value of the log-likelihood at the point params given the data dataset, as a tensor. The function should be written using TensorFlow operations, as this allows the gradient to be calculated automatically; it also lets you take advantage of the wide range of distribution objects and matrix operations that TensorFlow provides. A tutorial of TensorFlow for R is beyond the scope of this article; for more details we refer the reader to the TensorFlow for R website. With this in place we can define the logLik function as follows
logLik = function(params, dataset) {
    # Estimated probability that the response of each observation is 1
    yEstimated = 1 / (1 + tf$exp( - tf$squeeze(params$bias + tf$matmul(dataset$X, params$beta))))
    # Bernoulli log-likelihood summed over the current minibatch
    logLik = tf$reduce_sum(dataset$y * tf$log(yEstimated) + (1 - dataset$y) * tf$log(1 - yEstimated))
    return(logLik)
}
Next we want to define our log-prior density. We assume the bias and each coefficient \(\beta_i\) have independent Laplace prior distributions with location 0 and scale 1, so that \(\log p( \beta_0, \beta ) = - \sum_{i=0}^d | \beta_i |\), up to an additive constant. Similar to the log-likelihood, the log-prior density is defined as a function with input params. In our case the definition is
logPrior = function(params) {
    # Laplace(0, 1) log-density of the bias and coefficients, up to a constant
    logPrior = - (tf$reduce_sum(tf$abs(params$beta)) + tf$reduce_sum(tf$abs(params$bias)))
    return(logPrior)
}
Finally, we’ll set the stepsize parameters for the algorithm, along with the minibatch size. sgldcv relies on two sets of stepsizes: one for the initial optimization step and one for the MCMC step. To allow stepsizes to be set separately for different parameters, the MCMC stepsizes are given as a list with names corresponding to each of the names in params. The optimization stepsize is just a single value, as that step is tuned automatically
stepsizesMCMC = list("beta" = 2e-5, "bias" = 2e-5)
stepsizesOptimization = 1e-1
Alternatively, we can simply use the shortcut stepsizesMCMC = 2e-5, which would set the stepsize for every parameter to 2e-5. The optimization step is performed using the TensorFlow AdamOptimizer.
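Written out explicitly, the shortcut mentioned above replaces the list of MCMC stepsizes with a single value
# Equivalent shortcut: the same MCMC stepsize is applied to both beta and bias
stepsizesMCMC = 2e-5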
Now we can run the SGLD-CV algorithm using the function sgldcv from the sgmcmc package, which returns a list of Markov chains, one for each parameter, as output. Use the argument verbose = FALSE to hide the printed output of the function. As the dataset is quite large, we’ll change the minibatchSize from its default of 0.01 * N to 500. To allow a small 1000 iteration burn-in we’ll set the number of iterations to 11000
output = sgldcv(logLik, dataset, params, stepsizesMCMC, stepsizesOptimization, logPrior = logPrior, minibatchSize = 500, nIters = 11000 )
##
## Finding initial MAP estimates...
## Iteration: 100 Log posterior estimate: -318934.15625
## Iteration: 200 Log posterior estimate: -311100.21875
## Iteration: 300 Log posterior estimate: -304136.59375
## Iteration: 400 Log posterior estimate: -288272.03125
## Iteration: 500 Log posterior estimate: -278034.9375
## Iteration: 600 Log posterior estimate: -323016.59375
## Iteration: 700 Log posterior estimate: -283915.875
## Iteration: 800 Log posterior estimate: -315919.34375
## Iteration: 900 Log posterior estimate: -315031.375
## Iteration: 1000 Log posterior estimate: -288968.625
## Iteration: 1100 Log posterior estimate: -291185.53125
## Iteration: 1200 Log posterior estimate: -287866.09375
## Iteration: 1300 Log posterior estimate: -293647.34375
## Iteration: 1400 Log posterior estimate: -310071.84375
## Iteration: 1500 Log posterior estimate: -291897.6875
## Iteration: 1600 Log posterior estimate: -300818.34375
## Iteration: 1700 Log posterior estimate: -305018.8125
## Iteration: 1800 Log posterior estimate: -287657.65625
## Iteration: 1900 Log posterior estimate: -278235.75
## Iteration: 2000 Log posterior estimate: -303246.3125
## Iteration: 2100 Log posterior estimate: -281668.8125
## Iteration: 2200 Log posterior estimate: -308920
## Iteration: 2300 Log posterior estimate: -290777.21875
## Iteration: 2400 Log posterior estimate: -289024
## Iteration: 2500 Log posterior estimate: -290047.40625
## Iteration: 2600 Log posterior estimate: -291038.5625
## Iteration: 2700 Log posterior estimate: -292185.90625
## Iteration: 2800 Log posterior estimate: -295259.53125
## Iteration: 2900 Log posterior estimate: -294429.59375
## Iteration: 3000 Log posterior estimate: -278330.78125
## Iteration: 3100 Log posterior estimate: -282764.84375
## Iteration: 3200 Log posterior estimate: -310849.28125
## Iteration: 3300 Log posterior estimate: -310186.90625
## Iteration: 3400 Log posterior estimate: -298926.96875
## Iteration: 3500 Log posterior estimate: -303079.125
## Iteration: 3600 Log posterior estimate: -294008.40625
## Iteration: 3700 Log posterior estimate: -312734.9375
## Iteration: 3800 Log posterior estimate: -293633.59375
## Iteration: 3900 Log posterior estimate: -276008.75
## Iteration: 4000 Log posterior estimate: -283248.15625
## Iteration: 4100 Log posterior estimate: -289258.1875
## Iteration: 4200 Log posterior estimate: -289803.8125
## Iteration: 4300 Log posterior estimate: -305144.0625
## Iteration: 4400 Log posterior estimate: -292219.1875
## Iteration: 4500 Log posterior estimate: -296130.59375
## Iteration: 4600 Log posterior estimate: -294772.84375
## Iteration: 4700 Log posterior estimate: -298234.53125
## Iteration: 4800 Log posterior estimate: -287088.6875
## Iteration: 4900 Log posterior estimate: -310976.40625
## Iteration: 5000 Log posterior estimate: -300378.5
## Iteration: 5100 Log posterior estimate: -307007.65625
## Iteration: 5200 Log posterior estimate: -312647.78125
## Iteration: 5300 Log posterior estimate: -304682.5
## Iteration: 5400 Log posterior estimate: -280271.6875
## Iteration: 5500 Log posterior estimate: -307953.4375
## Iteration: 5600 Log posterior estimate: -319640.40625
## Iteration: 5700 Log posterior estimate: -289323.59375
## Iteration: 5800 Log posterior estimate: -312788.75
## Iteration: 5900 Log posterior estimate: -284317.3125
## Iteration: 6000 Log posterior estimate: -302825.9375
## Iteration: 6100 Log posterior estimate: -287940.21875
## Iteration: 6200 Log posterior estimate: -276601.59375
## Iteration: 6300 Log posterior estimate: -307636.375
## Iteration: 6400 Log posterior estimate: -291986.96875
## Iteration: 6500 Log posterior estimate: -305340.8125
## Iteration: 6600 Log posterior estimate: -302332.21875
## Iteration: 6700 Log posterior estimate: -299466
## Iteration: 6800 Log posterior estimate: -298791.125
## Iteration: 6900 Log posterior estimate: -297779.21875
## Iteration: 7000 Log posterior estimate: -289950.53125
## Iteration: 7100 Log posterior estimate: -330373.25
## Iteration: 7200 Log posterior estimate: -282641.34375
## Iteration: 7300 Log posterior estimate: -283827.9375
## Iteration: 7400 Log posterior estimate: -289006.65625
## Iteration: 7500 Log posterior estimate: -282916.34375
## Iteration: 7600 Log posterior estimate: -295792.21875
## Iteration: 7700 Log posterior estimate: -283200.34375
## Iteration: 7800 Log posterior estimate: -283077.9375
## Iteration: 7900 Log posterior estimate: -281539.96875
## Iteration: 8000 Log posterior estimate: -297851.03125
## Iteration: 8100 Log posterior estimate: -295451
## Iteration: 8200 Log posterior estimate: -283193.9375
## Iteration: 8300 Log posterior estimate: -285842.03125
## Iteration: 8400 Log posterior estimate: -293102.65625
## Iteration: 8500 Log posterior estimate: -288639.90625
## Iteration: 8600 Log posterior estimate: -284451.9375
## Iteration: 8700 Log posterior estimate: -291608.84375
## Iteration: 8800 Log posterior estimate: -291328.71875
## Iteration: 8900 Log posterior estimate: -311649.3125
## Iteration: 9000 Log posterior estimate: -302892.8125
## Iteration: 9100 Log posterior estimate: -290860.4375
## Iteration: 9200 Log posterior estimate: -285636.34375
## Iteration: 9300 Log posterior estimate: -295014.34375
## Iteration: 9400 Log posterior estimate: -299942.1875
## Iteration: 9500 Log posterior estimate: -295133.34375
## Iteration: 9600 Log posterior estimate: -308232.9375
## Iteration: 9700 Log posterior estimate: -278838.59375
## Iteration: 9800 Log posterior estimate: -282269.53125
## Iteration: 9900 Log posterior estimate: -304174.8125
## Iteration: 10000 Log posterior estimate: -309968.5
##
## Sampling using SGMCMC...
## Iteration: 100 Log posterior estimate: -290171.03125
## Iteration: 200 Log posterior estimate: -279223.8125
## Iteration: 300 Log posterior estimate: -278308.4375
## Iteration: 400 Log posterior estimate: -279897.9375
## Iteration: 500 Log posterior estimate: -283376.3125
## Iteration: 600 Log posterior estimate: -317401.28125
## Iteration: 700 Log posterior estimate: -295789.21875
## Iteration: 800 Log posterior estimate: -310984.4375
## Iteration: 900 Log posterior estimate: -279781.84375
## Iteration: 1000 Log posterior estimate: -299323.03125
## Iteration: 1100 Log posterior estimate: -298043.65625
## Iteration: 1200 Log posterior estimate: -289478.84375
## Iteration: 1300 Log posterior estimate: -305483.5625
## Iteration: 1400 Log posterior estimate: -287282.75
## Iteration: 1500 Log posterior estimate: -282191.40625
## Iteration: 1600 Log posterior estimate: -309234.09375
## Iteration: 1700 Log posterior estimate: -281940.0625
## Iteration: 1800 Log posterior estimate: -303067.40625
## Iteration: 1900 Log posterior estimate: -299208.5625
## Iteration: 2000 Log posterior estimate: -283990.34375
## Iteration: 2100 Log posterior estimate: -277591.9375
## Iteration: 2200 Log posterior estimate: -299869.28125
## Iteration: 2300 Log posterior estimate: -272720.03125
## Iteration: 2400 Log posterior estimate: -280149.03125
## Iteration: 2500 Log posterior estimate: -287280.03125
## Iteration: 2600 Log posterior estimate: -291755.59375
## Iteration: 2700 Log posterior estimate: -283403.4375
## Iteration: 2800 Log posterior estimate: -302196.125
## Iteration: 2900 Log posterior estimate: -292763.46875
## Iteration: 3000 Log posterior estimate: -281877.9375
## Iteration: 3100 Log posterior estimate: -333051.3125
## Iteration: 3200 Log posterior estimate: -298832.09375
## Iteration: 3300 Log posterior estimate: -288650.34375
## Iteration: 3400 Log posterior estimate: -289696.53125
## Iteration: 3500 Log posterior estimate: -286719.65625
## Iteration: 3600 Log posterior estimate: -300798.28125
## Iteration: 3700 Log posterior estimate: -297927.84375
## Iteration: 3800 Log posterior estimate: -291033.34375
## Iteration: 3900 Log posterior estimate: -275082.15625
## Iteration: 4000 Log posterior estimate: -286376.1875
## Iteration: 4100 Log posterior estimate: -324981.34375
## Iteration: 4200 Log posterior estimate: -279220.78125
## Iteration: 4300 Log posterior estimate: -282712.6875
## Iteration: 4400 Log posterior estimate: -268535.65625
## Iteration: 4500 Log posterior estimate: -290500.875
## Iteration: 4600 Log posterior estimate: -313894.1875
## Iteration: 4700 Log posterior estimate: -290789
## Iteration: 4800 Log posterior estimate: -266297.6875
## Iteration: 4900 Log posterior estimate: -311418.59375
## Iteration: 5000 Log posterior estimate: -305157.3125
## Iteration: 5100 Log posterior estimate: -310085.53125
## Iteration: 5200 Log posterior estimate: -294280.625
## Iteration: 5300 Log posterior estimate: -271685.03125
## Iteration: 5400 Log posterior estimate: -297813.25
## Iteration: 5500 Log posterior estimate: -278053.9375
## Iteration: 5600 Log posterior estimate: -289355.15625
## Iteration: 5700 Log posterior estimate: -309618.65625
## Iteration: 5800 Log posterior estimate: -298705.8125
## Iteration: 5900 Log posterior estimate: -286467.96875
## Iteration: 6000 Log posterior estimate: -332059.78125
## Iteration: 6100 Log posterior estimate: -294429.3125
## Iteration: 6200 Log posterior estimate: -294467.53125
## Iteration: 6300 Log posterior estimate: -283324.78125
## Iteration: 6400 Log posterior estimate: -300902.09375
## Iteration: 6500 Log posterior estimate: -284868.625
## Iteration: 6600 Log posterior estimate: -293198.96875
## Iteration: 6700 Log posterior estimate: -289145.625
## Iteration: 6800 Log posterior estimate: -295144.5625
## Iteration: 6900 Log posterior estimate: -267968.5
## Iteration: 7000 Log posterior estimate: -304658.15625
## Iteration: 7100 Log posterior estimate: -308917.3125
## Iteration: 7200 Log posterior estimate: -322541.1875
## Iteration: 7300 Log posterior estimate: -291500.15625
## Iteration: 7400 Log posterior estimate: -299093.03125
## Iteration: 7500 Log posterior estimate: -272092.375
## Iteration: 7600 Log posterior estimate: -278800.78125
## Iteration: 7700 Log posterior estimate: -286997.1875
## Iteration: 7800 Log posterior estimate: -284155.875
## Iteration: 7900 Log posterior estimate: -293022.46875
## Iteration: 8000 Log posterior estimate: -284036.3125
## Iteration: 8100 Log posterior estimate: -294387.4375
## Iteration: 8200 Log posterior estimate: -299493.6875
## Iteration: 8300 Log posterior estimate: -295624.5
## Iteration: 8400 Log posterior estimate: -283145.8125
## Iteration: 8500 Log posterior estimate: -311274.75
## Iteration: 8600 Log posterior estimate: -282711.8125
## Iteration: 8700 Log posterior estimate: -273210.59375
## Iteration: 8800 Log posterior estimate: -301488.28125
## Iteration: 8900 Log posterior estimate: -315557.84375
## Iteration: 9000 Log posterior estimate: -296773.5
## Iteration: 9100 Log posterior estimate: -295796.8125
## Iteration: 9200 Log posterior estimate: -321198.09375
## Iteration: 9300 Log posterior estimate: -315938.75
## Iteration: 9400 Log posterior estimate: -299486.5625
## Iteration: 9500 Log posterior estimate: -280255.4375
## Iteration: 9600 Log posterior estimate: -277771.90625
## Iteration: 9700 Log posterior estimate: -302288.46875
## Iteration: 9800 Log posterior estimate: -294178.75
## Iteration: 9900 Log posterior estimate: -285188.96875
## Iteration: 10000 Log posterior estimate: -299911.75
## Iteration: 10100 Log posterior estimate: -281869.3125
## Iteration: 10200 Log posterior estimate: -264595.375
## Iteration: 10300 Log posterior estimate: -270510.53125
## Iteration: 10400 Log posterior estimate: -295076.34375
## Iteration: 10500 Log posterior estimate: -300016.78125
## Iteration: 10600 Log posterior estimate: -285350.21875
## Iteration: 10700 Log posterior estimate: -279112.71875
## Iteration: 10800 Log posterior estimate: -291103.84375
## Iteration: 10900 Log posterior estimate: -288083.71875
## Iteration: 11000 Log posterior estimate: -336730.34375
To check that the algorithm has converged, we’ll plot the average log-predictive density of the test set every 10 iterations. Let \[\hat \pi_i^{(j)} := \frac{1}{1 + \exp\left[-\beta_0^{(j)} - \mathbf x_i \beta^{(j)}\right]},\] so that \(\hat \pi_i^{(j)}\) is the probability that the \(j^{\text{th}}\) MCMC sample assigns to the response of test observation \(i\) being 1. Denoting our test set by \(T\), the average log-predictive density at iteration \(j\) is \[A^{(j)} := \frac{1}{|T|} \sum_{i \in T} \left[ y_i \log \hat \pi_i^{(j)} + (1 - y_i) \log(1 - \hat \pi_i^{(j)}) \right].\] We calculate and plot this quantity every 10 iterations as follows
yTest = testSet[,1]
XTest = testSet[,2:ncol(testSet)]
# Remove burn-in
output$bias = output$bias[-c(1:1000)]
output$beta = output$beta[-c(1:1000),,]
iterations = seq(from = 1, to = 10^4, by = 10)
avLogPred = rep(0, length(iterations))
# Calculate the average log-predictive density every 10 iterations
for ( iter in 1:length(iterations) ) {
    j = iterations[iter]
    # Get parameters at iteration j
    beta0_j = output$bias[j]
    beta_j = output$beta[j,]
    for ( i in 1:length(yTest) ) {
        pihat_ij = 1 / (1 + exp(- beta0_j - sum(XTest[i,] * beta_j)))
        y_i = yTest[i]
        # Log-predictive density of the current test set point
        logPred_curr = y_i * log(pihat_ij) + (1 - y_i) * log(1 - pihat_ij)
        avLogPred[iter] = avLogPred[iter] + logPred_curr / length(yTest)
    }
}
library(ggplot2)
plotFrame = data.frame("iteration" = iterations, "logPredictive" = avLogPred)
ggplot(plotFrame, aes(x = iteration, y = logPredictive)) +
    geom_line() +
    ylab("Average log predictive of test set")