
Getting Started with ensembleML

Overview

ensembleML provides a single, consistent API for ensemble machine learning in R. Regardless of which algorithm you choose, the core workflow is always:

em_fit()  ->  em_predict()  ->  em_evaluate()

Advanced usage adds:

em_cv()          # k-fold cross-validation (stability estimates)
em_tune()        # grid-search hyperparameter optimisation
em_compare()     # side-by-side algorithm comparison
em_importance()  # feature importance
em_partial()     # partial dependence plots
em_confusion()   # confusion matrix heatmap
em_calibration() # calibration / reliability diagram
em_residuals()   # regression diagnostics

1. Train a model

data(iris)
set.seed(42)
idx   <- sample(nrow(iris), 120)
train <- iris[idx,  ]
test  <- iris[-idx, ]

rf <- em_fit(Species ~ ., data = train, method = "random_forest",
             verbose = TRUE)
#> [ensembleML] task auto-detected as 'classification'
#> 
#> ╭────────────────────────────────────────────────────╮
#> │  Algorithm:           random_forest               │
#> │  Task:                classification              │
#> │  Response:            Species                     │
#> │  Classes:             setosa, versicolor, virginica│
#> │  Predictors:          4  (Sepal.Length, Sepal.Width, Petal.Length, …)│
#> │  Training n:          120                         │
#> │  Fit time:            0.020 sec                   │
#> │  Train metrics:       accuracy=1.0000  kappa=1.0000  precision=1.0000  recall=1.0000  f1=1.0000  auc=NA│
#> │  ⚠  Use em_evaluate() on held-out data         │
#> ╰────────────────────────────────────────────────────╯

Switching algorithms requires changing a single argument:

xgb <- em_fit(Species ~ ., data = train, method = "xgboost")
ada <- em_fit(Species ~ ., data = train, method = "adaboost")
bag <- em_fit(Species ~ ., data = train, method = "bagging")

2. Predict

preds <- em_predict(rf, newdata = test)
head(preds)
#>      7     11     12     19     23     28 
#> setosa setosa setosa setosa setosa setosa 
#> Levels: setosa versicolor virginica

Class probabilities:

probs <- em_predict(rf, newdata = test, type = "prob")
head(probs, 3)
#>    setosa versicolor virginica
#> 7   1.000      0.000         0
#> 11  0.998      0.002         0
#> 12  1.000      0.000         0
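The probability matrix has one column per class. As a minimal base-R sketch, the hard label can be recovered by taking the most probable column in each row. The matrix below is hypothetical (not output from rf), and whether em_predict() resolves ties the same way is an assumption.

```r
# Hypothetical class probabilities: one row per observation, one column per class
prob_mat <- matrix(c(1.00, 0.00, 0.00,
                     0.10, 0.70, 0.20,
                     0.05, 0.35, 0.60),
                   nrow = 3, byrow = TRUE,
                   dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# Hard label = class with the highest probability in each row.
# ties.method = "first" is an illustrative choice, not ensembleML's documented rule.
hard_labels <- colnames(prob_mat)[max.col(prob_mat, ties.method = "first")]
hard_labels  # expected: setosa, versicolor, virginica
```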

3. Evaluate

em_evaluate(rf, newdata = test)
#>  accuracy     kappa precision    recall        f1       auc 
#>    0.9333    0.8997    0.9364    0.9364    0.9364        NA

Select specific metrics:

em_evaluate(rf, newdata = test, metrics = c("accuracy", "f1", "kappa"))
#> accuracy       f1    kappa 
#>   0.9333   0.9364   0.8997
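For reference, here is a base-R sketch of what these classification metrics measure, computed from a small hypothetical confusion table. It assumes precision, recall and F1 are macro-averaged (an unweighted mean over classes). ensembleML may average differently, so treat this as illustrative only.

```r
# Hypothetical labels for a 3-class problem
lv    <- c("a", "b", "c")
truth <- factor(c("a", "a", "a", "b", "b", "b", "c", "c", "c", "c"), levels = lv)
pred  <- factor(c("a", "a", "b", "b", "b", "b", "c", "c", "c", "a"), levels = lv)

cm <- table(truth = truth, predicted = pred)  # rows = truth, columns = prediction
n  <- sum(cm)

accuracy <- sum(diag(cm)) / n

# Cohen's kappa: observed agreement corrected for agreement expected by chance
p_expected  <- sum(rowSums(cm) * colSums(cm)) / n^2
cohen_kappa <- (accuracy - p_expected) / (1 - p_expected)

# Per-class precision, recall and F1, then an unweighted (macro) average
precision_k <- diag(cm) / colSums(cm)
recall_k    <- diag(cm) / rowSums(cm)
f1_k        <- 2 * precision_k * recall_k / (precision_k + recall_k)

metrics <- c(accuracy  = accuracy,
             kappa     = cohen_kappa,
             precision = mean(precision_k),
             recall    = mean(recall_k),
             f1        = mean(f1_k))
round(metrics, 4)
```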

4. Cross-validation

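Use em_cv() to estimate the mean and standard deviation of each metric across folds before committing to a method. With cv_folds = 5 and repeats = 3, the model is fitted 15 times (5 folds, repeated 3 times):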

cv_res <- em_cv(Species ~ ., data = iris, method = "random_forest",
                cv_folds = 5, repeats = 3)
cv_res$summary
em_plot_cv(cv_res, metric = "accuracy")
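The following base-R sketch shows the fold mechanics of repeated k-fold cross-validation in concrete terms. It uses lm() on the built-in mtcars data as a stand-in model and RMSE as the metric. em_cv()'s own fold assignment (for example, whether it stratifies by class) is an assumption and may differ.

```r
set.seed(1)
k       <- 5
repeats <- 3
dat     <- mtcars  # stand-in data; lm() below is a stand-in model

scores <- numeric(0)
for (r in seq_len(repeats)) {
  # Randomly assign each row to one of k folds of (near-)equal size
  fold_id <- sample(rep(seq_len(k), length.out = nrow(dat)))
  for (f in seq_len(k)) {
    train_f <- dat[fold_id != f, ]
    test_f  <- dat[fold_id == f, ]
    fit_f   <- lm(mpg ~ wt + hp, data = train_f)
    err     <- test_f$mpg - predict(fit_f, newdata = test_f)
    scores  <- c(scores, sqrt(mean(err^2)))  # RMSE on the held-out fold
  }
}

# Stability summary: k * repeats scores, summarised as mean and SD
c(n = length(scores), mean = mean(scores), sd = sd(scores))
```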

5. Tune hyperparameters

grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))

tuned <- em_tune(
  Species ~ ., data = train, method = "random_forest",
  param_grid = grid, cv_folds = 5
)

tuned$best_params
tuned$best_score
tuned$results
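A full grid search evaluates every combination of the supplied values. Assuming em_tune() expands param_grid as a Cartesian product (consistent with its description as grid search), the grid above yields 9 candidate settings, each scored with 5-fold CV. This base-R sketch only counts the work involved. It does not reproduce em_tune() itself.

```r
grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))

# Cartesian product: one row per hyperparameter combination
combos <- expand.grid(grid)
combos

n_candidates <- nrow(combos)      # 3 x 3 = 9 settings
n_fits       <- n_candidates * 5  # with cv_folds = 5: 45 fits during the search
c(candidates = n_candidates, fits = n_fits)
```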

6. Compare algorithms

cmp <- em_compare(Species ~ ., train = train, test = test)
cmp$table
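A side-by-side comparison fits every candidate on the same training rows and scores it on the same test rows, so differences reflect the method rather than the split. The base-R sketch below shows that pattern with three hypothetical lm() specifications on mtcars standing in for the ensemble algorithms. It is not how em_compare() is implemented internally.

```r
set.seed(3)
rows_tr <- sample(nrow(mtcars), 24)
train_c <- mtcars[rows_tr, ]
test_c  <- mtcars[-rows_tr, ]

# Hypothetical candidate models (stand-ins for different algorithms)
candidates <- list(
  wt_only  = mpg ~ wt,
  wt_hp    = mpg ~ wt + hp,
  all_vars = mpg ~ .
)

# Same training rows and same test rows for every candidate
test_rmse <- sapply(candidates, function(f) {
  fit <- lm(f, data = train_c)
  sqrt(mean((test_c$mpg - predict(fit, newdata = test_c))^2))
})
sort(test_rmse)  # lower RMSE is better
```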

7. Feature importance

em_importance(rf, top_n = 4)
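This guide does not state how em_importance() computes its scores. ensembleML may report the underlying algorithm's native importance. A common model-agnostic alternative is permutation importance, which measures how much the error grows when one feature is shuffled. Here is a minimal base-R sketch of that technique, with lm() on mtcars as a stand-in model.

```r
set.seed(2)
fit_imp <- lm(mpg ~ wt + hp + qsec, data = mtcars)  # stand-in model

# RMSE of a fitted model on a data set
rmse_of <- function(model, d) sqrt(mean((d$mpg - predict(model, newdata = d))^2))

# Scored on the training data for brevity; held-out data is preferable in practice
baseline <- rmse_of(fit_imp, mtcars)

features <- c("wt", "hp", "qsec")
perm_imp <- sapply(features, function(v) {
  shuffled <- mtcars
  shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and mpg
  rmse_of(fit_imp, shuffled) - baseline   # increase in error = importance
})
sort(perm_imp, decreasing = TRUE)
```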


8. Partial dependence

em_partial(rf, data = train, feature = "Petal.Length")
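Partial dependence holds one feature at each value on a grid and averages the model's predictions over the data. The base-R sketch below shows that computation, using lm() on mtcars as a stand-in for the fitted ensemble. For a linear model the curve is a straight line whose slope equals the fitted coefficient, which makes the result easy to check.

```r
fit_pd  <- lm(mpg ~ wt + hp, data = mtcars)  # stand-in model
wt_grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 10)

pd <- sapply(wt_grid, function(g) {
  d <- mtcars
  d$wt <- g                           # set wt to g for every row
  mean(predict(fit_pd, newdata = d))  # average prediction across rows
})
data.frame(wt = wt_grid, partial_dependence = pd)
```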

9. Confusion matrix

em_confusion(rf, newdata = test)
em_confusion(rf, newdata = test, normalise = TRUE)
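You can reproduce the counts behind a confusion-matrix heatmap with table(). With normalise = TRUE, the cells are presumably converted to proportions. The sketch below assumes row-wise normalisation (each true class sums to 1). That convention is common, but it is not confirmed for em_confusion(). The labels are hypothetical.

```r
lv    <- c("a", "b", "c")
truth <- factor(c("a", "a", "b", "b", "b", "c"), levels = lv)  # hypothetical
pred  <- factor(c("a", "b", "b", "b", "c", "c"), levels = lv)  # hypothetical

cm <- table(truth = truth, predicted = pred)  # raw counts
cm

# Row-normalised: share of each true class assigned to each predicted class
cm_norm <- prop.table(cm, margin = 1)
round(cm_norm, 3)
```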

10. Regression example

Everything works identically for numeric responses:

set.seed(7)
n  <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)                     # x2 is pure noise, unrelated to y
reg_data  <- data.frame(x1 = x1, x2 = x2,
                        y  = 3 + 2 * x1 + rnorm(n))
reg_train <- reg_data[1:160, ]
reg_test  <- reg_data[161:200, ]

reg_model <- em_fit(y ~ ., data = reg_train, method = "random_forest")
#> [ensembleML] task auto-detected as 'regression'
em_evaluate(reg_model, reg_test)   # returns rmse, mae, mape, rsq and adj_rsq
em_residuals(reg_model, reg_test)
#> `geom_smooth()` using formula = 'y ~ x'
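For reference, here is a base-R sketch of the regression metrics reported by em_evaluate(), computed on hypothetical observed and predicted values. The definitions are common ones, but they are assumptions about ensembleML. mape is expressed in percent. rsq is taken as 1 - SSE/SST, which can be negative when a model does worse than predicting the mean. adj_rsq needs the number of predictors p.

```r
# Hypothetical observed and predicted values
obs  <- c(3.0, 5.0, 2.5, 7.0)
pred <- c(2.5, 5.0, 3.0, 8.0)
p    <- 2  # number of predictors, used only by the adjusted R-squared

err  <- obs - pred
n    <- length(obs)
rmse <- sqrt(mean(err^2))
mae  <- mean(abs(err))
mape <- 100 * mean(abs(err / obs))  # percent; undefined if any obs == 0
rsq  <- 1 - sum(err^2) / sum((obs - mean(obs))^2)
adj_rsq <- 1 - (1 - rsq) * (n - 1) / (n - p - 1)

round(c(rmse = rmse, mae = mae, mape = mape, rsq = rsq, adj_rsq = adj_rsq), 4)
```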


Citation

If you use ensembleML in published work, please cite it:

citation("ensembleML")

The individual algorithms should also be cited; see citation("ensembleML") for the full list of references.


Session info

sessionInfo()
#> R version 4.2.1 (2022-06-23 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 26200)
#> 
#> Matrix products: default
#> 
#> locale:
#> [1] LC_COLLATE=C                          
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ensembleML_0.2.5
#> 
#> loaded via a namespace (and not attached):
#>  [1] bslib_0.10.0         compiler_4.2.1       pillar_1.11.1       
#>  [4] RColorBrewer_1.1-3   jquerylib_0.1.4      tools_4.2.1         
#>  [7] digest_0.6.39        lattice_0.20-45      nlme_3.1-168        
#> [10] jsonlite_2.0.0       evaluate_1.0.5       lifecycle_1.0.5     
#> [13] tibble_3.3.1         gtable_0.3.6         mgcv_1.8-40         
#> [16] pkgconfig_2.0.3      rlang_1.1.7          Matrix_1.6-5        
#> [19] cli_3.6.5            rstudioapi_0.18.0    yaml_2.3.12         
#> [22] xfun_0.57            fastmap_1.2.0        gridExtra_2.3       
#> [25] withr_3.0.2          dplyr_1.2.0          knitr_1.51          
#> [28] generics_0.1.4       sass_0.4.10          vctrs_0.7.2         
#> [31] grid_4.2.1           tidyselect_1.2.1     glue_1.7.0          
#> [34] R6_2.6.1             otel_0.2.0           rmarkdown_2.31      
#> [37] ggplot2_4.0.2        farver_2.1.2         magrittr_2.0.3      
#> [40] splines_4.2.1        scales_1.4.0         htmltools_0.5.9     
#> [43] randomForest_4.7-1.2 labeling_0.4.3       S7_0.2.1            
#> [46] cachem_1.1.0
