Linear Regression with Mini-Batch Gradient Descent

Vignette Author

2018-01-13

Import libraries

library(rTorch)

torch      <- import("torch")
Variable   <- import("torch.autograd")$Variable
np         <- import("numpy")
optim      <- import("torch.optim") 
py         <- import_builtins()

Generate the data

# make it reproducible: fix the RNG seed before any random draw
torch$manual_seed(42L)
#> <torch._C.Generator>

# 101 evenly spaced inputs on [-1, 1]
X  <- torch$linspace(-1L, 1L, 101L)
# noiseless target with slope 2 and no intercept
y1 <- X$mul(2)                          # equivalent to 2 * X
# standard-normal noise (one draw per input), scaled by 0.33
y2 <- torch$randn(X$size())$mul(0.33)   # equivalent to torch$randn(...) * 0.33
# observed (noisy) targets
Y  <- y1$add(y2)                        # equivalent to y1 + y2

Model hyperparameters

# Number of training examples; presumably the length of the 1-D tensor X
# as reported by rTorch's torch_size() -- TODO confirm.
n_examples    <- torch_size(X)
n_features    <- 1L          # one input value per example

learning_rate <- 0.01        # SGD step size
momentum      <- 0.9         # SGD momentum factor
n_classes     <- 1L          # output width of the linear layer (one target value)
batch_size    <- 10L         # examples per mini-batch
epochs        <- 100L        # full passes over the data
neurons       <- 512L        # NOTE(review): not used anywhere in this script

Build the model

# Build a one-layer, bias-free linear model inside an nn.Sequential container.
#
# input_dim  : number of input features per example
# output_dim : number of outputs per example
# Returns the Sequential container; its single submodule is named "linear".
build_model <- function(input_dim, output_dim) {
    net <- torch$nn$Sequential()
    layer <- torch$nn$Linear(input_dim, output_dim, bias = FALSE)
    net$add_module("linear", layer)
    net
}

# Run one optimization step on a single mini-batch and return its loss.
#
# model     : torch module applied to the batch reshaped to one feature per row
# loss      : loss module, called as loss$forward(prediction, target)
# optimizer : torch.optim optimizer over the model's parameters
# x, y      : tensors with the batch inputs and targets (one value per example)
#
# Returns the batch loss as extracted by `output$data$index(0L)`; the caller
# converts it with `$numpy()`.
train <- function(model, loss, optimizer, x, y) {

    # Only the model weights need gradients, not the data.
    x <- Variable(x, requires_grad = FALSE)
    y <- Variable(y, requires_grad = FALSE)

    # reset gradients accumulated by the previous step
    optimizer$zero_grad()

    # forward: reshape BOTH inputs and targets to column vectors (n, 1).
    # Reshaping only `x` pairs an (n, 1) prediction with an (n,) target, which
    # current PyTorch broadcasts to an (n, n) loss instead of comparing
    # element-wise.
    fx     <- model$forward(x$view(-1L, 1L))
    output <- loss$forward(fx, y$view(-1L, 1L))

    # backward
    output$backward()

    # update parameters
    optimizer$step()

    # NOTE(review): `Tensor.index` is a legacy (PyTorch 0.3-era) API; on
    # current PyTorch `output$item()` is the usual way to read a scalar loss,
    # but the caller expects an object with a `$numpy()` method -- confirm
    # before changing.
    output$data$index(0L)
}

# Linear model built by build_model(), mean-squared-error loss, and SGD with
# momentum over the model's weights.
model     <- build_model(n_features, n_classes)
# MSELoss averages over elements by default; the explicit
# `size_average = TRUE` argument is deprecated in current PyTorch.
loss      <- torch$nn$MSELoss()
optimizer <- optim$SGD(model$parameters(), lr = learning_rate, momentum = momentum)

Run optimization

# Mini-batch training: each epoch walks the data in consecutive,
# non-overlapping slices of `batch_size` examples. Examples left over after
# the last full batch (n_examples %% batch_size of them) are not used.
num_batches <- n_examples %/% batch_size   # loop-invariant, computed once
# With zero batches the epoch cost below would be 0 / 0.
stopifnot(num_batches > 0)

for (i in seq_len(epochs)) {
    ccost <- 0.0

    # k is the 0-based batch index, because torch offsets start at 0
    for (k in seq_len(num_batches) - 1L) {
        start <- as.integer(k * batch_size)

        cost  <- train(model, loss, optimizer,
                       X$narrow(-1L, start, batch_size),
                       Y$narrow(-1L, start, batch_size))

        # accumulate through $numpy() so the running sum is an R value
        ccost <- ccost + cost$numpy()
    }
    cat(sprintf("Epoch = %3d, cost = %s \n", i, ccost / num_batches))
}
#> Epoch =   1, cost = 1.20651289448142 
#> Epoch =   2, cost = 0.560786741226912 
#> Epoch =   3, cost = 0.174031929299235 
#> Epoch =   4, cost = 0.102411837130785 
#> Epoch =   5, cost = 0.110093450173736 
#> Epoch =   6, cost = 0.111259509623051 
#> Epoch =   7, cost = 0.10720083899796 
#> Epoch =   8, cost = 0.104798705130816 
#> Epoch =   9, cost = 0.104044001922011 
#> Epoch =  10, cost = 0.103783955425024 
#> Epoch =  11, cost = 0.103662004321814 
#> Epoch =  12, cost = 0.103639147803187 
#> Epoch =  13, cost = 0.10366975069046 
#> Epoch =  14, cost = 0.103705078735948 
#> Epoch =  15, cost = 0.103725069761276 
#> Epoch =  16, cost = 0.103730804100633 
#> Epoch =  17, cost = 0.103729251399636 
#> Epoch =  18, cost = 0.103725837171078 
#> Epoch =  19, cost = 0.103723178431392 
#> Epoch =  20, cost = 0.103721882030368 
#> Epoch =  21, cost = 0.103721595183015 
#> Epoch =  22, cost = 0.103721782565117 
#> Epoch =  23, cost = 0.103722047433257 
#> Epoch =  24, cost = 0.103722236305475 
#> Epoch =  25, cost = 0.103722312673926 
#> Epoch =  26, cost = 0.10372233428061 
#> Epoch =  27, cost = 0.103722311928868 
#> Epoch =  28, cost = 0.103722286969423 
#> Epoch =  29, cost = 0.103722278401256 
#> Epoch =  30, cost = 0.103722276166081 
#> Epoch =  31, cost = 0.103722276166081 
#> Epoch =  32, cost = 0.103722277656198 
#> Epoch =  33, cost = 0.103722277656198 
#> Epoch =  34, cost = 0.103722277656198 
#> Epoch =  35, cost = 0.103722283616662 
#> Epoch =  36, cost = 0.103722278773785 
#> Epoch =  37, cost = 0.103722278773785 
#> Epoch =  38, cost = 0.103722278773785 
#> Epoch =  39, cost = 0.103722278773785 
#> Epoch =  40, cost = 0.103722276911139 
#> Epoch =  41, cost = 0.103722277656198 
#> Epoch =  42, cost = 0.103722277656198 
#> Epoch =  43, cost = 0.103722283616662 
#> Epoch =  44, cost = 0.103722278773785 
#> Epoch =  45, cost = 0.103722278773785 
#> Epoch =  46, cost = 0.103722278773785 
#> Epoch =  47, cost = 0.103722278773785 
#> Epoch =  48, cost = 0.103722276911139 
#> Epoch =  49, cost = 0.103722277656198 
#> Epoch =  50, cost = 0.103722277656198 
#> Epoch =  51, cost = 0.103722283616662 
#> Epoch =  52, cost = 0.103722278773785 
#> Epoch =  53, cost = 0.103722278773785 
#> Epoch =  54, cost = 0.103722278773785 
#> Epoch =  55, cost = 0.103722278773785 
#> Epoch =  56, cost = 0.103722276911139 
#> Epoch =  57, cost = 0.103722277656198 
#> Epoch =  58, cost = 0.103722277656198 
#> Epoch =  59, cost = 0.103722283616662 
#> Epoch =  60, cost = 0.103722278773785 
#> Epoch =  61, cost = 0.103722278773785 
#> Epoch =  62, cost = 0.103722278773785 
#> Epoch =  63, cost = 0.103722278773785 
#> Epoch =  64, cost = 0.103722276911139 
#> Epoch =  65, cost = 0.103722277656198 
#> Epoch =  66, cost = 0.103722277656198 
#> Epoch =  67, cost = 0.103722283616662 
#> Epoch =  68, cost = 0.103722278773785 
#> Epoch =  69, cost = 0.103722278773785 
#> Epoch =  70, cost = 0.103722278773785 
#> Epoch =  71, cost = 0.103722278773785 
#> Epoch =  72, cost = 0.103722276911139 
#> Epoch =  73, cost = 0.103722277656198 
#> Epoch =  74, cost = 0.103722277656198 
#> Epoch =  75, cost = 0.103722283616662 
#> Epoch =  76, cost = 0.103722278773785 
#> Epoch =  77, cost = 0.103722278773785 
#> Epoch =  78, cost = 0.103722278773785 
#> Epoch =  79, cost = 0.103722278773785 
#> Epoch =  80, cost = 0.103722276911139 
#> Epoch =  81, cost = 0.103722277656198 
#> Epoch =  82, cost = 0.103722277656198 
#> Epoch =  83, cost = 0.103722283616662 
#> Epoch =  84, cost = 0.103722278773785 
#> Epoch =  85, cost = 0.103722278773785 
#> Epoch =  86, cost = 0.103722278773785 
#> Epoch =  87, cost = 0.103722278773785 
#> Epoch =  88, cost = 0.103722276911139 
#> Epoch =  89, cost = 0.103722277656198 
#> Epoch =  90, cost = 0.103722277656198 
#> Epoch =  91, cost = 0.103722283616662 
#> Epoch =  92, cost = 0.103722278773785 
#> Epoch =  93, cost = 0.103722278773785 
#> Epoch =  94, cost = 0.103722278773785 
#> Epoch =  95, cost = 0.103722278773785 
#> Epoch =  96, cost = 0.103722276911139 
#> Epoch =  97, cost = 0.103722277656198 
#> Epoch =  98, cost = 0.103722277656198 
#> Epoch =  99, cost = 0.103722283616662 
#> Epoch = 100, cost = 0.103722278773785
model_param <- model$parameters()

# The bias-free linear layer owns exactly one parameter tensor, so the first
# item yielded by a fresh parameter generator is the fitted slope `w`.
w <- local({
    first_param <- iter_next(model$parameters())
    first_param$data
})
cat(sprintf("w = %.3f", w$numpy()))
#> w = 1.968

# Epoch =   1, cost = 0.103987852856517 
# Epoch = 100, cost = 0.103722278773785 
# w = 1.968