
tfNeuralODE-Adjoint

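In this example, we use the adjoint method to train a simple neural ODE so that it follows a new spiral trajectory. The dynamics matrix W stays fixed. What gets optimized is the initial state h0_var, so that the state at the end of the time span matches the target hN_target. We start by loading the libraries and setting the initial conditions, along with a few constants.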

library(reticulate)
library(tensorflow)
library(tfNeuralODE)
library(keras)
library(deSolve)

# iterations, time span (layers)
niters = 25
t = seq(0, 25, by = 25/100)
# initial conditions
h0 = tf$cast(t(c(1., 0.)), dtype= tf$float32)
W = tf$cast(rbind(c(-0.1, 1.0), c(-0.2, -0.1)), dtype = tf$float32)
h0_var = tf$Variable(h0, name = "")
hN_target = tf$cast(t(c(0., 0.5)), dtype = tf$float32)
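For reference, the setup above corresponds to the following problem. The notation is ours: h is a 1 × 2 row vector, matching tf$matmul(inputs, W), and T = 25 is the last entry of t. The loss is the one used in the training loop below, assuming forward() returns the state at time T.

```latex
\frac{dh(t)}{dt} = h(t)\,W, \qquad
h(0) = h_0 = (1,\ 0), \qquad
W = \begin{pmatrix} -0.1 & 1.0 \\ -0.2 & -0.1 \end{pmatrix}

\mathcal{L}(h_0) = \lVert h_N^{\text{target}} - h(T) \rVert_2^2, \qquad
h_N^{\text{target}} = (0,\ 0.5), \qquad T = 25

\det(\lambda I - W) = \lambda^2 + 0.2\,\lambda + 0.21 = 0
\;\Rightarrow\; \lambda = -0.1 \pm i\sqrt{0.2} \approx -0.1 \pm 0.447\,i
```

The eigenvalues are complex with a negative real part, which is what makes the trajectory a decaying spiral.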

We first solve the true ODE, du/dt = u A with A equal to W, using lsode from deSolve to obtain the initial trajectory.

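trueODEfunc <- function(t, u, p) {
  # deSolve calls func(time, state, parms); u is treated as a row vector
  true_A = rbind(c(-0.1, 1.0), c(-0.2, -0.1))
  du <- u %*% true_A
  return(list(as.vector(du)))
}

# solved ODE output: column 1 is time, columns 2 and 3 are the two states
init_path <- lsode(y = c(1., 0.), times = t, func = trueODEfunc, parms = NULL)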
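Because the system is linear, the lsode output can be sanity-checked against the closed form u(t) = u0 exp(A t). The base-R sketch below is our own check, not part of tfNeuralODE or deSolve. The helper expm_At is hypothetical: it computes the matrix exponential through an eigendecomposition, which is valid here because A has two distinct (complex) eigenvalues and is therefore diagonalizable.

```r
# True dynamics matrix and initial state, as in trueODEfunc above
A <- rbind(c(-0.1, 1.0), c(-0.2, -0.1))
u0 <- c(1, 0)
times <- seq(0, 25, by = 25 / 100)

# Eigendecomposition A = V diag(lambda) V^{-1}; lambda = -0.1 +/- 0.447i
e <- eigen(A)
V <- e$vectors
V_inv <- solve(V)

# Hypothetical helper: expm(A * s) = V diag(exp(lambda * s)) V^{-1}.
# For a real matrix the imaginary parts cancel, so only the real part is kept.
expm_At <- function(s) Re(V %*% diag(exp(e$values * s)) %*% V_inv)

# Row k holds u(times[k]) = u0 %*% expm(A * times[k]), laid out like init_path[, 2:3]
closed_form <- t(sapply(times, function(s) u0 %*% expm_At(s)))
```

In a session where init_path exists, max(abs(closed_form - init_path[, 2:3])) should be small, roughly on the order of lsode's default tolerances.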

Next we define a very simple ODE model whose vector field is the same linear map, h %*% W, so it starts out following the initial trajectory. We also create the SGD optimizer used to train it.

# ODE Model

optimizer = tf$keras$optimizers$legacy$SGD(learning_rate=1e-2, momentum=0.95)

OdeModel(keras$Model) %py_class% {
  call <- function(inputs) {
    tf$matmul(inputs, W)
  }
}
model<- OdeModel()
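The optimizer's two hyperparameters enter through the update rule documented for Keras SGD with momentum (non-Nesterov): velocity = momentum * velocity - learning_rate * g, then w = w + velocity. Below is a plain-R sketch of that rule, using the hypothetical helper sgd_momentum_step, to show what learning_rate = 1e-2 and momentum = 0.95 do to h0_var. It is an illustration, not the Keras implementation.

```r
# Hypothetical helper mirroring the documented Keras SGD-with-momentum rule:
#   velocity <- momentum * velocity - lr * grad
#   param    <- param + velocity
sgd_momentum_step <- function(param, grad, velocity, lr = 1e-2, momentum = 0.95) {
  velocity <- momentum * velocity - lr * grad
  list(param = param + velocity, velocity = velocity)
}

# Usage: two steps with a constant gradient, starting from zero velocity
state <- list(param = c(1, 0), velocity = c(0, 0))
g <- c(0.1, -0.2)
for (k in 1:2) state <- sgd_momentum_step(state$param, g, state$velocity)
state$param
```

With a constant gradient the velocity approaches -learning_rate * g / (1 - momentum), which is up to 20 times a plain SGD step when momentum = 0.95.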

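Now we train for 25 iterations. In each iteration, forward() integrates the ODE from h0_var, the GradientTape gives the gradient of the loss with respect to the predicted final state, and backward() propagates that gradient back through the ODE with the adjoint method to obtain the gradient with respect to the initial state, which the optimizer applies to h0_var. On the first iteration and every fifth iteration, we plot the current trajectory (red) against the initial trajectory (blue).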

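for(i in 1:niters){
  print(paste("Iteration", i, "out of", niters, "iterations."))
  with(tf$GradientTape() %as% tape, {
    # integrate the ODE forward from the current initial state
    pred = forward(model, inputs = h0_var, tsteps = t)
    tape$watch(pred)
    # squared distance between the predicted final state and the target
    loss = tf$reduce_sum((hN_target - pred) ^ 2)
  })
  #print(paste("loss:", as.numeric(loss)))
  # gradient of the loss with respect to the predicted final state
  dLoss = tape$gradient(loss, pred)
  # adjoint backward pass; the second element is the gradient w.r.t. the initial state
  dfdh0 = backward(model, t, pred, output_gradients = dLoss)[[2]]
  # apply a single (gradient, variable) pair to h0_var
  optimizer$apply_gradients(list(list(dfdh0, h0_var)))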

  # graphing the Neural ODE
  if(i %% 5 == 0 || i == 1){
    pred_y = forward(model = model, inputs = tf$cast((as.matrix(h0_var)), dtype = tf$float32),
                     tsteps = t, return_states = TRUE)
    pred_y_c<- k_concatenate(pred_y[[2]], 1)
    p_m<- as.matrix(pred_y_c)
    plot(p_m,
         main = paste("Iteration", i), type = "l", col = "red",
         xlim = c(-1,2), ylim = c(-1,2))
    lines(init_path[,2], init_path[,3], col = "blue")
  }
}
#> [1] "Iteration 1 out of 25 iterations."
plot of iteration 1
#> [1] "Iteration 2 out of 25 iterations."
#> [1] "Iteration 3 out of 25 iterations."
#> [1] "Iteration 4 out of 25 iterations."
#> [1] "Iteration 5 out of 25 iterations."
plot of iteration 5
#> [1] "Iteration 6 out of 25 iterations."
#> [1] "Iteration 7 out of 25 iterations."
#> [1] "Iteration 8 out of 25 iterations."
#> [1] "Iteration 9 out of 25 iterations."
#> [1] "Iteration 10 out of 25 iterations."
plot of iteration 10
#> [1] "Iteration 11 out of 25 iterations."
#> [1] "Iteration 12 out of 25 iterations."
#> [1] "Iteration 13 out of 25 iterations."
#> [1] "Iteration 14 out of 25 iterations."
#> [1] "Iteration 15 out of 25 iterations."
plot of iteration 15
#> [1] "Iteration 16 out of 25 iterations."
#> [1] "Iteration 17 out of 25 iterations."
#> [1] "Iteration 18 out of 25 iterations."
#> [1] "Iteration 19 out of 25 iterations."
#> [1] "Iteration 20 out of 25 iterations."
plot of iteration 20
#> [1] "Iteration 21 out of 25 iterations."
#> [1] "Iteration 22 out of 25 iterations."
#> [1] "Iteration 23 out of 25 iterations."
#> [1] "Iteration 24 out of 25 iterations."
#> [1] "Iteration 25 out of 25 iterations."
plot of iteration 25
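To make the adjoint step less of a black box, here is a base-R sketch of what the gradient computation amounts to for this particular linear system. It is not tfNeuralODE's implementation. We assume, as the loop suggests, that the loss depends only on the final state h(T), and we use a fixed-step RK4 solver of our own choosing. The adjoint a(t) = dL/dh(t) starts from a(T) = -2 (hN_target - h(T)), obeys da/dt = -a %*% t(W) (the Jacobian of h %*% W is W), and is integrated backward from T to 0, giving dL/dh0 = a(0). The helpers rk4_step, solve_rk4, loss_fn and adjoint_grad are hypothetical.

```r
# Fixed linear vector field of the model: dh/dt = h %*% W (h is a 1 x 2 row vector)
W <- rbind(c(-0.1, 1.0), c(-0.2, -0.1))
times <- seq(0, 25, by = 25 / 100)
h_target <- matrix(c(0, 0.5), nrow = 1)

# One classical fourth-order Runge-Kutta step of size dt (dt may be negative)
rk4_step <- function(g, y, dt) {
  k1 <- g(y)
  k2 <- g(y + dt / 2 * k1)
  k3 <- g(y + dt / 2 * k2)
  k4 <- g(y + dt * k3)
  y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
}

# Hypothetical helper: integrate dy/dt = g(y) along `grid`, return the final state
solve_rk4 <- function(g, y0, grid) {
  y <- y0
  for (k in seq_along(grid)[-1]) y <- rk4_step(g, y, grid[k] - grid[k - 1])
  y
}

f <- function(h) h %*% W
loss_fn <- function(h0) sum((h_target - solve_rk4(f, h0, times))^2)

# Hypothetical helper: dL/dh0 via the adjoint method
adjoint_grad <- function(h0) {
  hT <- solve_rk4(f, h0, times)   # forward pass: state at t = 25
  aT <- -2 * (h_target - hT)      # a(T) = dL/dh(T)
  # The Jacobian is W everywhere, so the state need not be reconstructed
  # during the backward pass; integrate da/dt = -a %*% t(W) from T back to 0
  solve_rk4(function(a) -a %*% t(W), aT, rev(times))
}

# Check the adjoint gradient against central finite differences
h_init <- matrix(c(1, 0), nrow = 1)
eps <- 1e-6
fd_grad <- sapply(1:2, function(j) {
  e <- matrix(0, nrow = 1, ncol = 2)
  e[j] <- eps
  (loss_fn(h_init + e) - loss_fn(h_init - e)) / (2 * eps)
})
ad_grad <- adjoint_grad(h_init)

# 25 steps of SGD with momentum on h0 (lr = 1e-2, momentum = 0.95, as above)
h0 <- h_init
velocity <- matrix(0, nrow = 1, ncol = 2)
losses <- numeric(25)
for (i in 1:25) {
  losses[i] <- loss_fn(h0)
  velocity <- 0.95 * velocity - 1e-2 * adjoint_grad(h0)
  h0 <- h0 + velocity
}
```

For this linear system, integrating the adjoint backward with the same RK4 steps reproduces the gradient of the discretized loss up to rounding, so ad_grad and fd_grad should agree closely, and losses should decrease over the 25 steps.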
