Other types of models

In the following, we explain the counterfactuals workflow for both a classification and a regression task using concrete use cases.

library("counterfactuals")
library("iml")
library("rpart")

Other types of models

The Predictor class of the iml package provides the necessary flexibility to cover classification and regression models fitted with diverse R packages. In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages. In the following, we show extensions to a regression tree fitted with the caret package, the tidymodels package, the mlr package (a predecessor of mlr3), and the rpart package directly. For each model, we generate counterfactuals for the first row of the BostonHousing dataset using the NICE method (NICERegr).

# Load the BostonHousing data shipped with the mlbench package.
data("BostonHousing", package = "mlbench")
# The first row is the observation for which counterfactuals are generated.
x_interest = BostonHousing[1L, , drop = FALSE]

rpart - caret package

library("caret")

# Fit a regression tree through caret's interface to rpart; the one-row
# tuning grid fixes the complexity parameter cp at 0.01.
treecaret = caret::train(
  medv ~ .,
  data = BostonHousing[-1L, ],
  method = "rpart",
  tuneGrid = data.frame(cp = 0.01)
)
# Wrap the fitted model in an iml Predictor; the data excludes the first row,
# which is the observation passed to the counterfactual search below.
predcaret = Predictor$new(
  model = treecaret,
  data = BostonHousing[-1L, ],
  y = "medv"
)
predcaret$predict(x_interest)
#>   .prediction
#> 1    27.49074
# Configure NICE for regression and look for a counterfactual whose
# prediction lies in the desired outcome range [30, 40].
nicecaret = NICERegr$new(
  predcaret,
  optimization = "plausibility",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [30, 40] 
#>  
#> Head: 
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47

rpart - tidymodels package

library("tidymodels")
# Specify a regression decision tree with the rpart engine and fit it on all
# rows except the first one.
treetm = decision_tree(
  mode = "regression",
  engine = "rpart"
) %>%
  fit(medv ~ ., data = BostonHousing[-1L, ])
# Wrap the fitted model in an iml Predictor.
predtm = Predictor$new(
  model = treetm,
  data = BostonHousing[-1L, ],
  y = "medv"
)
predtm$predict(x_interest)
#>      .pred
#> 1 27.49074
# Configure NICE for regression and look for a counterfactual whose
# prediction lies in the desired outcome range [30, 40].
nicetm = NICERegr$new(
  predtm,
  optimization = "plausibility",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [30, 40] 
#>  
#> Head: 
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47

rpart - mlr package

library("mlr")
# Define the regression task on all rows except the first one and create an
# rpart regression learner.
task = mlr::makeRegrTask(data = BostonHousing[-1L, ], target = "medv")
mod = mlr::makeLearner("regr.rpart")

# Qualify train() explicitly: caret, which is used earlier in this file, also
# provides a train() function, so the unqualified name would resolve according
# to the order in which the packages were attached.
treemlr = mlr::train(mod, task)
# Wrap the fitted model in an iml Predictor.
predmlr = Predictor$new(
  model = treemlr,
  data = BostonHousing[-1L, ],
  y = "medv"
)
predmlr$predict(x_interest)
#>   .prediction
#> 1    27.49074
# Configure NICE for regression and look for a counterfactual whose
# prediction lies in the desired outcome range [30, 40].
nicemlr = NICERegr$new(
  predmlr,
  optimization = "plausibility",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [30, 40] 
#>  
#> Head: 
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47

Decision tree - rpart package

# Fit a regression tree directly with rpart on all rows except the first one.
treerpart = rpart(
  medv ~ .,
  data = BostonHousing[-1L, ]
)
# Wrap the fitted model in an iml Predictor.
predrpart = Predictor$new(
  model = treerpart,
  data = BostonHousing[-1L, ],
  y = "medv"
)
predrpart$predict(x_interest)
#>       pred
#> 1 27.49074
# Configure NICE for regression and look for a counterfactual whose
# prediction lies in the desired outcome range [30, 40].
nicerpart = NICERegr$new(
  predrpart,
  optimization = "plausibility",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [30, 40] 
#>  
#> Head: 
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47