The hardware and bandwidth for this mirror is donated by dogado GmbH, the Webhosting and Full Service-Cloud Provider. Check out our Wordpress Tutorial.
If you wish to report a bug, or if you are interested in having us mirror your free-software or open-source project, please feel free to contact us at mirror[@]dogado.de.
Let’s see an example for DALEX
package for
classification models for the survival problem for Titanic dataset. Here
we are using a dataset titanic_imputed
avaliable in the
DALEX
package. Note that this data was copied from the
stablelearner
package and changed for practicality.
library("DALEX")
head(titanic_imputed)
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
Ok, now it’s time to create a model. Let’s use the Random Forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch,
data = titanic_imputed, probability = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, probability = TRUE)
#>
#> Type: Probability estimation
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 10
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error (Brier s.): 0.1422968
The third step (it’s optional but useful) is to create a
DALEX
explainer for random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed[,-8],
y = titanic_imputed[,8],
label = "Random Forest")
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest
#> -> data : 2207 rows 7 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.14.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.01164526 , mean = 0.3215481 , max = 0.9899436
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.7923093 , mean = 0.0006086512 , max = 0.8905081
#> A new explainer has been created!
Use the feature_importance()
explainer to present
importance of particular features. Note that
type = "difference"
normalizes dropouts, and now they all
start in 0.
library("ingredients")
fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.3408062 Random Forest
#> 2 parch 0.3520488 Random Forest
#> 3 sibsp 0.3520933 Random Forest
#> 4 embarked 0.3527842 Random Forest
#> 5 age 0.3760269 Random Forest
#> 6 fare 0.3848921 Random Forest
plot(fi_rf)
As we see the most important feature is gender
. Next
three importnat features are class
, age
and
fare
. Let’s see the link between model response and these
features.
Such univariate relation can be calculated with
partial_dependence()
.
Kids 5 years old and younger have much higher survival probability.
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 fare Random Forest 0.0000000 0.3630884 0
#> 2 age Random Forest 0.1666667 0.5347603 0
#> 3 age Random Forest 2.0000000 0.5536098 0
#> 4 age Random Forest 4.0000000 0.5595259 0
#> 5 fare Random Forest 6.1793080 0.3100674 0
#> 6 age Random Forest 7.0000000 0.5159751 0
plot(pp_age)
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
Let’s see break down explanation for model predictions for 8 years old male from 1st class that embarked from port C.
First Ceteris Paribus Profiles for numerical variables
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passanger)
plot(sp_rf) +
show_observations(sp_rf)
And for selected categorical variables. Note, that sibsp is numerical but here is presented as a categorical variable.
plot(sp_rf,
variables = c("class", "embarked", "gender", "sibsp"),
variable_type = "categorical")
It looks like the most important feature for this passenger is
age
and sex
. After all his odds for survival
are higher than for the average passenger. Mainly because of the young
age and despite of being a male.
passangers <- select_sample(titanic, n = 100)
sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles :
#> _vname_ _label_ _x_ _cluster_ _yhat_ _ids_
#> 1 fare Random Forest_1 0.0000000 1 0.2372045 0
#> 2 parch Random Forest_1 0.0000000 1 0.1658665 0
#> 3 sibsp Random Forest_1 0.0000000 1 0.1699181 0
#> 4 age Random Forest_1 0.1666667 1 0.4653162 0
#> 5 parch Random Forest_1 1.0000000 1 0.2539302 0
#> 6 sibsp Random Forest_1 1.0000000 1 0.1519697 0
plot(sp_rf, alpha = 0.1) +
show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
sessionInfo()
#> R version 4.2.2 (2022-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Monterey 12.5.1
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] C/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.4.0 ranger_0.14.1 ingredients_2.3.0 DALEX_2.4.2
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.9 highr_0.10 bslib_0.4.2 compiler_4.2.2
#> [5] pillar_1.8.1 jquerylib_0.1.4 tools_4.2.2 digest_0.6.31
#> [9] lattice_0.20-45 jsonlite_1.8.4 evaluate_0.19 lifecycle_1.0.3
#> [13] tibble_3.1.8 gtable_0.3.1 pkgconfig_2.0.3 rlang_1.0.6
#> [17] Matrix_1.5-1 cli_3.6.0 rstudioapi_0.14 yaml_2.3.6
#> [21] xfun_0.36 fastmap_1.1.0 withr_2.5.0 stringr_1.5.0
#> [25] knitr_1.41 vctrs_0.5.1 sass_0.4.4 grid_4.2.2
#> [29] glue_1.6.2 R6_2.5.1 fansi_1.0.3 rmarkdown_2.19
#> [33] farver_2.1.1 magrittr_2.0.3 scales_1.2.1 htmltools_0.5.4
#> [37] colorspace_2.0-3 labeling_0.4.2 utf8_1.2.2 stringi_1.7.12
#> [41] munsell_0.5.0 cachem_1.0.6
These binaries (installable software) and packages are in development.
They may not be fully stable and should be used with caution. We make no claims about them.
Health stats visible at Monitor.