
Transfer Learning with Keras Applications

Introduction

Transfer learning is a technique in which a model developed for one task is reused as the starting point for a model on a second task. It is especially popular in computer vision, where models such as ResNet50, pre-trained on the large ImageNet dataset, can serve as ready-made feature extractors.

The kerasnip package makes it easy to incorporate these pre-trained Keras Applications directly into a tidymodels workflow. This vignette will demonstrate how to:

  1. Define a kerasnip model that uses a pre-trained ResNet50 as a frozen base layer.
  2. Add a new, trainable classification “head” on top of the frozen base.
  3. Fit the model and evaluate it on held-out test data using a standard tidymodels workflow.

Setup

First, we load the necessary packages.

library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
#> ✔ broom        1.0.8     ✔ recipes      1.3.0
#> ✔ dials        1.4.0     ✔ rsample      1.3.0
#> ✔ dplyr        1.1.4     ✔ tibble       3.2.1
#> ✔ ggplot2      3.5.2     ✔ tidyr        1.3.1
#> ✔ infer        1.0.8     ✔ tune         1.3.0
#> ✔ modeldata    1.4.0     ✔ workflows    1.2.0
#> ✔ parsnip      1.3.1     ✔ workflowsets 1.1.0
#> ✔ purrr        1.0.4     ✔ yardstick    1.3.2
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(keras3)
#> 
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#> 
#>     get_weights

Data Preparation

We’ll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes. keras3 provides a convenient function to download it.

The ResNet50 model was pre-trained on ImageNet, which has a different set of classes. Our goal is to reuse the features it learned there, with a new classification head on top, to classify the 10 classes in CIFAR-10.
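One detail matters when reusing these weights: the ImageNet weights for Keras' ResNet50 were trained on "caffe"-style inputs, meaning the RGB channels are reordered to BGR and each channel is zero-centred with the ImageNet mean, with no rescaling to [0, 1]. The code below rescales to [0, 1] for simplicity. As a minimal base-R sketch of the caffe transform, caffe_preprocess() is a hypothetical helper (not part of kerasnip or keras3) that assumes a raw [n, height, width, 3] RGB array with values in [0, 255], such as cifar10$train$x.

```r
# Hypothetical helper (not part of kerasnip or keras3): caffe-style
# preprocessing, the input convention of Keras' ImageNet ResNet50 weights.
# Assumes `x` is an [n, height, width, 3] RGB array with values in [0, 255].
caffe_preprocess <- function(x) {
  stopifnot(length(dim(x)) == 4, dim(x)[4] == 3)
  # Reverse the channel axis: RGB -> BGR
  x <- x[, , , 3:1, drop = FALSE]
  # ImageNet per-channel means, in BGR order
  bgr_means <- c(103.939, 116.779, 123.68)
  for (k in seq_len(3)) {
    # Zero-centre each channel; integer input is promoted to double here
    x[, , , k] <- x[, , , k] - bgr_means[k]
  }
  x
}

# Possible use in place of dividing by 255:
# x_train <- caffe_preprocess(cifar10$train$x)
# x_test  <- caffe_preprocess(cifar10$test$x)
```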

# Load CIFAR-10 dataset
cifar10 <- dataset_cifar10()

# Separate training and test data
x_train <- cifar10$train$x
y_train <- cifar10$train$y
x_test <- cifar10$test$x
y_test <- cifar10$test$y

# Rescale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255

# Convert outcomes to factors for tidymodels
y_train_factor <- factor(y_train[, 1])
y_test_factor <- factor(y_test[, 1])

# For tidymodels, it's best to work with data frames.
# We'll use a list-column to hold the image arrays.
train_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_train)), function(i) x_train[i, , , , drop = TRUE]),
  y = y_train_factor
)

test_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_test)), function(i) x_test[i, , , , drop = TRUE]),
  y = y_test_factor
)

# Use a smaller subset for faster vignette execution
train_df_small <- train_df[1:500, ]
test_df_small <- test_df[1:100, ]
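The list-column stores one 32×32×3 array per row, so anything that feeds these images to Keras has to stack them back into a single [n, 32, 32, 3] array. How kerasnip does this internally is not shown in this vignette; the base-R sketch below, built around a hypothetical stack_images() helper, only illustrates that the round trip from array to list-column and back preserves the data.

```r
# Hypothetical helper: stack a list of equally sized [h, w, c] arrays into
# one [n, h, w, c] array, the layout Keras expects for image batches.
stack_images <- function(images) {
  # simplify2array() binds along a new *last* axis: [h, w, c, n]
  stacked <- simplify2array(images, higher = TRUE)
  # Move the sample axis to the front: [n, h, w, c]
  aperm(stacked, c(4, 1, 2, 3))
}

# Round trip: 4D array -> list of 3D arrays (as in train_df$x) -> 4D array
set.seed(1)
x <- array(runif(2 * 4 * 4 * 3), dim = c(2, 4, 4, 3))
images <- lapply(seq_len(dim(x)[1]), function(i) x[i, , , , drop = TRUE])
stopifnot(identical(stack_images(images), x))

# With the vignette data: dim(stack_images(train_df_small$x)) is 500 32 32 3
```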

Functional API with a Pre-trained Base

The standard approach to transfer learning is to use the Keras Functional API. We will define a model in which:

  1. The base is a pre-trained ResNet50 with its final classification layer removed (include_top = FALSE).
  2. The weights of the base are frozen (trainable = FALSE), so only our new layers are trained.
  3. A new classification “head”, consisting of a flatten layer and a dense output layer, is added on top.

Define Layer Blocks

# Input block: shape is determined automatically from the data
input_block <- function(input_shape) {
  layer_input(shape = input_shape)
}

# ResNet50 base block
resnet_base_block <- function(tensor) {
  # The base model is not trainable; we use it for feature extraction.
  resnet_base <- application_resnet50(
    weights = "imagenet",
    include_top = FALSE
  )
  resnet_base$trainable <- FALSE
  resnet_base(tensor)
}

# New classification head
flatten_block <- function(tensor) {
  tensor |> layer_flatten()
}

output_block_functional <- function(tensor, num_classes) {
  tensor |> layer_dense(units = num_classes, activation = "softmax")
}
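It helps to know what the new head actually receives. The Keras ResNet50 backbone downsamples its input by an overall factor of 32 and its last stage has 2048 channels, so a 32×32 CIFAR-10 image comes out of the base as a 1×1×2048 feature map. layer_flatten() turns that into a vector of length 2048, and a 10-unit dense layer on top of it has 2048 × 10 weights plus 10 biases. The base-R helper below, head_shapes() (hypothetical), reproduces that arithmetic under the assumption that the input side is a multiple of 32.

```r
# Hypothetical helper: feature-map shape and trainable parameter count of
# the head, assuming a ResNet50 base (overall stride 32, 2048 output
# channels) and an input whose height and width are multiples of 32.
head_shapes <- function(input_side, num_classes,
                        stride = 32, channels = 2048) {
  stopifnot(input_side %% stride == 0)
  side <- input_side / stride          # spatial size of the feature map
  flat <- side * side * channels       # length after layer_flatten()
  list(
    feature_map = c(side, side, channels),
    flatten_units = flat,
    dense_params = flat * num_classes + num_classes  # weights + biases
  )
}

str(head_shapes(32, 10))
# feature_map 1 1 2048, flatten_units 2048, dense_params 20490
```

Flattening is harmless here because the feature map is 1×1. At 224×224 the same head would see a 7×7×2048 map and need about a million dense weights; passing pooling = "avg" to application_resnet50() keeps the head's input at 2048 regardless of image size.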

Create the kerasnip Specification

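We connect these blocks using create_keras_functional_spec(). Each inp_spec() call wires a block to the output of the named upstream block, and the call makes a new model specification function, resnet_transfer(), available under the name given in model_name.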

create_keras_functional_spec(
  model_name = "resnet_transfer",
  layer_blocks = list(
    input = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten = inp_spec(flatten_block, "resnet_base"),
    output = inp_spec(output_block_functional, "flatten")
  ),
  mode = "classification"
)
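For comparison, the graph these blocks describe can be written directly with keras3, without kerasnip. This is a sketch for checking the architecture, not a description of how kerasnip assembles the model internally. build_resnet_transfer() is a hypothetical name, and its weights argument is exposed so the structure can be inspected with weights = NULL, which skips downloading the ImageNet weights. It requires keras3 with a working backend.

```r
library(keras3)

# Hypothetical helper: frozen ResNet50 base + flatten + softmax head,
# mirroring the input -> resnet_base -> flatten -> output blocks above.
build_resnet_transfer <- function(input_shape = c(32, 32, 3),
                                  num_classes = 10,
                                  weights = "imagenet") {
  inputs <- layer_input(shape = input_shape)
  resnet_base <- application_resnet50(weights = weights, include_top = FALSE)
  resnet_base$trainable <- FALSE       # freeze every layer in the base
  features <- resnet_base(inputs)
  outputs <- features |>
    layer_flatten() |>
    layer_dense(units = num_classes, activation = "softmax")
  keras_model(inputs = inputs, outputs = outputs)
}

model <- build_resnet_transfer(weights = NULL)
# With the base frozen, only the dense kernel and bias should be trainable
length(model$trainable_weights)
```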

Fit and Evaluate the Model

Now we can use our new resnet_transfer() specification within a tidymodels workflow.
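The specification below sets fit_validation_split = 0.2. Assuming kerasnip forwards fit_-prefixed arguments to Keras' fit() (as the names suggest), this becomes Keras' validation_split, which holds out the last fraction of the samples, taken before any shuffling, rather than a random subset. The hypothetical helper keras_validation_rows() below sketches which rows that leaves for training and validation.

```r
# Hypothetical helper: which rows Keras' validation_split holds out.
# Keras keeps the first floor(n * (1 - split)) samples for training and
# uses the remaining (last) samples for validation, before shuffling.
keras_validation_rows <- function(n, validation_split) {
  split_at <- floor(n * (1 - validation_split))
  list(
    train = seq_len(split_at),
    validation = if (split_at < n) seq(split_at + 1, n) else integer(0)
  )
}

rows <- keras_validation_rows(500, 0.2)
lengths(rows)           # 400 training rows, 100 validation rows
range(rows$validation)  # rows 401 to 500
```

For train_df_small this means the model is fit on the first 400 images and validated on the last 100. CIFAR-10 is not stored sorted by class, so this is harmless here, but with ordered data it is worth shuffling the rows first.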

spec_functional <- resnet_transfer(
  fit_epochs = 5,
  fit_validation_split = 0.2
) |>
  set_engine("keras")

rec_functional <- recipe(y ~ x, data = train_df_small)

wf_functional <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_functional)

fit_functional <- fit(wf_functional, data = train_df_small)

# Evaluate on the test set
predictions <- predict(fit_functional, new_data = test_df_small)
#> 4/4 - 4s - 962ms/step
bind_cols(predictions, test_df_small) |>
  accuracy(truth = y, estimate = .pred_class)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass      0.11

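An accuracy of 0.11 is barely above the 0.10 expected from random guessing across 10 classes, so this run demonstrates the workflow rather than what transfer learning can achieve. Several shortcuts taken to keep the vignette fast work against accuracy: only 500 training images and 5 epochs, 32×32 inputs that are much smaller than the 224×224 images ResNet50 was trained on, and pixel values rescaled to [0, 1] instead of the preprocessing the ImageNet weights expect.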
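A quick way to put such a score in context is to compare it with trivial classifiers on the same test images: uniform random guessing has an expected accuracy of 1 / (number of classes), and always predicting the most frequent class scores that class's share of the test set. The helper majority_baseline() below is hypothetical and uses base R only.

```r
# Hypothetical helper: accuracy of two trivial baselines for a factor outcome.
majority_baseline <- function(truth) {
  if (!is.factor(truth)) truth <- factor(truth)
  counts <- table(truth)   # includes zero counts for unused levels
  list(
    majority_class = names(counts)[which.max(counts)],
    accuracy = max(counts) / length(truth),  # always predict the majority
    random_guess = 1 / nlevels(truth)        # expected accuracy, uniform guess
  )
}

# Usage with the vignette's test subset:
# majority_baseline(test_df_small$y)
```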

Conclusion

This vignette demonstrated how kerasnip bridges the world of pre-trained Keras applications with the structured, reproducible workflows of tidymodels.

The Functional API is the most direct way to perform transfer learning by attaching a new head to a frozen base model.

This approach lets you leverage deep learning models trained on massive datasets, which can substantially improve performance on smaller, domain-specific tasks, given enough training data and epochs and input preprocessing that matches what the base model expects.
