Node.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
/**
 * Weighted column means of a column-major n x q matrix.
 *
 * For each column c of y, computes sum_i weights[i] * y[c*n + i] and
 * divides it by sweights.
 *
 * @param y        n x q matrix, stored column by column
 * @param n        number of rows (observations)
 * @param q        number of columns
 * @param weights  n observation weights
 * @param sweights divisor applied to every weighted column sum
 *                 (the caller in this file passes the sum of weights)
 * @param ans      output vector of length q, overwritten
 */
void C_prediction(const double *y, int n, int q, const double *weights, 
                  const double sweights, double *ans) {

    int col, row;

    for (col = 0; col < q; col++) {
        const double *ycol = &y[col * n];  /* start of column `col` */
        double *out = &ans[col];

        *out = 0.0;
        for (row = 0; row < n; row++)
            *out += weights[row] * ycol[row];
        *out /= sweights;
    }
}
00035 
00036 
00049 void C_Node(SEXP node, SEXP learnsample, SEXP weights, 
00050             SEXP fitmem, SEXP controls, int TERMINAL, int depth) {
00051     
00052     int nobs, ninputs, jselect, q, j, k, i;
00053     double mincriterion, sweights, *dprediction;
00054     double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00055     double *standstat, *splitstat;
00056     SEXP responses, inputs, x, expcovinf, linexpcov;
00057     SEXP varctrl, splitctrl, gtctrl, tgctrl, split, testy, predy;
00058     double *dxtransf, *dweights, *thisweights;
00059     int *itable;
00060     
00061     nobs = get_nobs(learnsample);
00062     ninputs = get_ninputs(learnsample);
00063     varctrl = get_varctrl(controls);
00064     splitctrl = get_splitctrl(controls);
00065     gtctrl = get_gtctrl(controls);
00066     tgctrl = get_tgctrl(controls);
00067     mincriterion = get_mincriterion(gtctrl);
00068     responses = GET_SLOT(learnsample, PL2_responsesSym);
00069     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00070     testy = get_test_trafo(responses);
00071     predy = get_predict_trafo(responses);
00072     q = ncol(testy);
00073 
00074     /* <FIXME> we compute C_GlobalTest even for TERMINAL nodes! </FIXME> */
00075 
00076     /* compute the test statistics and the node criteria for each input */        
00077     C_GlobalTest(learnsample, weights, fitmem, varctrl,
00078                  gtctrl, get_minsplit(splitctrl), 
00079                  REAL(S3get_teststat(node)), REAL(S3get_criterion(node)), depth);
00080     
00081     /* sum of weights: C_GlobalTest did nothing if sweights < mincriterion */
00082     sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym), 
00083                              PL2_sumweightsSym))[0];
00084     REAL(VECTOR_ELT(node, S3_SUMWEIGHTS))[0] = sweights;
00085 
00086     /* compute the prediction of this node */
00087     dprediction = REAL(S3get_prediction(node));
00088 
00089     /* <FIXME> feed raw numeric values OR dummy encoded factors as y 
00090        Problem: what happens for survival times ? */
00091     C_prediction(REAL(predy), nobs, ncol(predy), REAL(weights), 
00092                      sweights, dprediction);
00093     /* </FIXME> */
00094 
00095     teststat = REAL(S3get_teststat(node));
00096     pvalue = REAL(S3get_criterion(node));
00097 
00098     /* try the two out of ninputs best inputs variables */
00099     /* <FIXME> be more flexible and add a parameter controlling
00100                the number of inputs tried </FIXME> */
00101     for (j = 0; j < 2; j++) {
00102 
00103         smax = C_max(pvalue, ninputs);
00104         REAL(S3get_maxcriterion(node))[0] = smax;
00105     
00106         /* if the global null hypothesis was rejected */
00107         if (smax > mincriterion && !TERMINAL) {
00108 
00109             /* the input variable with largest association to the response */
00110             jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00111 
00112             /* get the raw numeric values or the codings of a factor */
00113             x = get_variable(inputs, jselect);
00114             if (has_missings(inputs, jselect)) {
00115                 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect), 
00116                                     PL2_expcovinfSym);
00117                 thisweights = C_tempweights(jselect, weights, fitmem, inputs);
00118             } else {
00119                 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00120                 thisweights = REAL(weights);
00121             }
00122 
00123             /* <FIXME> handle ordered factors separatly??? </FIXME> */
00124             if (!is_nominal(inputs, jselect)) {
00125             
00126                 /* search for a split in a ordered variable x */
00127                 split = S3get_primarysplit(node);
00128                 
00129                 /* check if the n-vector of splitstatistics 
00130                    should be returned for each primary split */
00131                 if (get_savesplitstats(tgctrl)) {
00132                     C_init_orderedsplit(split, nobs);
00133                     splitstat = REAL(S3get_splitstatistics(split));
00134                 } else {
00135                     C_init_orderedsplit(split, 0);
00136                     splitstat = REAL(get_splitstatistics(fitmem));
00137                 }
00138 
00139                 C_split(REAL(x), 1, REAL(testy), q, thisweights, nobs,
00140                         INTEGER(get_ordering(inputs, jselect)), splitctrl, 
00141                         GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00142                         expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00143                         splitstat);
00144                 S3set_variableID(split, jselect);
00145              } else {
00146            
00147                  /* search of a set of levels (split) in a numeric variable x */
00148                  split = S3get_primarysplit(node);
00149                  
00150                 /* check if the n-vector of splitstatistics 
00151                    should be returned for each primary split */
00152                 if (get_savesplitstats(tgctrl)) {
00153                     C_init_nominalsplit(split, 
00154                         LENGTH(get_levels(inputs, jselect)), 
00155                         nobs);
00156                     splitstat = REAL(S3get_splitstatistics(split));
00157                 } else {
00158                     C_init_nominalsplit(split, 
00159                         LENGTH(get_levels(inputs, jselect)), 
00160                         0);
00161                     splitstat = REAL(get_splitstatistics(fitmem));
00162                 }
00163           
00164                  linexpcov = get_varmemory(fitmem, jselect);
00165                  standstat = Calloc(get_dimension(linexpcov), double);
00166                  C_standardize(REAL(GET_SLOT(linexpcov, 
00167                                              PL2_linearstatisticSym)),
00168                                REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00169                                REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00170                                get_dimension(linexpcov), get_tol(splitctrl), 
00171                                standstat);
00172  
00173                  C_splitcategorical(INTEGER(x), 
00174                                     LENGTH(get_levels(inputs, jselect)), 
00175                                     REAL(testy), q, thisweights, 
00176                                     nobs, standstat, splitctrl, 
00177                                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00178                                     expcovinf, &cutpoint, 
00179                                     INTEGER(S3get_splitpoint(split)),
00180                                     &maxstat, splitstat);
00181 
00182                  /* compute which levels of a factor are available in this node 
00183                     (for printing) later on. A real `table' for this node would
00184                     induce too much overhead here. Maybe later. */
00185                     
00186                  itable = INTEGER(S3get_table(split));
00187                  dxtransf = REAL(get_transformation(inputs, jselect));
00188                  for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00189                      itable[k] = 0;
00190                      for (i = 0; i < nobs; i++) {
00191                          if (dxtransf[k * nobs + i] * thisweights[i] > 0) {
00192                              itable[k] = 1;
00193                              continue;
00194                          }
00195                      }
00196                  }
00197 
00198                  Free(standstat);
00199             }
00200             if (maxstat == 0) {
00201             
00202                 if (j == 1) {          
00203                     S3set_nodeterminal(node);
00204                 } else {
00205                     /* do not look at jselect in next iteration */
00206                     pvalue[jselect - 1] = R_NegInf;
00207                 }
00208             } else {
00209                 S3set_variableID(split, jselect);
00210                 break;
00211             }
00212         } else {
00213             S3set_nodeterminal(node);
00214             break;
00215         }
00216     }
00217 }       
00218 
00219 
00228 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00229             
00230      SEXP ans;
00231      
00232      PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00233      C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample), 
00234                  get_maxsurrogate(get_splitctrl(controls)),
00235                  ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00236 
00237      C_Node(ans, learnsample, weights, fitmem, controls, 0, 1);
00238      UNPROTECT(1);
00239      return(ans);
00240 }

Generated on Wed Jan 27 15:01:21 2010 for party by  doxygen 1.6.1