Customizing Wrapper Functions

Nickalus Redell

2020-05-05

Purpose

This vignette takes a closer look at how the user-supplied model-training and prediction wrapper functions can be modified to give greater control over the model-building process. The examples show how the wrapper functions can be written flexibly to keep a linear workflow in forecastML while modeling across multiple forecast horizons and validation datasets. The alternative would be to train models for a single forecast horizon and/or validation window at a time and to customize the wrapper functions for each specific setup.

Example 1 - Multiple forecast horizons and one model-training function

library(DT)
library(dplyr)
library(ggplot2)
library(forecastML)
library(randomForest)

data("data_seatbelts", package = "forecastML")
data <- data_seatbelts

data <- data[, c("DriversKilled", "kms", "PetrolPrice", "law")]

dates <- seq(as.Date("1969-01-01"), as.Date("1984-12-01"), by = "1 month")
data_train <- forecastML::create_lagged_df(data,
                                           type = "train",
                                           outcome_col = 1, 
                                           lookback = 1:12,
                                           horizons = c(3, 12),
                                           dates = dates,
                                           frequency = "1 month")

# View the horizon 3 lagged dataset.
DT::datatable(head(data_train$horizon_3), options = list("scrollX" = TRUE))


windows <- forecastML::create_windows(data_train, window_length = 0, 
                                      window_start = as.Date("1984-01-01"),
                                      window_stop = as.Date("1984-12-01"))

plot(windows, data_train)

User-defined model-training function

attributes(data_train$horizon_3)$horizon
## [1] 3
attributes(data_train$horizon_12)$horizon
## [1] 12
model_function <- function(data, my_outcome_col = 1, n_tree = c(200, 100)) {

  outcome_names <- names(data)[my_outcome_col]
  model_formula <- formula(paste0(outcome_names, " ~ ."))

  # Each lagged dataset carries its forecast horizon as an attribute (see above).
  horizon <- attributes(data)$horizon

  if (horizon == 3) {  # Model 1

    model <- randomForest::randomForest(formula = model_formula,
                                        data = data,
                                        ntree = n_tree[1])

    return(list("my_trained_model" = model, "n_tree" = n_tree[1],
                "meta_data" = horizon))

  } else if (horizon == 12) {  # Model 2

    model <- randomForest::randomForest(formula = model_formula,
                                        data = data,
                                        ntree = n_tree[2])

    return(list("my_trained_model" = model, "n_tree" = n_tree[2],
                "meta_data" = horizon))

  } else {  # Fail loudly instead of silently returning NULL.

    stop("No model is defined for forecast horizon ", horizon, ".")
  }
}
model_results <- forecastML::train_model(data_train, windows, model_name = "RF", model_function)
model_results$horizon_3$window_1$model
## $my_trained_model
## 
## Call:
##  randomForest(formula = model_formula, data = data, ntree = n_tree[1]) 
##                Type of random forest: regression
##                      Number of trees: 200
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 247.162
##                     % Var explained: 59.72
## 
## $n_tree
## [1] 200
## 
## $meta_data
## [1] 3
model_results$horizon_12$window_1$model
## $my_trained_model
## 
## Call:
##  randomForest(formula = model_formula, data = data, ntree = n_tree[2]) 
##                Type of random forest: regression
##                      Number of trees: 100
## No. of variables tried at each split: 1
## 
##           Mean of squared residuals: 420.6497
##                     % Var explained: 31.45
## 
## $n_tree
## [1] 100
## 
## $meta_data
## [1] 12
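
The if/else branching on attributes(data)$horizon used in the model-training function above can also be written as a lookup, which may scale better as more horizons are added. The following is a minimal base-R sketch, not part of forecastML: horizon_settings and get_horizon_settings() are hypothetical names, and the toy data.frames only stand in for data_train$horizon_3 and data_train$horizon_12.

```r
# Hypothetical per-horizon settings; list names are forecast horizons as strings.
horizon_settings <- list("3" = list(ntree = 200), "12" = list(ntree = 100))

# Hypothetical helper: look up the settings that match a dataset's horizon attribute.
get_horizon_settings <- function(data, settings = horizon_settings) {
  horizon <- attributes(data)$horizon
  key <- as.character(horizon)
  # Stop if the attribute is missing or no settings exist for this horizon.
  if (is.null(horizon) || !key %in% names(settings)) {
    stop("No settings are defined for this forecast horizon.")
  }
  settings[[key]]
}

# Toy data.frames with a manually set horizon attribute (assumption: they mimic
# the attribute carried by the lagged datasets shown earlier).
toy_h3 <- data.frame(y = c(1, 2, 3), x = c(4, 5, 6))
attr(toy_h3, "horizon") <- 3
toy_h12 <- toy_h3
attr(toy_h12, "horizon") <- 12

get_horizon_settings(toy_h3)$ntree
get_horizon_settings(toy_h12)$ntree
```

Inside a model-training function, the looked-up ntree value could be passed to a single randomForest::randomForest() call, so that adding a horizon only requires a new entry in the settings list.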

User-defined prediction function

prediction_function <- function(model, data_features) {

  if (model$meta_data == 3) {  # Model 1: add any horizon-3-specific transformation here.

    data_pred <- data.frame("y_pred" = predict(model$my_trained_model, data_features))

  } else if (model$meta_data == 12) {  # Model 2: add any horizon-12-specific transformation here.

    data_pred <- data.frame("y_pred" = predict(model$my_trained_model, data_features))

  } else {  # Fail with a clear message rather than an undefined data_pred.

    stop("No prediction rule is defined for forecast horizon ", model$meta_data, ".")
  }

  return(data_pred)
}
data_results <- predict(model_results,
                        prediction_function = list(prediction_function),
                        data = data_train)
plot(data_results)
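
Because the prediction function depends only on the list returned by the model-training function, it can be checked in isolation before it is run across every horizon and validation window. The sketch below is an illustration under stated assumptions: lm() stands in for randomForest to avoid extra dependencies, and toy_train, toy_model, and toy_prediction_function are hypothetical names that mirror the structure used in this example.

```r
# Stand-in for the list returned by a user-defined model-training function.
# lm() is an assumption made to keep the sketch dependency-free.
toy_train <- data.frame(y = c(1, 3, 5, 7, 9), x = 1:5)
toy_model <- list("my_trained_model" = lm(y ~ x, data = toy_train),
                  "meta_data" = 3)

# Same shape as the prediction function in this example, with an explicit
# error for any horizon that has no prediction rule.
toy_prediction_function <- function(model, data_features) {
  if (model$meta_data %in% c(3, 12)) {
    data_pred <- data.frame("y_pred" = predict(model$my_trained_model, data_features))
  } else {
    stop("No prediction rule is defined for forecast horizon ", model$meta_data, ".")
  }
  return(data_pred)
}

# Predict on two new rows; the result is a 1-column data.frame named y_pred.
toy_pred <- toy_prediction_function(toy_model, data.frame(x = c(6, 7)))
toy_pred
```

A quick check like this confirms the shape of the returned data.frame before the function is handed to predict() together with the trained forecastML models.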

