00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037 jselect = S3get_variableID(S3get_primarysplit(node));
00038 y = S3get_nodeweights(VECTOR_ELT(node, 7));
00039
00040 for (j = 0; j < ninputs; j++) {
00041 if (is_nominal(inputs, j + 1)) continue;
00042 nvar++;
00043 }
00044 nvar--;
00045
00046 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00047 error("nodes does not have %d surrogate splits", maxsurr);
00048 if (maxsurr > nvar)
00049 error("cannot set up %d surrogate splits with only %d ordered input variable(s)",
00050 maxsurr, nvar);
00051
00052 tweights = Calloc(nobs, double);
00053 dweights = REAL(weights);
00054 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00055 if (has_missings(inputs, jselect)) {
00056 thiswhichNA = get_missings(inputs, jselect);
00057 for (k = 0; k < LENGTH(thiswhichNA); k++)
00058 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00059 }
00060
00061 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00062 C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00063
00064 splitstat = REAL(get_splitstatistics(fitmem));
00065
00066 maxstat = Calloc(ninputs, double);
00067 cutpoint = Calloc(ninputs, double);
00068 order = Calloc(ninputs, int);
00069
00070
00071
00072
00073
00074
00075 for (j = 0; j < ninputs; j++) {
00076
00077 order[j] = j + 1;
00078 maxstat[j] = 0.0;
00079 cutpoint[j] = 0.0;
00080
00081
00082 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00083 continue;
00084
00085 x = get_variable(inputs, j + 1);
00086
00087 if (has_missings(inputs, j + 1)) {
00088
00089 thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00090
00091 C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00092
00093 C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00094 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00095 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00096 expcovinf, &cp, &ms, splitstat);
00097 } else {
00098
00099 C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00100 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00101 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00102 expcovinf, &cp, &ms, splitstat);
00103 }
00104
00105 maxstat[j] = -ms;
00106 cutpoint[j] = cp;
00107 }
00108
00109
00110 rsort_with_index(maxstat, order, ninputs);
00111
00112 twotab = Calloc(4, double);
00113
00114
00115 for (j = 0; j < maxsurr; j++) {
00116
00117 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00118 cut = cutpoint[order[j] - 1];
00119 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00120 split = allocVector(VECSXP, SPLIT_LENGTH));
00121 C_init_orderedsplit(split, 0);
00122 S3set_variableID(split, order[j]);
00123 REAL(S3get_splitpoint(split))[0] = cut;
00124 dx = REAL(get_variable(inputs, order[j]));
00125 dy = REAL(y);
00126
00127
00128
00129
00130
00131 for (i = 0; i < nobs; i++) {
00132 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00133 twotab[1] += (dy[i] == 1) * tweights[i];
00134 twotab[2] += (dx[i] <= cut) * tweights[i];
00135 twotab[3] += tweights[i];
00136 }
00137 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00138 twotab[3]) > 0);
00139 }
00140
00141 Free(maxstat);
00142 Free(cutpoint);
00143 Free(order);
00144 Free(tweights);
00145 Free(twotab);
00146 }
00147
00158 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00159 SEXP fitmem) {
00160
00161 C_surrogates(node, learnsample, weights, controls, fitmem);
00162 return(S3get_surrogatesplits(node));
00163
00164 }
00165
00173 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00174
00175 SEXP weights, split, surrsplit;
00176 SEXP inputs, whichNA;
00177 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00178 int *iwhichNA, k;
00179 int nobs, i, nna, ns;
00180
00181 weights = S3get_nodeweights(node);
00182 dweights = REAL(weights);
00183 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00184 nobs = get_nobs(learnsample);
00185
00186 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00187 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00188 surrsplit = S3get_surrogatesplits(node);
00189
00190
00191 split = S3get_primarysplit(node);
00192 if (has_missings(inputs, S3get_variableID(split))) {
00193
00194
00195 whichNA = get_missings(inputs, S3get_variableID(split));
00196 iwhichNA = INTEGER(whichNA);
00197 nna = LENGTH(whichNA);
00198
00199
00200 for (k = 0; k < nna; k++) {
00201 ns = 0;
00202 i = iwhichNA[k] - 1;
00203 if (dweights[i] == 0) continue;
00204
00205
00206 while(TRUE) {
00207
00208 if (ns >= LENGTH(surrsplit)) break;
00209
00210 split = VECTOR_ELT(surrsplit, ns);
00211 if (has_missings(inputs, S3get_variableID(split))) {
00212 if (INTEGER(get_missings(inputs,
00213 S3get_variableID(split)))[i]) {
00214 ns++;
00215 continue;
00216 }
00217 }
00218
00219 cutpoint = REAL(S3get_splitpoint(split))[0];
00220 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00221
00222 if (S3get_toleft(split)) {
00223 if (dx[i] <= cutpoint) {
00224 leftweights[i] = dweights[i];
00225 rightweights[i] = 0.0;
00226 } else {
00227 rightweights[i] = dweights[i];
00228 leftweights[i] = 0.0;
00229 }
00230 } else {
00231 if (dx[i] <= cutpoint) {
00232 rightweights[i] = dweights[i];
00233 leftweights[i] = 0.0;
00234 } else {
00235 leftweights[i] = dweights[i];
00236 rightweights[i] = 0.0;
00237 }
00238 }
00239 break;
00240 }
00241 }
00242 }
00243 }