Node.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
/**
    Weighted column means of a response matrix.\n
    For every column of y the products weights[i] * y[i, column] are summed
    over all n rows and the sum is divided by sweights.
    *\param y values of an n x q matrix, stored column after column
    *\param n number of rows of y (and length of weights)
    *\param q number of columns of y (and length of ans)
    *\param weights one weight per row
    *\param sweights divisor applied to every weighted column sum;
                     it is not checked against zero here
    *\param ans output vector of length q, overwritten
*/

void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    const double *column = y;
    int col, row;

    /* walk through y one column (n consecutive values) at a time */
    for (col = 0; col < q; col++, column += n) {
        double total = 0.0;

        for (row = 0; row < n; row++) {
            total += weights[row] * column[row];
        }
        ans[col] = total / sweights;
    }
}
00035 
00036 
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights, 
00049             SEXP fitmem, SEXP controls, int TERMINAL) {
00050     
00051     int nobs, ninputs, jselect, q, j, k, i;
00052     double mincriterion, sweights, *dprediction;
00053     double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054     double *standstat, *splitstat;
00055     SEXP responses, inputs, x, expcovinf, linexpcov;
00056     SEXP varctrl, splitctrl, gtctrl, tgctrl, split, testy, predy;
00057     double *dxtransf, *dweights, *thisweights;
00058     int *itable;
00059     
00060     nobs = get_nobs(learnsample);
00061     ninputs = get_ninputs(learnsample);
00062     varctrl = get_varctrl(controls);
00063     splitctrl = get_splitctrl(controls);
00064     gtctrl = get_gtctrl(controls);
00065     tgctrl = get_tgctrl(controls);
00066     mincriterion = get_mincriterion(gtctrl);
00067     responses = GET_SLOT(learnsample, PL2_responsesSym);
00068     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069     testy = get_test_trafo(responses);
00070     predy = get_predict_trafo(responses);
00071     q = ncol(testy);
00072 
00073     /* <FIXME> we compute C_GlobalTest even for TERMINAL nodes! </FIXME> */
00074 
00075     /* compute the test statistics and the node criteria for each input */        
00076     C_GlobalTest(learnsample, weights, fitmem, varctrl,
00077                  gtctrl, get_minsplit(splitctrl), 
00078                  REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00079     
00080     /* sum of weights: C_GlobalTest did nothing if sweights < mincriterion */
00081     sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym), 
00082                              PL2_sumweightsSym))[0];
00083     REAL(VECTOR_ELT(node, S3_SUMWEIGHTS))[0] = sweights;
00084 
00085     /* compute the prediction of this node */
00086     dprediction = REAL(S3get_prediction(node));
00087 
00088     /* <FIXME> feed raw numeric values OR dummy encoded factors as y 
00089        Problem: what happens for survival times ? */
00090     C_prediction(REAL(predy), nobs, ncol(predy), REAL(weights), 
00091                      sweights, dprediction);
00092     /* </FIXME> */
00093 
00094     teststat = REAL(S3get_teststat(node));
00095     pvalue = REAL(S3get_criterion(node));
00096 
00097     /* try the two out of ninputs best inputs variables */
00098     /* <FIXME> be more flexible and add a parameter controlling
00099                the number of inputs tried </FIXME> */
00100     for (j = 0; j < 2; j++) {
00101 
00102         smax = C_max(pvalue, ninputs);
00103         REAL(S3get_maxcriterion(node))[0] = smax;
00104     
00105         /* if the global null hypothesis was rejected */
00106         if (smax > mincriterion && !TERMINAL) {
00107 
00108             /* the input variable with largest association to the response */
00109             jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00110 
00111             /* get the raw numeric values or the codings of a factor */
00112             x = get_variable(inputs, jselect);
00113             if (has_missings(inputs, jselect)) {
00114                 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect), 
00115                                     PL2_expcovinfSym);
00116                 thisweights = C_tempweights(jselect, weights, fitmem, inputs);
00117             } else {
00118                 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00119                 thisweights = REAL(weights);
00120             }
00121 
00122             /* <FIXME> handle ordered factors separatly??? </FIXME> */
00123             if (!is_nominal(inputs, jselect)) {
00124             
00125                 /* search for a split in a ordered variable x */
00126                 split = S3get_primarysplit(node);
00127                 
00128                 /* check if the n-vector of splitstatistics 
00129                    should be returned for each primary split */
00130                 if (get_savesplitstats(tgctrl)) {
00131                     C_init_orderedsplit(split, nobs);
00132                     splitstat = REAL(S3get_splitstatistics(split));
00133                 } else {
00134                     C_init_orderedsplit(split, 0);
00135                     splitstat = REAL(get_splitstatistics(fitmem));
00136                 }
00137 
00138                 C_split(REAL(x), 1, REAL(testy), q, thisweights, nobs,
00139                         INTEGER(get_ordering(inputs, jselect)), splitctrl, 
00140                         GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00141                         expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00142                         splitstat);
00143                 S3set_variableID(split, jselect);
00144              } else {
00145            
00146                  /* search of a set of levels (split) in a numeric variable x */
00147                  split = S3get_primarysplit(node);
00148                  
00149                 /* check if the n-vector of splitstatistics 
00150                    should be returned for each primary split */
00151                 if (get_savesplitstats(tgctrl)) {
00152                     C_init_nominalsplit(split, 
00153                         LENGTH(get_levels(inputs, jselect)), 
00154                         nobs);
00155                     splitstat = REAL(S3get_splitstatistics(split));
00156                 } else {
00157                     C_init_nominalsplit(split, 
00158                         LENGTH(get_levels(inputs, jselect)), 
00159                         0);
00160                     splitstat = REAL(get_splitstatistics(fitmem));
00161                 }
00162           
00163                  linexpcov = get_varmemory(fitmem, jselect);
00164                  standstat = Calloc(get_dimension(linexpcov), double);
00165                  C_standardize(REAL(GET_SLOT(linexpcov, 
00166                                              PL2_linearstatisticSym)),
00167                                REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00168                                REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00169                                get_dimension(linexpcov), get_tol(splitctrl), 
00170                                standstat);
00171  
00172                  C_splitcategorical(INTEGER(x), 
00173                                     LENGTH(get_levels(inputs, jselect)), 
00174                                     REAL(testy), q, thisweights, 
00175                                     nobs, standstat, splitctrl, 
00176                                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00177                                     expcovinf, &cutpoint, 
00178                                     INTEGER(S3get_splitpoint(split)),
00179                                     &maxstat, splitstat);
00180 
00181                  /* compute which levels of a factor are available in this node 
00182                     (for printing) later on. A real `table' for this node would
00183                     induce too much overhead here. Maybe later. */
00184                     
00185                  itable = INTEGER(S3get_table(split));
00186                  dxtransf = REAL(get_transformation(inputs, jselect));
00187                  for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00188                      itable[k] = 0;
00189                      for (i = 0; i < nobs; i++) {
00190                          if (dxtransf[k * nobs + i] * thisweights[i] > 0) {
00191                              itable[k] = 1;
00192                              continue;
00193                          }
00194                      }
00195                  }
00196 
00197                  Free(standstat);
00198             }
00199             if (maxstat == 0) {
00200             
00201                 if (j == 1) {          
00202                     S3set_nodeterminal(node);
00203                 } else {
00204                     /* do not look at jselect in next iteration */
00205                     pvalue[jselect - 1] = R_NegInf;
00206                 }
00207             } else {
00208                 S3set_variableID(split, jselect);
00209                 break;
00210             }
00211         } else {
00212             S3set_nodeterminal(node);
00213             break;
00214         }
00215     }
00216 }       
00217 
00218 
00227 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00228             
00229      SEXP ans;
00230      
00231      PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00232      C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample), 
00233                  get_maxsurrogate(get_splitctrl(controls)),
00234                  ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00235 
00236      C_Node(ans, learnsample, weights, fitmem, controls, 0);
00237      UNPROTECT(1);
00238      return(ans);
00239 }

Generated on Mon Feb 23 11:05:48 2009 for party by  doxygen 1.5.6