Transfer learning is a powerful technique where a model developed for
one task is reused as the starting point for a model on a second task.
It is especially popular in computer vision, where pre-trained models
like ResNet50, which were trained on the massive ImageNet dataset, can
be used as powerful, ready-made feature extractors.
The kerasnip
package makes it easy to incorporate these
pre-trained Keras Applications directly into a tidymodels
workflow. This vignette will demonstrate how to:
kerasnip
model that uses a pre-trained
ResNet50
as a frozen base layer.tidymodels
workflow.First, we load the necessary packages.
library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
#> ✔ broom 1.0.8 ✔ recipes 1.3.0
#> ✔ dials 1.4.0 ✔ rsample 1.3.0
#> ✔ dplyr 1.1.4 ✔ tibble 3.2.1
#> ✔ ggplot2 3.5.2 ✔ tidyr 1.3.1
#> ✔ infer 1.0.8 ✔ tune 1.3.0
#> ✔ modeldata 1.4.0 ✔ workflows 1.2.0
#> ✔ parsnip 1.3.1 ✔ workflowsets 1.1.0
#> ✔ purrr 1.0.4 ✔ yardstick 1.3.2
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
library(keras3)
#>
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#>
#> get_weights
We’ll use the CIFAR-10 dataset, which consists of 60,000 32x32 color
images in 10 classes. keras3 provides a convenient function to download
it.
The ResNet50 model was pre-trained on ImageNet, which has a different
set of classes. Our goal is to fine-tune it to classify the 10 classes
in CIFAR-10.
# Load CIFAR-10 dataset
cifar10 <- dataset_cifar10()
# Separate training and test data
x_train <- cifar10$train$x
y_train <- cifar10$train$y
x_test <- cifar10$test$x
y_test <- cifar10$test$y
# Rescale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255
# Convert outcomes to factors for tidymodels
y_train_factor <- factor(y_train[, 1])
y_test_factor <- factor(y_test[, 1])
# For tidymodels, it's best to work with data frames.
# We'll use a list-column to hold the image arrays.
train_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_train)), function(i) x_train[i, , , , drop = TRUE]),
  y = y_train_factor
)
test_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_test)), function(i) x_test[i, , , , drop = TRUE]),
  y = y_test_factor
)
# Use a smaller subset for faster vignette execution
train_df_small <- train_df[1:500, ]
test_df_small <- test_df[1:100, ]
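To see what the list-column pattern above produces, here is a
lightweight, self-contained sketch on a synthetic 4-image batch (no
CIFAR-10 download needed; the name x_demo is illustrative only):

```r
# Synthetic stand-in for an image batch: 4 "images" of 2x2 pixels, 3 channels
x_demo <- array(runif(4 * 2 * 2 * 3), dim = c(4, 2, 2, 3))

# Split the 4-D array into a list with one 3-D array per image,
# exactly as done for x_train above
imgs <- lapply(seq_len(dim(x_demo)[1]), function(i) x_demo[i, , , , drop = TRUE])

length(imgs)   # 4 elements, one per image
dim(imgs[[1]]) # c(2, 2, 3): height x width x channels
```

Each element of the list is an independent height x width x channels
array, which is what the recipe later hands to Keras.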
The standard approach for transfer learning is to use the Keras
Functional API. We will define a model where:
1. The base is a pre-trained ResNet50, with its final classification layer removed (include_top = FALSE).
2. The weights of the base are frozen (trainable = FALSE) so that only our new layers are trained.
3. A new classification “head” is added, consisting of a flatten layer and a dense output layer.
# Input block: shape is determined automatically from the data
input_block <- function(input_shape) {
  layer_input(shape = input_shape)
}

# ResNet50 base block
resnet_base_block <- function(tensor) {
  # The base model is not trainable; we use it for feature extraction.
  resnet_base <- application_resnet50(
    weights = "imagenet",
    include_top = FALSE
  )
  resnet_base$trainable <- FALSE
  resnet_base(tensor)
}

# New classification head
flatten_block <- function(tensor) {
  tensor |> layer_flatten()
}

output_block_functional <- function(tensor, num_classes) {
  tensor |> layer_dense(units = num_classes, activation = "softmax")
}
The kerasnip Specification

We connect these blocks using create_keras_functional_spec().
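The original call is not shown in this excerpt, so the following is a
sketch only: the argument names follow the create_keras_functional_spec()
documentation, and the layer_blocks wiring is an assumption about the
intended vignette code rather than verified output.

```r
# Sketch: register a resnet_transfer() model specification built from
# the blocks defined above (block names and ordering are assumptions).
create_keras_functional_spec(
  model_name = "resnet_transfer",
  layer_blocks = list(
    input = input_block,
    resnet_base = resnet_base_block,
    flatten = flatten_block,
    output = output_block_functional
  ),
  mode = "classification"
)
```

After this call, resnet_transfer() is available as a parsnip-style model
specification, which is what the workflow below uses.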
Now we can use our new resnet_transfer() specification within a
tidymodels workflow.
spec_functional <- resnet_transfer(
  fit_epochs = 5,
  fit_validation_split = 0.2
) |>
  set_engine("keras")
rec_functional <- recipe(y ~ x, data = train_df_small)
wf_functional <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_functional)
fit_functional <- fit(wf_functional, data = train_df_small)
# Evaluate on the test set
predictions <- predict(fit_functional, new_data = test_df_small)
#> 4/4 - 4s - 962ms/step
bind_cols(predictions, test_df_small) |>
  accuracy(truth = y, estimate = .pred_class)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.11
With only 500 training images, 5 epochs, and simple [0, 1] rescaling, accuracy lands near the 10% chance level for 10 classes; the frozen ResNet50 features need more data, more epochs, and ideally the ImageNet-style preprocessing they were trained with to pay off. The point here is the workflow itself, which scales unchanged to the full dataset.
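As a quick sanity check on what accuracy() computes: multiclass accuracy
is simply the fraction of predictions matching the truth. A tiny base-R
illustration (the labels here are made up):

```r
# Toy truth/prediction vectors with a shared level set
truth    <- factor(c("cat", "dog", "cat", "bird"))
estimate <- factor(c("cat", "dog", "dog", "bird"), levels = levels(truth))

# Multiclass accuracy: proportion of element-wise matches
mean(truth == estimate)  # 0.75
```

yardstick::accuracy() returns the same proportion, wrapped in a tibble
with .metric and .estimator columns.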
This vignette demonstrated how kerasnip bridges the world of pre-trained
Keras applications with the structured, reproducible workflows of
tidymodels.
The Functional API is the most direct way to perform transfer learning, attaching a new head to a frozen base model.
This approach allows you to leverage the power of deep learning models that have been trained on massive datasets, significantly boosting performance on smaller, domain-specific tasks.