In the following, we explain the `counterfactuals` workflow for both a classification and a regression task using concrete use cases.
The `Predictor` class of the `iml` package provides the flexibility needed to cover classification and regression models fitted with diverse R packages. In the introduction vignette, we saw models fitted with the `mlr3` and `randomForest` packages. Below, we show the same workflow for regression trees fitted with the `caret` package, `tidymodels`, `mlr` (a predecessor of `mlr3`), and the `rpart` package directly. For each model, we generate counterfactuals for the first row of the `BostonHousing` dataset (`x_interest`), train the model on the remaining rows, and use the NICE method (`NICERegr`) with a desired predicted `medv` between 30 and 40.
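For a model class that `iml` does not recognize, `Predictor$new()` also accepts a custom `predict.function`, whose first argument is the fitted model and whose second is the new data. The sketch below illustrates that fallback with an `lm()` fit chosen only to keep it short (`iml` already handles `lm` objects, so the custom function is not strictly required here); `train_data`, `linmod`, and `predlm` are illustrative names that are not used elsewhere in this vignette.

```r
library("iml")
data("BostonHousing", package = "mlbench")
x_interest = BostonHousing[1L, ]
train_data = BostonHousing[-1L, ]

# Illustrative model: any fitted object works as long as the
# prediction function below knows how to call it
linmod = lm(medv ~ ., data = train_data)

# predict.function(model, newdata) must return the predictions;
# iml passes only the feature columns as newdata
predlm = Predictor$new(
  model = linmod, data = train_data, y = "medv",
  predict.function = function(model, newdata) predict(model, newdata = newdata)
)
predlm$predict(x_interest)
```

The resulting `predlm` can be passed to a counterfactual method such as `NICERegr$new()` in the same way as the predictors below.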
```r
library("counterfactuals")
library("iml")
library("caret")
# BostonHousing ships with the mlbench package
data("BostonHousing", package = "mlbench")
# Point of interest: the first row, which is excluded from training
x_interest = BostonHousing[1L, ]

treecaret = caret::train(medv ~ ., data = BostonHousing[-1L, ], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = BostonHousing[-1L, ], y = "medv")
predcaret$predict(x_interest)
#>   .prediction
#> 1    27.49074
nicecaret = NICERegr$new(predcaret, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47
```
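`find_counterfactuals()` returns a `Counterfactuals` object, which the examples in this section only print. A minimal sketch that keeps the object and re-predicts the returned counterfactual with the same `Predictor` to check that it falls inside the desired range [30, 40], assuming the counterfactual rows are available through the object's `data` field; `cfactuals`, `cf_pred`, and `in_range` are illustrative names. The caret setup is repeated so the sketch runs on its own.

```r
library("counterfactuals")
library("iml")
library("caret")
data("BostonHousing", package = "mlbench")
x_interest = BostonHousing[1L, ]

treecaret = caret::train(medv ~ ., data = BostonHousing[-1L, ], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = BostonHousing[-1L, ], y = "medv")
nicecaret = NICERegr$new(predcaret, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)

# Keep the Counterfactuals object instead of only printing it
cfactuals = nicecaret$find_counterfactuals(x_interest, desired_outcome = c(30, 40))

# Assumption: cfactuals$data holds the counterfactual rows (features only)
cf_pred = predcaret$predict(as.data.frame(cfactuals$data))[[1L]]

# TRUE for each counterfactual whose prediction lies in [30, 40]
in_range = cf_pred >= 30 & cf_pred <= 40
in_range
```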
```r
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
  fit(medv ~ ., data = BostonHousing[-1L, ])
predtm = Predictor$new(model = treetm, data = BostonHousing[-1L, ], y = "medv")
predtm$predict(x_interest)
#>      .pred
#> 1 27.49074
nicetm = NICERegr$new(predtm, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47
```
```r
library("mlr")
# Qualify mlr functions explicitly: caret, loaded above, also exports train()
task = mlr::makeRegrTask(data = BostonHousing[-1L, ], target = "medv")
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = BostonHousing[-1L, ], y = "medv")
predmlr$predict(x_interest)
#>   .prediction
#> 1    27.49074
nicemlr = NICERegr$new(predmlr, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47
```
```r
# rpart() is not attached by the packages above, so load it explicitly
library("rpart")
treerpart = rpart(medv ~ ., data = BostonHousing[-1L, ])
predrpart = Predictor$new(model = treerpart, data = BostonHousing[-1L, ], y = "medv")
predrpart$predict(x_interest)
#>       pred
#> 1 27.49074
nicerpart = NICERegr$new(predrpart, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#>       crim zn indus chas   nox   rm  age    dis rad tax ptratio     b lstat
#> 1: 0.07503 33  2.18    0 0.472 7.42 71.9 3.0992   7 222    18.4 396.9  6.47
```
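All four examples use NICE. Because the `Predictor` wrapper is independent of the counterfactual method, switching methods should only change the constructor. A minimal sketch with the WhatIf method on the `rpart` model, assuming `WhatIfRegr$new()` takes the `Predictor` plus an `n_counterfactuals` argument; `whatif` and `cfactuals_wi` are illustrative names, and the number of counterfactuals requested is an arbitrary choice.

```r
library("counterfactuals")
library("iml")
library("rpart")
data("BostonHousing", package = "mlbench")
x_interest = BostonHousing[1L, ]

treerpart = rpart(medv ~ ., data = BostonHousing[-1L, ])
predrpart = Predictor$new(model = treerpart, data = BostonHousing[-1L, ], y = "medv")

# Only the method object changes; the Predictor and the
# find_counterfactuals() call are the same as in the NICE examples
whatif = WhatIfRegr$new(predrpart, n_counterfactuals = 3L)
cfactuals_wi = whatif$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
cfactuals_wi
```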