00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab, *ytmp, sumw = 0.0;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037 jselect = S3get_variableID(S3get_primarysplit(node));
00038
00039
00040 y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT));
00041 ytmp = Calloc(nobs, double);
00042 for (i = 0; i < nobs; i++) {
00043 ytmp[i] = REAL(y)[i];
00044 if (ytmp[i] > 1.0) ytmp[i] = 1.0;
00045 }
00046
00047 for (j = 0; j < ninputs; j++) {
00048 if (is_nominal(inputs, j + 1)) continue;
00049 nvar++;
00050 }
00051 nvar--;
00052
00053 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00054 error("nodes does not have %d surrogate splits", maxsurr);
00055 if (maxsurr > nvar)
00056 error("cannot set up %d surrogate splits with only %d ordered input variable(s)",
00057 maxsurr, nvar);
00058
00059 tweights = Calloc(nobs, double);
00060 dweights = REAL(weights);
00061 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00062 if (has_missings(inputs, jselect)) {
00063 thiswhichNA = get_missings(inputs, jselect);
00064 for (k = 0; k < LENGTH(thiswhichNA); k++)
00065 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00066 }
00067
00068
00069 sumw = 0.0;
00070 for (i = 0; i < nobs; i++) sumw += tweights[i];
00071 if (sumw < 2.0)
00072 error("can't implement surrogate splits, not enough observations available");
00073
00074 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00075 C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf);
00076
00077 splitstat = REAL(get_splitstatistics(fitmem));
00078
00079 maxstat = Calloc(ninputs, double);
00080 cutpoint = Calloc(ninputs, double);
00081 order = Calloc(ninputs, int);
00082
00083
00084
00085
00086
00087
00088 for (j = 0; j < ninputs; j++) {
00089
00090 order[j] = j + 1;
00091 maxstat[j] = 0.0;
00092 cutpoint[j] = 0.0;
00093
00094
00095 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00096 continue;
00097
00098 x = get_variable(inputs, j + 1);
00099
00100 if (has_missings(inputs, j + 1)) {
00101
00102 thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00103
00104
00105 sumw = 0.0;
00106 for (i = 0; i < nobs; i++) sumw += thisweights[i];
00107 if (sumw < 2.0) continue;
00108
00109 C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf);
00110
00111 C_split(REAL(x), 1, ytmp, 1, thisweights, nobs,
00112 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00113 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00114 expcovinf, &cp, &ms, splitstat);
00115 } else {
00116
00117 C_split(REAL(x), 1, ytmp, 1, tweights, nobs,
00118 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00119 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00120 expcovinf, &cp, &ms, splitstat);
00121 }
00122
00123 maxstat[j] = -ms;
00124 cutpoint[j] = cp;
00125 }
00126
00127
00128
00129
00130
00131
00132
00133
00134 rsort_with_index(maxstat, order, ninputs);
00135
00136 twotab = Calloc(4, double);
00137
00138
00139 for (j = 0; j < maxsurr; j++) {
00140
00141 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00142 cut = cutpoint[order[j] - 1];
00143 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00144 split = allocVector(VECSXP, SPLIT_LENGTH));
00145 C_init_orderedsplit(split, 0);
00146 S3set_variableID(split, order[j]);
00147 REAL(S3get_splitpoint(split))[0] = cut;
00148 dx = REAL(get_variable(inputs, order[j]));
00149 dy = REAL(y);
00150
00151
00152
00153
00154
00155 for (i = 0; i < nobs; i++) {
00156 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00157 twotab[1] += (dy[i] == 1) * tweights[i];
00158 twotab[2] += (dx[i] <= cut) * tweights[i];
00159 twotab[3] += tweights[i];
00160 }
00161 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00162 twotab[3]) > 0);
00163 }
00164
00165 Free(maxstat);
00166 Free(cutpoint);
00167 Free(order);
00168 Free(tweights);
00169 Free(twotab);
00170 Free(ytmp);
00171 }
00172
00183 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00184 SEXP fitmem) {
00185
00186 C_surrogates(node, learnsample, weights, controls, fitmem);
00187 return(S3get_surrogatesplits(node));
00188
00189 }
00190
00198 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00199
00200 SEXP weights, split, surrsplit;
00201 SEXP inputs, whichNA, whichNAns;
00202 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00203 int *iwhichNA, k;
00204 int i, nna, ns;
00205
00206 weights = S3get_nodeweights(node);
00207 dweights = REAL(weights);
00208 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00209
00210 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00211 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00212 surrsplit = S3get_surrogatesplits(node);
00213
00214
00215 split = S3get_primarysplit(node);
00216 if (has_missings(inputs, S3get_variableID(split))) {
00217
00218
00219 whichNA = get_missings(inputs, S3get_variableID(split));
00220 iwhichNA = INTEGER(whichNA);
00221 nna = LENGTH(whichNA);
00222
00223
00224 for (k = 0; k < nna; k++) {
00225 ns = 0;
00226 i = iwhichNA[k] - 1;
00227 if (dweights[i] == 0) continue;
00228
00229
00230 while(TRUE) {
00231
00232 if (ns >= LENGTH(surrsplit)) break;
00233
00234 split = VECTOR_ELT(surrsplit, ns);
00235 if (has_missings(inputs, S3get_variableID(split))) {
00236 whichNAns = get_missings(inputs, S3get_variableID(split));
00237 if (C_i_in_set(i + 1, whichNAns)) {
00238 ns++;
00239 continue;
00240 }
00241 }
00242
00243 cutpoint = REAL(S3get_splitpoint(split))[0];
00244 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00245
00246 if (S3get_toleft(split)) {
00247 if (dx[i] <= cutpoint) {
00248 leftweights[i] = dweights[i];
00249 rightweights[i] = 0.0;
00250 } else {
00251 rightweights[i] = dweights[i];
00252 leftweights[i] = 0.0;
00253 }
00254 } else {
00255 if (dx[i] <= cutpoint) {
00256 rightweights[i] = dweights[i];
00257 leftweights[i] = 0.0;
00258 } else {
00259 leftweights[i] = dweights[i];
00260 rightweights[i] = 0.0;
00261 }
00262 }
00263 break;
00264 }
00265 }
00266 }
00267 }