Predict.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022 
00023     SEXP weights, leftnode, rightnode, split;
00024     SEXP responses, inputs, whichNA;
00025     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026     double sleft = 0.0, sright = 0.0;
00027     int *ix, *levelset, *iwhichNA;
00028     int nobs, i, nna;
00029                     
00030     weights = S3get_nodeweights(node);
00031     dweights = REAL(weights);
00032     responses = GET_SLOT(learnsample, PL2_responsesSym);
00033     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034     nobs = get_nobs(learnsample);
00035             
00036     /* set up memory for the left daughter */
00037     SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038     C_init_node(leftnode, nobs, 
00039         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040         ncol(get_predict_trafo(responses)));
00041     leftweights = REAL(S3get_nodeweights(leftnode));
00042 
00043     /* set up memory for the right daughter */
00044     SET_VECTOR_ELT(node, S3_RIGHT, 
00045                    rightnode = allocVector(VECSXP, NODE_LENGTH));
00046     C_init_node(rightnode, nobs, 
00047         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048         ncol(get_predict_trafo(responses)));
00049     rightweights = REAL(S3get_nodeweights(rightnode));
00050 
00051     /* split according to the primary split */
00052     split = S3get_primarysplit(node);
00053     if (has_missings(inputs, S3get_variableID(split))) {
00054         whichNA = get_missings(inputs, S3get_variableID(split));
00055         iwhichNA = INTEGER(whichNA);
00056         nna = LENGTH(whichNA);
00057     } else {
00058         nna = 0;
00059         whichNA = R_NilValue;
00060         iwhichNA = NULL;
00061     }
00062     
00063     if (S3is_ordered(split)) {
00064         cutpoint = REAL(S3get_splitpoint(split))[0];
00065         dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066         for (i = 0; i < nobs; i++) {
00067             if (nna > 0) {
00068                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069             }
00070             if (dx[i] <= cutpoint) 
00071                 leftweights[i] = dweights[i]; 
00072             else 
00073                 leftweights[i] = 0.0;
00074             rightweights[i] = dweights[i] - leftweights[i];
00075             sleft += leftweights[i];
00076             sright += rightweights[i];
00077         }
00078     } else {
00079         levelset = INTEGER(S3get_splitpoint(split));
00080         ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081 
00082         for (i = 0; i < nobs; i++) {
00083             if (nna > 0) {
00084                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085             }
00086             if (levelset[ix[i] - 1])
00087                 leftweights[i] = dweights[i];
00088             else 
00089                 leftweights[i] = 0.0;
00090             rightweights[i] = dweights[i] - leftweights[i];
00091             sleft += leftweights[i];
00092             sright += rightweights[i];
00093         }
00094     }
00095     
00096     /* for the moment: NA's go with majority */
00097     if (nna > 0) {
00098         for (i = 0; i < nna; i++) {
00099             if (sleft > sright) {
00100                 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101                 rightweights[iwhichNA[i] - 1] = 0.0;
00102             } else {
00103                 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104                 leftweights[iwhichNA[i] - 1] = 0.0;
00105             }
00106         }
00107     }
00108 }
00109 
00110 
00121 SEXP C_get_node(SEXP subtree, SEXP newinputs, 
00122                 double mincriterion, int numobs, int varperm) {
00123 
00124     SEXP split, whichNA, ssplit, surrsplit;
00125     double cutpoint, x, swleft, swright;
00126     int level, *levelset, i, ns;
00127 
00128     if (S3get_nodeterminal(subtree) || 
00129         REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
00130         return(subtree);
00131     
00132     split = S3get_primarysplit(subtree);
00133 
00134     /* Maybe store the proportions left / right in each node? */
00135     swleft = S3get_sumweights(S3get_leftnode(subtree));
00136     swright = S3get_sumweights(S3get_rightnode(subtree));
00137 
00138     /* splits based on variable varperm are random */    
00139     if (S3get_variableID(split) == varperm) {
00140         if (unif_rand() < swleft / (swleft + swright)) {
00141             return(C_get_node(S3get_leftnode(subtree),
00142                        newinputs, mincriterion, numobs, varperm));
00143         } else {
00144             return(C_get_node(S3get_rightnode(subtree),
00145                        newinputs, mincriterion, numobs, varperm));
00146         }
00147     }
00148                    
00149     /* missing values */
00150     if (has_missings(newinputs, S3get_variableID(split))) {
00151         whichNA = get_missings(newinputs, S3get_variableID(split));
00152     
00153         /* numobs 0 ... n - 1 but whichNA has 1:n */
00154         if (C_i_in_set(numobs + 1, whichNA)) {
00155         
00156             surrsplit = S3get_surrogatesplits(subtree);
00157             ns = 0;
00158             i = numobs;      
00159 
00160             /* try to find a surrogate split */
00161             while(TRUE) {
00162     
00163                 if (ns >= LENGTH(surrsplit)) break;
00164             
00165                 ssplit = VECTOR_ELT(surrsplit, ns);
00166                 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00167                     if (INTEGER(get_missings(newinputs, 
00168                                              S3get_variableID(ssplit)))[i]) {
00169                         ns++;
00170                         continue;
00171                     }
00172                 }
00173 
00174                 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00175                 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00176                      
00177                 if (S3get_toleft(ssplit)) {
00178                     if (x <= cutpoint) {
00179                         return(C_get_node(S3get_leftnode(subtree),
00180                                           newinputs, mincriterion, numobs, varperm));
00181                     } else {
00182                         return(C_get_node(S3get_rightnode(subtree),
00183                                newinputs, mincriterion, numobs, varperm));
00184                     }
00185                 } else {
00186                     if (x <= cutpoint) {
00187                         return(C_get_node(S3get_rightnode(subtree),
00188                                           newinputs, mincriterion, numobs, varperm));
00189                     } else {
00190                         return(C_get_node(S3get_leftnode(subtree),
00191                                newinputs, mincriterion, numobs, varperm));
00192                     }
00193                 }
00194                 break;
00195             }
00196 
00197             /* if this was not successful, we go with the majority */
00198             if (swleft > swright) {
00199                 return(C_get_node(S3get_leftnode(subtree), 
00200                                   newinputs, mincriterion, numobs, varperm));
00201             } else {
00202                 return(C_get_node(S3get_rightnode(subtree), 
00203                                   newinputs, mincriterion, numobs, varperm));
00204             }
00205         }
00206     }
00207     
00208     if (S3is_ordered(split)) {
00209         cutpoint = REAL(S3get_splitpoint(split))[0];
00210         x = REAL(get_variable(newinputs, 
00211                      S3get_variableID(split)))[numobs];
00212         if (x <= cutpoint) {
00213             return(C_get_node(S3get_leftnode(subtree), 
00214                               newinputs, mincriterion, numobs, varperm));
00215         } else {
00216             return(C_get_node(S3get_rightnode(subtree), 
00217                               newinputs, mincriterion, numobs, varperm));
00218         }
00219     } else {
00220         levelset = INTEGER(S3get_splitpoint(split));
00221         level = INTEGER(get_variable(newinputs, 
00222                             S3get_variableID(split)))[numobs];
00223         /* level is in 1, ..., K */
00224         if (levelset[level - 1]) {
00225             return(C_get_node(S3get_leftnode(subtree), newinputs, 
00226                               mincriterion, numobs, varperm));
00227         } else {
00228             return(C_get_node(S3get_rightnode(subtree), newinputs, 
00229                               mincriterion, numobs, varperm));
00230         }
00231     }
00232 }
00233 
00234 
00243 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
00244                 SEXP numobs) {
00245     return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00246                       INTEGER(numobs)[0] - 1, -1));
00247 }
00248 
00249 
00256 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00257     
00258     if (nodenum == S3get_nodeID(subtree)) return(subtree);
00259 
00260     if (S3get_nodeterminal(subtree)) 
00261         error("no node with number %d\n", nodenum);
00262 
00263     if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00264         return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00265     } else {
00266         return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00267     }
00268 }
00269 
00270 
00277 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00278     return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00279 }
00280 
00281 
00291 SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
00292                       double mincriterion, int numobs, int varperm) {
00293     return(S3get_prediction(C_get_node(subtree, newinputs, 
00294                             mincriterion, numobs, varperm)));
00295 }
00296 
00297 
00306 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
00307                        double mincriterion, int numobs) {
00308     return(S3get_nodeweights(C_get_node(subtree, newinputs, 
00309                              mincriterion, numobs, -1)));
00310 }
00311 
00312 
00321 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00322                   double mincriterion, int numobs) {
00323      return(S3get_nodeID(C_get_node(subtree, newinputs, 
00324             mincriterion, numobs, -1)));
00325 }
00326 
00327 
00335 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00336 
00337     SEXP ans;
00338     int nobs, i, *dans;
00339             
00340     nobs = get_nobs(newinputs);
00341     PROTECT(ans = allocVector(INTSXP, nobs));
00342     dans = INTEGER(ans);
00343     for (i = 0; i < nobs; i++)
00344          dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00345     UNPROTECT(1);
00346     return(ans);
00347 }
00348 
00349 
00359 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, 
00360                int varperm, SEXP ans) {
00361     
00362     int nobs, i;
00363     
00364     nobs = get_nobs(newinputs);    
00365     if (LENGTH(ans) != nobs) 
00366         error("ans is not of length %d\n", nobs);
00367         
00368     for (i = 0; i < nobs; i++)
00369         SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
00370                        mincriterion, i, varperm));
00371 }
00372 
00373 
00381 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00382 
00383     SEXP ans;
00384     int nobs;
00385     
00386     nobs = get_nobs(newinputs);
00387     PROTECT(ans = allocVector(VECSXP, nobs));
00388     C_predict(tree, newinputs, REAL(mincriterion)[0], 
00389               -1, ans);
00390     UNPROTECT(1);
00391     return(ans);
00392 }
00393 
00402 SEXP R_predict2(SEXP tree, SEXP newinputs, SEXP mincriterion,
00403                SEXP varperm) {
00404 
00405     SEXP ans;
00406     int nobs;
00407     
00408     nobs = get_nobs(newinputs);
00409     PROTECT(ans = allocVector(VECSXP, nobs));
00410     C_predict(tree, newinputs, REAL(mincriterion)[0], 
00411               INTEGER(varperm)[0], ans);
00412     UNPROTECT(1);
00413     return(ans);
00414 }
00415 
00423 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00424 
00425     int nobs, i, *iwhere;
00426     
00427     nobs = LENGTH(where);
00428     iwhere = INTEGER(where);
00429     if (LENGTH(ans) != nobs)
00430         error("ans is not of length %d\n", nobs);
00431         
00432     for (i = 0; i < nobs; i++)
00433         SET_VECTOR_ELT(ans, i, S3get_prediction(
00434             C_get_nodebynum(tree, iwhere[i])));
00435 }
00436 
00437 
00444 SEXP R_getpredictions(SEXP tree, SEXP where) {
00445 
00446     SEXP ans;
00447     int nobs;
00448             
00449     nobs = LENGTH(where);
00450     PROTECT(ans = allocVector(VECSXP, nobs));
00451     C_getpredictions(tree, where, ans);
00452     UNPROTECT(1);
00453     return(ans);
00454 }                        
00455 
00466 SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights, 
00467                          SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00468 
00469     SEXP ans, tree, bw;
00470     int ntrees, nobs, i, b, j, iwhere, oob = 0, count = 0, ntrain;
00471     
00472     if (LOGICAL(oobpred)[0]) oob = 1;
00473     
00474     nobs = get_nobs(newinputs);
00475     ntrees = LENGTH(forest);
00476 
00477     if (oob) {
00478         if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
00479             error("number of observations don't match");
00480     }    
00481     
00482     tree = VECTOR_ELT(forest, 0);
00483     ntrain = LENGTH(VECTOR_ELT(weights, 0));
00484     
00485     PROTECT(ans = allocVector(VECSXP, nobs));
00486     
00487     for (i = 0; i < nobs; i++) {
00488         count = 0;
00489         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00490         for (j = 0; j < ntrain; j++)
00491             REAL(bw)[j] = 0.0;
00492         for (b = 0; b < ntrees; b++) {
00493             tree = VECTOR_ELT(forest, b);
00494 
00495             if (oob && 
00496                 REAL(VECTOR_ELT(weights, b))[i] > 0.0) 
00497                 continue;
00498 
00499             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00500             
00501             for (j = 0; j < ntrain; j++) {
00502                 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00503                     REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
00504             }
00505             count++;
00506         }
00507         if (count == 0) 
00508             error("cannot compute out-of-bag predictions for obs ", i + 1);
00509     }
00510     UNPROTECT(1);
00511     return(ans);
00512 }
00513 
00514 
00520 SEXP R_proximity(SEXP where) {
00521 
00522     SEXP ans, bw, bin;
00523     int ntrees, nobs, i, b, j, iwhere;
00524     
00525     ntrees = LENGTH(where);
00526     nobs = LENGTH(VECTOR_ELT(where, 0));
00527     
00528     PROTECT(ans = allocVector(VECSXP, nobs));
00529     PROTECT(bin = allocVector(INTSXP, nobs));
00530      
00531     for (i = 0; i < nobs; i++) {
00532         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
00533         for (j = 0; j < nobs; j++) {
00534             REAL(bw)[j] = 0.0;
00535             INTEGER(bin)[j] = 0;
00536         }
00537         for (b = 0; b < ntrees; b++) {
00538             /* don't look at out-of-bag observations */
00539             if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
00540                 continue;
00541             iwhere = INTEGER(VECTOR_ELT(where, b))[i];
00542             for (j = 0; j < nobs; j++) {
00543                 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00544                     /* only count the number of trees; no weights */
00545                     REAL(bw)[j]++;
00546                 if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
00547                     /* count the number of bootstrap samples
00548                     containing both i and j */
00549                     INTEGER(bin)[j]++;
00550             }
00551         }
00552         for (j = 0; j < nobs; j++)
00553             REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
00554     }
00555     UNPROTECT(2);
00556     return(ans);
00557 }