library(luz)
The Accelerator API is a simplified port of the Hugging Face Accelerate library. It lets you avoid the boilerplate code needed to write training loops that work correctly on both CPU and GPU devices. Currently it handles only CPU and single-GPU usage.
This API is meant to be the most flexible way to use the luz package. With the Accelerator API, you write the raw torch training loop and, with a few code changes, you automatically handle device placement of the model, optimizers and dataloaders, so you don't need to add many $to(device="cuda") calls in your code or think about the order in which to create the model and optimizers.
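To make the boilerplate concrete, here is a minimal sketch of manual device placement in plain torch, without luz. The choice of device via cuda_is_available() and the small input tensor `x` are illustrative assumptions, not part of the original example.

```r
library(torch)

# Pick the device by hand: use the GPU if one is available, otherwise the CPU.
device <- if (cuda_is_available()) "cuda" else "cpu"

# Move the model to the device, then create the optimizer from its parameters.
model <- nn_linear(10, 1)
model$to(device = device)
opt <- optim_adam(model$parameters)

# Every input tensor also has to be moved explicitly before the forward pass.
x <- torch_randn(4, 10)$to(device = device)  # hypothetical batch of 4 rows
preds <- model(x)
```

Each of these $to() calls is a place where a tensor can be left on the wrong device; the Accelerator API is meant to take over exactly this bookkeeping.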
The Accelerator API is best explained by showing an example diff in a raw torch training loop.
library(torch)
+ library(luz)

+ acc <- accelerator()
- device <- "cpu"

data <- tensor_dataset(
  x = torch_randn(100, 10),
  y = torch_rand(100, 1)
)

dl <- dataloader(data, batch_size = 10)

model <- nn_linear(10, 1)
- model$to(device = device)
opt <- optim_adam(model$parameters)

+ c(model, opt, dl) %<-% acc$prepare(model, opt, dl)

model$train()
coro::loop(for (batch in dl) {

  opt$zero_grad()

-  preds <- model(batch$x$to(device = device))
+  preds <- model(batch$x)
-  loss <- nnf_mse_loss(preds, batch$y$to(device = device))
+  loss <- nnf_mse_loss(preds, batch$y)

  loss$backward()
  opt$step()
})
With the code changes shown, you no longer need to manually move data and parameters between devices, which makes your code easier to read and less error prone.
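For reference, applying the diff above gives roughly the following script. It is assembled from the lines already shown; the only additions, made here for illustration, are the `losses` vector and the final cat() call, which record and print the per-batch loss.

```r
library(torch)
library(luz)

# The accelerator decides where the model, optimizer and data should live.
acc <- accelerator()

data <- tensor_dataset(
  x = torch_randn(100, 10),
  y = torch_rand(100, 1)
)
dl <- dataloader(data, batch_size = 10)

model <- nn_linear(10, 1)
opt <- optim_adam(model$parameters)

# prepare() returns the objects in the order they were passed in.
c(model, opt, dl) %<-% acc$prepare(model, opt, dl)

losses <- c()  # illustrative addition: collect the loss of each batch

model$train()
coro::loop(for (batch in dl) {
  opt$zero_grad()

  # No $to(device = ...) calls: batches come out of the prepared dataloader.
  preds <- model(batch$x)
  loss <- nnf_mse_loss(preds, batch$y)

  loss$backward()
  opt$step()

  losses <- c(losses, loss$item())
})

cat("mean loss over", length(losses), "batches:", mean(losses), "\n")
```

Because nothing in the loop names a device, the same script should run unchanged on a machine with or without a GPU.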
You can find additional documentation using help(accelerator).