policy_learn

library("data.table")
library("polle")

This vignette is a guide to policy_learn() and some of the associated S3 methods. The purpose of policy_learn() is to specify a policy learning algorithm and estimate an optimal policy. For details on the methodology, see the associated paper (Nordland and Holst 2023).

As a general setup, we consider a fixed two-stage problem. We simulate data using sim_two_stage() and create a policy_data object using policy_data():

d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
#> Policy data with n = 2000 observations and maximal K = 2 stages.
#> 
#>      action
#> stage    0    1    n
#>     1 1017  983 2000
#>     2  819 1181 2000
#> 
#> Baseline covariates: B, BB
#> State covariates: L, C
#> Average utility: 0.84

Specifying and applying a policy learner

policy_learn() specifies a policy learning algorithm via the type argument: Q-learning (ql), doubly robust Q-learning (drql), doubly robust blip learning (blip), policy tree learning (ptl), and outcome weighted learning (owl).
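
As a minimal sketch of how the type argument is used, the following specifies a Q-learner with default control arguments and applies it to a smaller simulated data set. The sample size and the object names pl_ql and po_ql are chosen for illustration only, and we assume that a single q_glm() is reused at both stages:

```r
library("polle")

# Small simulated two-stage data set (same setup as in this vignette).
d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# Q-learning with default control arguments.
pl_ql <- policy_learn(type = "ql")

# Q-learning only needs Q-models; one q_glm() is passed for both stages.
po_ql <- pl_ql(pd, q_models = q_glm())

# Recommended actions for the first few (id, stage) pairs.
pa_ql <- get_policy(po_ql)(pd)
head(pa_ql)
```

The other types follow the same pattern, with type-specific settings passed through control.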

Because the control arguments differ between policy learning types, they are passed as a list via the control argument. To help the user set the required control arguments and to provide documentation, each type has a helper function control_type() (for example control_blip() for type blip), which sets the default control arguments and overwrites them with any values supplied by the user.
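
To illustrate the role of the helper functions, the sketch below calls control_blip() with and without a user-supplied model. We assume here that the helper simply returns a named list of control arguments; the object names are illustrative:

```r
library("polle")

# Default control arguments for the blip learner.
ctrl_default <- control_blip()

# Overwrite the default blip model with a user-specified formula.
ctrl_custom <- control_blip(blip_models = q_glm(formula = ~ L + C))

# Both objects are plain lists that can be passed to the control argument.
names(ctrl_default)
names(ctrl_custom)
```

Using the helper rather than a hand-written list makes it harder to misspell a control argument, since unknown argument names cause an error at the call.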

As an example we specify a doubly robust blip learner:

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)

For details on the implementation, see Algorithm 3 in (Nordland and Holst 2023). The only required control argument for blip learning is a model input. The blip_models argument expects a q_model. In this case we input a simple linear model as implemented in q_glm.

The output of policy_learn() is again a function:

pl_blip
#> Policy learner with arguments:
#> policy_data, g_models=NULL, g_functions=NULL,
#> g_full_history=FALSE, q_models, q_full_history=FALSE

To apply the policy learner, we pass a policy_data object together with the nuisance models g_models and q_models, which are used to compute the doubly robust score:

(po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
 ))
#> Policy object with list elements:
#> blip_functions, q_functions, action_set, stage_action_sets,
#> alpha, threshold, K
#> Use 'get_policy' to get the associated policy.
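
A policy learner can also be passed directly to policy_eval() via its policy_learn argument, in which case the policy is learned and its value estimated in one call. The following is a hedged sketch on a smaller simulated data set; the sample size and the name pe_blip are illustrative and the estimate is not meant to reproduce any number in this vignette:

```r
library("polle")

d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)

# Learn the policy and estimate its value with the doubly robust estimator.
pe_blip <- policy_eval(
  pd,
  policy_learn = pl_blip,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)
pe_blip
```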

Cross-fitting the doubly robust score

As with policy_eval(), it is possible to cross-fit the doubly robust score used as input to the policy model. The number of folds for the cross-fitting procedure is set via the L argument. By default, the cross-fitted nuisance models are not saved, but they can be kept by setting the save_cross_fit_models argument:

pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )

From a user perspective, nothing has changed. However, the policy object now contains each of the cross-fitted nuisance models:

po_blip_cross$g_functions_cf
#> $`1`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.18321      0.15191      0.90737     -0.03865      0.18927      0.15088  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1384 
#> Residual Deviance: 1086  AIC: 1098
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.24410      0.13150      0.99426     -0.02289     -0.41777     -0.17383  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1349 
#> Residual Deviance: 1082  AIC: 1094
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
#> 
#> $`2`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    0.113952    -0.240397     1.142507    -0.094362    -0.009235    -0.101783  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1386 
#> Residual Deviance: 1065  AIC: 1077
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.15426      0.01307      0.96485     -0.08554     -0.33532     -0.12597  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1357 
#> Residual Deviance: 1102  AIC: 1114
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
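
Rather than printing every model, the structure of the saved cross-fitted models can be checked programmatically. The sketch below refits the cross-fitted learner on a smaller simulated data set (an illustrative choice) and relies only on the list structure visible in the printed output above, namely one element per fold, each with one element per stage:

```r
library("polle")

d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

# Number of folds: one set of g-functions per fold.
n_folds <- length(po_blip_cross$g_functions_cf)
n_folds

# Stage names within the first fold.
names(po_blip_cross$g_functions_cf[[1]])
```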

Realistic policy learning

Realistic policy learning is implemented for types ql, drql, blip and ptl (for a binary action set). The alpha argument sets the probability threshold for defining the realistic action set. For implementation details, see Algorithm 5 in (Nordland and Holst 2023). Here we set a 5% restriction:

pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = 0.05,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )

The policy object now lists the alpha level as well as the g-model used to define the realistic action set:

po_blip_alpha$alpha
#> [1] 0.05
po_blip_alpha$g_functions
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.03295     -0.05107      1.02271     -0.06478      0.09582      0.02370  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2772 
#> Residual Deviance: 2161  AIC: 2173
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.19814      0.07355      0.97991     -0.05280     -0.37163     -0.14598  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2707 
#> Residual Deviance: 2186  AIC: 2198
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
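
To get a feel for how often the restriction is active, one can look at the fitted action probabilities. The sketch below assumes that predict() applied to the fitted g-functions returns one row per (id, stage) with one probability column per action, prefixed by g_. It then counts the rows in which the least likely action falls below the 5% threshold. The sample size and object names are illustrative:

```r
library("polle")

d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

alpha <- 0.05
pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = alpha,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

# Fitted action probabilities for every (id, stage).
g_values <- as.data.frame(predict(po_blip_alpha$g_functions, pd))

# Assumption: probability columns are prefixed with "g_".
g_cols <- grep("^g_", names(g_values), value = TRUE)

# Probability of the least likely action in each row.
g_min <- do.call(pmin, unname(as.list(g_values[g_cols])))

# Rows where one action is excluded from the realistic action set.
table(stage = g_values$stage, restricted = g_min < alpha)
```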

Implementation/Simulation and get_policy_functions()

A policy function, as returned by get_policy(), is well suited for evaluating a given policy, or even for implementing or simulating from a single-stage policy. However, it is not useful for implementing or simulating from a learned multi-stage policy. To access the policy function for each stage we use get_policy_functions(). In this case we get the second-stage policy function:

pf_blip <- get_policy_functions(po_blip, stage = 2)

The stage-specific policy function requires a data.table with named columns as input and returns a character vector with the recommended actions:

pf_blip(
  H = data.table(BB = c("group2", "group1"),
                 L = c(1, 0),
                 C = c(1, 2))
)
#> [1] "1" "1"
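
The first-stage policy function is obtained in the same way. The sketch below refits the blip learner on a smaller simulated data set and applies the stage 1 function to three hypothetical histories; the covariate values are made up for illustration:

```r
library("data.table")
library("polle")

d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)
po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

# Policy function for the first stage.
pf_stage_1 <- get_policy_functions(po_blip, stage = 1)

# Hypothetical histories: the columns match the blip model formula.
H_new <- data.table(BB = c("group1", "group2", "group3"),
                    L = c(0, 1, -1),
                    C = c(2, 0, 1))
actions_stage_1 <- pf_stage_1(H = H_new)
actions_stage_1
```

In a sequential implementation, the stage 1 recommendation would be applied first and the stage 2 function called once the stage 2 covariates have been observed.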

Policy objects and get_policy()

Applying the policy learner returns a policy_object containing all of the components needed to specify the learned policy. In this case, the only component of the policy is a model for the blip function:

po_blip$blip_functions$stage_1$blip_model
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)     BBgroup2     BBgroup3            L            C  
#>      0.4076       0.2585       0.2231       0.1765       0.8624  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1995 Residual
#> Null Deviance:       56820 
#> Residual Deviance: 53220     AIC: 12250
#> 
#> attr(,"class")
#> [1] "q_glm"

To access and apply the policy itself, use get_policy(). The returned function behaves as a policy, meaning that it can be applied to any (suitable) policy_data object to get the policy actions:

get_policy(po_blip)(pd) |> head(4)
#> Key: <id, stage>
#>       id stage      d
#>    <int> <int> <char>
#> 1:     1     1      1
#> 2:     1     2      1
#> 3:     2     1      0
#> 4:     2     2      0
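
Because the policy actions are keyed by id and stage, they can be merged with the observed actions. The sketch below, on a smaller simulated data set, computes the share of observations at each stage where the learned policy agrees with the action actually taken. It assumes that get_actions() returns the observed action in a column named A:

```r
library("data.table")
library("polle")

d <- sim_two_stage(n = 500, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)
po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

# Policy actions (column d) and observed actions (assumed column A).
policy_actions <- get_policy(po_blip)(pd)
observed_actions <- get_actions(pd)

# One row per (id, stage) with both the recommended and observed action.
cmp <- merge(policy_actions, observed_actions, by = c("id", "stage"))

# Share of observations where the policy agrees with the observed action.
agreement <- cmp[, .(agreement = mean(d == A)), by = stage]
agreement
```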

SessionInfo

sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: aarch64-apple-darwin23.5.0
#> Running under: macOS Sonoma 14.6.1
#> 
#> Matrix products: default
#> BLAS:   /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib 
#> LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: Europe/Copenhagen
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] splines   stats     graphics  grDevices utils     datasets  methods  
#> [8] base     
#> 
#> other attached packages:
#> [1] ggplot2_3.5.1       data.table_1.15.4   polle_1.5          
#> [4] SuperLearner_2.0-29 gam_1.22-4          foreach_1.5.2      
#> [7] nnls_1.5           
#> 
#> loaded via a namespace (and not attached):
#>  [1] sass_0.4.9          utf8_1.2.4          future_1.33.2      
#>  [4] lattice_0.22-6      listenv_0.9.1       digest_0.6.36      
#>  [7] magrittr_2.0.3      evaluate_0.24.0     grid_4.4.1         
#> [10] iterators_1.0.14    mvtnorm_1.2-5       policytree_1.2.3   
#> [13] fastmap_1.2.0       jsonlite_1.8.8      Matrix_1.7-0       
#> [16] survival_3.6-4      fansi_1.0.6         scales_1.3.0       
#> [19] numDeriv_2016.8-1.1 codetools_0.2-20    jquerylib_0.1.4    
#> [22] lava_1.8.0          cli_3.6.3           rlang_1.1.4        
#> [25] mets_1.3.4          parallelly_1.37.1   future.apply_1.11.2
#> [28] munsell_0.5.1       withr_3.0.0         cachem_1.1.0       
#> [31] yaml_2.3.8          tools_4.4.1         parallel_4.4.1     
#> [34] colorspace_2.1-0    ranger_0.16.0       globals_0.16.3     
#> [37] vctrs_0.6.5         R6_2.5.1            lifecycle_1.0.4    
#> [40] pkgconfig_2.0.3     timereg_2.0.5       progressr_0.14.0   
#> [43] bslib_0.7.0         pillar_1.9.0        gtable_0.3.5       
#> [46] Rcpp_1.0.13         glue_1.7.0          xfun_0.45          
#> [49] tibble_3.2.1        highr_0.11          knitr_1.47         
#> [52] farver_2.1.2        htmltools_0.5.8.1   rmarkdown_2.27     
#> [55] labeling_0.4.3      compiler_4.4.1

References

Nordland, Andreas, and Klaus K. Holst. 2023. “Policy Learning with the Polle Package.” https://doi.org/10.48550/arXiv.2212.02335.
