This vignette demonstrates how to use the SMMAL package to estimate the Average Treatment Effect (ATE) using semi-supervised machine learning. We provide an example dataset and walk through the required input format and function usage.
The sample data contain 1000 observations, with 60% of Y and A missing at random. Y is the outcome, A is the treatment indicator, X contains the covariates, and S contains the surrogates.
For the sample data, missingness occurs at random and is encoded as NA. This package can handle datasets with a high proportion of missing values, but it requires a sufficiently large sample size to ensure that each fold in cross-validation contains at least 20 labeled observations.
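The sketch below illustrates the expected input format and the fold-size requirement. It simulates a toy dataset with 60% of Y and A set to NA and counts the labeled observations in each fold. Several things here are assumptions made only for illustration: the data-generating process, the choice to make Y and A missing on the same rows, the random fold assignment, and the helper name `count_labeled_per_fold()`. None of them reproduce the package's sample data or its internal fold assignment.

```r
# Toy data in the input format described above (hypothetical simulation,
# not the SMMAL sample data).
set.seed(1)
n <- 1000
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))      # covariates
S <- data.frame(s1 = rnorm(n))                     # surrogate
A <- rbinom(n, 1, plogis(0.5 * X$x1))              # treatment indicator
Y <- rbinom(n, 1, plogis(0.3 * A + X$x2 + S$s1))   # binary outcome

# Make Y and A missing at random for 60% of rows (assumed: same rows), encoded as NA
missing_rows <- sample(n, size = 0.6 * n)
Y[missing_rows] <- NA
A[missing_rows] <- NA

# Hypothetical check: count labeled rows (Y and A observed) in each CV fold
count_labeled_per_fold <- function(Y, A, nfold = 5) {
  foldid <- sample(rep(seq_len(nfold), length.out = length(Y)))
  labeled <- !is.na(Y) & !is.na(A)
  table(factor(foldid[labeled], levels = seq_len(nfold)))
}

counts <- count_labeled_per_fold(Y, A, nfold = 5)
print(counts)
all(counts >= 20)  # each fold should have at least 20 labeled observations
```

With 400 labeled rows and 5 folds, each fold holds roughly 80 labeled observations. That is comfortably above the stated minimum of 20.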
The inputs S and X must be data frames, even when each contains only a single variable (a vector).
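This is a minimal sketch that assumes a surrogate and a covariate stored as plain numeric vectors. The names `s_vec` and `x_vec` are hypothetical. Each vector is wrapped in `data.frame()` before being passed on.

```r
# Hypothetical single-variable inputs stored as plain vectors
s_vec <- c(0.2, 1.5, -0.3, 0.8)
x_vec <- c(1.1, 0.4, 2.0, -0.7)

# Wrap each vector as a one-column data frame, as required for S and X
S <- data.frame(S1 = s_vec)
X <- data.frame(X1 = x_vec)

str(S)
str(X)
```

If several covariates are already held in a matrix, `as.data.frame()` converts the matrix in one step.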
Users can choose the model used for the nuisance functions by setting the cf_model parameter. If cf_model is not specified, the default is “bspline”.
After cross-validation and prediction, the best-performing model is selected as the one with the lowest cross-entropy (log loss). Users can control the number of cross-validation folds with the nfold parameter; if nfold is not specified, the default is 5.
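The sketch below makes the selection criterion concrete. It computes binary cross-entropy and keeps the candidate with the smallest value. The `log_loss()` definition (including the clipping constant `eps`) and the candidate names are assumptions for illustration, and the package's own log_loss may differ in detail.

```r
# Binary cross-entropy (log loss). Probabilities are clipped away from 0 and 1
# so that log() stays finite; the clipping constant is an assumption.
log_loss <- function(y_true, y_pred, eps = 1e-15) {
  p <- pmin(pmax(y_pred, eps), 1 - eps)
  -mean(y_true * log(p) + (1 - y_true) * log(1 - p))
}

# Hypothetical cross-validated predictions from three candidate models
y <- c(1, 0, 1, 1, 0)
candidates <- list(
  model_a = c(0.9, 0.2, 0.7, 0.6, 0.3),
  model_b = c(0.6, 0.5, 0.5, 0.5, 0.4),
  model_c = c(0.2, 0.8, 0.3, 0.4, 0.7)
)

# Lower log loss means better-calibrated probabilities; keep the minimum
losses <- vapply(candidates, function(p) log_loss(y, p), numeric(1))
best <- names(which.min(losses))
losses
best  # "model_a"
```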
SMMAL_output1 <- SMMAL(Y=Y,A=A,S=S,X=X)
print(SMMAL_output1)
#> $est
#> [1] 0.1048938
#>
#> $se
#> [1] 0.04907399
Other options for cf_model are “xgboost”
SMMAL_output2 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "xgboost")
print(SMMAL_output2)
#> $est
#> [1] 0.09174881
#>
#> $se
#> [1] 0.05081966
or “randomforest” (written as one word, as in the call below)
SMMAL_output3 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "randomforest")
print(SMMAL_output3)
#> $est
#> [1] 0.09918991
#>
#> $se
#> [1] 0.05248152
or “glm” (cf_model = "glm").
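If you wrap SMMAL in your own code, base R's `match.arg()` can mirror the four documented cf_model values. When the argument is missing, `match.arg()` returns the first choice, which here is the documented default. The helper below is hypothetical and is not SMMAL's internal argument handling.

```r
# Hypothetical helper that validates a cf_model value against the options
# documented above; the first element acts as the default.
choose_cf_model <- function(cf_model = c("bspline", "xgboost", "randomforest", "glm")) {
  match.arg(cf_model)
}

choose_cf_model()                # "bspline" (the documented default)
choose_cf_model("randomforest")  # accepted
# choose_cf_model("random forest") makes this helper stop with an error,
# because the value with a space matches none of the choices
```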
Users may customize the feature‐selection or penalization strategy by supplying their own function through the custom_model_fun argument. To do so, pass a function that meets these requirements:
It accepts the arguments X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, and log_loss, which SMMAL uses internally to partition and fit the data.
It may use log_loss, a function that computes cross‐entropy (log loss), by calling log_loss(true_labels, predicted_probs) to evaluate each tuning parameter.
It returns a list of cross-validated predicted probabilities, one element per tuning-parameter setting, as described under Output below.
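Here is a sketch of a custom function with the same argument names as the SMMAL_ada_lasso signature printed later in this vignette. It swaps in a different, plainly simpler technique. For each size k in a small grid, it keeps the k covariates most correlated with Y in the training folds and fits an unpenalized logistic regression with `stats::glm()`. The function name `top_k_glm()`, the grid `c(1, 2, 5)`, and the assumption that sub_set is a logical vector aligned with the rows of X are all hypothetical.

```r
# Hypothetical custom_model_fun: correlation-based feature selection + glm.
# Returns one n_full x 1 matrix of out-of-fold probabilities per grid value.
top_k_glm <- function(X, Y, X_full, foldid, foldid_labelled, sub_set,
                      labeled_indices, nfold, log_loss) {
  X <- as.matrix(X)
  X_full <- as.matrix(X_full)
  k_grid <- c(1, 2, 5)                    # tuning grid (an assumption)
  k_grid <- k_grid[k_grid <= ncol(X)]     # drop sizes larger than the number of covariates
  fold_predictions <- vector("list", length(k_grid))

  for (j in seq_along(k_grid)) {
    preds_all <- matrix(NA_real_, nrow = nrow(X_full), ncol = 1)
    for (ifold in seq_len(nfold)) {
      # Training rows: labeled rows outside this fold with an observed outcome
      # (assumes sub_set is a logical vector aligned with the rows of X)
      trainpos <- which(foldid_labelled != ifold & sub_set & !is.na(Y))
      # Test rows: every observation (labeled or not) assigned to this fold
      testpos <- which(foldid == ifold)
      X_train <- X[trainpos, , drop = FALSE]
      Y_train <- Y[trainpos]

      # Keep the k covariates with the largest absolute correlation with Y
      score <- abs(cor(X_train, Y_train))[, 1]
      keep <- order(score, decreasing = TRUE)[seq_len(k_grid[j])]

      # Use the same column names for training and prediction data
      train_df <- as.data.frame(X_train[, keep, drop = FALSE])
      names(train_df) <- paste0("v", keep)
      train_df$y <- Y_train
      test_df <- as.data.frame(X_full[testpos, keep, drop = FALSE])
      names(test_df) <- paste0("v", keep)

      fit <- glm(y ~ ., data = train_df, family = binomial())
      preds_all[testpos, 1] <- predict(fit, newdata = test_df, type = "response")
    }
    fold_predictions[[j]] <- preds_all
  }
  fold_predictions
}
```

Like SMMAL_ada_lasso, this sketch does not call log_loss itself and returns one prediction matrix per tuning setting. It could be supplied the same way, e.g. `custom_model_fun = top_k_glm`. Check that SMMAL accepts a grid of this size before relying on it.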
Below is an example showing how to plug in the packaged SMMAL_ada_lasso() as custom_model_fun. In practice, you could substitute any function with the same signature and return type:
SMMAL_output5 <- SMMAL(Y=Y,A=A,S=S,X=X, custom_model_fun = SMMAL_ada_lasso)
print(SMMAL_output5)
#> $est
#> [1] 0.1280551
#>
#> $se
#> [1] 0.05729187
SMMAL_ada_lasso
#> function (X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices,
#> nfold, log_loss)
#> {
#> fold_predictions <- NULL
#> param_grid <- param_fun()
#> ridge_list <- param_grid$ridge
#> lambda_list <- param_grid$lambda
#> fold_predictions <- vector("list", length(ridge_list) * length(lambda_list))
#> for (r in seq_along(ridge_list)) {
#> ridge_val <- ridge_list[[r]]
#> for (i in seq_along(lambda_list)) {
#> lambda <- lambda_list[[i]]
#> ridge_fit_all <- glmnet::glmnet(X, Y, lambda = ridge_val,
#> alpha = 0, family = "binomial")
#> ridge_coef <- as.numeric(coef(ridge_fit_all))[-1]
#> penalty_factors <- 1/(abs(ridge_coef) + 1e-04)
#> all_preds_matrix <- matrix(NA, nrow = length(foldid),
#> ncol = 1)
#> for (ifold in seq_len(nfold)) {
#> trainpos <- which((foldid_labelled != ifold) &
#> sub_set[labeled_indices])
#> testpos <- which(foldid == ifold)
#> X_train <- as.matrix(X[trainpos, , drop = FALSE])
#> Y_train <- as.numeric(Y[trainpos])
#> X_test <- as.matrix(X_full[testpos, , drop = FALSE])
#> valid_idx <- which(!is.na(Y_train))
#> X_train <- X_train[valid_idx, , drop = FALSE]
#> Y_train <- Y_train[valid_idx]
#> fit <- glmnet::glmnet(X_train, Y_train, lambda = lambda,
#> alpha = 1, family = "binomial", penalty.factor = penalty_factors,
#> maxit = 1e+06)
#> preds <- predict(fit, newx = X_test, type = "response")
#> all_preds_matrix[testpos, ] <- preds
#> }
#> fold_predictions[[(r - 1) * length(lambda_list) +
#> i]] <- all_preds_matrix
#> }
#> }
#> return(fold_predictions)
#> }
#> <bytecode: 0x00000253d775deb8>
#> <environment: namespace:SMMAL>
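Two details of the printed code are worth spelling out. First, the penalty factors are adaptive-lasso weights: 1/(|ridge coefficient| + 1e-04). Covariates with small ridge coefficients therefore receive a heavier lasso penalty. Second, the list position (r - 1) * length(lambda_list) + i stores results in ridge-major order. The helper names below are hypothetical. The grid size `n_lambda <- 20` is an assumption for illustration only, because the values returned by param_fun() are not shown.

```r
# Adaptive-lasso penalty factors, following the formula in SMMAL_ada_lasso():
# small |coefficient| -> large penalty; the 1e-04 offset avoids division by zero.
adaptive_penalty <- function(ridge_coef, eps = 1e-04) {
  1 / (abs(ridge_coef) + eps)
}

# Map (ridge index r, lambda index i) to a list position and back,
# matching fold_predictions[[(r - 1) * length(lambda_list) + i]]
grid_position <- function(r, i, n_lambda) (r - 1) * n_lambda + i
grid_indices <- function(pos, n_lambda) {
  c(r = (pos - 1) %/% n_lambda + 1, i = (pos - 1) %% n_lambda + 1)
}

n_lambda <- 20  # assumed grid size, for illustration only
adaptive_penalty(c(0.5, -2, 0))                   # largest weight for the zero coefficient
grid_position(r = 2, i = 3, n_lambda = n_lambda)  # 23
grid_indices(23, n_lambda = n_lambda)             # r = 2, i = 3
```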
str(data_loaded)
#> List of 9
#> $ X : num [1:400, 1:50] 0.672 0.056 0.108 0.142 0.371 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : chr [1:400] "1" "2" "3" "4" ...
#> .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#> $ Y : int [1:400] 0 0 0 0 1 0 0 1 1 0 ...
#> $ X_full : num [1:1000, 1:50] 0.672 0.056 0.108 0.142 0.732 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : NULL
#> .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#> $ foldid : num [1:1000] 1 4 5 3 2 3 1 4 4 5 ...
#> $ foldid_labelled: num [1:400] 1 4 5 3 1 4 5 2 2 3 ...
#> $ sub_set : logi [1:400] TRUE TRUE TRUE TRUE TRUE TRUE ...
#> $ labeled_indices: int [1:400] 1 2 3 4 7 9 12 15 17 18 ...
#> $ nfold : num 5
#> $ log_loss :function (y_true, y_pred)
#> ..- attr(*, "srcref")= 'srcref' int [1:8] 1 15 5 3 15 3 1 5
#> .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x00000253d568a9d0>
Input: X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss
X: Matrix of predictors for the labeled observations only (one row per labeled observation).
Y: Binary outcome vector aligned with the rows of X; any NA entries are removed before fitting.
X_full: The full matrix of predictors for all observations.
foldid: A vector assigning each observation (labelled or unlabelled) to a fold.
foldid_labelled: Vector assigning each labeled observation (each row of X) to a cross-validation fold (1 to nfold).
sub_set: Logical or integer vector indicating rows included in supervised CV.
labeled_indices: Row indices, within X_full, of the labeled observations (where Y is not missing).
nfold: Number of cross-validation folds (e.g., 5 or 10).
log_loss: Function that computes log-loss: log_loss(true_labels, pred_probs) returns a single numeric.
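To test a custom function without running SMMAL, you can assemble a toy object with the same components as data_loaded. The sketch below is a stand-in built on several assumptions: the simulated data, the helper name `make_custom_inputs()`, and the choice `sub_set = TRUE` for every labeled row. It also sets `foldid_labelled = foldid[labeled_indices]`, which agrees with the first values shown by str(data_loaded) above but is not confirmed by the package documentation.

```r
# Hypothetical helper that builds inputs shaped like data_loaded
make_custom_inputs <- function(n = 1000, p = 5, prop_labeled = 0.4, nfold = 5,
                               seed = 1) {
  set.seed(seed)
  X_full <- matrix(runif(n * p), n, p,
                   dimnames = list(NULL, paste0("V", seq_len(p))))
  labeled_indices <- sort(sample(n, round(prop_labeled * n)))
  X <- X_full[labeled_indices, , drop = FALSE]        # labeled rows only
  Y <- rbinom(length(labeled_indices), 1, plogis(2 * X[, 1] - 1))
  foldid <- sample(rep(seq_len(nfold), length.out = n))
  log_loss <- function(y_true, y_pred) {
    prob <- pmin(pmax(y_pred, 1e-15), 1 - 1e-15)      # keep log() finite
    -mean(y_true * log(prob) + (1 - y_true) * log(1 - prob))
  }
  list(
    X = X, Y = Y, X_full = X_full, foldid = foldid,
    foldid_labelled = foldid[labeled_indices],       # folds of the labeled rows
    sub_set = rep(TRUE, length(labeled_indices)),    # use every labeled row
    labeled_indices = labeled_indices, nfold = nfold, log_loss = log_loss
  )
}

toy <- make_custom_inputs()
str(toy, max.level = 1)
```

Because the list names match the argument names, a candidate function can be called as `do.call(candidate_fun, toy)`, where `candidate_fun` stands for your own function.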
Output: fold_predictions
When SMMAL_ada_lasso() is used as custom_model_fun, it returns a list with one element per combination of ridge and lambda values. Each element is a one-column numeric matrix with one row per observation (labeled and unlabeled), holding the cross-validated predicted probabilities for that combination.
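The sketch below does two things. It checks that a custom function's return value has the shape described here, and it shows how a lowest-log-loss element could be picked. The helper names `check_fold_predictions()` and `best_element()` are hypothetical. SMMAL performs its own selection internally. This sketch scores only the labeled rows (labeled_indices) against Y and assumes that Y is aligned with labeled_indices.

```r
# Hypothetical checker for the fold_predictions contract described above
check_fold_predictions <- function(fold_predictions, n_full) {
  stopifnot(is.list(fold_predictions), length(fold_predictions) > 0)
  for (m in fold_predictions) {
    m <- as.matrix(m)                     # accept vectors or one-column matrices
    stopifnot(is.numeric(m), nrow(m) == n_full, ncol(m) == 1,
              !anyNA(m), all(m >= 0 & m <= 1))
  }
  invisible(TRUE)
}

# Position of the element with the lowest log loss on the labeled rows
best_element <- function(fold_predictions, Y, labeled_indices, log_loss) {
  losses <- vapply(fold_predictions,
                   function(m) log_loss(Y, as.matrix(m)[labeled_indices, 1]),
                   numeric(1))
  which.min(losses)
}

# Toy example: two candidate prediction sets for 6 observations, 3 labeled
log_loss <- function(y_true, y_pred) {
  -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
}
preds <- list(matrix(c(0.8, 0.5, 0.2, 0.5, 0.7, 0.5)),
              matrix(c(0.4, 0.5, 0.6, 0.5, 0.5, 0.5)))
check_fold_predictions(preds, n_full = 6)
best_element(preds, Y = c(1, 0, 1), labeled_indices = c(1, 3, 5),
             log_loss = log_loss)  # 1
```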
Below is a sample run of SMMAL_ada_lasso and its output:
SMMAL_fold_predictions <- SMMAL_ada_lasso(
X = data_loaded$X,
Y = data_loaded$Y,
X_full = data_loaded$X_full,
foldid = data_loaded$foldid,
foldid_labelled = data_loaded$foldid_labelled,
sub_set = data_loaded$sub_set,
labeled_indices = data_loaded$labeled_indices,
nfold = data_loaded$nfold,
log_loss = data_loaded$log_loss
)
str(SMMAL_fold_predictions)
#> List of 100
#> $ : num [1:1000, 1] 0.485 0.486 0.439 0.359 0.442 ...
#> $ : num [1:1000, 1] 0.482 0.53 0.421 0.369 0.426 ...
#> $ : num [1:1000, 1] 0.485 0.548 0.41 0.38 0.425 ...
#> $ : num [1:1000, 1] 0.496 0.566 0.408 0.393 0.434 ...
#> $ : num [1:1000, 1] 0.51 0.583 0.4 0.402 0.439 ...
#> $ : num [1:1000, 1] 0.525 0.602 0.395 0.409 0.441 ...
#> $ : num [1:1000, 1] 0.54 0.61 0.399 0.415 0.447 ...
#> $ : num [1:1000, 1] 0.559 0.611 0.409 0.423 0.455 ...
#> $ : num [1:1000, 1] 0.58 0.607 0.42 0.431 0.463 ...
#> $ : num [1:1000, 1] 0.588 0.603 0.438 0.434 0.474 ...
#> $ : num [1:1000, 1] 0.598 0.602 0.46 0.443 0.484 ...
#> $ : num [1:1000, 1] 0.61 0.603 0.485 0.461 0.482 ...
#> $ : num [1:1000, 1] 0.623 0.603 0.514 0.48 0.483 ...
#> $ : num [1:1000, 1] 0.637 0.601 0.525 0.498 0.494 ...
#> $ : num [1:1000, 1] 0.647 0.594 0.538 0.517 0.508 ...
#> $ : num [1:1000, 1] 0.648 0.584 0.552 0.534 0.519 ...
#> $ : num [1:1000, 1] 0.642 0.578 0.563 0.543 0.526 ...
#> $ : num [1:1000, 1] 0.628 0.577 0.567 0.562 0.533 ...
#> $ : num [1:1000, 1] 0.621 0.575 0.572 0.581 0.54 ...
#> $ : num [1:1000, 1] 0.613 0.575 0.574 0.581 0.549 ...
#> $ : num [1:1000, 1] 0.475 0.404 0.469 0.353 0.455 ...
#> $ : num [1:1000, 1] 0.485 0.443 0.448 0.359 0.451 ...
#> $ : num [1:1000, 1] 0.49 0.484 0.429 0.369 0.45 ...
#> $ : num [1:1000, 1] 0.494 0.526 0.412 0.381 0.436 ...
#> $ : num [1:1000, 1] 0.499 0.557 0.406 0.394 0.431 ...
#> $ : num [1:1000, 1] 0.514 0.576 0.4 0.404 0.437 ...
#> $ : num [1:1000, 1] 0.529 0.595 0.391 0.41 0.441 ...
#> $ : num [1:1000, 1] 0.545 0.612 0.395 0.412 0.447 ...
#> $ : num [1:1000, 1] 0.563 0.614 0.404 0.42 0.454 ...
#> $ : num [1:1000, 1] 0.581 0.609 0.415 0.428 0.462 ...
#> $ : num [1:1000, 1] 0.591 0.605 0.433 0.433 0.474 ...
#> $ : num [1:1000, 1] 0.601 0.603 0.455 0.442 0.484 ...
#> $ : num [1:1000, 1] 0.613 0.604 0.476 0.46 0.482 ...
#> $ : num [1:1000, 1] 0.627 0.604 0.509 0.482 0.48 ...
#> $ : num [1:1000, 1] 0.642 0.603 0.523 0.5 0.489 ...
#> $ : num [1:1000, 1] 0.651 0.596 0.536 0.519 0.504 ...
#> $ : num [1:1000, 1] 0.652 0.586 0.549 0.537 0.518 ...
#> $ : num [1:1000, 1] 0.65 0.579 0.561 0.541 0.525 ...
#> $ : num [1:1000, 1] 0.637 0.577 0.566 0.559 0.532 ...
#> $ : num [1:1000, 1] 0.622 0.575 0.571 0.579 0.539 ...
#> $ : num [1:1000, 1] 0.48 0.455 0.436 0.384 0.468 ...
#> $ : num [1:1000, 1] 0.502 0.493 0.409 0.397 0.459 ...
#> $ : num [1:1000, 1] 0.525 0.535 0.396 0.407 0.452 ...
#> $ : num [1:1000, 1] 0.539 0.578 0.386 0.413 0.447 ...
#> $ : num [1:1000, 1] 0.557 0.614 0.385 0.416 0.445 ...
#> $ : num [1:1000, 1] 0.575 0.62 0.391 0.419 0.45 ...
#> $ : num [1:1000, 1] 0.596 0.616 0.402 0.423 0.458 ...
#> $ : num [1:1000, 1] 0.6 0.611 0.418 0.431 0.471 ...
#> $ : num [1:1000, 1] 0.607 0.608 0.442 0.441 0.48 ...
#> $ : num [1:1000, 1] 0.619 0.607 0.459 0.456 0.478 ...
#> $ : num [1:1000, 1] 0.634 0.607 0.483 0.476 0.475 ...
#> $ : num [1:1000, 1] 0.649 0.606 0.514 0.502 0.477 ...
#> $ : num [1:1000, 1] 0.657 0.6 0.532 0.523 0.491 ...
#> $ : num [1:1000, 1] 0.659 0.59 0.545 0.54 0.507 ...
#> $ : num [1:1000, 1] 0.661 0.582 0.559 0.542 0.521 ...
#> $ : num [1:1000, 1] 0.651 0.579 0.565 0.556 0.53 ...
#> $ : num [1:1000, 1] 0.638 0.576 0.569 0.573 0.537 ...
#> $ : num [1:1000, 1] 0.622 0.575 0.574 0.581 0.545 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.554 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.564 0.575 0.376 0.42 0.468 ...
#> $ : num [1:1000, 1] 0.598 0.606 0.374 0.422 0.466 ...
#> $ : num [1:1000, 1] 0.615 0.616 0.381 0.431 0.472 ...
#> $ : num [1:1000, 1] 0.62 0.622 0.393 0.436 0.471 ...
#> $ : num [1:1000, 1] 0.623 0.619 0.419 0.441 0.475 ...
#> $ : num [1:1000, 1] 0.63 0.614 0.434 0.454 0.472 ...
#> $ : num [1:1000, 1] 0.645 0.611 0.451 0.471 0.468 ...
#> $ : num [1:1000, 1] 0.662 0.611 0.479 0.497 0.465 ...
#> $ : num [1:1000, 1] 0.664 0.602 0.514 0.527 0.479 ...
#> $ : num [1:1000, 1] 0.666 0.593 0.538 0.538 0.492 ...
#> $ : num [1:1000, 1] 0.669 0.584 0.553 0.543 0.506 ...
#> $ : num [1:1000, 1] 0.662 0.58 0.561 0.555 0.52 ...
#> $ : num [1:1000, 1] 0.652 0.578 0.567 0.572 0.532 ...
#> $ : num [1:1000, 1] 0.64 0.576 0.571 0.581 0.542 ...
#> $ : num [1:1000, 1] 0.624 0.575 0.574 0.581 0.551 ...
#> $ : num [1:1000, 1] 0.61 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.582 0.556 0.393 0.429 0.493 ...
#> $ : num [1:1000, 1] 0.595 0.575 0.377 0.433 0.491 ...
#> $ : num [1:1000, 1] 0.604 0.591 0.37 0.438 0.488 ...
#> $ : num [1:1000, 1] 0.623 0.604 0.37 0.442 0.486 ...
#> $ : num [1:1000, 1] 0.639 0.61 0.388 0.447 0.48 ...
#> $ : num [1:1000, 1] 0.644 0.62 0.405 0.453 0.472 ...
#> $ : num [1:1000, 1] 0.652 0.623 0.419 0.463 0.463 ...
#> $ : num [1:1000, 1] 0.666 0.619 0.436 0.485 0.458 ...
#> $ : num [1:1000, 1] 0.668 0.607 0.465 0.513 0.463 ...
#> $ : num [1:1000, 1] 0.67 0.598 0.497 0.527 0.477 ...
#> $ : num [1:1000, 1] 0.673 0.588 0.533 0.537 0.49 ...
#> $ : num [1:1000, 1] 0.673 0.581 0.555 0.546 0.503 ...
#> $ : num [1:1000, 1] 0.664 0.58 0.561 0.561 0.518 ...
#> $ : num [1:1000, 1] 0.653 0.578 0.567 0.578 0.53 ...
#> $ : num [1:1000, 1] 0.642 0.577 0.571 0.582 0.543 ...
#> $ : num [1:1000, 1] 0.628 0.576 0.574 0.581 0.553 ...
#> $ : num [1:1000, 1] 0.615 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> [list output truncated]