00001
00009 #include "party.h"
00010
00011
00023 void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem,
00024 SEXP controls, int *where, int *nodenum, int depth) {
00025
00026 SEXP weights;
00027 int nobs, i, stop;
00028 double *dweights;
00029
00030 weights = S3get_nodeweights(node);
00031
00032
00033
00034 stop = (nodenum[0] == 2 || nodenum[0] == 3) &&
00035 get_stump(get_tgctrl(controls));
00036 stop = stop || !check_depth(get_tgctrl(controls), depth);
00037
00038 if (stop)
00039 C_Node(node, learnsample, weights, fitmem, controls, 1, depth);
00040 else
00041 C_Node(node, learnsample, weights, fitmem, controls, 0, depth);
00042
00043 S3set_nodeID(node, nodenum[0]);
00044
00045 if (!S3get_nodeterminal(node)) {
00046
00047 C_splitnode(node, learnsample, controls);
00048
00049
00050 if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
00051 C_surrogates(node, learnsample, weights, controls, fitmem);
00052 C_splitsurrogate(node, learnsample);
00053 }
00054
00055 nodenum[0] += 1;
00056 C_TreeGrow(S3get_leftnode(node), learnsample, fitmem,
00057 controls, where, nodenum, depth + 1);
00058
00059 nodenum[0] += 1;
00060 C_TreeGrow(S3get_rightnode(node), learnsample, fitmem,
00061 controls, where, nodenum, depth + 1);
00062
00063 } else {
00064 dweights = REAL(weights);
00065 nobs = get_nobs(learnsample);
00066 for (i = 0; i < nobs; i++)
00067 if (dweights[i] > 0) where[i] = nodenum[0];
00068 }
00069 }
00070
00071
00081 SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {
00082
00083 SEXP ans, nweights;
00084 double *dnweights, *dweights;
00085 int nobs, i, nodenum = 1;
00086
00087 GetRNGstate();
00088
00089 nobs = get_nobs(learnsample);
00090 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00091 C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
00092 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00093
00094 nweights = S3get_nodeweights(ans);
00095 dnweights = REAL(nweights);
00096 dweights = REAL(weights);
00097 for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];
00098
00099 C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1);
00100
00101 PutRNGstate();
00102
00103 UNPROTECT(1);
00104 return(ans);
00105 }