# Document-level parameters: a list whose single entry `eval` is TRUE.
params <- list(eval = TRUE)

## ----include = FALSE----------------------------------------------------------
# Global knitr chunk options for the rest of the document.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

library(LBBNN)

# TRUE only when the torch namespace can be loaded and
# `torch::torch_is_installed()` also returns TRUE. The later chunk headers use
# this flag as their `eval` condition, so torch-dependent code is skipped
# when torch is unavailable.
has_torch <- if (requireNamespace("torch", quietly = TRUE)) {
  torch::torch_is_installed()
} else {
  FALSE
}

## ----eval=FALSE---------------------------------------------------------------
# if(!requireNamespace("torchvision"))
#   install.packages("torchvision")
# torch::torch_manual_seed(42)
# dir <- "./dataset/kmnist"
# kmnist_transform <- function(x) {
#   d <- dim(x)
#   if (length(d) == 3 && d[3] > 1 && d[1] == d[2]) { # input has shape [28, 28, batch], as seen on Windows and Linux(?)
#     x <- torchvision::transform_to_tensor(x) # shape should now be [batch, 28, 28]
#     x <- x$unsqueeze(2) # add the channel dimension -> [batch, 1, 28, 28]
#   }
#   else{ #on mac, everything is fine
#     x <- torchvision::transform_to_tensor(x)
#   }
#   return(x)
# }
# #get datasets from torchvision and define training and test loaders
# train_ds <- torchvision::kmnist_dataset(
#   dir,
#   download = TRUE,
#   transform = kmnist_transform)
# 
# test_ds <- torchvision::kmnist_dataset(
#   dir,
#   train = FALSE,
#   transform = kmnist_transform)
# 
# train_loader_kmnist <- torch::dataloader(train_ds, batch_size = 100, shuffle = TRUE)
# test_loader_kmnist <- torch::dataloader(test_ds, batch_size = 100)

## ----eval = has_torch---------------------------------------------------------
# Synthetic stand-in data set. The seed is fixed first so the random draws
# below are reproducible.
torch::torch_manual_seed(42)

# 200 random inputs of shape 1 x 28 x 28.
x <- torch::torch_randn(200, 1, 28, 28)

# 200 random labels drawn between 1 and 11.
# NOTE(review): the upper bound is presumably exclusive (labels 1..10), and
# the dtype is whatever torch_randint defaults to -- confirm that the training
# code accepts it as a class target.
y <- torch::torch_randint(low = 1, high = 11, size = 200)

# Pair inputs with labels and serve them in batches of 100.
dataset <- torch::tensor_dataset(x, y)
train_loader <- torch::dataloader(dataset = dataset, batch_size = 100)

## ----eval = has_torch---------------------------------------------------------
# Device on which every layer (and later the model) is placed.
device <- "cpu"

# First convolutional layer: 1 input channel -> 32 output channels, 5x5 kernel.
conv_layer_1 <- lbbnn_conv2d(
  in_channels = 1,
  out_channels = 32,
  kernel_size = 5,
  prior_inclusion = 0.5,
  standard_prior = 1,
  density_init = c(-10, 10),
  num_transforms = 2,
  flow = FALSE,
  hidden_dims = c(200, 200),
  device = device
)

# Second convolutional layer: 32 -> 64 channels, 5x5 kernel.
conv_layer_2 <- lbbnn_conv2d(
  in_channels = 32,
  out_channels = 64,
  kernel_size = 5,
  prior_inclusion = 0.5,
  standard_prior = 1,
  density_init = c(-10, 15),
  num_transforms = 2,
  flow = FALSE,
  hidden_dims = c(200, 200),
  device = device
)

# First fully connected layer: 1024 -> 300 features.
# NOTE(review): 1024 presumably equals 64 channels x 4 x 4 after two unpadded
# 5x5 convolutions and two 2x2 poolings of a 28x28 input -- confirm against
# lbbnn_conv2d's padding and stride.
linear_layer_1 <- lbbnn_linear(
  in_features = 1024,
  out_features = 300,
  prior_inclusion = 0.5,
  standard_prior = 1,
  density_init = c(-10, 10),
  num_transforms = 2,
  flow = FALSE,
  hidden_dims = c(200, 200),
  device = device,
  bias_inclusion_prob = FALSE,
  conv_net = TRUE
)

# Output layer: 300 -> 10 features (one per class).
linear_layer_2 <- lbbnn_linear(
  in_features = 300,
  out_features = 10,
  prior_inclusion = 0.5,
  standard_prior = 1,
  density_init = c(-5, 15),
  num_transforms = 2,
  flow = FALSE,
  hidden_dims = c(200, 200),
  device = device,
  bias_inclusion_prob = FALSE,
  conv_net = TRUE
)

## ----eval = has_torch---------------------------------------------------------
# Convolutional classifier assembled from four pre-built layers: two
# convolutional layers (each followed by activation and 2x2 max pooling) and
# two fully connected layers. Besides `forward`, the module exposes `kl_div`,
# `density`, `compute_paths` and `density_active_path`.
# NOTE(review): `problem_type`, `input_skip`, `loss_fn` and the extra methods
# are presumably what train_lbbnn()/validate_lbbnn() expect to find on a
# model -- confirm against the package.
LBBNN_ConvNet <- torch::nn_module(
  "LBBNN_ConvNet",

  # Store the supplied layers and create the pooling, activation, output and
  # loss modules.
  #
  # conv1, conv2  Convolutional layers, applied in that order.
  # fc1, fc2      Fully connected layers, applied in that order.
  # device        Not read by the constructor. Its default is a literal
  #               because a default of `device = device` would refer to
  #               itself and fail if it were ever evaluated.
  initialize = function(conv1, conv2, fc1, fc2, device = "cpu") {
    self$problem_type <- "multiclass classification"
    self$input_skip <- FALSE
    self$conv1 <- conv1
    self$conv2 <- conv2
    self$fc1 <- fc1
    self$fc2 <- fc2
    self$pool <- torch::nn_max_pool2d(2)
    self$act <- torch::nn_leaky_relu()
    self$out <- torch::nn_log_softmax(dim = 2)
    self$pout <- torch::nn_softmax(dim = 2)
    self$loss_fn <- torch::nn_nll_loss(reduction = "sum")
  },

  # Forward pass.
  #
  # x        Input tensor handed to the first convolutional layer.
  # MPM      Flag passed unchanged as the second argument to every layer.
  #          NOTE(review): presumably selects the median probability model
  #          -- confirm against the layer documentation.
  # predict  FALSE: return log-softmax over dimension 2 (pairs with the NLL
  #          loss stored in `loss_fn`). TRUE: return softmax over dimension 2.
  forward = function(x, MPM = FALSE, predict = FALSE) {
    x <- self$pool(self$act(self$conv1(x, MPM)))
    x <- self$pool(self$act(self$conv2(x, MPM)))
    # Collapse everything from dimension 2 onwards into one feature axis.
    x <- torch::torch_flatten(x, start_dim = 2)
    x <- self$act(self$fc1(x, MPM))
    scores <- self$fc2(x, MPM)
    # Return the result explicitly rather than as the invisible value of a
    # trailing assignment.
    if (predict) {
      self$pout(scores)
    } else {
      self$out(scores)
    }
  },

  # Sum of the `kl_div()` terms of the four layers.
  kl_div = function() {
    self$conv1$kl_div() + self$conv2$kl_div() +
      self$fc1$kl_div() + self$fc2$kl_div()
  },

  # Fraction of `alpha` entries, pooled over all four layers, that exceed 0.5.
  density = function() {
    alphas <- c(
      as.numeric(self$conv1$alpha),
      as.numeric(self$conv2$alpha),
      as.numeric(self$fc1$alpha),
      as.numeric(self$fc2$alpha)
    )
    mean(alphas > 0.5)
  },

  # No path computation for this network; always returns NULL.
  compute_paths = function() {
    NULL
  },

  # No active-path density for this network; always returns NA.
  density_active_path = function() {
    NA
  }
)

# Build the network from the four layers and move it to the selected device.
model_conv <- LBBNN_ConvNet(
  conv1 = conv_layer_1,
  conv2 = conv_layer_2,
  fc1 = linear_layer_1,
  fc2 = linear_layer_2,
  device = device
)
model_conv$to(device = device)

## ----eval = has_torch---------------------------------------------------------
# Train for two epochs on the synthetic loader.
train_lbbnn(
  epochs = 2,
  LBBNN = model_conv,
  lr = 0.01,
  train_dl = train_loader,
  device = device
)

# Validate with two samples. The training loader is reused as the test
# loader here.
validate_lbbnn(
  model_conv,
  num_samples = 2,
  test_dl = train_loader,
  device = device
)

