Introduction and Basics: PRISM

Thomas Jemielita

Introduction

Welcome to the StratifiedMedicine R package. The overall goal of this package is to develop analytic and visualization tools to aid in stratified and personalized medicine. Stratified medicine aims to find subsets or subgroups of patients with similar treatment effects, for example responders vs non-responders, while personalized medicine aims to understand treatment effects at the individual level (does a specific individual respond to treatment A?). Development of this package is ongoing.

Currently, the main algorithm in this package is “PRISM” (Patient Responder Identifiers for Stratified Medicine; Jemielita and Mehrotra 2019, in progress). Given a data structure of \((Y, A, X)\) (outcome, treatments, covariates), PRISM is a five-step procedure:

  1. Estimand: Determine the question or estimand of interest. For example, \(\theta_0 = E(Y|A=1)-E(Y|A=0)\), where A is a binary treatment variable. While this isn’t an explicit step in the PRISM function, the question of interest guides how to set up PRISM.

  2. Filter (filter): Reduce covariate space by removing variables unrelated to outcome/treatment. Formally: \[ filter(Y, A, X) \longrightarrow (Y, A, X^{\star}) \] where \(X^{\star}\) has potentially lower dimension than \(X\).

  3. Patient-level estimate (ple): Estimate counterfactual patient-level quantities, for example the individual treatment effect, \(\theta(x) = E(Y|X=x,A=1)-E(Y|X=x,A=0)\). Formally: \[ ple(Y, A, X^{\star}) \longrightarrow \hat{\mathbf{\theta}}(X^{\star}) \] where \(\hat{\mathbf{\theta}}(X^{\star})\) is the vector of patient-level estimates.

  4. Subgroup model (submod): Partition the data into subsets of patients (likely with similar treatment effects). Formally: \[ submod(Y, A, X^{\star}, \hat{\mathbf{\theta}}(X^{\star})) \longrightarrow \mathbf{S}(X^{\star}) \] where \(\mathbf{S}(X^{\star})\) is a distinct set of rules that define the \(k=0,...,K\) discovered subgroups, for example \(\mathbf{S}(X^{\star}) = \{X_1=0, X_2=0\}\). Note that subgroups could be formed using the observed outcomes, PLEs, or both. By default, \(k=0\) corresponds to the overall population.

  5. Parameter estimation and inference (param): For the overall population and discovered subgroups, output point estimates and variability metrics. Formally: \[ param(Y, A, X^{\star}, \hat{\mathbf{\theta}}(X^{\star}), \mathbf{S}(X^{\star}) ) \longrightarrow \{ \hat{\theta}_{k}, SE(\hat{\theta}_k), CI_{\alpha}(\hat{\theta}_{k}), P(\hat{\theta}_{k} > c) \} \text{ for } k=0,...,K \] where \(\hat{\theta}_{k}\) is the point estimate, \(SE(\hat{\theta}_k)\) is the standard error, \(CI_{\alpha}(\hat{\theta}_{k})\) is a two-sided (or one-sided) confidence interval with nominal coverage \(1-\alpha\), and \(P(\hat{\theta}_{k} > c)\) is a probability statement for some constant \(c\) (e.g., \(c=0\)). These outputs are crucial for Go-No-Go decision making.

Ultimately, PRISM provides information at the patient-level, the subgroup-level (if any), and the overall population. While there are defaults in place, the user can also input their own functions/model wrappers into the PRISM algorithm. We will demonstrate this later.
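
To make the flow of the steps concrete, here is a minimal base-R sketch on simulated data. It does not call StratifiedMedicine. Every stand-in (the p-value screen, the treatment-specific linear models, the median split on the PLEs, and the helper `param_one`) is a hypothetical simplification chosen for illustration, not a package default.

```r
# Toy illustration of filter -> ple -> submod -> param (base R only).
# All stand-ins below are assumptions for illustration, not package code.
set.seed(1)
n <- 400
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
A <- rbinom(n, size = 1, prob = 0.5)
# Treatment helps only when X1 > 0; X2 is purely prognostic
Y <- 0.5 * X$X2 + A * (X$X1 > 0) + rnorm(n)

# Filter: keep covariates whose main effect or interaction with A has p < 0.10
keep <- sapply(names(X), function(v) {
  xv <- X[[v]]
  cf <- summary(lm(Y ~ A * xv))$coefficients
  any(cf[c("xv", "A:xv"), 4] < 0.10)
})
if (!any(keep)) keep[] <- TRUE  # fall back to all covariates if none pass
Xstar <- X[, keep, drop = FALSE]

# PLE: treatment-specific linear models, then mu1(x) - mu0(x)
dat  <- data.frame(Y = Y, Xstar)
mod1 <- lm(Y ~ ., data = dat[A == 1, , drop = FALSE])
mod0 <- lm(Y ~ ., data = dat[A == 0, , drop = FALSE])
ple  <- predict(mod1, newdata = Xstar) - predict(mod0, newdata = Xstar)

# Subgroup model: crude split at the median PLE (subgroups 1 and 2)
S <- ifelse(ple > median(ple), 2, 1)

# Parameter estimation: difference in means with SE, overall (k = 0) and by subgroup
param_one <- function(idx, k) {
  cf <- summary(lm(Y[idx] ~ A[idx]))$coefficients
  data.frame(Subgrps = k, N = sum(idx), est = cf[2, 1], SE = cf[2, 2])
}
param.dat <- rbind(param_one(rep(TRUE, n), 0),
                   param_one(S == 1, 1),
                   param_one(S == 2, 2))
param.dat$LCL <- with(param.dat, est - qt(0.975, N - 1) * SE)
param.dat$UCL <- with(param.dat, est + qt(0.975, N - 1) * SE)
print(param.dat)
```

Because the same data are used both to find the subgroups and to estimate their effects, naive subgroup estimates of this kind tend to be optimistic.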

Example: Continuous Outcome with Binary Treatment

Next, we demonstrate PRISM for a continuous outcome with a binary treatment. The estimand of interest is the average treatment effect, \(\theta_0 = E(Y|A=1)-E(Y|A=0)\). First, we simulate continuous data where roughly 30% of the patients receive no treatment benefit from \(A=1\) vs \(A=0\). Responders vs non-responders are defined by the continuous predictive covariates \(X_1\) and \(X_2\).

library(StratifiedMedicine)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized 
length(Y)
#> [1] 800
table(A)
#> A
#>   0   1 
#> 409 391
dim(X)
#> [1] 800  50

Next, we run the default PRISM setting for continuous data: filter_glmnet (elastic net), ple_ranger (treatment-specific random forest models), submod_lmtree (model-based partitioning with OLS loss), and param_ple (parameter estimation/inference through the PLEs). Jemielita and Mehrotra 2019 (in progress) show that this configuration performs quite well in terms of bias, efficiency, coverage, and selecting the right predictive covariates.

# PRISM Default: filter_glmnet, ple_ranger, submod_lmtree, param_ple #
res0 = PRISM(Y=Y, A=A, X=X)
#> Observed Data
#> Filtering: filter_glmnet
#> PLE: ple_ranger
#> Subgroup Identification: submod_lmtree
#> Parameter Estimation: param_ple
## This is the same as running ##
res1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_glmnet", 
             ple = "ple_ranger", submod = "submod_lmtree", param="param_ple")
#> Observed Data
#> Filtering: filter_glmnet
#> PLE: ple_ranger
#> Subgroup Identification: submod_lmtree
#> Parameter Estimation: param_ple
## Plot the distribution of PLEs ###
hist(res0$mu_train$PLE, main="Estimated Distribution of PLEs",
     xlab = "Estimated PLEs: E(Y|X=x, A=1)-E(Y|X=x,A=0)")

## Plot of the subgroup model (lmtree) ##
plot(res0$Sub.mod)

## Overall/subgroup specific parameter estimates/inference
res0$param.dat
#>   Subgrps   N         est         SE          LCL       UCL       pval
#> 1       0 800 0.119163550 0.06277086 -0.004051715 0.2423788 0.05800485
#> 2       3  71 0.269953773 0.20845268 -0.145791990 0.6856995 0.19956283
#> 3       5 303 0.001823749 0.10423555 -0.203296193 0.2069437 0.98605215
#> 4       6 107 0.145690459 0.16108360 -0.173673450 0.4650544 0.36781312
#> 5       8 221 0.124661570 0.11491806 -0.101819583 0.3511427 0.27920349
#> 6       9  98 0.331351436 0.18714952 -0.040088556 0.7027914 0.07978260
## Forest plot: Overall/subgroup specific parameter estimates (CIs)
plot(res0)

A key output of PRISM is the object “param.dat”, which includes point estimates and variability metrics. Lower confidence limits (LCL) and upper confidence limits (UCL) are formed based on inputs alpha_ovrl (overall) and alpha_s (subgroups); default is 0.05 for both. The plot() function creates a forest plot based on the “param.dat” object.
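
As a quick check on how these columns relate, the sketch below recomputes the limits and p-value for the overall population (first row of the table above) from its point estimate and standard error. The use of a t reference distribution with N - 1 degrees of freedom is an assumption made here for illustration; the exact formula inside param_ple is not shown in this document.

```r
# First row of param.dat above (overall population, Subgrps = 0)
est <- 0.119163550
SE  <- 0.06277086
N   <- 800
alpha_ovrl <- 0.05

# Assumption: t reference distribution with N - 1 degrees of freedom
crit <- qt(1 - alpha_ovrl / 2, df = N - 1)
LCL  <- est - crit * SE
UCL  <- est + crit * SE
pval <- 2 * pt(-abs(est / SE), df = N - 1)
round(c(LCL = LCL, UCL = UCL, pval = pval), 4)
```

Under this assumption the recomputed values should agree closely with the printed LCL, UCL, and pval. Setting alpha_ovrl to a smaller value widens the interval.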

The hyper-parameters for the individual steps of PRISM can also be easily modified. For example, “filter_glmnet” by default selects covariates based on “lambda.min”, “ple_ranger” requires nodes to contain at least 10% of the total observations, and “submod_lmtree” requires nodes to contain at least 5% of the total observations. To modify this:

# PRISM Default: filter_glmnet, ple_ranger, submod_lmtree, param_ple #
# Change hyper-parameters #
res_new_hyper = PRISM(Y=Y, A=A, X=X, filter.hyper = list(lambda="lambda.1se"),
                      ple.hyper = list(min.node.pct=0.05), 
                      submod.hyper = list(minsize=200))
#> Observed Data
#> Filtering: filter_glmnet
#> PLE: ple_ranger
#> Subgroup Identification: submod_lmtree
#> Parameter Estimation: param_ple
plot(res_new_hyper$Sub.mod) # Plot subgroup model results

plot(res_new_hyper)

Resampling is also a feature in PRISM. Both bootstrap (resample="Bootstrap") and permutation (resample="Permutation") based resampling are included; resampling can be stratified by the discovered subgroups (default: stratify=TRUE). Resampling can be useful for obtaining valid inference and estimating posterior probabilities.

library(ggplot2)
res_boot = PRISM(Y=Y, A=A, X=X, resample = "Bootstrap", R=50, verbose=FALSE)
## Plot of distributions and P(est>0) ##
plot(res_boot, type="resample")+geom_vline(xintercept = 0)

aggregate(I(est>0)~Subgrps, data=res_boot$resamp.dist, FUN="mean")
#>   Subgrps I(est > 0)
#> 1       0       0.94
#> 2       3       0.80
#> 3       5       0.74
#> 4       6       0.82
#> 5       8       0.90
#> 6       9       1.00
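
The idea behind the proportion reported above can be sketched in base R. The example below is a simplified illustration, not the PRISM implementation: it simulates its own data, holds a hypothetical subgroup label `S` fixed, resamples within each subgroup (a stratified bootstrap), and reports the share of resampled estimates above zero. Whether and how PRISM re-runs the earlier steps within each resample is not covered here.

```r
set.seed(2)
n <- 200
A <- rbinom(n, 1, 0.5)
Y <- 0.3 * A + rnorm(n)
S <- rep(c(1, 2), length.out = n)  # fixed, hypothetical subgroup labels

R <- 200
boot_est <- replicate(R, {
  # Stratified resampling: draw with replacement within each subgroup
  idx <- unlist(lapply(split(seq_len(n), S), function(i) {
    i[sample.int(length(i), replace = TRUE)]
  }))
  Yb <- Y[idx]
  Ab <- A[idx]
  # Difference in means, E(Y|A=1) - E(Y|A=0), in the resampled data
  mean(Yb[Ab == 1]) - mean(Yb[Ab == 0])
})

# Proportion of bootstrap estimates above zero
mean(boot_est > 0)
```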

Example: Survival Outcome with Binary Treatment

Survival outcomes are also allowed in PRISM. The default settings use glmnet to filter (“filter_glmnet”) and to estimate patient-level hazard ratios (“ple_glmnet”), “submod_weibull” (MOB with a Weibull loss function) for subgroup identification, and “param_cox” (subgroup-specific Cox regression models) for parameter estimation. Another subgroup option is “submod_ctree”, which uses the conditional inference tree (CTREE) algorithm to find subgroups; this looks for partitions irrespective of treatment assignment and thus corresponds to finding prognostic effects.

library(survival)
library(ggplot2)
# Load TH.data (no treatment; generate treatment randomly to simulate null effect) ##
data("GBSG2", package = "TH.data")
surv.dat = GBSG2
# Design Matrices ###
Y = with(surv.dat, Surv(time, cens))
X = surv.dat[,!(colnames(surv.dat) %in% c("time", "cens")) ]
A = rbinom( n = dim(X)[1], size=1, prob=0.5  )

# Default: filter_glmnet ==> submod_weibull (MOB with Weibull) ==> param_cox (Cox regression)
res_weibull1 = PRISM(Y=Y, A=A, X=X, ple=NULL)
#> Observed Data
#> Filtering: filter_glmnet
#> PLE: ple_glmnet
#> Subgroup Identification: submod_weibull
#> Parameter Estimation: param_cox
plot(res_weibull1$Sub.mod)

plot(res_weibull1)+ylab("HR [95% CI]")


# PRISM: filter_glmnet ==> submod_ctree ==> param_cox (Cox regression) #
res_ctree1 = PRISM(Y=Y, A=A, X=X, ple=NULL, submod = "submod_ctree")
#> Observed Data
#> Filtering: filter_glmnet
#> PLE: ple_glmnet
#> Subgroup Identification: submod_ctree
#> Parameter Estimation: param_cox
plot(res_ctree1$Sub.mod)

plot(res_ctree1)+ylab("HR [95% CI]")
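
For intuition about what a subgroup-specific Cox summary contains, the following sketch fits `coxph` from the survival package overall and within two subgroups. The simulated data, the subgroup labels `S`, and the helper `hr_one` are all hypothetical; this is not the param_cox source code.

```r
library(survival)

set.seed(3)
n <- 300
A <- rbinom(n, 1, 0.5)
S <- sample(c(2, 3), n, replace = TRUE)  # hypothetical subgroup labels
# Exponential event times with no treatment effect, plus independent censoring
event_time <- rexp(n, rate = 0.10)
cens_time  <- rexp(n, rate = 0.05)
time   <- pmin(event_time, cens_time)
status <- as.numeric(event_time <= cens_time)

# Hazard ratio (A=1 vs A=0) with a 95% CI for the rows selected by idx
hr_one <- function(idx, k) {
  fit <- coxph(Surv(time[idx], status[idx]) ~ A[idx])
  ci  <- summary(fit)$conf.int  # columns: exp(coef), exp(-coef), lower .95, upper .95
  data.frame(Subgrps = k, N = sum(idx),
             HR = ci[1, 1], LCL = ci[1, 3], UCL = ci[1, 4])
}

hr.dat <- rbind(hr_one(rep(TRUE, n), 0),
                hr_one(S == 2, 2),
                hr_one(S == 3, 3))
hr.dat
```

Since treatment is assigned at random and has no effect in this simulation, the hazard ratios should scatter around 1, which mirrors the null setup used with the GBSG2 data above.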

User-Created Models

One advantage of PRISM is the flexibility to adjust each step of the algorithm and also to input user-created functions/models. This facilitates faster testing and experimentation. First, let’s simulate the continuous data again.

library(StratifiedMedicine)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized 

Next, let’s go over the basic function template for the “filter”, “ple”, “submod”, and “param” steps in PRISM. For “filter,” consider the lasso:

filter_lasso = function(Y, A, X, lambda="lambda.min", family="gaussian", ...){
  require(glmnet)
  ## Convert X (data frame) to a numeric model matrix #
  X = model.matrix(~. -1, data = X )

  ##### Lasso (alpha=1) on the observed outcome #####
  set.seed(6134)
  if (family=="survival") { family = "cox"  }
  mod <- cv.glmnet(x = X, y = Y, nlambda = 100, alpha=1, family=family)

  ### Extract filtered variables based on lambda ###
  VI <- coef(mod, s = lambda)[,1]
  ## Drop the intercept by name (Cox fits have no intercept row) ##
  VI = VI[names(VI) != "(Intercept)"]
  filter.vars = names(VI[VI!=0])
  return( list(mod=mod, filter.vars=filter.vars) )
}

Note that the filter uses the observed data (Y,A,X), which are required inputs, and outputs an object called “filter.vars.” This must contain the names of the variables that pass the filtering step. The function also includes an optional argument for changing lambda, which affects which variables remain after filtering (lambda.min keeps more, lambda.1se keeps fewer). This can then be adjusted through the “filter.hyper” argument in PRISM.
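
The same input/output contract can be met without glmnet. As a hedged illustration, the hypothetical filter below (`filter_pval` and its `pval.thres` argument are made up for this example and assume numeric covariates) screens covariates by their marginal p-value and returns the required “filter.vars” element.

```r
# Hypothetical user filter: keep covariates with a small marginal p-value.
# Assumes X is a data frame of numeric covariates.
filter_pval <- function(Y, A, X, pval.thres = 0.10, ...) {
  pvals <- sapply(X, function(x) summary(lm(Y ~ x))$coefficients[2, 4])
  filter.vars <- names(pvals)[pvals < pval.thres]
  return(list(mod = pvals, filter.vars = filter.vars))
}

# Small simulated check: only X1 is related to the outcome
set.seed(4)
n <- 200
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- 2 * X$X1 + rnorm(n)
out <- filter_pval(Y, A, X)
out$filter.vars
```

Following the pattern used elsewhere in this document, such a function would presumably be passed by name (filter="filter_pval"), with its threshold changed through filter.hyper = list(pval.thres = 0.05).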

For “ple,” consider treatment-specific random forest (ranger) models.

ple_ranger_mtry = function(Y, A, X, Xtest, mtry=5, ...){
   require(ranger)
   ## Split data by treatment ###
    train0 =  data.frame(Y=Y[A==0], X[A==0,])
    train1 =  data.frame(Y=Y[A==1], X[A==1,])
    # Trt 0 #
    mod0 <- ranger(Y ~ ., data = train0, seed=1, mtry = mtry)
    # Trt 1 #
    mod1 <- ranger(Y ~ ., data = train1, seed=2, mtry = mtry)
    mods = list(mod0=mod0, mod1=mod1)
    ## Predictions: Train/Test ##
    mu_train = data.frame( mu1 =  predict(mod1, data = X)$predictions,
                             mu0 = predict(mod0, data = X)$predictions)
    mu_train$PLE = with(mu_train, mu1 - mu0 )

    mu_test = data.frame( mu1 =  predict(mod1, data = Xtest)$predictions,
                            mu0 = predict(mod0, data = Xtest)$predictions)
    mu_test$PLE = with(mu_test, mu1 - mu0 )
    return( list(mods=mods, mu_train=mu_train, mu_test=mu_test))
}

For the “ple” model, the only required arguments are the observed data (Y,A,X) and Xtest. By default, if Xtest is not provided in PRISM, it uses the training X instead. The only required outputs are mu_train and mu_test, each of which must include a vector/column named “PLE” (patient-level estimates). In this example, “mtry” (number of variables randomly selected at each split) is allowed to vary and can be altered through the “ple.hyper” argument in PRISM.
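
For comparison, here is a lighter-weight sketch of the same template using treatment-specific linear models in base R. The function name `ple_lm` and the helper `make_mu` are hypothetical; the point is only to show the required inputs and the mu_train/mu_test outputs with a “PLE” column.

```r
# Hypothetical ple function: one linear model per treatment arm
ple_lm <- function(Y, A, X, Xtest, ...) {
  train <- data.frame(Y = Y, X)
  mod0 <- lm(Y ~ ., data = train[A == 0, , drop = FALSE])
  mod1 <- lm(Y ~ ., data = train[A == 1, , drop = FALSE])
  # Predict both counterfactual means and take their difference
  make_mu <- function(newX) {
    mu <- data.frame(mu1 = predict(mod1, newdata = newX),
                     mu0 = predict(mod0, newdata = newX))
    mu$PLE <- mu$mu1 - mu$mu0
    mu
  }
  list(mods = list(mod0 = mod0, mod1 = mod1),
       mu_train = make_mu(X), mu_test = make_mu(Xtest))
}

# Small simulated check: constant treatment effect of 1
set.seed(5)
n <- 400
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- X$X1 + 1 * A + rnorm(n)
out <- ple_lm(Y, A, X, Xtest = X)
head(out$mu_train)
```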

For “submod”, consider a modified version of “submod_lmtree” where we search for predictive effects only. By default, “submod_lmtree” searches for prognostic and/or predictive effects.

submod_lmtree_pred = function(Y, A, X, Xtest, mu_train, ...){
  require(partykit)
  ## Fit Model ##
  mod <- lmtree(Y~A | ., data = X, parm=2) ##parm=2 focuses on treatment interaction #
  ##  Predict Subgroups for Train/Test ##
  Subgrps.train = as.numeric( predict(mod, type="node") )
  Subgrps.test = as.numeric( predict(mod, type="node", newdata = Xtest) )
  ## Predict E(Y|X=x, A=1)-E(Y|X=x,A=0) ##
  pred.train = predict( mod, newdata = data.frame(A=1, X) ) -
    predict( mod, newdata = data.frame(A=0, X) )
  pred.test =  predict( mod, newdata = data.frame(A=1, Xtest) ) -
    predict( mod, newdata = data.frame(A=0, Xtest) )
  ## Return Results ##
  return(  list(mod=mod, Subgrps.train=Subgrps.train, Subgrps.test=Subgrps.test,
                pred.train=pred.train, pred.test=pred.test) )
}

For the “submod” model, the only required arguments are the observed data (Y,A,X) and Xtest. By default, if Xtest is not provided in PRISM, it uses the training X instead. Note that “mu_train” is included as an argument here. If a subgroup model uses the PLEs to form groups, then this is a required argument. Required outputs include mod (the subgroup model) and Subgrps.train/Subgrps.test (predicted subgroups in train/test set). This function also returns predicted treatment effects for train/test; these aren’t required but are needed if the subgroup model is explicitly used in the final parameter estimation step.
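
The lmtree example above does not actually use “mu_train”. As a rough sketch of a subgroup model that does, the hypothetical function below (`submod_ple_median` is not part of the package) regresses the PLEs on the covariates and splits patients at the median fitted value, so that subgroups can also be predicted for Xtest. It is a deliberately crude illustration of the template, not a recommended method.

```r
# Hypothetical submod function that forms subgroups from the PLEs
submod_ple_median <- function(Y, A, X, Xtest, mu_train, ...) {
  # Surrogate model: regress the PLEs on the covariates
  mod <- lm(PLE ~ ., data = data.frame(PLE = mu_train$PLE, X))
  cut <- median(fitted(mod))
  # Subgroup 2 = above-median predicted effect, subgroup 1 = otherwise
  Subgrps.train <- unname(ifelse(fitted(mod) > cut, 2, 1))
  Subgrps.test  <- unname(ifelse(predict(mod, newdata = Xtest) > cut, 2, 1))
  list(mod = mod, Subgrps.train = Subgrps.train, Subgrps.test = Subgrps.test)
}

# Small simulated check with stand-in PLEs that depend on X1
set.seed(6)
n <- 200
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- A * X$X1 + rnorm(n)
mu_train <- data.frame(PLE = X$X1 + rnorm(n, sd = 0.2))  # stand-in for ple output
out <- submod_ple_median(Y, A, X, Xtest = X, mu_train = mu_train)
table(out$Subgrps.train)
```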

Lastly, the “param” model outputs parameter estimates and variability metrics for the overall population and discovered subgroups. Just for demonstration, one option could be to fit subgroup-specific robust regression (M-estimator) models.


### Robust linear Regression: E(Y|A=1) - E(Y|A=0) ###
param_rlm = function(Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s, ...){
  require(MASS)
  indata = data.frame(Y=Y,A=A, X)
  mod.ovrl = rlm(Y ~ A , data=indata)
  param.dat0 = data.frame( Subgrps=0, N = dim(indata)[1],
                           est = summary(mod.ovrl)$coefficients[2,1],
                           SE = summary(mod.ovrl)$coefficients[2,2] )
  param.dat0$LCL = with(param.dat0, est-qt(1-alpha_ovrl/2, N-1)*SE)
  param.dat0$UCL = with(param.dat0, est+qt(1-alpha_ovrl/2, N-1)*SE)
  param.dat0$pval = with(param.dat0, 2*pt(-abs(est/SE), df=N-1) )

  ## Subgroup Specific Estimate ##
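  looper = function(s){
    n.s = dim(indata[Subgrps==s,])[1]
    rlm.mod = tryCatch( rlm(Y ~ A , data=indata[Subgrps==s,]),
                       error = function(e) NULL )
    ## If the fit fails, return NAs so the other subgroups are still reported ##
    if (is.null(rlm.mod)) { return( rep(NA_real_, 5) ) }
    est = summary(rlm.mod)$coefficients[2,1]
    SE = summary(rlm.mod)$coefficients[2,2]
    ## Subgroup intervals use alpha_s (alpha_ovrl is for the overall population) ##
    LCL =  est-qt(1-alpha_s/2, n.s-1)*SE
    UCL =  est+qt(1-alpha_s/2, n.s-1)*SE
    pval = 2*pt(-abs(est/SE), df=n.s-1)
    return( c(est, SE, LCL, UCL, pval) )
  }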
  S_levels = as.numeric( names(table(Subgrps)) )
  S_N = as.numeric( table(Subgrps) )
  param.dat = lapply(S_levels, looper)
  param.dat = do.call(rbind, param.dat)
  param.dat = data.frame( S = S_levels, N=S_N, param.dat)
  colnames(param.dat) = c("Subgrps", "N", "est", "SE", "LCL", "UCL", "pval")
  param.dat = rbind( param.dat0, param.dat)
  return( param.dat )
}

Finally, let’s put it all together directly into PRISM:


res_user1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_lasso", 
             ple = "ple_ranger_mtry", submod = "submod_lmtree_pred",
             param="param_rlm")
#> Observed Data
#> Filtering: filter_lasso
#> PLE: ple_ranger_mtry
#> Subgroup Identification: submod_lmtree_pred
#> Parameter Estimation: param_rlm
## variables that remain after filtering ##
res_user1$filter.vars
#>  [1] "X1"  "X2"  "X3"  "X5"  "X7"  "X8"  "X10" "X12" "X16" "X18" "X24"
#> [12] "X26" "X31" "X40" "X46" "X50"
## Subgroup model: lmtree searching for predictive only ##
plot(res_user1$Sub.mod)

## Parameter estimates/inference
res_user1$param.dat
#>   Subgrps   N        est         SE         LCL       UCL       pval
#> 1       0 800 0.14576659 0.07962457 -0.01053146 0.3020646 0.06752160
#> 2       2 481 0.08694906 0.09989682 -0.10934005 0.2832382 0.38452309
#> 3       3 319 0.22210668 0.12319010 -0.02026392 0.4644773 0.07234124
## Forest Plot (95% CI) ##
plot(res_user1)

Conclusion

Overall, PRISM is a flexible algorithm that can aid in subgroup detection and exploration of heterogeneous treatment effects. Each step of PRISM is customizable, allowing for fast experimentation and improvement of individual steps. The StratifiedMedicine R package and PRISM will be continually updated and improved. User feedback will further facilitate improvements.