SurrogateSplits.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037     jselect = S3get_variableID(S3get_primarysplit(node));
00038     y = S3get_nodeweights(VECTOR_ELT(node, 7));
00039 
00040     for (j = 0; j < ninputs; j++) {
00041         if (is_nominal(inputs, j + 1)) continue;
00042         nvar++;
00043     }
00044     nvar--;
00045 Rprintf("nvar: %d\n", nvar);
00046 
00047     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00048         error("nodes does not have %d surrogate splits", maxsurr);
00049     if (maxsurr > nvar)
00050         error("cannot set up %d surrogate splits with only %d ordered input variable(s)", 
00051               maxsurr, nvar);
00052 
00053     tweights = Calloc(nobs, double);
00054     dweights = REAL(weights);
00055     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00056     if (has_missings(inputs, jselect)) {
00057         thiswhichNA = get_missings(inputs, jselect);
00058         for (k = 0; k < LENGTH(thiswhichNA); k++)
00059             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00060     }
00061 
00062     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00063     C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00064     
00065     splitstat = REAL(get_splitstatistics(fitmem));
00066     /* <FIXME> extend `TreeFitMemory' to those as well ... */
00067     maxstat = Calloc(ninputs, double);
00068     cutpoint = Calloc(ninputs, double);
00069     order = Calloc(ninputs, int);
00070     /* <FIXME> */
00071     
00072     /* this is essentially an exhaustive search */
00073     /* <FIXME>: we don't want to do this for random forest like trees 
00074        </FIXME>
00075      */
00076     for (j = 0; j < ninputs; j++) {
00077     
00078          order[j] = j + 1;
00079          maxstat[j] = 0.0;
00080          cutpoint[j] = 0.0;
00081 
00082          /* ordered input variables only (for the moment) */
00083          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00084              continue;
00085 
00086          x = get_variable(inputs, j + 1);
00087 
00088          if (has_missings(inputs, j + 1)) {
00089 
00090              thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00091                  
00092              C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00093              
00094              C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00095                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00096                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00097                      expcovinf, &cp, &ms, splitstat);
00098          } else {
00099          
00100              C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00101              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00102              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00103              expcovinf, &cp, &ms, splitstat);
00104          }
00105 
00106          maxstat[j] = -ms;
00107          cutpoint[j] = cp;
00108     }
00109 
00110     /* order with respect to maximal statistic */
00111     rsort_with_index(maxstat, order, ninputs);
00112     
00113     twotab = Calloc(4, double);
00114     
00115     /* the best `maxsurr' ones are implemented */
00116     for (j = 0; j < maxsurr; j++) {
00117 
00118         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00119         cut = cutpoint[order[j] - 1];
00120         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00121                        split = allocVector(VECSXP, SPLIT_LENGTH));
00122         C_init_orderedsplit(split, 0);
00123         S3set_variableID(split, order[j]);
00124         REAL(S3get_splitpoint(split))[0] = cut;
00125         dx = REAL(get_variable(inputs, order[j]));
00126         dy = REAL(y);
00127 
00128         /* OK, this is a dirty hack: determine if the split 
00129            goes left or right by the Pearson residual of a 2x2 table.
00130            I don't want to use the big caliber here 
00131         */
00132         for (i = 0; i < nobs; i++) {
00133             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00134             twotab[1] += (dy[i] == 1) * tweights[i];
00135             twotab[2] += (dx[i] <= cut) * tweights[i];
00136             twotab[3] += tweights[i];
00137         }
00138         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00139                      twotab[3]) > 0);
00140     }
00141     
00142     Free(maxstat);
00143     Free(cutpoint);
00144     Free(order);
00145     Free(tweights);
00146     Free(twotab);
00147 }
00148 
00159 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00160                   SEXP fitmem) {
00161 
00162     C_surrogates(node, learnsample, weights, controls, fitmem);
00163     return(S3get_surrogatesplits(node));
00164     
00165 }
00166 
00174 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00175 
00176     SEXP weights, split, surrsplit;
00177     SEXP inputs, whichNA;
00178     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00179     int *iwhichNA, k;
00180     int nobs, i, nna, ns;
00181                     
00182     weights = S3get_nodeweights(node);
00183     dweights = REAL(weights);
00184     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00185     nobs = get_nobs(learnsample);
00186             
00187     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00188     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00189     surrsplit = S3get_surrogatesplits(node);
00190 
00191     /* if the primary split has any missings */
00192     split = S3get_primarysplit(node);
00193     if (has_missings(inputs, S3get_variableID(split))) {
00194 
00195         /* where are the missings? */
00196         whichNA = get_missings(inputs, S3get_variableID(split));
00197         iwhichNA = INTEGER(whichNA);
00198         nna = LENGTH(whichNA);
00199 
00200         /* for all missing values ... */
00201         for (k = 0; k < nna; k++) {
00202             ns = 0;
00203             i = iwhichNA[k] - 1;
00204             if (dweights[i] == 0) continue;
00205             
00206             /* loop over surrogate splits until an appropriate one is found */
00207             while(TRUE) {
00208             
00209                 if (ns >= LENGTH(surrsplit)) break;
00210             
00211                 split = VECTOR_ELT(surrsplit, ns);
00212                 if (has_missings(inputs, S3get_variableID(split))) {
00213                     if (INTEGER(get_missings(inputs, 
00214                             S3get_variableID(split)))[i]) {
00215                         ns++;
00216                         continue;
00217                     }
00218                 }
00219 
00220                 cutpoint = REAL(S3get_splitpoint(split))[0];
00221                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00222 
00223                 if (S3get_toleft(split)) {
00224                     if (dx[i] <= cutpoint) {
00225                         leftweights[i] = dweights[i];
00226                         rightweights[i] = 0.0;
00227                     } else {
00228                         rightweights[i] = dweights[i];
00229                         leftweights[i] = 0.0;
00230                     }
00231                 } else {
00232                     if (dx[i] <= cutpoint) {
00233                         rightweights[i] = dweights[i];
00234                         leftweights[i] = 0.0;
00235                     } else {
00236                         leftweights[i] = dweights[i];
00237                         rightweights[i] = 0.0;
00238                     }
00239                 }
00240                 break;
00241             }
00242         }
00243     }
00244 }

Generated on Fri Nov 30 16:04:21 2007 for party by  doxygen 1.4.6