00001
00009 #include "party.h"
00010
00011
/**
    Weighted column means of a response matrix, used as a node's prediction.

    For every column j of y the function stores
        ans[j] = sum_i weights[i] * y[j * n + i] / sweights,
    i.e. y is read as n rows by q columns with each column contiguous.

    *\param y response values, n * q doubles, column after column
    *\param n number of observations (rows of y, length of weights)
    *\param q number of columns of y (length of ans)
    *\param weights case weights, one per observation
    *\param sweights divisor applied to every weighted column sum; it is
                     not checked against zero here
    *\param ans output, q doubles; every element is overwritten
*/
void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    int col, obs;
    const double *ycol;
    double wsum;

    for (col = 0; col < q; col++) {
        /* start of the current column inside y */
        ycol = y + col * n;
        wsum = 0.0;
        for (obs = 0; obs < n; obs++) {
            wsum += weights[obs] * ycol[obs];
        }
        ans[col] = wsum / sweights;
    }
}
00035
00036
00049 void C_Node(SEXP node, SEXP learnsample, SEXP weights,
00050 SEXP fitmem, SEXP controls, int TERMINAL, int depth) {
00051
00052 int nobs, ninputs, jselect, q, j, k, i;
00053 double mincriterion, sweights, *dprediction;
00054 double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00055 double *standstat, *splitstat;
00056 SEXP responses, inputs, x, expcovinf, linexpcov;
00057 SEXP varctrl, splitctrl, gtctrl, tgctrl, split, testy, predy;
00058 double *dxtransf, *dweights, *thisweights;
00059 int *itable;
00060
00061 nobs = get_nobs(learnsample);
00062 ninputs = get_ninputs(learnsample);
00063 varctrl = get_varctrl(controls);
00064 splitctrl = get_splitctrl(controls);
00065 gtctrl = get_gtctrl(controls);
00066 tgctrl = get_tgctrl(controls);
00067 mincriterion = get_mincriterion(gtctrl);
00068 responses = GET_SLOT(learnsample, PL2_responsesSym);
00069 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00070 testy = get_test_trafo(responses);
00071 predy = get_predict_trafo(responses);
00072 q = ncol(testy);
00073
00074
00075
00076
00077 C_GlobalTest(learnsample, weights, fitmem, varctrl,
00078 gtctrl, get_minsplit(splitctrl),
00079 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)), depth);
00080
00081
00082 sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym),
00083 PL2_sumweightsSym))[0];
00084 REAL(VECTOR_ELT(node, S3_SUMWEIGHTS))[0] = sweights;
00085
00086
00087 dprediction = REAL(S3get_prediction(node));
00088
00089
00090
00091 C_prediction(REAL(predy), nobs, ncol(predy), REAL(weights),
00092 sweights, dprediction);
00093
00094
00095 teststat = REAL(S3get_teststat(node));
00096 pvalue = REAL(S3get_criterion(node));
00097
00098
00099
00100
00101 for (j = 0; j < 2; j++) {
00102
00103 smax = C_max(pvalue, ninputs);
00104 REAL(S3get_maxcriterion(node))[0] = smax;
00105
00106
00107 if (smax > mincriterion && !TERMINAL) {
00108
00109
00110 jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00111
00112
00113 x = get_variable(inputs, jselect);
00114 if (has_missings(inputs, jselect)) {
00115 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect),
00116 PL2_expcovinfSym);
00117 thisweights = C_tempweights(jselect, weights, fitmem, inputs);
00118 } else {
00119 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00120 thisweights = REAL(weights);
00121 }
00122
00123
00124 if (!is_nominal(inputs, jselect)) {
00125
00126
00127 split = S3get_primarysplit(node);
00128
00129
00130
00131 if (get_savesplitstats(tgctrl)) {
00132 C_init_orderedsplit(split, nobs);
00133 splitstat = REAL(S3get_splitstatistics(split));
00134 } else {
00135 C_init_orderedsplit(split, 0);
00136 splitstat = REAL(get_splitstatistics(fitmem));
00137 }
00138
00139 C_split(REAL(x), 1, REAL(testy), q, thisweights, nobs,
00140 INTEGER(get_ordering(inputs, jselect)), splitctrl,
00141 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00142 expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00143 splitstat);
00144 S3set_variableID(split, jselect);
00145 } else {
00146
00147
00148 split = S3get_primarysplit(node);
00149
00150
00151
00152 if (get_savesplitstats(tgctrl)) {
00153 C_init_nominalsplit(split,
00154 LENGTH(get_levels(inputs, jselect)),
00155 nobs);
00156 splitstat = REAL(S3get_splitstatistics(split));
00157 } else {
00158 C_init_nominalsplit(split,
00159 LENGTH(get_levels(inputs, jselect)),
00160 0);
00161 splitstat = REAL(get_splitstatistics(fitmem));
00162 }
00163
00164 linexpcov = get_varmemory(fitmem, jselect);
00165 standstat = Calloc(get_dimension(linexpcov), double);
00166 C_standardize(REAL(GET_SLOT(linexpcov,
00167 PL2_linearstatisticSym)),
00168 REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00169 REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00170 get_dimension(linexpcov), get_tol(splitctrl),
00171 standstat);
00172
00173 C_splitcategorical(INTEGER(x),
00174 LENGTH(get_levels(inputs, jselect)),
00175 REAL(testy), q, thisweights,
00176 nobs, standstat, splitctrl,
00177 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00178 expcovinf, &cutpoint,
00179 INTEGER(S3get_splitpoint(split)),
00180 &maxstat, splitstat);
00181
00182
00183
00184
00185
00186 itable = INTEGER(S3get_table(split));
00187 dxtransf = REAL(get_transformation(inputs, jselect));
00188 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00189 itable[k] = 0;
00190 for (i = 0; i < nobs; i++) {
00191 if (dxtransf[k * nobs + i] * thisweights[i] > 0) {
00192 itable[k] = 1;
00193 continue;
00194 }
00195 }
00196 }
00197
00198 Free(standstat);
00199 }
00200 if (maxstat == 0) {
00201
00202 if (j == 1) {
00203 S3set_nodeterminal(node);
00204 } else {
00205
00206 pvalue[jselect - 1] = R_NegInf;
00207 }
00208 } else {
00209 S3set_variableID(split, jselect);
00210 break;
00211 }
00212 } else {
00213 S3set_nodeterminal(node);
00214 break;
00215 }
00216 }
00217 }
00218
00219
00228 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00229
00230 SEXP ans;
00231
00232 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00233 C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample),
00234 get_maxsurrogate(get_splitctrl(controls)),
00235 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00236
00237 C_Node(ans, learnsample, weights, fitmem, controls, 0, 1);
00238 UNPROTECT(1);
00239 return(ans);
00240 }