Keras 3 is a deep learning framework that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras 3 workflows.
Let’s start by installing Keras 3:
install.packages("keras3")
keras3::install_keras()
We’re going to be using the TensorFlow backend here, but you can pass "jax" or "torch" to use_backend() instead (restarting the R session first if Keras is already loaded, “Restart runtime” in a hosted notebook), and the whole notebook will run just the same! This entire guide is backend-agnostic.
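A minimal sketch of the backend selection step, assuming the keras3 package and the chosen backend’s Python dependencies are already installed. use_backend() and config_backend() are keras3 functions; "tensorflow" simply mirrors the choice made in this guide:

```r
library(keras3)

# Select the backend before any other Keras function is called.
# Switching afterwards generally requires restarting the R session.
use_backend("tensorflow")   # alternatives: "jax" or "torch"

# Report which backend Keras is actually using
config_backend()
```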
Let’s start with the Hello World of ML: training a convnet to classify MNIST digits.
Here’s the data:
library(keras3)
# Load the data and split it between train and test sets
c(c(x_train, y_train), c(x_test, y_test)) %<-% keras3::dataset_mnist()
# Scale images to the [0, 1] range
x_train <- x_train / 255
x_test <- x_test / 255
# Make sure images have shape (28, 28, 1)
x_train <- array_reshape(x_train, c(-1, 28, 28, 1))
x_test <- array_reshape(x_test, c(-1, 28, 28, 1))
dim(x_train)
## [1] 60000 28 28 1
dim(x_test)
## [1] 10000 28 28 1
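Why is this reshape safe? Appending a trailing channel axis of size 1 leaves every pixel at the same [image, row, column] index, so no image gets scrambled. (array_reshape() uses row-major semantics, which does matter for other reshapes; for this particular one the result matches base R’s dim<-.) A quick base-R check on a toy array standing in for the MNIST images (the toy values are arbitrary):

```r
# Toy stand-in for the image array: 2 "images" of 3 x 3 pixels
x <- array(1:18, dim = c(2, 3, 3))

# Append a channel axis of size 1 (base R, column-major)
x_ch <- x
dim(x_ch) <- c(dim(x), 1)

dim(x_ch)                    # 2 3 3 1
# Every pixel keeps its [image, row, col] position
identical(x_ch[, , , 1], x)  # TRUE
```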
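Here’s our model. Keras offers several model-building options: the Sequential API (used below), the Functional API (the most typical choice), and writing your own models via subclassing (for advanced use cases).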
# Model parameters
num_classes <- 10
input_shape <- c(28, 28, 1)
model <- keras_model_sequential(input_shape = input_shape)
model |>
layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
layer_max_pooling_2d(pool_size = c(2, 2)) |>
layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
layer_global_average_pooling_2d() |>
layer_dropout(rate = 0.5) |>
layer_dense(units = num_classes, activation = "softmax")
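For comparison, here is a sketch of the same architecture written with the Functional API, where layers are wired explicitly from a symbolic input to an output. It is named model_fn (an illustrative name) so that it does not replace the Sequential model used in the rest of this guide; it has the same 260,298 parameters as the summary reports for the Sequential version.

```r
library(keras3)

# Functional API: start from a symbolic input, chain layers, then build the model
inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
  layer_max_pooling_2d(pool_size = c(2, 2)) |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
  layer_global_average_pooling_2d() |>
  layer_dropout(rate = 0.5) |>
  layer_dense(units = 10, activation = "softmax")

model_fn <- keras_model(inputs = inputs, outputs = outputs)
model_fn$count_params()
```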
Here’s our model summary:
summary(model)
## Model: "sequential"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
## ┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
## │ conv2d (Conv2D)                 │ (None, 26, 26, 64)     │           640 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_1 (Conv2D)               │ (None, 24, 24, 64)     │        36,928 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ max_pooling2d (MaxPooling2D)    │ (None, 12, 12, 64)     │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_2 (Conv2D)               │ (None, 10, 10, 128)    │        73,856 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_3 (Conv2D)               │ (None, 8, 8, 128)      │       147,584 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ global_average_pooling2d        │ (None, 128)            │             0 │
## │ (GlobalAveragePooling2D)        │                        │               │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout (Dropout)               │ (None, 128)            │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense (Dense)                   │ (None, 10)             │         1,290 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
## Total params: 260,298 (1016.79 KB)
## Trainable params: 260,298 (1016.79 KB)
## Non-trainable params: 0 (0.00 B)
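The numbers in the summary can be checked by hand: a Conv2D layer with a k x k kernel has (k * k * in_channels + 1) * filters weights (the + 1 is the bias), a Dense layer has (in_features + 1) * units, and the pooling and dropout layers have none. With "valid" padding and stride 1, each 3 x 3 convolution shrinks each spatial dimension by 2, and 2 x 2 max pooling halves it. A base-R sketch of that arithmetic, assuming 4-byte float32 weights for the size estimate:

```r
# Weights in a Conv2D layer: (k * k * in_channels + 1 bias) per filter
conv_params <- function(k, in_ch, filters) (k * k * in_ch + 1) * filters
# Weights in a Dense layer: (in_features + 1 bias) per unit
dense_params <- function(in_features, units) (in_features + 1) * units

params <- c(
  conv2d   = conv_params(3, 1, 64),     # 640
  conv2d_1 = conv_params(3, 64, 64),    # 36928
  conv2d_2 = conv_params(3, 64, 128),   # 73856
  conv2d_3 = conv_params(3, 128, 128),  # 147584
  dense    = dense_params(128, 10)      # 1290
)
sum(params)                     # 260298

# Spatial size: a "valid" 3 x 3 conv removes 2, 2 x 2 max pooling halves
side <- 28
side <- side - 2                # conv2d: 26
side <- side - 2                # conv2d_1: 24
side <- side %/% 2              # max_pooling2d: 12
side <- side - 2                # conv2d_2: 10
side <- side - 2                # conv2d_3: 8

# Memory at 4 bytes per float32 weight, in KB (1 KB = 1024 bytes)
sum(params) * 4 / 1024          # 1016.789, shown as 1016.79 KB
```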
We use the compile() method to specify the optimizer, the loss function, and the metrics to monitor. Note that with the JAX and TensorFlow backends, XLA compilation is turned on by default.
model |> compile(
optimizer = "adam",
loss = "sparse_categorical_crossentropy",
metrics = list(
metric_sparse_categorical_accuracy(name = "acc")
)
)
Let’s train and evaluate the model. We’ll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.
batch_size <- 128
epochs <- 10
callbacks <- list(
callback_model_checkpoint(filepath="model_at_epoch_{epoch}.keras"),
callback_early_stopping(monitor="val_loss", patience=2)
)
model |> fit(
x_train, y_train,
batch_size = batch_size,
epochs = epochs,
validation_split = 0.15,
callbacks = callbacks
)
## Epoch 1/10
## 399/399 - 8s - 20ms/step - acc: 0.7476 - loss: 0.7467 - val_acc: 0.9663 - val_loss: 0.1179
## Epoch 2/10
## 399/399 - 2s - 5ms/step - acc: 0.9384 - loss: 0.2066 - val_acc: 0.9770 - val_loss: 0.0765
## Epoch 3/10
## 399/399 - 2s - 5ms/step - acc: 0.9569 - loss: 0.1467 - val_acc: 0.9817 - val_loss: 0.0622
## Epoch 4/10
## 399/399 - 2s - 5ms/step - acc: 0.9652 - loss: 0.1170 - val_acc: 0.9860 - val_loss: 0.0499
## Epoch 5/10
## 399/399 - 2s - 5ms/step - acc: 0.9709 - loss: 0.0999 - val_acc: 0.9873 - val_loss: 0.0447
## Epoch 6/10
## 399/399 - 2s - 5ms/step - acc: 0.9752 - loss: 0.0863 - val_acc: 0.9877 - val_loss: 0.0400
## Epoch 7/10
## 399/399 - 2s - 5ms/step - acc: 0.9764 - loss: 0.0787 - val_acc: 0.9890 - val_loss: 0.0395
## Epoch 8/10
## 399/399 - 2s - 5ms/step - acc: 0.9794 - loss: 0.0678 - val_acc: 0.9874 - val_loss: 0.0432
## Epoch 9/10
## 399/399 - 2s - 5ms/step - acc: 0.9802 - loss: 0.0658 - val_acc: 0.9894 - val_loss: 0.0395
## Epoch 10/10
## 399/399 - 2s - 5ms/step - acc: 0.9825 - loss: 0.0584 - val_acc: 0.9914 - val_loss: 0.0342
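The 399 steps per epoch in the log follow from the validation split and the batch size. Keras takes the validation data from the last samples of x and y (before any shuffling), so 15% of the 60,000 training images are held out, leaving 51,000 for training, and the final partial batch still counts as a step. A base-R sketch of that arithmetic:

```r
n_train <- 60000          # MNIST training images
batch_size <- 128
validation_split <- 0.15

n_val <- n_train * validation_split      # 9000 images held out for validation
n_fit <- n_train - n_val                 # 51000 images used for gradient steps
steps_per_epoch <- ceiling(n_fit / batch_size)
steps_per_epoch                          # 399: 398 full batches + 1 partial one
n_fit - (steps_per_epoch - 1) * batch_size   # size of that last batch: 56
```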
During training, we were saving a model at the end of each epoch. You can also save the model in its latest state with save_model() and reload it with load_model().
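A minimal sketch of the save/reload round trip. To keep it self-contained it uses a tiny stand-in model (model_small) and a file under tempdir(), both illustrative rather than part of the guide; the same two calls apply to the trained convnet above. The .keras extension selects the native Keras 3 format.

```r
library(keras3)

# Tiny stand-in model: 4 inputs -> 2-class softmax (4 * 2 + 2 = 10 weights)
model_small <- keras_model_sequential(input_shape = c(4)) |>
  layer_dense(units = 2, activation = "softmax")

path <- file.path(tempdir(), "final_model.keras")

# Save the model in its latest state (architecture and weights)
model_small |> save_model(filepath = path, overwrite = TRUE)

# Reload it into a fresh object
reloaded <- load_model(path)
reloaded$count_params()   # 10
```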
Next, you can query predictions of class probabilities with predict():
predictions <- model |> predict(x_test)
dim(predictions)
## 313/313 - 0s - 2ms/step
## [1] 10000 10
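Each row of predictions holds 10 class probabilities, so the predicted digit is the column with the largest value, minus 1 because R columns are 1-based while MNIST labels run from 0 to 9. A sketch on a small hand-made probability matrix (a stand-in for the real predictions, with made-up true labels):

```r
# Hand-made stand-in for `predictions`: 3 samples x 10 digit probabilities
probs <- matrix(0.01, nrow = 3, ncol = 10)
probs[1, 8] <- 0.91   # sample 1: most likely digit 7
probs[2, 3] <- 0.91   # sample 2: most likely digit 2
probs[3, 2] <- 0.91   # sample 3: most likely digit 1
rowSums(probs)        # each row sums to 1

# Column index of each row's maximum, shifted to 0-based digit labels
pred_digits <- max.col(probs, ties.method = "first") - 1L
pred_digits           # 7 2 1

# Accuracy against (made-up) true labels
y_true <- c(7, 2, 0)
mean(pred_digits == y_true)   # 2/3
```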
That’s it for the basics!
Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let’s take a look at custom layers first.
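The op_ namespace contains an implementation of the NumPy API, e.g. op_stack() or op_matmul(), and a set of neural-network-specific ops that are absent from NumPy, such as op_conv() or op_binary_crossentropy().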
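A quick sketch of two of these ops on small constant tensors (the shapes and values are illustrative). op_convert_to_numpy() brings a backend tensor back as an R array, so the results can be inspected the same way on every backend:

```r
library(keras3)

a <- op_ones(c(2, 3))
b <- op_ones(c(3, 4))

# Matrix product, shape (2, 4); each entry sums three 1 * 1 terms, so equals 3
ab <- op_matmul(a, b)

# Stack two tensors along a new leading axis, shape (2, 2, 3)
ab_stack <- op_stack(list(a, a))

# Convert back to R arrays to inspect them
ab_r <- op_convert_to_numpy(ab)
dim(ab_r)                               # 2 4
dim(op_convert_to_numpy(ab_stack))      # 2 2 3
```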
Let’s make a custom Dense layer that works with all backends:
layer_my_dense <- Layer(
classname = "MyDense",
initialize = function(units, activation = NULL, name = NULL, ...) {
super$initialize(name = name, ...)
self$units <- units
self$activation <- activation
},
build = function(input_shape) {
input_dim <- tail(input_shape, 1)
self$w <- self$add_weight(
shape = shape(input_dim, self$units),
initializer = initializer_glorot_normal(),
name = "kernel",
trainable = TRUE
)
self$b <- self$add_weight(
shape = shape(self$units),
initializer = initializer_zeros(),
name = "bias",
trainable = TRUE
)
},
call = function(inputs) {
# Use Keras ops to create backend-agnostic layers/metrics/etc.
x <- op_matmul(inputs, self$w) + self$b
if (!is.null(self$activation))
x <- self$activation(x)
x
}
)
Next, let’s make a custom Dropout layer that relies on the random_* namespace:
layer_my_dropout <- Layer(
"MyDropout",
initialize = function(rate, name = NULL, seed = NULL, ...) {
super$initialize(name = name, ...)
self$rate <- rate
# Use seed_generator for managing RNG state.
# It is a state element and its seed variable is
# tracked as part of `layer$variables`.
self$seed_generator <- random_seed_generator(seed)
},
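call = function(inputs, training = NULL) {
# Apply dropout only while training; pass inputs through unchanged at inference.
if (!isTRUE(training))
return(inputs)
# Use `keras3::random_*` for random ops.
random_dropout(inputs, self$rate, seed = self$seed_generator)
}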
)
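The seed_generator comment reflects a key design choice. With an integer seed, Keras random ops return the same values on every call; a seed generator advances its internal state, so successive calls (such as successive training steps) draw fresh masks. A small sketch (the tensor, rate, and seed values are illustrative):

```r
library(keras3)

x <- op_ones(c(1, 20))

# Integer seed: the same dropout mask on every call
d1 <- op_convert_to_numpy(random_dropout(x, rate = 0.5, seed = 42L))
d2 <- op_convert_to_numpy(random_dropout(x, rate = 0.5, seed = 42L))
identical(d1, d2)          # TRUE

# Kept entries are rescaled by 1 / (1 - rate), so every value is 0 or 2
all(d1 %in% c(0, 2))

# Seed generator: its state advances, so each call typically draws a new mask
gen <- random_seed_generator(seed = 42L)
g1 <- op_convert_to_numpy(random_dropout(x, rate = 0.5, seed = gen))
g2 <- op_convert_to_numpy(random_dropout(x, rate = 0.5, seed = gen))
```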
Next, let’s write a custom subclassed model that uses our two custom layers:
MyModel <- Model(
"MyModel",
initialize = function(num_classes, ...) {
super$initialize(...)
self$conv_base <-
keras_model_sequential() |>
layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
layer_max_pooling_2d(pool_size = c(2, 2)) |>
layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
layer_global_average_pooling_2d()
self$dp <- layer_my_dropout(rate = 0.5)
self$dense <- layer_my_dense(units = num_classes,
activation = activation_softmax)
},
call = function(inputs) {
inputs |>
self$conv_base() |>
self$dp() |>
self$dense()
}
)
Let’s compile it and fit it:
model <- MyModel(num_classes = 10)
model |> compile(
loss = loss_sparse_categorical_crossentropy(),
optimizer = optimizer_adam(learning_rate = 1e-3),
metrics = list(
metric_sparse_categorical_accuracy(name = "acc")
)
)
model |> fit(
x_train, y_train,
batch_size = batch_size,
epochs = 1, # For speed
validation_split = 0.15
)
## 399/399 - 6s - 15ms/step - acc: 0.7343 - loss: 0.7741 - val_acc: 0.9269 - val_loss: 0.2399
All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you’re using. This includes tf_dataset objects, PyTorch DataLoader objects, and Keras PyDataset objects. They all work whether you’re using TensorFlow, JAX, or PyTorch as your Keras backend.
Let’s try this out with tf_dataset:
library(tfdatasets, exclude = "shape")
train_dataset <- list(x_train, y_train) |>
tensor_slices_dataset() |>
dataset_batch(batch_size) |>
dataset_prefetch(buffer_size = tf$data$AUTOTUNE)
test_dataset <- list(x_test, y_test) |>
tensor_slices_dataset() |>
dataset_batch(batch_size) |>
dataset_prefetch(buffer_size = tf$data$AUTOTUNE)
model <- MyModel(num_classes = 10)
model |> compile(
loss = loss_sparse_categorical_crossentropy(),
optimizer = optimizer_adam(learning_rate = 1e-3),
metrics = list(
metric_sparse_categorical_accuracy(name = "acc")
)
)
model |> fit(train_dataset, epochs = 1, validation_data = test_dataset)
## 469/469 - 7s - 14ms/step - acc: 0.7499 - loss: 0.7454 - val_acc: 0.9051 - val_loss: 0.3089
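Note that train_dataset above is never shuffled, so every epoch sees the batches in the same order. A sketch of the same kind of tfdatasets pipeline with dataset_shuffle() added before batching, on small random toy data (the data, buffer size, and batch size are illustrative assumptions):

```r
library(tfdatasets, exclude = "shape")

# Toy stand-ins for x_train / y_train: 100 samples with 4 features each
x <- array(runif(100 * 4), dim = c(100, 4))
y <- sample(0:9, 100, replace = TRUE)

toy_dataset <- list(x, y) |>
  tensor_slices_dataset() |>
  # Shuffle individual samples; a buffer as large as the data gives a full
  # shuffle, and by default the order is redrawn on every pass (epoch)
  dataset_shuffle(buffer_size = 100) |>
  dataset_batch(32)

# 100 samples in batches of 32 -> 4 batches (32, 32, 32, 4)
batches <- reticulate::iterate(as_array_iterator(toy_dataset))
length(batches)
```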
This concludes our short overview of the new multi-backend capabilities of Keras 3. Next, you can learn about customizing fit().
Want to implement a non-standard training algorithm yourself but still want to benefit from the power and usability of fit()? It’s easy to customize fit() to support arbitrary use cases.