
Extending Autograd

library(torch)

Adding operations to autograd requires implementing a new autograd_function for each operation. Recall that autograd_functions are what autograd uses to compute results and gradients, and to encode the operation history. Every new function requires you to implement two methods: forward(), which performs the operation, and backward(), which computes the gradients of the output with respect to the inputs.

Note

It’s the user’s responsibility to use the special functions in forward()’s ctx properly in order to ensure that the new autograd_function works correctly with the autograd engine.

Examples

Below you can find code for a linear function:

linear <- autograd_function(
  forward = function(ctx, input, weight, bias = NULL) {
    ctx$save_for_backward(input = input, weight = weight, bias = bias)
    output <- input$mm(weight$t())
    if (!is.null(bias))
      output <- output + bias$unsqueeze(1)$expand_as(output)
    
    output
  },
  backward = function(ctx, grad_output) {
    
    s <- ctx$saved_variables
    
    grads <- list(
      input = NULL,
      weight = NULL,
      bias = NULL
    )
    
    if (ctx$needs_input_grad$input)
      grads$input <- grad_output$mm(s$weight)
    
    if (ctx$needs_input_grad$weight)
      grads$weight <- grad_output$t()$mm(s$input)
    
    if (!is.null(s$bias) && ctx$needs_input_grad$bias)
      grads$bias <- grad_output$sum(dim = 1)
    
    grads
  }
)
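
Assuming the `linear` definition above is in scope, it can be called like any other function; a quick sanity check on shapes and gradients (the dimensions below are illustrative choices, not part of the API) might look like:

```r
library(torch)

# Assumes `linear` from the example above has been defined.
x <- torch_randn(3, 4, requires_grad = TRUE)  # batch of 3, in_features = 4
w <- torch_randn(2, 4, requires_grad = TRUE)  # out_features = 2
b <- torch_randn(2, requires_grad = TRUE)

y <- linear(x, w, b)   # shape: (3, 2)
y$sum()$backward()     # populate the $grad fields

x$grad  # same shape as x: (3, 4)
w$grad  # same shape as w: (2, 4)
```

The forward output should also agree with `nnf_linear(x, w, b)`, which is a convenient way to check the implementation against the built-in operation.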

Here, we give an additional example of a function that is parametrized by non-Tensor arguments:

mul_constant <- autograd_function(
  forward = function(ctx, tensor, constant) {
    ctx$save_for_backward(constant = constant)
    tensor * constant
  },
  backward = function(ctx, grad_output) {
    v <- ctx$saved_variables
    list(
      tensor = grad_output * v$constant
    )
  }
)
x <- torch_tensor(1, requires_grad = TRUE)
o <- mul_constant(x, 2)
o$backward()
x$grad  # tensor(2): the gradient is the constant
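
Gradients of a custom autograd_function can also be obtained functionally with `autograd_grad()`, which returns them as a list instead of accumulating into `$grad`. A small sketch, assuming the `mul_constant` definition above is in scope:

```r
library(torch)

# Assumes `mul_constant` from the example above has been defined.
x <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
y <- mul_constant(x, 5)$sum()

# autograd_grad() computes gradients without mutating x$grad
g <- autograd_grad(y, x)
g[[1]]  # every element equals the constant, 5
```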
