00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037 jselect = S3get_variableID(S3get_primarysplit(node));
00038 y = S3get_nodeweights(VECTOR_ELT(node, 7));
00039
00040 for (j = 0; j < ninputs; j++) {
00041 if (is_nominal(inputs, j + 1)) continue;
00042 nvar++;
00043 }
00044 nvar--;
00045 Rprintf("nvar: %d\n", nvar);
00046
00047 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00048 error("nodes does not have %d surrogate splits", maxsurr);
00049 if (maxsurr > nvar)
00050 error("cannot set up %d surrogate splits with only %d ordered input variable(s)",
00051 maxsurr, nvar);
00052
00053 tweights = Calloc(nobs, double);
00054 dweights = REAL(weights);
00055 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00056 if (has_missings(inputs, jselect)) {
00057 thiswhichNA = get_missings(inputs, jselect);
00058 for (k = 0; k < LENGTH(thiswhichNA); k++)
00059 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00060 }
00061
00062 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00063 C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00064
00065 splitstat = REAL(get_splitstatistics(fitmem));
00066
00067 maxstat = Calloc(ninputs, double);
00068 cutpoint = Calloc(ninputs, double);
00069 order = Calloc(ninputs, int);
00070
00071
00072
00073
00074
00075
00076 for (j = 0; j < ninputs; j++) {
00077
00078 order[j] = j + 1;
00079 maxstat[j] = 0.0;
00080 cutpoint[j] = 0.0;
00081
00082
00083 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00084 continue;
00085
00086 x = get_variable(inputs, j + 1);
00087
00088 if (has_missings(inputs, j + 1)) {
00089
00090 thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00091
00092 C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00093
00094 C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00095 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00096 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00097 expcovinf, &cp, &ms, splitstat);
00098 } else {
00099
00100 C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00101 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00102 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00103 expcovinf, &cp, &ms, splitstat);
00104 }
00105
00106 maxstat[j] = -ms;
00107 cutpoint[j] = cp;
00108 }
00109
00110
00111 rsort_with_index(maxstat, order, ninputs);
00112
00113 twotab = Calloc(4, double);
00114
00115
00116 for (j = 0; j < maxsurr; j++) {
00117
00118 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00119 cut = cutpoint[order[j] - 1];
00120 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00121 split = allocVector(VECSXP, SPLIT_LENGTH));
00122 C_init_orderedsplit(split, 0);
00123 S3set_variableID(split, order[j]);
00124 REAL(S3get_splitpoint(split))[0] = cut;
00125 dx = REAL(get_variable(inputs, order[j]));
00126 dy = REAL(y);
00127
00128
00129
00130
00131
00132 for (i = 0; i < nobs; i++) {
00133 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00134 twotab[1] += (dy[i] == 1) * tweights[i];
00135 twotab[2] += (dx[i] <= cut) * tweights[i];
00136 twotab[3] += tweights[i];
00137 }
00138 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00139 twotab[3]) > 0);
00140 }
00141
00142 Free(maxstat);
00143 Free(cutpoint);
00144 Free(order);
00145 Free(tweights);
00146 Free(twotab);
00147 }
00148
00159 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00160 SEXP fitmem) {
00161
00162 C_surrogates(node, learnsample, weights, controls, fitmem);
00163 return(S3get_surrogatesplits(node));
00164
00165 }
00166
00174 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00175
00176 SEXP weights, split, surrsplit;
00177 SEXP inputs, whichNA;
00178 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00179 int *iwhichNA, k;
00180 int nobs, i, nna, ns;
00181
00182 weights = S3get_nodeweights(node);
00183 dweights = REAL(weights);
00184 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00185 nobs = get_nobs(learnsample);
00186
00187 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00188 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00189 surrsplit = S3get_surrogatesplits(node);
00190
00191
00192 split = S3get_primarysplit(node);
00193 if (has_missings(inputs, S3get_variableID(split))) {
00194
00195
00196 whichNA = get_missings(inputs, S3get_variableID(split));
00197 iwhichNA = INTEGER(whichNA);
00198 nna = LENGTH(whichNA);
00199
00200
00201 for (k = 0; k < nna; k++) {
00202 ns = 0;
00203 i = iwhichNA[k] - 1;
00204 if (dweights[i] == 0) continue;
00205
00206
00207 while(TRUE) {
00208
00209 if (ns >= LENGTH(surrsplit)) break;
00210
00211 split = VECTOR_ELT(surrsplit, ns);
00212 if (has_missings(inputs, S3get_variableID(split))) {
00213 if (INTEGER(get_missings(inputs,
00214 S3get_variableID(split)))[i]) {
00215 ns++;
00216 continue;
00217 }
00218 }
00219
00220 cutpoint = REAL(S3get_splitpoint(split))[0];
00221 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00222
00223 if (S3get_toleft(split)) {
00224 if (dx[i] <= cutpoint) {
00225 leftweights[i] = dweights[i];
00226 rightweights[i] = 0.0;
00227 } else {
00228 rightweights[i] = dweights[i];
00229 leftweights[i] = 0.0;
00230 }
00231 } else {
00232 if (dx[i] <= cutpoint) {
00233 rightweights[i] = dweights[i];
00234 leftweights[i] = 0.0;
00235 } else {
00236 leftweights[i] = dweights[i];
00237 rightweights[i] = 0.0;
00238 }
00239 }
00240 break;
00241 }
00242 }
00243 }
00244 }