RandomForest.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00022 SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP bwhere, SEXP bweights, 
00023                 SEXP fitmem, SEXP controls) {
00024             
00025      SEXP nweights, tree, where, ans, bw;
00026      double *dnweights, *dweights, sw = 0.0, *prob, tmp;
00027      int nobs, i, b, B , nodenum = 1, *iweights, *iweightstmp, 
00028          *iwhere, replace, fraction, wgrzero = 0, realweights = 0;
00029      int j, k, l;
00030      
00031      B = get_ntree(controls);
00032      nobs = get_nobs(learnsample);
00033      
00034      PROTECT(ans = allocVector(VECSXP, B));
00035 
00036      iweights = Calloc(nobs, int);
00037      iweightstmp = Calloc(nobs, int);
00038      prob = Calloc(nobs, double);
00039      dweights = REAL(weights);
00040 
00041      for (i = 0; i < nobs; i++) {
00042          /* sum of weights */
00043          sw += dweights[i];
00044          /* number of weights > 0 */
00045          if (dweights[i] > 0) wgrzero++;
00046          /* case weights or real weights? */
00047          if (dweights[i] - ftrunc(dweights[i]) > 0) 
00048              realweights = 1;
00049      }
00050      for (i = 0; i < nobs; i++)
00051          prob[i] = dweights[i]/sw;
00052 
00053      replace = get_replace(controls);
00054      /* fraction of number of obs with weight > 0 */
00055      if (realweights) {
00056          /* fraction of number of obs with weight > 0 for real weights*/
00057          tmp = (get_fraction(controls) * wgrzero);
00058      } else {
00059          /* fraction of sum of weights for case weights */
00060          tmp = (get_fraction(controls) * sw);
00061      }
00062      fraction = (int) ftrunc(tmp);
00063      if (ftrunc(tmp) < tmp) fraction++;
00064 
00065      if (!replace) {
00066          if (fraction < 10)
00067              error("fraction of %f is too small", fraction);
00068      }
00069 
00070      /* <FIXME> can we call those guys ONCE? what about the deeper
00071          calls??? </FIXME> */
00072      GetRNGstate();
00073   
00074      if (get_trace(controls))
00075          Rprintf("\n");
00076      for (b  = 0; b < B; b++) {
00077          SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
00078          SET_VECTOR_ELT(bwhere, b, where = allocVector(INTSXP, nobs));
00079          SET_VECTOR_ELT(bweights, b, bw = allocVector(REALSXP, nobs));
00080          
00081          iwhere = INTEGER(where);
00082          for (i = 0; i < nobs; i++) iwhere[i] = 0;
00083      
00084          C_init_node(tree, nobs, get_ninputs(learnsample), 
00085                      get_maxsurrogate(get_splitctrl(controls)),
00086                      ncol(get_predict_trafo(GET_SLOT(learnsample, 
00087                                                    PL2_responsesSym))));
00088 
00089          /* generate altered weights for perturbation */
00090          if (replace) {
00091              /* weights for a bootstrap sample */
00092              rmultinom((int) sw, prob, nobs, iweights);
00093          } else {
00094              /* weights for sample splitting */
00095              C_SampleSplitting(nobs, prob, iweights, fraction);
00096          }
00097 
00098          nweights = S3get_nodeweights(tree);
00099          dnweights = REAL(nweights);
00100          for (i = 0; i < nobs; i++) {
00101              REAL(bw)[i] = (double) iweights[i];
00102              dnweights[i] = REAL(bw)[i];
00103          }
00104      
00105          C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1);
00106          nodenum = 1;
00107          C_remove_weights(tree, 0);
00108          
00109          if (get_trace(controls)) {
00110              /* progress bar; inspired by 
00111              http://avinashjoshi.co.in/2009/10/13/creating-a-progress-bar-in-c/ */
00112              Rprintf("[");
00113              /* Print the = until the current percentage */
00114              l = (int) ceil( ((double) b * 50.0) / B);
00115              for (j = 0; j < l; j++)
00116                  Rprintf("=");
00117              Rprintf(">");
00118              for (k = j; k < 50; k++)
00119                  Rprintf(" ");
00120              Rprintf("]");
00121              /* % completed */
00122                  Rprintf(" %3d%% completed", j * 2);
00123              /* To delete the previous line */
00124              Rprintf("\r");
00125              /* Flush all char in buffer */
00126              fflush(stdout);
00127          }
00128      }
00129      if (get_trace(controls))
00130          Rprintf("\n");
00131 
00132      PutRNGstate();
00133 
00134      Free(prob); Free(iweights); Free(iweightstmp);
00135      UNPROTECT(1);
00136      return(ans);
00137 }