No black-box model without XAI. This is where packages like {flashlight} come into play. {flashlight} offers the following XAI methods:

- `light_performance()`: Performance metrics like RMSE and/or \(R^2\)
- `light_importance()`: Permutation variable importance (Fisher, Rudin, and Dominici 2018)
- `light_ice()`: Individual conditional expectation (ICE) profiles (Goldstein et al. 2015), centered or uncentered
- `light_profile()`: Partial dependence (Friedman 2001), accumulated local effects (ALE) (Apley and Zhu 2016), and average predicted/observed/residual profiles
- `light_profile2d()`: Two-dimensional version of `light_profile()`
- `light_effects()`: Combines partial dependence, ALE, response and prediction profiles
- `light_interaction()`: Different variants of Friedman's H statistics (Friedman and Popescu 2008)
- `light_breakdown()`: Variable contribution breakdown (approximate SHAP) for single observations (Gosiewska and Biecek 2019)
- `light_global_surrogate()`: Global surrogate trees (Molnar 2019)

Good to know:
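- Each XAI method is applied to a flashlight explainer object (see examples and Section "flashlights").
- Multiple flashlights can be combined into a multiflashlight via `multiflashlight()`.
- `plot()` visualizes the results via {ggplot2}.

```r
# From CRAN
install.packages("flashlight")

# Development version
devtools::install_github("mayer79/flashlight")
```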
Let’s start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.
```r
library(ggplot2)
library(MetricsWeighted)
library(flashlight)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Make explainer object
fl_lm <- flashlight(
  model = fit_lm,
  data = iris,
  y = "Sepal.Length",
  label = "lm",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
```
```r
fl_lm |>
  light_performance() |>
  plot(fill = "darkred") +
  labs(x = element_blank(), title = "Performance on training data")

fl_lm |>
  light_performance(by = "Species") |>
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
```
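For intuition, here is a base-R sketch of the two metrics shown in the performance plots: RMSE and \(R^2\) on the training data, overall and within each Species. It uses the textbook definitions. We assume these agree with `MetricsWeighted::rmse()` and `MetricsWeighted::r_squared()` for unweighted data, but this sketch is not taken from either package. The helper names `rmse_base()` and `r2_base()` are hypothetical.

```r
fit <- lm(Sepal.Length ~ ., data = iris)
pred <- predict(fit, iris)
y <- iris$Sepal.Length

# Hypothetical helpers using the textbook definitions
rmse_base <- function(y, pred) sqrt(mean((y - pred)^2))
r2_base <- function(y, pred) 1 - sum((y - pred)^2) / sum((y - mean(y))^2)

overall <- c(RMSE = rmse_base(y, pred), R2 = r2_base(y, pred))

# Same metrics computed within each Species group (columns = species)
by_species <- sapply(
  split(seq_along(y), iris$Species),
  function(idx) c(RMSE = rmse_base(y[idx], pred[idx]), R2 = r2_base(y[idx], pred[idx]))
)

overall
by_species
```

Within a group, \(R^2\) is measured against the group mean rather than the overall mean. It can therefore be noticeably lower than the overall value even when the RMSE is similar.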
Error bars represent standard errors, i.e., the uncertainty of the estimated importance.
```r
fl_lm |>
  light_importance(m_repetitions = 4) |>
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
```
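Permutation importance measures how much the performance metric worsens when one feature's values are shuffled. Shuffling breaks the feature's link to the response while keeping its marginal distribution intact. The following base-R sketch repeats each shuffle `m_repetitions` times and reports the mean RMSE increase together with its standard error. We assume, without having checked the source, that flashlight's error bars are computed in a similar way. `perm_importance()` and `rmse_base()` are hypothetical helpers, not flashlight functions.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

rmse_base <- function(y, pred) sqrt(mean((y - pred)^2))

# Hypothetical helper: permutation importance with repeated shuffles
perm_importance <- function(fit, data, y, features, m_repetitions = 4) {
  base_rmse <- rmse_base(data[[y]], predict(fit, data))
  rows <- lapply(features, function(v) {
    increases <- replicate(m_repetitions, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and the response
      rmse_base(data[[y]], predict(fit, shuffled)) - base_rmse
    })
    data.frame(
      variable = v,
      increase = mean(increases),               # importance estimate
      se = sd(increases) / sqrt(m_repetitions)  # standard error of that estimate
    )
  })
  do.call(rbind, rows)
}

set.seed(1)
features <- setdiff(names(iris), "Sepal.Length")
imp <- perm_importance(fit, iris, "Sepal.Length", features)
imp[order(-imp$increase), ]
```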
```r
fl_lm |>
  light_ice("Sepal.Width", n_max = 200) |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")

fl_lm |>
  light_ice("Sepal.Width", n_max = 200, center = "middle") |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")
```
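An ICE curve shows how the prediction for a single observation changes as one feature moves along a grid while all its other features stay fixed. The minimal base-R sketch below computes ICE curves and a centered version. It centers each curve at the middle grid point, which we assume corresponds roughly to `center = "middle"`. It also simply takes the first `n_max` rows, whereas flashlight's row selection may differ. `ice_curves()` is a hypothetical helper.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

# Hypothetical helper: one ICE curve per observation, evaluated on a grid
ice_curves <- function(fit, data, v, grid, n_max = 200) {
  rows <- data[seq_len(min(n_max, nrow(data))), , drop = FALSE]
  curves <- vapply(seq_len(nrow(rows)), function(i) {
    obs <- rows[rep(i, length(grid)), , drop = FALSE]  # copy observation i once per grid value
    obs[[v]] <- grid                                   # vary only feature v
    unname(predict(fit, obs))
  }, numeric(length(grid)))
  t(curves)  # rows = observations, columns = grid values
}

grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 9)
ice <- ice_curves(fit, iris, "Sepal.Width", grid)

# Centered ICE: subtract each curve's value at the middle grid point
mid <- ceiling(length(grid) / 2)
c_ice <- ice - ice[, mid]
```

The linear model has no interactions, so all ICE curves are parallel and every centered curve is identical. Centered curves that fan out would hint at interactions involving the feature.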
### PDPs
```r
fl_lm |>
  light_profile("Sepal.Width", n_bins = 40) |>
  plot() +
  ggtitle("PDP for 'Sepal.Width'")

fl_lm |>
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("Same grouped by 'Species'")
```
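A partial dependence profile is the average of the uncentered ICE curves. For each grid value, the feature is set to that value in every row, and the predictions are then averaged. The grouped version repeats this on each group's rows. Below is a base-R sketch with a hypothetical helper `partial_dependence()`. It is not flashlight's implementation, and the grid of 40 equally spaced values is our choice.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

# Hypothetical helper: partial dependence of feature v on a grid
partial_dependence <- function(fit, data, v, grid) {
  vapply(grid, function(g) {
    data[[v]] <- g            # set v to g in every row (local copy of data)
    mean(predict(fit, data))  # average over the observed values of all other features
  }, numeric(1))
}

grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 40)
pd <- partial_dependence(fit, iris, "Sepal.Width", grid)

# Grouped version: one PDP per Species, computed on that group's rows
pd_by_species <- sapply(
  split(iris, iris$Species), partial_dependence,
  fit = fit, v = "Sepal.Width", grid = grid
)
```

For this linear model, the PDP is a straight line whose slope equals the `Sepal.Width` coefficient, and the per-Species PDPs are vertical shifts of each other.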
```r
fl_lm |>
  light_profile2d(c("Petal.Width", "Petal.Length")) |>
  plot()
```
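Two-dimensional partial dependence, as in the plot above, is also the building block of Friedman's H statistic, which `light_interaction()` reports. The sketch computes the textbook pairwise version (Friedman and Popescu 2008), \(H^2_{jk} = \sum_i [\mathrm{PD}_{jk}(x_{ij}, x_{ik}) - \mathrm{PD}_j(x_{ij}) - \mathrm{PD}_k(x_{ik})]^2 / \sum_i \mathrm{PD}_{jk}(x_{ij}, x_{ik})^2\), with mean-centered partial dependence functions evaluated at the observed data. flashlight offers several variants, and its defaults may differ from this one. `pd_at_data()` and `h2_pairwise()` are hypothetical helpers. The model `fit_int` with an interaction term is added only for contrast.

```r
# Hypothetical helper: partial dependence of `vars`, evaluated at each row's own values
pd_at_data <- function(fit, data, vars) {
  vapply(seq_len(nrow(data)), function(i) {
    X <- data
    for (v in vars) X[[v]] <- data[[v]][i]  # fix vars at row i, keep the rest
    mean(predict(fit, X))
  }, numeric(1))
}

# Hypothetical helper: pairwise H^2 with mean-centered partial dependence
h2_pairwise <- function(fit, data, j, k) {
  center <- function(z) z - mean(z)
  pd_j <- center(pd_at_data(fit, data, j))
  pd_k <- center(pd_at_data(fit, data, k))
  pd_jk <- center(pd_at_data(fit, data, c(j, k)))
  sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2)
}

fit_add <- lm(Sepal.Length ~ ., data = iris)                           # no interactions
fit_int <- lm(Sepal.Length ~ Petal.Width * Petal.Length, data = iris)  # with interaction

h2_add <- h2_pairwise(fit_add, iris, "Petal.Width", "Petal.Length")
h2_int <- h2_pairwise(fit_int, iris, "Petal.Width", "Petal.Length")
c(additive = h2_add, interaction = h2_int)
```

For the additive model, the two-dimensional partial dependence is exactly the sum of the one-dimensional ones, so \(H^2\) is zero up to rounding. The interaction term makes it positive.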
```r
fl_lm |>
  light_profile("Sepal.Width", type = "ale") |>
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
```
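ALE replaces the PDP's "set the feature everywhere" step with local prediction differences. Each observation's feature is moved only across the bin it falls into. These differences are averaged per bin and then accumulated. Below is a minimal first-order ALE sketch in base R, following the idea of Apley and Zhu (2016). The quantile bins, `n_bins = 10`, and the count-weighted centering are our assumptions and may not match flashlight's choices. `ale_1d()` is a hypothetical helper.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

# Hypothetical helper: first-order ALE on quantile bins
ale_1d <- function(fit, data, v, n_bins = 10) {
  x <- data[[v]]
  # Bin edges at observed quantiles (type = 1 returns actual data values)
  edges <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1),
                           names = FALSE, type = 1))
  n_int <- length(edges) - 1
  bin <- findInterval(x, edges, rightmost.closed = TRUE, all.inside = TRUE)

  # Local effect: prediction change when moving v across the observation's bin
  lower <- upper <- data
  lower[[v]] <- edges[bin]
  upper[[v]] <- edges[bin + 1]
  local_effect <- predict(fit, upper) - predict(fit, lower)

  # Average local effect per bin, then accumulate over bins
  mean_effect <- vapply(seq_len(n_int), function(b) {
    if (any(bin == b)) mean(local_effect[bin == b]) else 0
  }, numeric(1))
  ale <- c(0, cumsum(mean_effect))

  # Center: subtract the count-weighted mean of the bin midpoint values
  counts <- tabulate(bin, nbins = n_int)
  ale <- ale - sum(counts * (ale[-1] + ale[-length(ale)]) / 2) / sum(counts)
  data.frame(edge = edges, ale = ale)
}

res <- ale_1d(fit, iris, "Sepal.Width")
res
```

Because ALE only evaluates the model close to observed feature values, it is less affected by correlated features than the PDP. For this additive linear model, both are straight lines with the coefficient as slope.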
```r
fl_lm |>
  light_effects("Sepal.Width") |>
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")
```
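Besides PDP and ALE, the combined plot contains simple binned averages of the response and of the predictions. These are the kind of average observed/predicted/residual profiles that `light_profile()` can also produce. A base-R sketch with 8 equal-width bins follows. The binning is our assumption, not flashlight's rule.

```r
fit <- lm(Sepal.Length ~ ., data = iris)
pred <- predict(fit, iris)

# Hypothetical binning: 8 equal-width bins of Sepal.Width
bins <- cut(iris$Sepal.Width, breaks = 8)

profiles <- data.frame(
  bin = levels(bins),
  observed = as.vector(tapply(iris$Sepal.Length, bins, mean)),
  predicted = as.vector(tapply(pred, bins, mean)),
  residual = as.vector(tapply(iris$Sepal.Length - pred, bins, mean))
)
profiles
```

Unlike the PDP, these averages do not hold the other features fixed. They therefore also reflect how the other features vary with `Sepal.Width`. Within each bin, observed minus predicted equals the average residual.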
```r
fl_lm |>
  light_breakdown(new_obs = iris[1, ]) |>
  plot()
```
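A breakdown starts from the average prediction on the reference data. It then fixes the features one by one at the values of the explained observation and records how much the average prediction moves at each step. Averaging these contributions over random variable orders gives an approximation of SHAP values, in the spirit of Gosiewska and Biecek (2019). The sketch below is not flashlight's implementation. In particular, flashlight's default visit order and number of orders are not reproduced, and `breakdown()` is a hypothetical helper.

```r
fit <- lm(Sepal.Length ~ ., data = iris)
features <- setdiff(names(iris), "Sepal.Length")

# Hypothetical helper: sequential breakdown for one observation and one variable order
breakdown <- function(fit, data, new_obs, order) {
  X <- data
  baseline <- mean(predict(fit, X))  # average prediction on the reference data
  prev <- baseline
  contributions <- setNames(numeric(length(order)), order)
  for (v in order) {
    X[[v]] <- new_obs[[v]]           # fix v at the explained observation's value
    current <- mean(predict(fit, X))
    contributions[[v]] <- current - prev
    prev <- current
  }
  list(baseline = baseline, contributions = contributions, prediction = prev)
}

bd <- breakdown(fit, iris, iris[1, ], order = features)

# Approximate SHAP: average the contributions over random variable orders
set.seed(1)
shap_approx <- rowMeans(sapply(1:20, function(r) {
  breakdown(fit, iris, iris[1, ], order = sample(features))$contributions[features]
}))
```

By construction, the baseline plus all contributions equals the prediction for the observation. For this additive model, the contributions do not depend on the order. With interactions they would, which is why averaging over orders matters.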
```r
fl_lm |>
  light_global_surrogate() |>
  plot()
```
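A global surrogate is an interpretable model trained to mimic the black-box model's predictions rather than the observed response. Its fidelity can be measured as the \(R^2\) between surrogate and model predictions. The base-R sketch below uses {rpart}, which this document loads in the next section. The depth limit `maxdepth = 3` and the fidelity measure are our assumptions, not necessarily flashlight's settings.

```r
library(rpart)

fit <- lm(Sepal.Length ~ ., data = iris)

# Train a shallow tree to mimic the model's predictions (not the observed response)
X <- iris[setdiff(names(iris), "Sepal.Length")]
X$pred <- predict(fit, iris)
surrogate <- rpart(pred ~ ., data = X, control = rpart.control(maxdepth = 3))

# Fidelity: R-squared of the surrogate with respect to the model's predictions
fidelity <- 1 - sum((X$pred - predict(surrogate, X))^2) /
  sum((X$pred - mean(X$pred))^2)
fidelity
```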
### Multiple models
Multiple flashlights can be combined into a multiflashlight.
```r
library(rpart)

fit_tree <- rpart(
  Sepal.Length ~ .,
  data = iris,
  control = list(cp = 0, xval = 0, maxdepth = 5)
)

# Make explainer object
fl_tree <- flashlight(
  model = fit_tree,
  data = iris,
  y = "Sepal.Length",
  label = "tree",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

# Combine with other explainer
fls <- multiflashlight(list(fl_tree, fl_lm))
```
```r
fls |>
  light_performance() |>
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")

fls |>
  light_importance() |>
  plot(fill = "chartreuse4") +
  labs(y = "Increase in RMSE", title = "Permutation importance")

fls |>
  light_profile("Petal.Length", n_bins = 40) |>
  plot() +
  ggtitle("PDP")

fls |>
  light_profile("Petal.Length", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("PDP by Species")
```
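Conceptually, a multiflashlight applies the same method with the same data and metrics to each explainer, so that the results line up for comparison. Here is a base-R sketch of that idea for the performance comparison, reusing the two model specifications from above. `rmse_base()` and `r2_base()` are hypothetical helpers, and the numbers refer to the training data.

```r
library(rpart)

fit_lm <- lm(Sepal.Length ~ ., data = iris)
fit_tree <- rpart(
  Sepal.Length ~ ., data = iris,
  control = list(cp = 0, xval = 0, maxdepth = 5)
)
models <- list(tree = fit_tree, lm = fit_lm)

rmse_base <- function(y, pred) sqrt(mean((y - pred)^2))
r2_base <- function(y, pred) 1 - sum((y - pred)^2) / sum((y - mean(y))^2)

# Evaluate every model with the same metrics on the same data (rows = models)
perf <- t(sapply(models, function(m) {
  pred <- predict(m, iris)
  c(RMSE = rmse_base(iris$Sepal.Length, pred), R2 = r2_base(iris$Sepal.Length, pred))
}))
perf
```

Training-data performance can flatter flexible models such as a deep tree. For a fair comparison, you would evaluate on held-out data.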
The "flashlight" explainer expects the following information:

- `model`: Fitted model. Currently, this argument must be named.
- `data`: Reference data used to calculate things, often part of the validation data.
- `y`: Column name in `data` corresponding to the numeric response.
- `predict_function`: Function with the same signature as `stats::predict()`. It takes a model and a data.frame `data`, and provides numeric predictions; see below for more details.
- `linkinv`: Optional function applied to the output of `predict_function()`. Should actually be called "trafo".
- `w`: Optional column name in `data` corresponding to case weights.
- `by`: Optional column name in `data` used to group the results. Must be discrete.
- `metrics`: List of metrics, by default `list(rmse = MetricsWeighted::rmse)`. For binary (probabilistic) classification, good candidate metrics would be `MetricsWeighted::logLoss`.
- `label`: Mandatory name of the model.

### `predict_function`s (a selection)

- The default `stats::predict()` works for models of class `lm()`, `glm()` (for predictions on the link scale), and `rpart()`. It also works for meta-learner models like …
- Manual prediction functions are required, e.g., for models whose `predict()` method returns the predictions in a `$predictions` element: use `function(m, X) predict(m, X)$predictions` for regression, and `function(m, X) predict(m, X)$predictions[, 2]` for probabilistic binary classification.
- `glm()`: Use `function(m, X) predict(m, X, type = "response")` to get GLM predictions on the response scale.

Models whose native predict function does not work on data.frames are a bit more complicated. Example (XGBoost): the following works when all non-numeric features are factors (not character columns):

```r
# x: character vector with the names of the model features
predict_function <- function(m, df) predict(m, data.matrix(df[x]))
```
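The base-R sketch below illustrates two points from this section. First, for a logistic GLM, a `predict_function` that requests `type = "response"` should give the same probabilities as the default link-scale prediction followed by the inverse link. That is the role `linkinv` plays, assuming flashlight applies it exactly as described above. Second, `data.matrix()` turns factors into their integer codes, which is what a matrix-only model would then receive. The objects `dat`, `fit_glm`, `pf_response`, `pf_link`, and `inv_link` are hypothetical, and XGBoost itself is not used.

```r
dat <- transform(iris, virginica = as.numeric(Species == "virginica"))
fit_glm <- glm(virginica ~ Sepal.Length + Sepal.Width, data = dat, family = binomial())

# Variant 1: ask predict.glm() directly for the response scale (probabilities)
pf_response <- function(m, X) predict(m, X, type = "response")

# Variant 2: default prediction on the link scale (log-odds) plus an inverse link
pf_link <- function(m, X) predict(m, X)
inv_link <- binomial()$linkinv  # logistic function, same as plogis()

p_response <- pf_response(fit_glm, dat)
p_link <- inv_link(pf_link(fit_glm, dat))

# data.matrix() replaces factor columns by their integer level codes
codes <- data.matrix(iris[c(1, 51, 101), c("Sepal.Width", "Species")])
codes
```

Because the codes follow the factor levels, the features passed at prediction time should carry the same levels as during training. Otherwise the same code could refer to different categories.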