00001
00009 #include "party.h"
00010
00011
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022
00023 SEXP weights, leftnode, rightnode, split;
00024 SEXP responses, inputs, whichNA;
00025 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026 double sleft = 0.0, sright = 0.0;
00027 int *ix, *levelset, *iwhichNA;
00028 int nobs, i, nna;
00029
00030 weights = S3get_nodeweights(node);
00031 dweights = REAL(weights);
00032 responses = GET_SLOT(learnsample, PL2_responsesSym);
00033 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034 nobs = get_nobs(learnsample);
00035
00036
00037 SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038 C_init_node(leftnode, nobs,
00039 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040 ncol(get_predict_trafo(responses)));
00041 leftweights = REAL(S3get_nodeweights(leftnode));
00042
00043
00044 SET_VECTOR_ELT(node, S3_RIGHT,
00045 rightnode = allocVector(VECSXP, NODE_LENGTH));
00046 C_init_node(rightnode, nobs,
00047 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048 ncol(get_predict_trafo(responses)));
00049 rightweights = REAL(S3get_nodeweights(rightnode));
00050
00051
00052 split = S3get_primarysplit(node);
00053 if (has_missings(inputs, S3get_variableID(split))) {
00054 whichNA = get_missings(inputs, S3get_variableID(split));
00055 iwhichNA = INTEGER(whichNA);
00056 nna = LENGTH(whichNA);
00057 } else {
00058 nna = 0;
00059 whichNA = R_NilValue;
00060 iwhichNA = NULL;
00061 }
00062
00063 if (S3is_ordered(split)) {
00064 cutpoint = REAL(S3get_splitpoint(split))[0];
00065 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066 for (i = 0; i < nobs; i++) {
00067 if (nna > 0) {
00068 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069 }
00070 if (dx[i] <= cutpoint)
00071 leftweights[i] = dweights[i];
00072 else
00073 leftweights[i] = 0.0;
00074 rightweights[i] = dweights[i] - leftweights[i];
00075 sleft += leftweights[i];
00076 sright += rightweights[i];
00077 }
00078 } else {
00079 levelset = INTEGER(S3get_splitpoint(split));
00080 ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081
00082 for (i = 0; i < nobs; i++) {
00083 if (nna > 0) {
00084 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085 }
00086 if (levelset[ix[i] - 1])
00087 leftweights[i] = dweights[i];
00088 else
00089 leftweights[i] = 0.0;
00090 rightweights[i] = dweights[i] - leftweights[i];
00091 sleft += leftweights[i];
00092 sright += rightweights[i];
00093 }
00094 }
00095
00096
00097 if (nna > 0) {
00098 for (i = 0; i < nna; i++) {
00099 if (sleft > sright) {
00100 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101 rightweights[iwhichNA[i] - 1] = 0.0;
00102 } else {
00103 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104 leftweights[iwhichNA[i] - 1] = 0.0;
00105 }
00106 }
00107 }
00108 }
00109
00110
00121 SEXP C_get_node(SEXP subtree, SEXP newinputs,
00122 double mincriterion, int numobs, int varperm) {
00123
00124 SEXP split, whichNA, ssplit, surrsplit;
00125 double cutpoint, x, swleft, swright;
00126 int level, *levelset, i, ns;
00127
00128 if (S3get_nodeterminal(subtree) ||
00129 REAL(S3get_maxcriterion(subtree))[0] < mincriterion)
00130 return(subtree);
00131
00132 split = S3get_primarysplit(subtree);
00133
00134
00135 swleft = S3get_sumweights(S3get_leftnode(subtree));
00136 swright = S3get_sumweights(S3get_rightnode(subtree));
00137
00138
00139 if (S3get_variableID(split) == varperm) {
00140 if (unif_rand() < swleft / (swleft + swright)) {
00141 return(C_get_node(S3get_leftnode(subtree),
00142 newinputs, mincriterion, numobs, varperm));
00143 } else {
00144 return(C_get_node(S3get_rightnode(subtree),
00145 newinputs, mincriterion, numobs, varperm));
00146 }
00147 }
00148
00149
00150 if (has_missings(newinputs, S3get_variableID(split))) {
00151 whichNA = get_missings(newinputs, S3get_variableID(split));
00152
00153
00154 if (C_i_in_set(numobs + 1, whichNA)) {
00155
00156 surrsplit = S3get_surrogatesplits(subtree);
00157 ns = 0;
00158 i = numobs;
00159
00160
00161 while(TRUE) {
00162
00163 if (ns >= LENGTH(surrsplit)) break;
00164
00165 ssplit = VECTOR_ELT(surrsplit, ns);
00166 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00167 if (INTEGER(get_missings(newinputs,
00168 S3get_variableID(ssplit)))[i]) {
00169 ns++;
00170 continue;
00171 }
00172 }
00173
00174 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00175 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00176
00177 if (S3get_toleft(ssplit)) {
00178 if (x <= cutpoint) {
00179 return(C_get_node(S3get_leftnode(subtree),
00180 newinputs, mincriterion, numobs, varperm));
00181 } else {
00182 return(C_get_node(S3get_rightnode(subtree),
00183 newinputs, mincriterion, numobs, varperm));
00184 }
00185 } else {
00186 if (x <= cutpoint) {
00187 return(C_get_node(S3get_rightnode(subtree),
00188 newinputs, mincriterion, numobs, varperm));
00189 } else {
00190 return(C_get_node(S3get_leftnode(subtree),
00191 newinputs, mincriterion, numobs, varperm));
00192 }
00193 }
00194 break;
00195 }
00196
00197
00198 if (swleft > swright) {
00199 return(C_get_node(S3get_leftnode(subtree),
00200 newinputs, mincriterion, numobs, varperm));
00201 } else {
00202 return(C_get_node(S3get_rightnode(subtree),
00203 newinputs, mincriterion, numobs, varperm));
00204 }
00205 }
00206 }
00207
00208 if (S3is_ordered(split)) {
00209 cutpoint = REAL(S3get_splitpoint(split))[0];
00210 x = REAL(get_variable(newinputs,
00211 S3get_variableID(split)))[numobs];
00212 if (x <= cutpoint) {
00213 return(C_get_node(S3get_leftnode(subtree),
00214 newinputs, mincriterion, numobs, varperm));
00215 } else {
00216 return(C_get_node(S3get_rightnode(subtree),
00217 newinputs, mincriterion, numobs, varperm));
00218 }
00219 } else {
00220 levelset = INTEGER(S3get_splitpoint(split));
00221 level = INTEGER(get_variable(newinputs,
00222 S3get_variableID(split)))[numobs];
00223
00224 if (levelset[level - 1]) {
00225 return(C_get_node(S3get_leftnode(subtree), newinputs,
00226 mincriterion, numobs, varperm));
00227 } else {
00228 return(C_get_node(S3get_rightnode(subtree), newinputs,
00229 mincriterion, numobs, varperm));
00230 }
00231 }
00232 }
00233
00234
00243 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion,
00244 SEXP numobs) {
00245 return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00246 INTEGER(numobs)[0] - 1, -1));
00247 }
00248
00249
00256 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00257
00258 if (nodenum == S3get_nodeID(subtree)) return(subtree);
00259
00260 if (S3get_nodeterminal(subtree))
00261 error("no node with number %d\n", nodenum);
00262
00263 if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00264 return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00265 } else {
00266 return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00267 }
00268 }
00269
00270
00277 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00278 return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00279 }
00280
00281
00291 SEXP C_get_prediction(SEXP subtree, SEXP newinputs,
00292 double mincriterion, int numobs, int varperm) {
00293 return(S3get_prediction(C_get_node(subtree, newinputs,
00294 mincriterion, numobs, varperm)));
00295 }
00296
00297
00306 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs,
00307 double mincriterion, int numobs) {
00308 return(S3get_nodeweights(C_get_node(subtree, newinputs,
00309 mincriterion, numobs, -1)));
00310 }
00311
00312
00321 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00322 double mincriterion, int numobs) {
00323 return(S3get_nodeID(C_get_node(subtree, newinputs,
00324 mincriterion, numobs, -1)));
00325 }
00326
00327
00335 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00336
00337 SEXP ans;
00338 int nobs, i, *dans;
00339
00340 nobs = get_nobs(newinputs);
00341 PROTECT(ans = allocVector(INTSXP, nobs));
00342 dans = INTEGER(ans);
00343 for (i = 0; i < nobs; i++)
00344 dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00345 UNPROTECT(1);
00346 return(ans);
00347 }
00348
00349
00359 void C_predict(SEXP tree, SEXP newinputs, double mincriterion,
00360 int varperm, SEXP ans) {
00361
00362 int nobs, i;
00363
00364 nobs = get_nobs(newinputs);
00365 if (LENGTH(ans) != nobs)
00366 error("ans is not of length %d\n", nobs);
00367
00368 for (i = 0; i < nobs; i++)
00369 SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs,
00370 mincriterion, i, varperm));
00371 }
00372
00373
00381 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00382
00383 SEXP ans;
00384 int nobs;
00385
00386 nobs = get_nobs(newinputs);
00387 PROTECT(ans = allocVector(VECSXP, nobs));
00388 C_predict(tree, newinputs, REAL(mincriterion)[0],
00389 -1, ans);
00390 UNPROTECT(1);
00391 return(ans);
00392 }
00393
00402 SEXP R_predict2(SEXP tree, SEXP newinputs, SEXP mincriterion,
00403 SEXP varperm) {
00404
00405 SEXP ans;
00406 int nobs;
00407
00408 nobs = get_nobs(newinputs);
00409 PROTECT(ans = allocVector(VECSXP, nobs));
00410 C_predict(tree, newinputs, REAL(mincriterion)[0],
00411 INTEGER(varperm)[0], ans);
00412 UNPROTECT(1);
00413 return(ans);
00414 }
00415
00423 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00424
00425 int nobs, i, *iwhere;
00426
00427 nobs = LENGTH(where);
00428 iwhere = INTEGER(where);
00429 if (LENGTH(ans) != nobs)
00430 error("ans is not of length %d\n", nobs);
00431
00432 for (i = 0; i < nobs; i++)
00433 SET_VECTOR_ELT(ans, i, S3get_prediction(
00434 C_get_nodebynum(tree, iwhere[i])));
00435 }
00436
00437
00444 SEXP R_getpredictions(SEXP tree, SEXP where) {
00445
00446 SEXP ans;
00447 int nobs;
00448
00449 nobs = LENGTH(where);
00450 PROTECT(ans = allocVector(VECSXP, nobs));
00451 C_getpredictions(tree, where, ans);
00452 UNPROTECT(1);
00453 return(ans);
00454 }
00455
00466 SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights,
00467 SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00468
00469 SEXP ans, tree, bw;
00470 int ntrees, nobs, i, b, j, iwhere, oob = 0, count = 0, ntrain;
00471
00472 if (LOGICAL(oobpred)[0]) oob = 1;
00473
00474 nobs = get_nobs(newinputs);
00475 ntrees = LENGTH(forest);
00476
00477 if (oob) {
00478 if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
00479 error("number of observations don't match");
00480 }
00481
00482 tree = VECTOR_ELT(forest, 0);
00483 ntrain = LENGTH(VECTOR_ELT(weights, 0));
00484
00485 PROTECT(ans = allocVector(VECSXP, nobs));
00486
00487 for (i = 0; i < nobs; i++) {
00488 count = 0;
00489 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00490 for (j = 0; j < ntrain; j++)
00491 REAL(bw)[j] = 0.0;
00492 for (b = 0; b < ntrees; b++) {
00493 tree = VECTOR_ELT(forest, b);
00494
00495 if (oob &&
00496 REAL(VECTOR_ELT(weights, b))[i] > 0.0)
00497 continue;
00498
00499 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00500
00501 for (j = 0; j < ntrain; j++) {
00502 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00503 REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
00504 }
00505 count++;
00506 }
00507 if (count == 0)
00508 error("cannot compute out-of-bag predictions for obs ", i + 1);
00509 }
00510 UNPROTECT(1);
00511 return(ans);
00512 }
00513
00514
00520 SEXP R_proximity(SEXP where) {
00521
00522 SEXP ans, bw, bin;
00523 int ntrees, nobs, i, b, j, iwhere;
00524
00525 ntrees = LENGTH(where);
00526 nobs = LENGTH(VECTOR_ELT(where, 0));
00527
00528 PROTECT(ans = allocVector(VECSXP, nobs));
00529 PROTECT(bin = allocVector(INTSXP, nobs));
00530
00531 for (i = 0; i < nobs; i++) {
00532 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
00533 for (j = 0; j < nobs; j++) {
00534 REAL(bw)[j] = 0.0;
00535 INTEGER(bin)[j] = 0;
00536 }
00537 for (b = 0; b < ntrees; b++) {
00538
00539 if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
00540 continue;
00541 iwhere = INTEGER(VECTOR_ELT(where, b))[i];
00542 for (j = 0; j < nobs; j++) {
00543 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00544
00545 REAL(bw)[j]++;
00546 if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
00547
00548
00549 INTEGER(bin)[j]++;
00550 }
00551 }
00552 for (j = 0; j < nobs; j++)
00553 REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
00554 }
00555 UNPROTECT(2);
00556 return(ans);
00557 }