SurrogateSplits.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037     jselect = S3get_variableID(S3get_primarysplit(node));
00038     y = S3get_nodeweights(VECTOR_ELT(node, 7));
00039 
00040     for (j = 0; j < ninputs; j++) {
00041         if (is_nominal(inputs, j + 1)) continue;
00042         nvar++;
00043     }
00044     nvar--;
00045 
00046     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00047         error("nodes does not have %d surrogate splits", maxsurr);
00048     if (maxsurr > nvar)
00049         error("cannot set up %d surrogate splits with only %d ordered input variable(s)", 
00050               maxsurr, nvar);
00051 
00052     tweights = Calloc(nobs, double);
00053     dweights = REAL(weights);
00054     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00055     if (has_missings(inputs, jselect)) {
00056         thiswhichNA = get_missings(inputs, jselect);
00057         for (k = 0; k < LENGTH(thiswhichNA); k++)
00058             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00059     }
00060 
00061     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00062     C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00063     
00064     splitstat = REAL(get_splitstatistics(fitmem));
00065     /* <FIXME> extend `TreeFitMemory' to those as well ... */
00066     maxstat = Calloc(ninputs, double);
00067     cutpoint = Calloc(ninputs, double);
00068     order = Calloc(ninputs, int);
00069     /* <FIXME> */
00070     
00071     /* this is essentially an exhaustive search */
00072     /* <FIXME>: we don't want to do this for random forest like trees 
00073        </FIXME>
00074      */
00075     for (j = 0; j < ninputs; j++) {
00076     
00077          order[j] = j + 1;
00078          maxstat[j] = 0.0;
00079          cutpoint[j] = 0.0;
00080 
00081          /* ordered input variables only (for the moment) */
00082          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00083              continue;
00084 
00085          x = get_variable(inputs, j + 1);
00086 
00087          if (has_missings(inputs, j + 1)) {
00088 
00089              thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00090                  
00091              C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00092              
00093              C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00094                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00095                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00096                      expcovinf, &cp, &ms, splitstat);
00097          } else {
00098          
00099              C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00100              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00101              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00102              expcovinf, &cp, &ms, splitstat);
00103          }
00104 
00105          maxstat[j] = -ms;
00106          cutpoint[j] = cp;
00107     }
00108 
00109     /* order with respect to maximal statistic */
00110     rsort_with_index(maxstat, order, ninputs);
00111     
00112     twotab = Calloc(4, double);
00113     
00114     /* the best `maxsurr' ones are implemented */
00115     for (j = 0; j < maxsurr; j++) {
00116 
00117         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00118         cut = cutpoint[order[j] - 1];
00119         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00120                        split = allocVector(VECSXP, SPLIT_LENGTH));
00121         C_init_orderedsplit(split, 0);
00122         S3set_variableID(split, order[j]);
00123         REAL(S3get_splitpoint(split))[0] = cut;
00124         dx = REAL(get_variable(inputs, order[j]));
00125         dy = REAL(y);
00126 
00127         /* OK, this is a dirty hack: determine if the split 
00128            goes left or right by the Pearson residual of a 2x2 table.
00129            I don't want to use the big caliber here 
00130         */
00131         for (i = 0; i < nobs; i++) {
00132             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00133             twotab[1] += (dy[i] == 1) * tweights[i];
00134             twotab[2] += (dx[i] <= cut) * tweights[i];
00135             twotab[3] += tweights[i];
00136         }
00137         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00138                      twotab[3]) > 0);
00139     }
00140     
00141     Free(maxstat);
00142     Free(cutpoint);
00143     Free(order);
00144     Free(tweights);
00145     Free(twotab);
00146 }
00147 
00158 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00159                   SEXP fitmem) {
00160 
00161     C_surrogates(node, learnsample, weights, controls, fitmem);
00162     return(S3get_surrogatesplits(node));
00163     
00164 }
00165 
00173 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00174 
00175     SEXP weights, split, surrsplit;
00176     SEXP inputs, whichNA;
00177     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00178     int *iwhichNA, k;
00179     int nobs, i, nna, ns;
00180                     
00181     weights = S3get_nodeweights(node);
00182     dweights = REAL(weights);
00183     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00184     nobs = get_nobs(learnsample);
00185             
00186     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00187     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00188     surrsplit = S3get_surrogatesplits(node);
00189 
00190     /* if the primary split has any missings */
00191     split = S3get_primarysplit(node);
00192     if (has_missings(inputs, S3get_variableID(split))) {
00193 
00194         /* where are the missings? */
00195         whichNA = get_missings(inputs, S3get_variableID(split));
00196         iwhichNA = INTEGER(whichNA);
00197         nna = LENGTH(whichNA);
00198 
00199         /* for all missing values ... */
00200         for (k = 0; k < nna; k++) {
00201             ns = 0;
00202             i = iwhichNA[k] - 1;
00203             if (dweights[i] == 0) continue;
00204             
00205             /* loop over surrogate splits until an appropriate one is found */
00206             while(TRUE) {
00207             
00208                 if (ns >= LENGTH(surrsplit)) break;
00209             
00210                 split = VECTOR_ELT(surrsplit, ns);
00211                 if (has_missings(inputs, S3get_variableID(split))) {
00212                     if (INTEGER(get_missings(inputs, 
00213                             S3get_variableID(split)))[i]) {
00214                         ns++;
00215                         continue;
00216                     }
00217                 }
00218 
00219                 cutpoint = REAL(S3get_splitpoint(split))[0];
00220                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00221 
00222                 if (S3get_toleft(split)) {
00223                     if (dx[i] <= cutpoint) {
00224                         leftweights[i] = dweights[i];
00225                         rightweights[i] = 0.0;
00226                     } else {
00227                         rightweights[i] = dweights[i];
00228                         leftweights[i] = 0.0;
00229                     }
00230                 } else {
00231                     if (dx[i] <= cutpoint) {
00232                         rightweights[i] = dweights[i];
00233                         leftweights[i] = 0.0;
00234                     } else {
00235                         leftweights[i] = dweights[i];
00236                         rightweights[i] = 0.0;
00237                     }
00238                 }
00239                 break;
00240             }
00241         }
00242     }
00243 }

Generated on Wed Jan 30 13:51:30 2008 for party by  doxygen 1.5.3