
SMMAL_vignette

Introduction

This vignette demonstrates how to use the SMMAL package to estimate the Average Treatment Effect (ATE) using semi-supervised machine learning. We provide an example dataset and walk through the required input format and function usage.

Import Sample Data

Sample data contain 1000 observations with 60% of Y and A missing at random. Y is the outcome. A is the treatment indicator. X are the covariates. S are the surrogates.

For the sample data, missingness occurs at random and is encoded as NA. This package can handle datasets with a high proportion of missing values, but it requires a sufficiently large sample size to ensure that each fold in cross-validation contains at least 20 labeled observations.

library(SMMAL)

file_path <- system.file("extdata", "sample_data_withmissing.rds", package = "SMMAL")
dat <- readRDS(file_path)


file_path2 <- system.file("extdata", "semi_supervised_data.rds", package = "SMMAL")
data_loaded <- readRDS(file_path2)
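
Before fitting, it may be worth confirming that the labeled sample is large enough for the chosen number of folds. The sketch below uses simulated data as a stand-in: the names Y_sim and fold_sim are hypothetical, and the balanced random fold assignment is an assumption rather than the package's own scheme. With the packaged data, dat$Y would take the place of Y_sim.

```r
set.seed(123)
n <- 1000
nfold <- 5

# Simulated binary outcome with 60% of entries set to NA (unlabeled)
Y_sim <- rbinom(n, size = 1, prob = 0.5)
Y_sim[sample(n, size = 600)] <- NA

# Assumed fold assignment: balanced random folds over all observations
fold_sim <- sample(rep(seq_len(nfold), length.out = n))

# Count labeled observations in each fold
labeled_per_fold <- table(fold_sim[!is.na(Y_sim)])
print(labeled_per_fold)

# TRUE when every fold meets the 20-labeled-observation guideline
enough_labels <- all(labeled_per_fold >= 20)
print(enough_labels)
```

If the check fails, reducing nfold or collecting more labeled observations are the two obvious remedies.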

Prepare Inputs

The inputs S and X must be data frames, even when each holds only a single vector.

  # Y and A are numeric vectors
  Y <- dat$Y
  A <- dat$A
  
  # S and X must be data frames
  S <- data.frame(dat$S)
  X <- data.frame(dat$X)
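
As a small illustration of the data-frame requirement, the sketch below wraps plain vectors in data.frame() and shows how a single-column data frame can silently turn back into a vector when subset. The names x_vec, s_vec, X_df, and S_df are hypothetical and not part of the package.

```r
# Hypothetical single covariate and single surrogate stored as plain vectors
x_vec <- c(0.2, 1.5, -0.7, 0.9)
s_vec <- c(1, 0, 0, 1)

# Wrapping each vector in data.frame() yields a one-column data frame
X_df <- data.frame(X1 = x_vec)
S_df <- data.frame(S1 = s_vec)
print(is.data.frame(X_df))
print(dim(S_df))

# Column subsetting drops to a vector unless drop = FALSE is given
dropped <- X_df[, 1]
kept <- X_df[, 1, drop = FALSE]
print(is.data.frame(dropped))
print(is.data.frame(kept))
```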

Estimate ATE with SMMAL & Output

Users can choose which model to use for the nuisance functions by setting the cf_model parameter. If no cf_model is indicated, the default value is “bspline”.

After cross-validation and prediction, the best-performing model is selected based on the lowest cross-entropy (log loss). Users can control how many folds are used in cross-validation by setting the nfold parameter. If no nfold is indicated, the default value is 5.
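
To make the selection rule concrete, here is a minimal sketch of choosing between two candidate sets of predicted probabilities by cross-entropy. The function log_loss_sketch is a hypothetical stand-in written for this illustration, not the package's internal implementation, and the candidate predictions are made up.

```r
# Hypothetical cross-entropy function; probabilities are clipped to avoid log(0)
log_loss_sketch <- function(y_true, y_pred, eps = 1e-15) {
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)
  -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
}

# Made-up labels and two made-up candidate prediction vectors
y_true <- c(1, 0, 1, 1, 0)
candidates <- list(
  model_a = c(0.9, 0.2, 0.8, 0.7, 0.1),
  model_b = c(0.6, 0.5, 0.5, 0.6, 0.4)
)

# Compute the log loss of each candidate and keep the one with the lowest value
losses <- vapply(candidates, function(p) log_loss_sketch(y_true, p), numeric(1))
best <- names(which.min(losses))
print(losses)
print(best)
```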

 SMMAL_output1 <- SMMAL(Y=Y,A=A,S=S,X=X)
 print(SMMAL_output1)
#> $est
#> [1] 0.1048938
#> 
#> $se
#> [1] 0.04907399

Other options for cf_model are “xgboost”

SMMAL_output2 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "xgboost")
print(SMMAL_output2)
#> $est
#> [1] 0.09174881
#> 
#> $se
#> [1] 0.05081966

or “randomforest”

SMMAL_output3 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "randomforest")
print(SMMAL_output3)
#> $est
#> [1] 0.09918991
#> 
#> $se
#> [1] 0.05248152

or “glm”

 SMMAL_output4 <- SMMAL(Y=Y,A=A,S=S,X=X, cf_model= "glm")
 print(SMMAL_output4)
#> $est
#> [1] 0.1341408
#> 
#> $se
#> [1] 0.05347756
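
The four runs above each report a point estimate ($est) and a standard error ($se). One way to compare them is to form approximate 95% intervals. The sketch below copies the printed values into a data frame and assumes a normal approximation (a Wald interval), which this vignette does not itself state; treat the intervals as a rough guide only.

```r
# Estimates and standard errors copied from the outputs printed above
results <- data.frame(
  cf_model = c("bspline", "xgboost", "randomforest", "glm"),
  est = c(0.1048938, 0.09174881, 0.09918991, 0.1341408),
  se = c(0.04907399, 0.05081966, 0.05248152, 0.05347756)
)

# Assumed normal approximation: estimate plus or minus z times the standard error
z <- qnorm(0.975)
results$lower <- results$est - z * results$se
results$upper <- results$est + z * results$se
print(results)
```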

Using Your Own custom_model_fun

Users may customize the feature‐selection or penalization strategy by supplying their own function through the custom_model_fun argument. To do so, pass a function that meets these requirements:

  1. Function Signature: It must accept exactly these arguments (in this order): X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss

(X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, and nfold are used internally by SMMAL to partition and fit the data.)

(log_loss is a function for computing cross‐entropy (log‐loss). It is passed in so that your function can call log_loss(true_labels, predicted_probs) to evaluate each tuning parameter.)

  2. Return Value: It must return a list with one element per tuning-parameter combination, that is, one element for each pairing of a “ridge” and a “lambda” value defined in param_fun(). Each element should hold the out‐of‐fold predicted probabilities for all n observations, either as a numeric vector of length n or as an n × 1 matrix, stacking together the predictions from every held‐out fold (no NA values, except where Y is genuinely missing).

Below is an example showing how to plug in the packaged SMMAL_ada_lasso() as custom_model_fun. In practice, you could substitute any function with the same signature and return type:

 SMMAL_output5 <- SMMAL(Y=Y,A=A,S=S,X=X, custom_model_fun = SMMAL_ada_lasso)
 print(SMMAL_output5)
#> $est
#> [1] 0.1280551
#> 
#> $se
#> [1] 0.05729187
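
For orientation, the sketch below shows the general shape such a function could take, using unpenalized logistic regression from base R instead of glmnet. Everything in it is an assumption made for illustration: the name my_glm_fun is hypothetical, the argument list follows the one printed for SMMAL_ada_lasso, sub_set and labeled_indices are accepted but deliberately unused, and it returns a single-element list. Whether SMMAL() accepts a list of that length has not been checked here, so the function is only exercised on simulated inputs.

```r
# Hypothetical custom model function: fold-by-fold logistic regression
my_glm_fun <- function(X, Y, X_full, foldid, foldid_labelled, sub_set,
                       labeled_indices, nfold, log_loss) {
  X <- as.data.frame(X)
  X_full <- as.data.frame(X_full)
  names(X_full) <- names(X)
  # One row of out-of-fold predictions per observation in X_full
  preds <- matrix(NA_real_, nrow = length(foldid), ncol = 1)
  for (ifold in seq_len(nfold)) {
    # Simplification: train on all labelled rows outside the current fold
    trainpos <- which(foldid_labelled != ifold & !is.na(Y))
    testpos <- which(foldid == ifold)
    train_df <- data.frame(y = Y[trainpos], X[trainpos, , drop = FALSE])
    fit <- stats::glm(y ~ ., data = train_df, family = stats::binomial())
    preds[testpos, 1] <- stats::predict(
      fit, newdata = X_full[testpos, , drop = FALSE], type = "response"
    )
  }
  list(preds)
}

# Simulated inputs: 200 observations, 80 of them labelled, 5 folds
set.seed(42)
n <- 200
n_lab <- 80
nfold <- 5
X_full <- matrix(rnorm(n * 2), ncol = 2)
foldid <- sample(rep(seq_len(nfold), length.out = n))
labeled_indices <- sort(sample(n, n_lab))
X <- X_full[labeled_indices, , drop = FALSE]
Y <- rbinom(n_lab, 1, plogis(X[, 1]))

out <- my_glm_fun(X, Y, X_full, foldid, foldid[labeled_indices],
                  rep(TRUE, n_lab), labeled_indices, nfold,
                  log_loss = function(y_true, y_pred) NA_real_)
str(out)
```

A function meant to be tuned over a grid would loop over the candidate values and append one prediction matrix per candidate, as SMMAL_ada_lasso does.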

Understanding SMMAL_ada_lasso

SMMAL_ada_lasso
#> function (X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, 
#>     nfold, log_loss) 
#> {
#>     fold_predictions <- NULL
#>     param_grid <- param_fun()
#>     ridge_list <- param_grid$ridge
#>     lambda_list <- param_grid$lambda
#>     fold_predictions <- vector("list", length(ridge_list) * length(lambda_list))
#>     for (r in seq_along(ridge_list)) {
#>         ridge_val <- ridge_list[[r]]
#>         for (i in seq_along(lambda_list)) {
#>             lambda <- lambda_list[[i]]
#>             ridge_fit_all <- glmnet::glmnet(X, Y, lambda = ridge_val, 
#>                 alpha = 0, family = "binomial")
#>             ridge_coef <- as.numeric(coef(ridge_fit_all))[-1]
#>             penalty_factors <- 1/(abs(ridge_coef) + 1e-04)
#>             all_preds_matrix <- matrix(NA, nrow = length(foldid), 
#>                 ncol = 1)
#>             for (ifold in seq_len(nfold)) {
#>                 trainpos <- which((foldid_labelled != ifold) & 
#>                   sub_set[labeled_indices])
#>                 testpos <- which(foldid == ifold)
#>                 X_train <- as.matrix(X[trainpos, , drop = FALSE])
#>                 Y_train <- as.numeric(Y[trainpos])
#>                 X_test <- as.matrix(X_full[testpos, , drop = FALSE])
#>                 valid_idx <- which(!is.na(Y_train))
#>                 X_train <- X_train[valid_idx, , drop = FALSE]
#>                 Y_train <- Y_train[valid_idx]
#>                 fit <- glmnet::glmnet(X_train, Y_train, lambda = lambda, 
#>                   alpha = 1, family = "binomial", penalty.factor = penalty_factors, 
#>                   maxit = 1e+06)
#>                 preds <- predict(fit, newx = X_test, type = "response")
#>                 all_preds_matrix[testpos, ] <- preds
#>             }
#>             fold_predictions[[(r - 1) * length(lambda_list) + 
#>                 i]] <- all_preds_matrix
#>         }
#>     }
#>     return(fold_predictions)
#> }
#> <bytecode: 0x00000253d775deb8>
#> <environment: namespace:SMMAL>
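
The adaptive step in the listing is the line that turns ridge coefficients into per-feature lasso penalties. The short sketch below reproduces that arithmetic on made-up coefficients (the values in ridge_coef are hypothetical) to show its effect: features with large ridge coefficients are penalized lightly, while features with coefficients near zero receive a very large penalty. The constant 1e-04 keeps the division finite when a coefficient is exactly zero.

```r
# Made-up ridge coefficients for three features
ridge_coef <- c(0.80, -0.05, 0.00)

# Same formula as in the listing: inverse absolute coefficient, offset by 1e-04
penalty_factors <- 1 / (abs(ridge_coef) + 1e-04)
print(round(penalty_factors, 2))

# Features ordered from least to most penalized
print(order(penalty_factors))
```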

Input of SMMAL_ada_lasso

str(data_loaded)
#> List of 9
#>  $ X              : num [1:400, 1:50] 0.672 0.056 0.108 0.142 0.371 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:400] "1" "2" "3" "4" ...
#>   .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#>  $ Y              : int [1:400] 0 0 0 0 1 0 0 1 1 0 ...
#>  $ X_full         : num [1:1000, 1:50] 0.672 0.056 0.108 0.142 0.732 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#>  $ foldid         : num [1:1000] 1 4 5 3 2 3 1 4 4 5 ...
#>  $ foldid_labelled: num [1:400] 1 4 5 3 1 4 5 2 2 3 ...
#>  $ sub_set        : logi [1:400] TRUE TRUE TRUE TRUE TRUE TRUE ...
#>  $ labeled_indices: int [1:400] 1 2 3 4 7 9 12 15 17 18 ...
#>  $ nfold          : num 5
#>  $ log_loss       :function (y_true, y_pred)  
#>   ..- attr(*, "srcref")= 'srcref' int [1:8] 1 15 5 3 15 3 1 5
#>   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x00000253d568a9d0>

Input: X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss

X: The matrix of predictors for the labelled observations only.

Y: Binary outcome vector with one entry per row of X; may contain NA for unlabeled rows.

X_full: The full matrix of predictors for all observations.

foldid: A vector assigning each observation (labelled or unlabelled) to a fold.

foldid_labelled: Integer vector assigning labeled rows to CV folds (1 to nfold); NA for unlabeled.

sub_set: Logical or integer vector indicating rows included in supervised CV.

labeled_indices: Indices of labeled observations (where Y is not missing).

nfold: Number of cross-validation folds (e.g., 5 or 10).

log_loss: Function that computes log-loss: log_loss(true_labels, pred_probs) returns a single numeric.
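
The sketch below shows one way these inputs could relate to each other, inferred from the dimensions in str(data_loaded) (400 labelled rows out of 1000). It is an assumption about the layout, not a description of how SMMAL builds the objects internally; the simulated data and the name Y_all are hypothetical.

```r
set.seed(7)
n <- 1000
nfold <- 5

# Simulated outcome with 60% missing, and predictors for every observation
Y_all <- rbinom(n, 1, 0.4)
Y_all[sample(n, 600)] <- NA
X_full <- matrix(runif(n * 3), ncol = 3)

# Labelled rows are those with an observed outcome
labeled_indices <- which(!is.na(Y_all))
X <- X_full[labeled_indices, , drop = FALSE]
Y <- Y_all[labeled_indices]

# Folds for all rows, and the same folds restricted to labelled rows
foldid <- sample(rep(seq_len(nfold), length.out = n))
foldid_labelled <- foldid[labeled_indices]

# Assumed here: every labelled row takes part in supervised CV
sub_set <- rep(TRUE, length(labeled_indices))

print(c(labelled = nrow(X), total = nrow(X_full)))
print(table(foldid_labelled))
```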

Running SMMAL_ada_lasso and Its Output

Output: fold_predictions

When you use SMMAL_ada_lasso() as a custom_model_fun, it returns a list with one element per combination of ridge and lambda values. Each element is a one-column numeric matrix with as many rows as there are observations in total, containing the cross-validated predicted probabilities for that combination.

Below is a sample run & output of SMMAL_ada_lasso

SMMAL_fold_predictions <- SMMAL_ada_lasso(
  X = data_loaded$X,
  Y = data_loaded$Y,
  X_full = data_loaded$X_full,
  foldid = data_loaded$foldid,
  foldid_labelled = data_loaded$foldid_labelled,
  sub_set = data_loaded$sub_set,
  labeled_indices = data_loaded$labeled_indices,
  nfold = data_loaded$nfold,
  log_loss = data_loaded$log_loss
)

str(SMMAL_fold_predictions)
#> List of 100
#>  $ : num [1:1000, 1] 0.485 0.486 0.439 0.359 0.442 ...
#>  $ : num [1:1000, 1] 0.482 0.53 0.421 0.369 0.426 ...
#>  $ : num [1:1000, 1] 0.485 0.548 0.41 0.38 0.425 ...
#>  $ : num [1:1000, 1] 0.496 0.566 0.408 0.393 0.434 ...
#>  $ : num [1:1000, 1] 0.51 0.583 0.4 0.402 0.439 ...
#>  $ : num [1:1000, 1] 0.525 0.602 0.395 0.409 0.441 ...
#>  $ : num [1:1000, 1] 0.54 0.61 0.399 0.415 0.447 ...
#>  $ : num [1:1000, 1] 0.559 0.611 0.409 0.423 0.455 ...
#>  $ : num [1:1000, 1] 0.58 0.607 0.42 0.431 0.463 ...
#>  $ : num [1:1000, 1] 0.588 0.603 0.438 0.434 0.474 ...
#>  $ : num [1:1000, 1] 0.598 0.602 0.46 0.443 0.484 ...
#>  $ : num [1:1000, 1] 0.61 0.603 0.485 0.461 0.482 ...
#>  $ : num [1:1000, 1] 0.623 0.603 0.514 0.48 0.483 ...
#>  $ : num [1:1000, 1] 0.637 0.601 0.525 0.498 0.494 ...
#>  $ : num [1:1000, 1] 0.647 0.594 0.538 0.517 0.508 ...
#>  $ : num [1:1000, 1] 0.648 0.584 0.552 0.534 0.519 ...
#>  $ : num [1:1000, 1] 0.642 0.578 0.563 0.543 0.526 ...
#>  $ : num [1:1000, 1] 0.628 0.577 0.567 0.562 0.533 ...
#>  $ : num [1:1000, 1] 0.621 0.575 0.572 0.581 0.54 ...
#>  $ : num [1:1000, 1] 0.613 0.575 0.574 0.581 0.549 ...
#>  $ : num [1:1000, 1] 0.475 0.404 0.469 0.353 0.455 ...
#>  $ : num [1:1000, 1] 0.485 0.443 0.448 0.359 0.451 ...
#>  $ : num [1:1000, 1] 0.49 0.484 0.429 0.369 0.45 ...
#>  $ : num [1:1000, 1] 0.494 0.526 0.412 0.381 0.436 ...
#>  $ : num [1:1000, 1] 0.499 0.557 0.406 0.394 0.431 ...
#>  $ : num [1:1000, 1] 0.514 0.576 0.4 0.404 0.437 ...
#>  $ : num [1:1000, 1] 0.529 0.595 0.391 0.41 0.441 ...
#>  $ : num [1:1000, 1] 0.545 0.612 0.395 0.412 0.447 ...
#>  $ : num [1:1000, 1] 0.563 0.614 0.404 0.42 0.454 ...
#>  $ : num [1:1000, 1] 0.581 0.609 0.415 0.428 0.462 ...
#>  $ : num [1:1000, 1] 0.591 0.605 0.433 0.433 0.474 ...
#>  $ : num [1:1000, 1] 0.601 0.603 0.455 0.442 0.484 ...
#>  $ : num [1:1000, 1] 0.613 0.604 0.476 0.46 0.482 ...
#>  $ : num [1:1000, 1] 0.627 0.604 0.509 0.482 0.48 ...
#>  $ : num [1:1000, 1] 0.642 0.603 0.523 0.5 0.489 ...
#>  $ : num [1:1000, 1] 0.651 0.596 0.536 0.519 0.504 ...
#>  $ : num [1:1000, 1] 0.652 0.586 0.549 0.537 0.518 ...
#>  $ : num [1:1000, 1] 0.65 0.579 0.561 0.541 0.525 ...
#>  $ : num [1:1000, 1] 0.637 0.577 0.566 0.559 0.532 ...
#>  $ : num [1:1000, 1] 0.622 0.575 0.571 0.579 0.539 ...
#>  $ : num [1:1000, 1] 0.48 0.455 0.436 0.384 0.468 ...
#>  $ : num [1:1000, 1] 0.502 0.493 0.409 0.397 0.459 ...
#>  $ : num [1:1000, 1] 0.525 0.535 0.396 0.407 0.452 ...
#>  $ : num [1:1000, 1] 0.539 0.578 0.386 0.413 0.447 ...
#>  $ : num [1:1000, 1] 0.557 0.614 0.385 0.416 0.445 ...
#>  $ : num [1:1000, 1] 0.575 0.62 0.391 0.419 0.45 ...
#>  $ : num [1:1000, 1] 0.596 0.616 0.402 0.423 0.458 ...
#>  $ : num [1:1000, 1] 0.6 0.611 0.418 0.431 0.471 ...
#>  $ : num [1:1000, 1] 0.607 0.608 0.442 0.441 0.48 ...
#>  $ : num [1:1000, 1] 0.619 0.607 0.459 0.456 0.478 ...
#>  $ : num [1:1000, 1] 0.634 0.607 0.483 0.476 0.475 ...
#>  $ : num [1:1000, 1] 0.649 0.606 0.514 0.502 0.477 ...
#>  $ : num [1:1000, 1] 0.657 0.6 0.532 0.523 0.491 ...
#>  $ : num [1:1000, 1] 0.659 0.59 0.545 0.54 0.507 ...
#>  $ : num [1:1000, 1] 0.661 0.582 0.559 0.542 0.521 ...
#>  $ : num [1:1000, 1] 0.651 0.579 0.565 0.556 0.53 ...
#>  $ : num [1:1000, 1] 0.638 0.576 0.569 0.573 0.537 ...
#>  $ : num [1:1000, 1] 0.622 0.575 0.574 0.581 0.545 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.554 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.564 0.575 0.376 0.42 0.468 ...
#>  $ : num [1:1000, 1] 0.598 0.606 0.374 0.422 0.466 ...
#>  $ : num [1:1000, 1] 0.615 0.616 0.381 0.431 0.472 ...
#>  $ : num [1:1000, 1] 0.62 0.622 0.393 0.436 0.471 ...
#>  $ : num [1:1000, 1] 0.623 0.619 0.419 0.441 0.475 ...
#>  $ : num [1:1000, 1] 0.63 0.614 0.434 0.454 0.472 ...
#>  $ : num [1:1000, 1] 0.645 0.611 0.451 0.471 0.468 ...
#>  $ : num [1:1000, 1] 0.662 0.611 0.479 0.497 0.465 ...
#>  $ : num [1:1000, 1] 0.664 0.602 0.514 0.527 0.479 ...
#>  $ : num [1:1000, 1] 0.666 0.593 0.538 0.538 0.492 ...
#>  $ : num [1:1000, 1] 0.669 0.584 0.553 0.543 0.506 ...
#>  $ : num [1:1000, 1] 0.662 0.58 0.561 0.555 0.52 ...
#>  $ : num [1:1000, 1] 0.652 0.578 0.567 0.572 0.532 ...
#>  $ : num [1:1000, 1] 0.64 0.576 0.571 0.581 0.542 ...
#>  $ : num [1:1000, 1] 0.624 0.575 0.574 0.581 0.551 ...
#>  $ : num [1:1000, 1] 0.61 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.582 0.556 0.393 0.429 0.493 ...
#>  $ : num [1:1000, 1] 0.595 0.575 0.377 0.433 0.491 ...
#>  $ : num [1:1000, 1] 0.604 0.591 0.37 0.438 0.488 ...
#>  $ : num [1:1000, 1] 0.623 0.604 0.37 0.442 0.486 ...
#>  $ : num [1:1000, 1] 0.639 0.61 0.388 0.447 0.48 ...
#>  $ : num [1:1000, 1] 0.644 0.62 0.405 0.453 0.472 ...
#>  $ : num [1:1000, 1] 0.652 0.623 0.419 0.463 0.463 ...
#>  $ : num [1:1000, 1] 0.666 0.619 0.436 0.485 0.458 ...
#>  $ : num [1:1000, 1] 0.668 0.607 0.465 0.513 0.463 ...
#>  $ : num [1:1000, 1] 0.67 0.598 0.497 0.527 0.477 ...
#>  $ : num [1:1000, 1] 0.673 0.588 0.533 0.537 0.49 ...
#>  $ : num [1:1000, 1] 0.673 0.581 0.555 0.546 0.503 ...
#>  $ : num [1:1000, 1] 0.664 0.58 0.561 0.561 0.518 ...
#>  $ : num [1:1000, 1] 0.653 0.578 0.567 0.578 0.53 ...
#>  $ : num [1:1000, 1] 0.642 0.577 0.571 0.582 0.543 ...
#>  $ : num [1:1000, 1] 0.628 0.576 0.574 0.581 0.553 ...
#>  $ : num [1:1000, 1] 0.615 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>   [list output truncated]
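
The listing of SMMAL_ada_lasso stores each result at position (r - 1) * length(lambda_list) + i, so lambda varies fastest within each ridge value. The sketch below rebuilds that mapping so a list position can be traced back to its grid cell. The grid values are hypothetical: the real ones come from param_fun(), and the sizes 5 and 20 are chosen only so that the product matches the 100 elements shown above.

```r
# Hypothetical grid; the actual values are defined by param_fun()
ridge_list <- c(0.001, 0.01, 0.1, 1, 10)
lambda_list <- seq(0.01, 0.2, length.out = 20)

# expand.grid varies its first argument fastest, matching the inner lambda loop
grid <- expand.grid(
  lambda_pos = seq_along(lambda_list),
  ridge_pos = seq_along(ridge_list)
)

# Same index formula as in the listing
grid$list_pos <- (grid$ridge_pos - 1) * length(lambda_list) + grid$lambda_pos

# Look up which grid cell produced the 21st list element
cell_21 <- grid[grid$list_pos == 21, ]
print(cell_21)
```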
