00001
00009 #include "party.h"
00010
00022 SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP bwhere, SEXP bweights,
00023 SEXP fitmem, SEXP controls) {
00024
00025 SEXP nweights, tree, where, ans, bw;
00026 double *dnweights, *dweights, sw = 0.0, *prob, tmp;
00027 int nobs, i, b, B , nodenum = 1, *iweights, *iweightstmp,
00028 *iwhere, replace, fraction, wgrzero = 0, realweights = 0;
00029 int j, k, l;
00030
00031 B = get_ntree(controls);
00032 nobs = get_nobs(learnsample);
00033
00034 PROTECT(ans = allocVector(VECSXP, B));
00035
00036 iweights = Calloc(nobs, int);
00037 iweightstmp = Calloc(nobs, int);
00038 prob = Calloc(nobs, double);
00039 dweights = REAL(weights);
00040
00041 for (i = 0; i < nobs; i++) {
00042
00043 sw += dweights[i];
00044
00045 if (dweights[i] > 0) wgrzero++;
00046
00047 if (dweights[i] - ftrunc(dweights[i]) > 0)
00048 realweights = 1;
00049 }
00050 for (i = 0; i < nobs; i++)
00051 prob[i] = dweights[i]/sw;
00052
00053 replace = get_replace(controls);
00054
00055 if (realweights) {
00056
00057 tmp = (get_fraction(controls) * wgrzero);
00058 } else {
00059
00060 tmp = (get_fraction(controls) * sw);
00061 }
00062 fraction = (int) ftrunc(tmp);
00063 if (ftrunc(tmp) < tmp) fraction++;
00064
00065 if (!replace) {
00066 if (fraction < 10)
00067 error("fraction of %f is too small", fraction);
00068 }
00069
00070
00071
00072 GetRNGstate();
00073
00074 if (get_trace(controls))
00075 Rprintf("\n");
00076 for (b = 0; b < B; b++) {
00077 SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
00078 SET_VECTOR_ELT(bwhere, b, where = allocVector(INTSXP, nobs));
00079 SET_VECTOR_ELT(bweights, b, bw = allocVector(REALSXP, nobs));
00080
00081 iwhere = INTEGER(where);
00082 for (i = 0; i < nobs; i++) iwhere[i] = 0;
00083
00084 C_init_node(tree, nobs, get_ninputs(learnsample),
00085 get_maxsurrogate(get_splitctrl(controls)),
00086 ncol(get_predict_trafo(GET_SLOT(learnsample,
00087 PL2_responsesSym))));
00088
00089
00090 if (replace) {
00091
00092 rmultinom((int) sw, prob, nobs, iweights);
00093 } else {
00094
00095 C_SampleSplitting(nobs, prob, iweights, fraction);
00096 }
00097
00098 nweights = S3get_nodeweights(tree);
00099 dnweights = REAL(nweights);
00100 for (i = 0; i < nobs; i++) {
00101 REAL(bw)[i] = (double) iweights[i];
00102 dnweights[i] = REAL(bw)[i];
00103 }
00104
00105 C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1);
00106 nodenum = 1;
00107 C_remove_weights(tree, 0);
00108
00109 if (get_trace(controls)) {
00110
00111
00112 Rprintf("[");
00113
00114 l = (int) ceil( ((double) b * 50.0) / B);
00115 for (j = 0; j < l; j++)
00116 Rprintf("=");
00117 Rprintf(">");
00118 for (k = j; k < 50; k++)
00119 Rprintf(" ");
00120 Rprintf("]");
00121
00122 Rprintf(" %3d%% completed", j * 2);
00123
00124 Rprintf("\r");
00125
00126 fflush(stdout);
00127 }
00128 }
00129 if (get_trace(controls))
00130 Rprintf("\n");
00131
00132 PutRNGstate();
00133
00134 Free(prob); Free(iweights); Free(iweightstmp);
00135 UNPROTECT(1);
00136 return(ans);
00137 }