In the following, we illustrate how further methods can be added to the counterfactuals package by integrating the featureTweakR package, which implements Feature Tweaking of Tolomei et al. (2017). (Note that the featureTweakR package has a couple of limitations, e.g., factors in the training data cause problems, and the algorithm is only applicable to randomForest models trained on standardized features. Therefore, featureTweakR is not part of the counterfactuals package, but it serves as a suitable example for our purpose here.)

Before we dive into the implementation details, we briefly explain the structure of the counterfactuals package.
Each counterfactual method is represented by its own R6 class. Depending on whether a counterfactual method supports classification or regression tasks, it inherits from the (abstract) class CounterfactualMethodClassif or CounterfactualMethodRegr, respectively. Counterfactual methods that support both tasks are split into two separate classes; for instance, since MOC is applicable to classification and regression tasks, we implement two classes: MOCClassif and MOCRegr.
Leaf classes (like MOCClassif and MOCRegr) inherit the find_counterfactuals() method from CounterfactualMethodClassif or CounterfactualMethodRegr, respectively. The key advantage of this approach is that we are able to provide a find_counterfactuals() interface tailored to the task at hand: for classification tasks, find_counterfactuals() has the two arguments desired_class and desired_prob; for regression tasks, it has the single argument desired_outcome.
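For illustration, the two interfaces differ only in how the desired prediction is specified. A hedged sketch, assuming moc_classif and moc_regr are already-constructed MOCClassif and MOCRegr objects:

# Classification: desired class plus a desired probability range
cfactuals = moc_classif$find_counterfactuals(
  x_interest, desired_class = "versicolor", desired_prob = c(0.6, 1)
)
# Regression: a desired outcome range instead
cfactuals = moc_regr$find_counterfactuals(
  x_interest, desired_outcome = c(20, 25)
)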
The find_counterfactuals() method calls a private run() method (1), implemented by the leaf classes, which performs the search and returns the counterfactuals as a data.table (2). The find_counterfactuals() method then creates a Counterfactuals object, which contains the counterfactuals and provides several methods for their evaluation and visualization (3).
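To make this control flow concrete, here is a simplified sketch, illustrative only and not the actual package code, of what the inherited classification interface does; we assume the private fields shown are set by the base class before run() is called:

# Simplified sketch of the inherited classification interface (not package code)
find_counterfactuals = function(x_interest, desired_class, desired_prob) {
  # store the query instance and the desired prediction for run()
  private$x_interest = x_interest
  private$desired_class = desired_class
  private$desired_prob = desired_prob
  # (1) + (2): the leaf class performs the search and returns a data.table
  cfactuals = private$run()
  # (3): wrap the result for evaluation and visualization
  # (further constructor arguments omitted in this sketch)
  Counterfactuals$new(cfactuals, private$predictor, x_interest)
}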
To integrate Feature Tweaking, we first need to install featureTweakR and pforeach, and load the required libraries.

devtools::install_github("katokohaku/featureTweakR")
# pforeach is required by featureTweakR
devtools::install_github("hoxo-m/pforeach")
library(counterfactuals)
library(featureTweakR)
library(randomForest)
library(R6)
A new leaf class needs at least two methods: initialize() and run(). The method print_parameters() is not mandatory, but strongly recommended, as it gives objects of that class an informative print() output.

As elaborated above, the new class inherits from either CounterfactualMethodClassif or CounterfactualMethodRegr, depending on which task it supports. Because Feature Tweaking supports classification tasks, the new FeatureTweakerClassif class inherits from the former.
$initialize method

The initialize() method must have a predictor argument that takes an iml::Predictor object. In addition, it may have further arguments that are specific to the counterfactual method, such as ktree, epsiron, and resample in this case (epsiron is the argument spelling used by featureTweakR). For argument checks, we recommend the checkmate package. We also fill the print_parameters() method with the parameters of Feature Tweaking.
FeatureTweakerClassif = R6Class("FeatureTweakerClassif",
  inherit = CounterfactualMethodClassif,
  public = list(
    initialize = function(predictor, ktree = NULL, epsiron = 0.1,
                          resample = FALSE) {
      # adds predictor to the private$predictor field
      super$initialize(predictor)
      private$ktree = ktree
      private$epsiron = epsiron
      private$resample = resample
    }
  ),
  private = list(
    ktree = NULL,
    epsiron = NULL,
    resample = NULL,
    run = function() {},
    print_parameters = function() {
      cat(" - epsiron: ", private$epsiron, "\n")
      cat(" - ktree: ", private$ktree, "\n")
      cat(" - resample: ", private$resample)
    }
  )
)
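To illustrate the recommended checkmate checks, initialize() could validate its arguments along the following lines; a sketch, which the class above omits for brevity:

# Possible argument checks at the top of initialize() (sketch)
checkmate::assert_class(predictor, "Predictor")
checkmate::assert_int(ktree, lower = 1L, null.ok = TRUE)
checkmate::assert_number(epsiron, lower = 0)
checkmate::assert_flag(resample)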
$run method

The run() method performs the search for counterfactuals. Its structure is completely free, which makes it flexible to add new counterfactual methods to the counterfactuals package.

The workflow for finding counterfactuals with the featureTweakR package essentially consists of the following steps:
# Rule extraction
rules = getRules(rf, ktree = 5L)
# Get e-satisfactory instance
es = set.eSatisfactory(rules, epsiron = 0.3)
# Compute counterfactuals
tweaked = tweak(
  es, rf, x_interest, label.from = ..., label.to = ..., .dopar = FALSE
)
tweaked$suggest
As long as ktree, the number of trees to parse, is smaller than the total number of trees in the randomForest, the rule extraction is a random process. Hence, these steps can be repeated (controlled by the resample argument) to obtain multiple counterfactuals, as in the sketch below.
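A sketch of such a repeated search, illustrative only; n_repetitions, class_from, and class_to are hypothetical placeholders, not featureTweakR names:

# Illustrative resampling loop (hypothetical placeholder names)
results = lapply(seq_len(n_repetitions), function(i) {
  rules = getRules(rf, ktree = 5L)             # random subset of trees
  es = set.eSatisfactory(rules, epsiron = 0.3)
  tweak(es, rf, x_interest, label.from = class_from,
    label.to = class_to, .dopar = FALSE)$suggest
})
# stack the suggestions from all repetitions
all_suggestions = do.call(rbind, results)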
FeatureTweakerClassif = R6Class("FeatureTweakerClassif",
  inherit = CounterfactualMethodClassif,
  public = list(
    initialize = function(predictor, ktree = NULL, epsiron = 0.1,
                          resample = FALSE) {
      # adds predictor to the private$predictor field
      super$initialize(predictor)
      private$ktree = ktree
      private$epsiron = epsiron
      private$resample = resample
    }
  ),
  private = list(
    ktree = NULL,
    epsiron = NULL,
    resample = NULL,
    run = function() {
      # Extract info from private fields
      predictor = private$predictor
      y_hat_interest = predictor$predict(private$x_interest)
      class_x_interest = names(y_hat_interest)[which.max(y_hat_interest)]
      rf = predictor$model
      # Search counterfactuals by calling functions in featureTweakR
      rules = getRules(rf, ktree = private$ktree, resample = private$resample)
      es = set.eSatisfactory(rules, epsiron = private$epsiron)
      tweaks = featureTweakR::tweak(
        es, rf, private$x_interest, label.from = class_x_interest,
        label.to = private$desired_class, .dopar = FALSE
      )
      # return the suggested counterfactuals
      tweaks$suggest
    },
    print_parameters = function() {
      cat(" - epsiron: ", private$epsiron, "\n")
      cat(" - ktree: ", private$ktree, "\n")
      cat(" - resample: ", private$resample)
    }
  )
)
Now that we have implemented FeatureTweakerClassif, let’s look at a short application to the iris dataset.

First, we train a randomForest model on the iris data set and set up the iml::Predictor object, omitting x_interest from the training data.
set.seed(78546)
X = subset(iris, select = -Species)[-130L, ]
y = iris$Species[-130L]
rf = randomForest(X, y, ntree = 20L)
predictor = iml::Predictor$new(
  rf, data = iris[-130L, ], y = "Species", type = "prob"
)
For x_interest, the model predicts a probability of 30% for versicolor.

x_interest = iris[130L, ]
predictor$predict(x_interest)
#>   setosa versicolor virginica
#> 1      0        0.3       0.7

Now, we use Feature Tweaking to address the question: "What would need to change in x_interest for the model to predict a probability of at least 60% for versicolor?"

# Set up FeatureTweakerClassif
ft_classif = FeatureTweakerClassif$new(predictor, ktree = 10L, resample = TRUE)
# Find counterfactuals and create a Counterfactuals object
cfactuals = ft_classif$find_counterfactuals(
  x_interest = x_interest, desired_class = "versicolor", desired_prob = c(0.6, 1)
)
#> extracting sampled 10 of 20 trees
#> Time difference of 3.120119 secs
#> set e-satisfactory instance (10 trees)
#> Time difference of 0.4980896 secs
#> 1 instances were predicted by 10 trees:
#> instance[1]: predicted "virginica" agreed by 7 tree (wants "virginica"->"versicolor")
#> evaluate 26 rules in 7 trees
#> - evalutate 10 candidate of rules (delta.min=0.951)
#> Time difference of 0.08894181 secs

Just as for the existing methods, the result is a Counterfactuals object.
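The object can therefore be handled like the results of the built-in methods, e.g., using evaluation and plotting methods of the Counterfactuals class such as evaluate() and plot_parallel():

# Inspect and evaluate the found counterfactuals
cfactuals$evaluate()
cfactuals$plot_parallel()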
Tolomei, G., Silvestri, F., Haines, A., Lalmas, M.: Interpretable Predictions of Tree-based Ensembles via Actionable Feature Tweaking. In: Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 465–474. KDD ’17, ACM, New York, NY, USA (2017).
Comments

A minor limitation of this basic implementation is that we would not be able to find counterfactuals for a setting with max(desired_prob) < 0.5, since featureTweakR::tweak only searches for instances that would be predicted as desired_class by majority vote. To enable this setting, we would need to change some featureTweakR internal code, but for the sake of clarity we will not do this here.