/* ------------------------------------------------------
 *
 * @file decision_tree.cpp
 *
 * @brief Decision Tree models for regression and classification
 *
 *
 */ /* ----------------------------------------------------------------------- */

#include <dbconnector/dbconnector.hpp>

#include <algorithm>
#include <iomanip>
#include <list>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "DT_proto.hpp"
#include "DT_impl.hpp"
#include "ConSplits.hpp"

#include <math.h>       /* fabs */

#include "decision_tree.hpp"

namespace madlib {

// Use Eigen
using namespace dbal::eigen_integration;

namespace modules {

namespace recursive_partitioning {

enum { NOT_FINISHED=0, FINISHED, TERMINATED };

// ------------------------------------------------------------
// types
// ------------------------------------------------------------
typedef DecisionTree<MutableRootContainer> MutableTree;
typedef DecisionTree<RootContainer> Tree;

// Transition State for collecting statistics
typedef TreeAccumulator<RootContainer, Tree> LevelState;
typedef TreeAccumulator<MutableRootContainer, Tree> MutableLevelState;

// ------------------------------------------------------------
// functions
// ------------------------------------------------------------
AnyType
initialize_decision_tree::run(AnyType & args){
    DecisionTree<MutableRootContainer> dt = DecisionTree<MutableRootContainer>();
    bool is_regression_tree = args[0].getAs<bool>();
    std::string impurity_func_str = args[1].getAs<std::string>();
    uint16_t n_y_labels = args[2].getAs<uint16_t>();
    uint16_t max_n_surr = args[3].getAs<uint16_t>();

    if (is_regression_tree)
        n_y_labels = REGRESS_N_STATS;

    dt.rebind(1u, n_y_labels, max_n_surr, is_regression_tree);
    dt.feature_indices(0) = dt.IN_PROCESS_LEAF;
    dt.feature_thresholds(0) = 0;
    dt.is_categorical(0) = 0;
    if (max_n_surr > 0){
        dt.surr_indices.setConstant(dt.SURR_NON_EXISTING);
        dt.surr_thresholds.setConstant(0);
        dt.surr_status.setConstant(0);
    }
    dt.predictions.row(0).setConstant(0);

    dt.is_regression = is_regression_tree;
    if (dt.is_regression){
        dt.impurity_type = dt.MSE;   // only MSE defined for regression
    } else {
        if ((impurity_func_str.compare("misclassification") == 0) ||
                (impurity_func_str.compare("misclass") == 0))
            dt.impurity_type = dt.MISCLASS;
        else if ((impurity_func_str.compare("entropy") == 0) ||
                (impurity_func_str.compare("cross-entropy") == 0))
            dt.impurity_type = dt.ENTROPY;
        else
            dt.impurity_type = dt.GINI;   // default impurity for classification
    }
    return dt.storage();
} // initialize_decision_tree
// ------------------------------------------------------------

////////////////////////////////////////////////////////////////
// Functions to capture leaf stats for picking primary splits //
////////////////////////////////////////////////////////////////
AnyType
compute_leaf_stats_transition::run(AnyType & args){
    MutableLevelState state = args[0].getAs<MutableByteString>();
    LevelState::tree_type dt = args[1].getAs<ByteString>();

    // need to change this according to the calling function
    if (state.terminated || args[4].isNull()) {
        return args[0];
    }
    double response = args[4].getAs<double>();
    double weight = args[5].getAs<double>();
    if (weight < 0)
        throw std::runtime_error("Negative weights present in the data");

    NativeIntegerVector cat_features;
    NativeColumnVector con_features;
    try {
        if (args[2].isNull()){
            cat_features.rebind(this->allocateArray<int>(0));
        }
        else {
            NativeIntegerVector xx_cat = args[2].getAs<NativeIntegerVector>();
            cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
        }
        if (args[3].isNull()){
            con_features.rebind(this->allocateArray<double>(0));
        }
        else {
            NativeColumnVector xx_con = args[3].getAs<NativeColumnVector>();
            con_features.rebind(xx_con.memoryHandle(), xx_con.size());
        }
    } catch (const ArrayWithNullException &e) {
        return args[0];
    }

    // cat_levels.size = n_cat_features
    NativeIntegerVector cat_levels;
    if (args[6].isNull()){
        cat_levels.rebind(this->allocateArray<int>(0));
    }
    else {
        MutableNativeIntegerVector n_levels_per_cat =
            args[6].getAs<MutableNativeIntegerVector>();
        for (Index i = 0; i < n_levels_per_cat.size(); i++){
            n_levels_per_cat[i] -= 1;   // ignore the last level since a split
                    // like 'var <= last level' would move all rows to
                    // one side. Such a split will always be ignored
                    // when selecting the best split.
        }
        cat_levels.rebind(n_levels_per_cat.memoryHandle(),
                          n_levels_per_cat.size());
    }

    // con_splits size = num_con_features x num_bins
    // When num_con_features = 0, the input will be an empty string that is read
    // as a ByteString
    ConSplitsResult<RootContainer> splits_results = args[7].getAs<ByteString>();

    // n_response_labels is the number of values the dependent variable takes
    uint16_t n_response_labels = args[8].getAs<uint16_t>();
    if (!dt.is_regression && n_response_labels <= 1){
        // for classification, the response should have at least two distinct values
        throw std::runtime_error("Invalid response variable for a "
                                 "classification tree. Should have "
                                 "more than one distinct value");
    }

    if (state.empty()){
        // To initialize the accumulator, first find which of the leaf nodes
        // in the current tree are actually reachable.
        // The lookup vector maps the leaf node index in a (fictional) complete
        // tree to the index in the actual tree.
        ColumnVector leaf_feature_indices =
            dt.feature_indices.tail(dt.feature_indices.size() / 2 + 1).cast<double>();
        IntegerVector leaf_node_lookup(leaf_feature_indices.size());
        uint32_t n_leaves_not_finished = 0;
        for (Index i = 0; i < leaf_feature_indices.size(); i++){
            if ((leaf_feature_indices(i) != dt.NODE_NON_EXISTING) &&
                    (leaf_feature_indices(i) != dt.FINISHED_LEAF)){
                leaf_node_lookup(i) = static_cast<int>(n_leaves_not_finished);
                n_leaves_not_finished++;
            }
            else{
                leaf_node_lookup(i) = -1;
            }
        }

        // see DT_proto.hpp for an explanation of stats_per_split
        uint16_t stats_per_split = dt.is_regression ?
            REGRESS_N_STATS : static_cast<uint16_t>(n_response_labels + 1);
        const bool weights_as_rows = args[9].getAs<bool>();
        state.rebind(static_cast<uint16_t>(splits_results.con_splits.cols()),
                     static_cast<uint16_t>(cat_features.size()),
                     static_cast<uint16_t>(con_features.size()),
                     static_cast<uint32_t>(cat_levels.sum()),
                     static_cast<uint16_t>(dt.tree_depth),
                     stats_per_split,
                     weights_as_rows,
                     static_cast<uint32_t>(n_leaves_not_finished)
                    );
        for (Index i = 0; i < state.stats_lookup.size(); i++)
            state.stats_lookup(i) = leaf_node_lookup(i);

        // compute the cumulative sum of the levels of the categorical variables
        int current_sum = 0;
        for (Index i = 0; i < state.n_cat_features; ++i){
            // Assuming that the levels of each categorical variable are
            // ordered, create splits of the form 'A <= t', where A has N
            // levels and t is in [0, N-2].
            // This split places all levels <= t on the true node and the
            // others on the false node. Checking till N-2 instead of N-1,
            // since at least 1 level should go to the false node.
            // A variable with just 1 level is maintained to ensure alignment,
            // even though the variable will not be used as a split feature.
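            // Worked example (illustrative numbers, not from the source):
            // three categorical features with (3, 2, 4) usable levels after
            // dropping the last level of each yield cat_levels_cumsum =
            // (3, 5, 9), so feature i's block of per-level split statistics
            // can be located at offset cat_levels_cumsum(i-1) (0 for i = 0)
            // in the flattened stats array.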
            current_sum += cat_levels(i);
            state.cat_levels_cumsum(i) = current_sum;
        }
    }

    state << MutableLevelState::tuple_type(dt, cat_features, con_features,
                                           response, weight, cat_levels,
                                           splits_results.con_splits);
    return state.storage();
} // transition function
// ------------------------------------------------------------

AnyType
compute_leaf_stats_merge::run(AnyType & args){
    MutableLevelState stateLeft = args[0].getAs<MutableByteString>();
    LevelState stateRight = args[1].getAs<ByteString>();
    if (stateLeft.empty()) {
        return stateRight.storage();
    }
    if (!stateRight.empty()) {
        stateLeft << stateRight;
    }
    return stateLeft.storage();
} // merge function

AnyType
dt_apply::run(AnyType & args){
    MutableTree dt = args[0].getAs<MutableByteString>();
    LevelState curr_level = args[1].getAs<ByteString>();

    // 0 = running, 1 = finished training, 2 = terminated prematurely
    uint16_t return_code;
    if (!curr_level.terminated){
        ConSplitsResult<RootContainer> con_splits_results =
            args[2].getAs<ByteString>();
        uint16_t min_split = args[3].getAs<uint16_t>();
        uint16_t min_bucket = args[4].getAs<uint16_t>();
        uint16_t max_depth = args[5].getAs<uint16_t>();
        bool subsample = args[6].getAs<bool>();
        int num_random_features = args[7].getAs<int>();

        bool finished = false;
        if (!subsample) {
            finished = dt.expand(curr_level, con_splits_results.con_splits,
                                 min_split, min_bucket, max_depth);
        } else {
            finished = dt.expand_by_sampling(curr_level,
                                             con_splits_results.con_splits,
                                             min_split, min_bucket, max_depth,
                                             num_random_features);
        }
        return_code = finished ? FINISHED : NOT_FINISHED;
    }
    else{
        return_code = TERMINATED;   // indicates termination due to error
    }

    AnyType output_tuple;
    output_tuple << dt.storage()
                 << return_code
                 << static_cast<uint16_t>(dt.tree_depth - 1);
    return output_tuple;
} // apply function
// -------------------------------------------------------------------------

///////////////////////////////////////////////
// Functions to capture surrogate statistics //
///////////////////////////////////////////////
AnyType
compute_surr_stats_transition::run(AnyType & args){
    MutableLevelState state = args[0].getAs<MutableByteString>();
    LevelState::tree_type dt = args[1].getAs<ByteString>();

    // need to change this according to the calling function
    if (state.terminated) {
        return args[0];
    }

    NativeIntegerVector cat_features;
    if (args[2].isNull()){
        cat_features.rebind(this->allocateArray<int>(0));
    }
    else {
        NativeIntegerVector xx_cat = args[2].getAs<NativeIntegerVector>();
        cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
    }

    NativeColumnVector con_features;
    if (args[3].isNull()){
        con_features.rebind(this->allocateArray<double>(0));
    }
    else {
        NativeColumnVector xx_con = args[3].getAs<NativeColumnVector>();
        con_features.rebind(xx_con.memoryHandle(), xx_con.size());
    }

    // cat_levels size = n_cat_features
    NativeIntegerVector cat_levels;
    if (args[4].isNull()){
        cat_levels.rebind(this->allocateArray<int>(0));
    }
    else {
        MutableNativeIntegerVector xx_cat =
            args[4].getAs<MutableNativeIntegerVector>();
        for (Index i = 0; i < xx_cat.size(); i++)
            xx_cat[i] -= 1;   // ignore the last level
        cat_levels.rebind(xx_cat.memoryHandle(), xx_cat.size());
    }

    // con_splits size = n_con_features x n_bins
    ConSplitsResult<RootContainer> splits_results = args[5].getAs<ByteString>();

    // tree_depth = 1 implies a single leaf node in the tree.
    // We compute surrogates only for internal nodes. Hence we need the root
    // to be an internal node, i.e. we need the tree_depth to be more than 1.
    if (dt.tree_depth > 1){
        if (state.empty()){
            // To initialize the accumulator, first find which of the last
            // level of internal nodes are actually reachable.
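            // Layout note (derived from the indexing used here and in
            // prune_tree below): the tree is stored as an array-backed
            // complete binary tree where node i has children 2*i+1 and
            // 2*i+2, so a tree of tree_depth t occupies 2^t - 1 slots. With
            // integer division, the deepest internal level starts at index
            // size/4 and holds size/4 + 1 nodes, and the leaf level is the
            // trailing size/2 + 1 slots. Illustration: t = 3 gives size 7,
            // last internal level at indices 1..2, leaves at 3..6.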
            ColumnVector final_internal_feature_indices =
                dt.feature_indices.segment(dt.feature_indices.size() / 4,
                                           dt.feature_indices.size() / 4 + 1).cast<double>();
            IntegerVector index_lookup(final_internal_feature_indices.size());
            uint32_t n_internal_nodes_reachable = 0;
            for (Index i = 0; i < final_internal_feature_indices.size(); i++){
                if (final_internal_feature_indices(i) >= 0){
                    index_lookup(i) =
                        static_cast<int>(n_internal_nodes_reachable);
                    n_internal_nodes_reachable++;
                }
                else{
                    index_lookup(i) = -1;
                }
            }

            // 1. We need to compute stats for the parent of each leaf.
            //    Hence the tree_depth is decremented by 1.
            // 2. We store 2 values for each surrogate split:
            //    the 1st value is for the <= split; the 2nd is for the > split
            //    (hence stats_per_split = 2)
            state.rebind(static_cast<uint16_t>(splits_results.con_splits.cols()),
                         static_cast<uint16_t>(cat_features.size()),
                         static_cast<uint16_t>(con_features.size()),
                         static_cast<uint32_t>(cat_levels.sum()),
                         static_cast<uint16_t>(dt.tree_depth - 1),
                         2,
                         false,   // dummy, only used in compute_leaf_stats
                         n_internal_nodes_reachable
                        );
            for (Index i = 0; i < state.stats_lookup.size(); i++)
                state.stats_lookup(i) = index_lookup(i);

            // compute the cumulative sum of the levels of the categorical variables
            int current_sum = 0;
            for (Index i = 0; i < state.n_cat_features; i++){
                current_sum += cat_levels(i);
                state.cat_levels_cumsum(i) = current_sum;
            }
        }

        const int dup_count = args[6].getAs<int>();
        state << MutableLevelState::surr_tuple_type(
                dt, cat_features, con_features, cat_levels,
                splits_results.con_splits, dup_count);
    }
    return state.storage();
}

AnyType
dt_surr_apply::run(AnyType & args){
    MutableTree dt = args[0].getAs<MutableByteString>();
    LevelState curr_level_surr = args[1].getAs<ByteString>();

    if (!curr_level_surr.terminated && dt.max_n_surr > 0){
        ConSplitsResult<RootContainer> con_splits_results =
            args[2].getAs<ByteString>();
        dt.pickSurrogates(curr_level_surr, con_splits_results.con_splits);
    }
    return dt.storage();
} // apply function
// -------------------------------------------------------------------------

/* @brief Return the probabilities of classes as the prediction */
AnyType
predict_dt_prob::run(AnyType &args){
    if (args[0].isNull()){
        return Null();
    }
    Tree dt = args[0].getAs<ByteString>();
    NativeIntegerVector cat_features;
    NativeColumnVector con_features;
    try {
        if (args[1].isNull()){
            cat_features.rebind(this->allocateArray<int>(0));
        }
        else {
            NativeIntegerVector xx_cat = args[1].getAs<NativeIntegerVector>();
            cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
        }
        if (args[2].isNull()){
            con_features.rebind(this->allocateArray<double>(0));
        }
        else {
            NativeColumnVector xx_con = args[2].getAs<NativeColumnVector>();
            con_features.rebind(xx_con.memoryHandle(), xx_con.size());
        }
    } catch (const ArrayWithNullException &e) {
        return Null();
    }
    ColumnVector prediction = dt.predict(cat_features, con_features);
    return prediction;
}
// -------------------------------------------------------------------------

/* @brief Return the regression prediction or the class that corresponds to
 *        the max prediction
 */
AnyType
predict_dt_response::run(AnyType &args){
    if (args[0].isNull()){
        return Null();
    }
    Tree dt = args[0].getAs<ByteString>();
    NativeIntegerVector cat_features;
    NativeColumnVector con_features;
    try {
        if (args[1].isNull()){
            cat_features.rebind(this->allocateArray<int>(0));
        }
        else {
            NativeIntegerVector xx_cat = args[1].getAs<NativeIntegerVector>();
            cat_features.rebind(xx_cat.memoryHandle(), xx_cat.size());
        }
        if (args[2].isNull()){
            con_features.rebind(this->allocateArray<double>(0));
        }
        else {
            NativeColumnVector xx_con = args[2].getAs<NativeColumnVector>();
            con_features.rebind(xx_con.memoryHandle(), xx_con.size());
        }
    } catch (const ArrayWithNullException &e) {
        // reach here only if surrogates are not used
        return Null();
    }
    return dt.predict_response(cat_features,
                               con_features);
}

AnyType
display_decision_tree::run(AnyType &args) {
    Tree dt = args[0].getAs<ByteString>();
    ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
    ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
    ArrayHandle<text*> dependent_var_levels = args[5].getAs<ArrayHandle<text*> >();
    std::string id_prefix = args[6].getAs<std::string>();
    bool verbose = args[7].getAs<bool>();

    std::string tree_str = dt.display(cat_feature_names, con_feature_names,
                                      cat_levels_text, cat_n_levels,
                                      dependent_var_levels, id_prefix, verbose);
    return tree_str;
}

AnyType
display_decision_tree_surrogate::run(AnyType &args) {
    Tree dt = args[0].getAs<ByteString>();
    ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
    ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
    return dt.surr_display(cat_feature_names, con_feature_names,
                           cat_levels_text, cat_n_levels);
}

AnyType
print_decision_tree::run(AnyType &args){
    Tree dt = args[0].getAs<ByteString>();
    AnyType tuple;
    tuple << static_cast<uint16_t>(dt.tree_depth - 1)
          << dt.feature_indices
          << dt.feature_thresholds
          << dt.is_categorical
          << dt.predictions
          << dt.surr_indices
          << dt.surr_thresholds
          << dt.surr_status;
    return tuple;
}

AnyType
compute_variable_importance::run(AnyType &args){
    Tree dt = args[0].getAs<ByteString>();
    const int n_cat_features = args[1].getAs<int>();
    const int n_con_features = args[2].getAs<int>();

    ColumnVector cat_var_importance = ColumnVector::Zero(n_cat_features);
    ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
    dt.computeVariableImportance(cat_var_importance, con_var_importance);

    ColumnVector combined_var_imp(n_cat_features + n_con_features);
    combined_var_imp << cat_var_importance, con_var_importance;
    return combined_var_imp;
}

AnyType
display_text_tree::run(AnyType &args){
    Tree dt = args[0].getAs<ByteString>();
    ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> con_feature_names = args[2].getAs<ArrayHandle<text*> >();
    ArrayHandle<text*> cat_levels_text = args[3].getAs<ArrayHandle<text*> >();
    ArrayHandle<int> cat_n_levels = args[4].getAs<ArrayHandle<int> >();
    ArrayHandle<text*> dep_levels = args[5].getAs<ArrayHandle<text*> >();

    return dt.print(0, cat_feature_names, con_feature_names, cat_levels_text,
                    cat_n_levels, dep_levels, 1u);
}

// ------------------------------------------------------------
// Prune the tree model using the cost-complexity parameter
// ------------------------------------------------------------

// Remove me's subtree and make it a leaf
void mark_subtree_removal_recur(MutableTree &dt, int me) {
    if (me < dt.predictions.rows() &&
            dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
        int left = static_cast<int>(dt.trueChild(static_cast<Index>(me)));
        int right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
        mark_subtree_removal_recur(dt, left);
        mark_subtree_removal_recur(dt, right);
        dt.feature_indices(me) = dt.NODE_NON_EXISTING;
    }
}

void mark_subtree_removal(MutableTree &dt, int me) {
    mark_subtree_removal_recur(dt, me);
    dt.feature_indices(me) = dt.FINISHED_LEAF;
}

// ------------------------------------------------------------

/*
 * Data structure that contains the info for lower sub-trees
 */
struct SubTreeInfo {
    /* id of the root node of the subtree */
    int root_id;
    /* number of node splits */
    int n_split;
    /* current node's own risk */
    double risk;
    /*
     * accumulated risk of the sub-tree with the current
     * node being the root
     */
    double sum_risk;
    /* sub-tree's average risk improvement per split */
    double complexity;
    SubTreeInfo * left_child;
    SubTreeInfo * right_child;
    SubTreeInfo(int i, int n, double r, double s, double c):
        root_id(i), n_split(n), risk(r),
        sum_risk(s), complexity(c){
        left_child = NULL;
        right_child = NULL;
    }
};

// FIXME: Remove after finalizing code
template <typename IterableContainer>
std::string print_debug_list(IterableContainer debug_list){
    std::stringstream debug;
    typename IterableContainer::iterator it;
    for (it = debug_list.begin(); it != debug_list.end(); it++){
        debug << std::setprecision(8) << *it << ", ";
    }
    return debug.str();
}

/*
 * Prune the tree by setting the pruned nodes'
 * feature_indices value to NODE_NON_EXISTING.
 *
 * Closely follows rpart's implementation. Please read the
 * source code of rpart/src/partition.c
 */
SubTreeInfo prune_tree(MutableTree &dt,
                       int me,
                       double alpha,
                       double estimated_complexity,
                       std::vector<double> & node_complexities) {
    if (me >= dt.feature_indices.size() ||   /* out of range */
            dt.feature_indices(me) == dt.NODE_NON_EXISTING)
        return SubTreeInfo(-1, 0, 0, 0, 0);

    double risk = dt.computeRisk(me);
    double adjusted_risk = risk > estimated_complexity ?
        estimated_complexity : risk;
    if (adjusted_risk <= alpha) {
        /* If the current node's risk is smaller than alpha, then the risk can
         * never decrease more than alpha by splitting. Remove the current
         * node's subtree and make the current node a leaf.
         */
        mark_subtree_removal(dt, me);
        node_complexities[me] = alpha;
        return SubTreeInfo(me, 0, risk, risk, alpha);
    }

    if (dt.feature_indices(me) >= 0) {
        SubTreeInfo left = prune_tree(dt, 2*me+1, alpha,
                                      adjusted_risk - alpha,
                                      node_complexities);
        double left_improve_per_split =
            (risk - left.sum_risk) / (left.n_split + 1);
        double left_child_improve = risk - left.risk;
        if (left_improve_per_split < left_child_improve)
            left_improve_per_split = left_child_improve;
        adjusted_risk = left_improve_per_split > estimated_complexity ?
            estimated_complexity : left_improve_per_split;
        SubTreeInfo right = prune_tree(dt, 2*me+2, alpha,
                                       adjusted_risk - alpha,
                                       node_complexities);

        /*
         * Closely follow rpart's algorithm, in rpart/src/partition.c
         *
         * If the average improvement of risk per split is larger
         * than the sub-tree's average improvement, the current
         * split is important. And we need to manually increase
         * the value of the current split's improvement, which
         * aims at keeping the current split if possible.
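         *
         * Worked example (illustrative numbers, not from the source): a
         * node with risk 10 whose pruned left/right subtrees report
         * sum_risk 4 and 3 with 2 and 1 splits has subtree complexity
         * (10 - (4 + 3)) / (2 + 1 + 1) = 0.75, the average risk reduction
         * per split. A child whose own complexity falls below this average
         * is treated as collapsed (its sum_risk and split count replaced
         * by the child's own risk and 0) and the average is recomputed,
         * which is what the branches below do.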
         */
        double left_risk = left.sum_risk;
        double right_risk = right.sum_risk;

        int left_n_split = left.n_split;
        int right_n_split = right.n_split;

        double tempcp = (risk - (left_risk + right_risk)) /
            (left_n_split + right_n_split + 1);

        if (right.complexity > left.complexity) {
            if (tempcp > left.complexity) {
                left_risk = left.risk;
                left_n_split = 0;
                tempcp = (risk - (left_risk + right_risk)) /
                    (left_n_split + right_n_split + 1);
                if (tempcp > right.complexity) {
                    right_risk = right.risk;
                    right_n_split = 0;
                }
            }
        } else if (tempcp > right.complexity) {
            right_risk = right.risk;
            right_n_split = 0;
            tempcp = (risk - (left_risk + right_risk)) /
                (left_n_split + right_n_split + 1);
            if (tempcp > left.complexity) {
                left_risk = left.risk;
                left_n_split = 0;
            }
        }

        double complexity = (risk - (left_risk + right_risk)) /
            (left_n_split + right_n_split + 1);
        if (complexity <= alpha) {
            /* Prune this split by removing the subtree */
            mark_subtree_removal(dt, me);
            node_complexities[me] = alpha;
            return SubTreeInfo(me, 0, risk, risk, alpha);
        } else {
            node_complexities[me] = complexity;
            return SubTreeInfo(me, left_n_split + right_n_split + 1,
                               risk, left_risk + right_risk, complexity);
        }
    } // end of if (dt.feature_indices(me) >= 0)
    else {
        // node is a leaf node
        node_complexities[me] = alpha;
        return SubTreeInfo(me, 0, risk, risk, alpha);
    }
}
// ------------------------------------------------------------

void make_cp_list(MutableTree & dt,
                  std::vector<double> & node_complexities,
                  const double & alpha,
                  std::list<double> & cp_list,
                  const double root_risk){
    cp_list.clear();
    cp_list.push_back(node_complexities[0] / root_risk);
    for (uint i = 1; i < node_complexities.size(); i++){
        Index parent_id = dt.parentIndex(i);
        if (dt.feature_indices(i) != dt.NODE_NON_EXISTING &&
                dt.feature_indices(parent_id) != dt.NODE_NON_EXISTING){
            double parent_cp = node_complexities[parent_id];
            if (node_complexities[i] > parent_cp)
                node_complexities[i] = parent_cp;
            double current_cp = node_complexities[i];
            if (current_cp < alpha)
                current_cp = alpha;   // don't explore any cp less than alpha
            if (current_cp < parent_cp){
                // the original complexity is scaled by root_risk, but the
                // user expects an unscaled cp value
                current_cp /= root_risk;
                std::list<double>::iterator it;
                bool skip_cp = false;
                for (it = cp_list.begin();
                        !skip_cp && it != cp_list.end();
                        it++){
                    if (fabs(current_cp - *it) < 1e-4)
                        skip_cp = true;
                    if (current_cp > *it)
                        break;
                }
                if (!skip_cp)
                    cp_list.insert(it, current_cp);
            }
        }
    }
}
// -------------------------------------------------------------------------

AnyType
prune_and_cplist::run(AnyType &args){
    MutableTree dt = args[0].getAs<MutableByteString>();
    double cp = args[1].getAs<double>();
    bool compute_cp_list = args[2].getAs<bool>();

    // We use a scaled version of risk (similar to rpart's definition).
    // The risk is relative to a tree with no splits (single node tree).
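    // Illustrative example (numbers not from the source): with cp = 0.01
    // and a root risk of 200, alpha = 2, so any split whose average risk
    // reduction per split is at most 2 is pruned; make_cp_list above
    // divides the stored complexities by root_risk to report them on the
    // user-facing, unscaled cp scale.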
    double root_risk = dt.computeRisk(0);
    double alpha = cp * root_risk;
    std::vector<double> node_complexities(dt.feature_indices.size(), alpha);
    prune_tree(dt, 0, alpha, root_risk, node_complexities);

    // Get the new tree_depth after pruning.
    // Note: externally, tree_depth starts from 0, but DecisionTree assumes
    // tree_depth starts from 1
    uint16_t pruned_depth = static_cast<uint16_t>(dt.recomputeTreeDepth() - 1);

    AnyType output_tuple;
    if (compute_cp_list){
        std::list<double> cp_list;
        make_cp_list(dt, node_complexities, alpha, cp_list, root_risk);

        // we copy to a vector since we currently only have the << operator
        // defined for a vector
        ColumnVector cp_vector(cp_list.size());
        std::list<double>::iterator it = cp_list.begin();
        for (Index i = 0; it != cp_list.end(); it++, i++){
            cp_vector(i) = *it;
        }
        output_tuple << dt.storage() << pruned_depth << cp_vector;
    } else {
        output_tuple << dt.storage() << pruned_depth;
    }
    return output_tuple;
}

// ------------------------------------------------------------
// Helper function for PivotalR
// Convert the result into rpart's frame item in the result

/*
 * Fill a row of the frame matrix using data in the tree
 */
void fill_row(MutableNativeMatrix &frame, Tree &dt, int me, int i,
              int n_cats) {
    frame(i,0) = static_cast<double>(dt.encodeIndex(dt.feature_indices(me),
                                                    dt.is_categorical(me),
                                                    n_cats));
    frame(i,5) = 1;   // complexity is not needed in plotting
    frame(i,6) = 0;   // ncompete is not needed in plotting

    // How many surrogate variables have been computed for this split
    int n_surrogates = 0;
    for (int ii = 0; ii < dt.max_n_surr; ii++) {
        if (dt.surr_indices(me * dt.max_n_surr + ii) >= 0) {
            n_surrogates++;
        }
    }
    frame(i,7) = n_surrogates;

    if (dt.is_regression) {
        frame(i,1) = dt.predictions(me,3);   // n
        frame(i,2) = dt.predictions(me,0);   // wt
        frame(i,3) = dt.computeRisk(me);     // weighted variance
        frame(i,4) = dt.predictions(me,1) / dt.predictions(me,0);   // yval
    } else {
        double total_records = dt.nodeWeightedCount(0);
        double n_records_innode = static_cast<double>(dt.nodeCount(me));
        double n_records_weighted_innode = dt.nodeWeightedCount(me);
        int n_dep_levels = static_cast<int>(dt.n_y_labels);
        // FIXME use the weight sum as the total number
        frame(i,1) = n_records_innode;
        frame(i,2) = n_records_weighted_innode;
        frame(i,3) = dt.computeMisclassification(me);
        Index max_index;
        dt.predictions.row(me).head(dt.n_y_labels).maxCoeff(&max_index);
        // start from 1 to be consistent with the R convention
        frame(i,4) = static_cast<double>(max_index + 1);
        frame(i,8) = frame(i,4);
        for (int j = 0; j < n_dep_levels; ++j) {
            frame(i,9 + j) = dt.predictions(me, j);
            frame(i,9 + j + n_dep_levels) =
                dt.predictions(me, j) / n_records_innode;
        }
        frame(i,9 + 2 * n_dep_levels) = n_records_innode / total_records;
    }
}
// ------------------------------------------------------------

/*
 * Recursively traverse the tree depth-first,
 * filling all rows of the frame at the same time
 */
void transverse_tree(Tree &dt, MutableNativeMatrix &frame,
                     int me, int &row, int n_cats) {
    if (me < dt.feature_indices.size() &&
            dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
        fill_row(frame, dt, me, row++, n_cats);
        transverse_tree(dt, frame, static_cast<int>(dt.falseChild(me)),
                        row, n_cats);
        transverse_tree(dt, frame, static_cast<int>(dt.trueChild(me)),
                        row, n_cats);
    }
}
// ------------------------------------------------------------

AnyType
convert_to_rpart_format::run(AnyType &args) {
    Tree dt = args[0].getAs<ByteString>();
    int n_cats = args[1].getAs<int>();

    // number of nodes in the tree
    int n_nodes = 0;
    for (int i = 0; i < dt.feature_indices.size(); ++i) {
        if (dt.feature_indices(i) != dt.NODE_NON_EXISTING)
            n_nodes++;
    }
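    // Column layout, mirroring rpart's frame as populated by fill_row above:
    // 0 = var, 1 = n, 2 = wt, 3 = dev, 4 = yval, 5 = complexity,
    // 6 = ncompete, 7 = nsurrogate. Classification trees append a yval2
    // block: the predicted class again, per-class counts, per-class
    // fractions, and the node's fraction of all records, which is why
    // n_col grows by 2 * n_y_labels below.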
    // number of columns in the rpart frame
    int n_col;
    if (dt.is_regression) {
        n_col = 8;
    } else {
        n_col = 10 + 2 * dt.n_y_labels;
    }
    MutableNativeMatrix frame(this->allocateArray<double>(n_col, n_nodes),
                              n_nodes, n_col);
    int row = 0;
    transverse_tree(dt, frame, 0, row, n_cats);
    return frame;
}

// traverse the tree to get the internal nodes' split thresholds
void transverse_tree_thresh(const Tree &dt, MutableNativeMatrix &thresh,
                            int me, int &row, int n_cats) {
    // only for non-leaf nodes
    if (dt.feature_indices(me) >= 0) {
        // primary
        thresh(row,0) = dt.encodeIndex(dt.feature_indices(me),
                                       dt.is_categorical(me),
                                       n_cats);
        thresh(row,1) = dt.feature_thresholds(me);
        row++;
        // surrogates
        for (int ii = 0; ii < dt.max_n_surr; ii++) {
            int surr_ii = me * dt.max_n_surr + ii;
            if (dt.surr_indices(surr_ii) >= 0) {
                thresh(row,0) = dt.encodeIndex(dt.surr_indices(surr_ii),
                                               dt.surr_status(surr_ii) == 1 ||
                                                   dt.surr_status(surr_ii) == -1,
                                               n_cats);
                thresh(row,1) = dt.surr_thresholds(surr_ii);
                row++;
            }
        }
        transverse_tree_thresh(dt, thresh,
                               static_cast<int>(dt.falseChild(me)),
                               row, n_cats);
        transverse_tree_thresh(dt, thresh,
                               static_cast<int>(dt.trueChild(me)),
                               row, n_cats);
    }
}
// ------------------------------------------------------------

AnyType
get_split_thresholds::run(AnyType &args) {
    Tree dt = args[0].getAs<ByteString>();
    int n_cats = args[1].getAs<int>();

    // number of internal nodes
    int in_nodes = 0;
    // count how many surrogate variables are in the whole tree
    int tot_surr_n = 0;
    for (int i = 0; i < dt.feature_indices.size(); ++i) {
        if (dt.feature_indices(i) >= 0) {
            in_nodes++;
            for (int ii = 0; ii < dt.max_n_surr; ii++) {
                if (dt.surr_indices(i * dt.max_n_surr + ii) >= 0) {
                    tot_surr_n++;
                }
            }
        }
    }
    MutableNativeMatrix thresh(
        this->allocateArray<double>(2, in_nodes + tot_surr_n),
        in_nodes + tot_surr_n, 2);
    int row = 0;
    transverse_tree_thresh(dt, thresh, 0, row, n_cats);
    return thresh;
}
// ------------------------------------------------------------

/*
 * PivotalR: randomForest
 * Fill a row of the frame matrix using data in the tree
 */
void fill_one_row(MutableNativeMatrix &frame, Tree &dt, int me, int i,
                  int &node_index) {
    int feature_index = dt.feature_indices(me);
    if (feature_index == dt.FINISHED_LEAF) {
        frame(i,0) = 0;
        frame(i,1) = 0;
        frame(i,2) = 0;
        frame(i,4) = -1;
        node_index--;
    } else {
        frame(i,0) = node_index * 2;
        frame(i,1) = node_index * 2 + 1;
        frame(i,4) = 1;
    }
    frame(i,2) = feature_index;
    frame(i,3) = dt.feature_thresholds(me);
    if (dt.is_regression) {
        frame(i,5) = dt.predictions(me,1) / dt.predictions(me,0);   // yval
    } else {
        Index max_index;
        dt.predictions.row(me).head(dt.n_y_labels).maxCoeff(&max_index);
        // start from 1 to be consistent with the R convention
        frame(i,5) = static_cast<double>(max_index + 1);
    }
}

/*
 * PivotalR: randomForest
 * Convert to R's randomForest format for the getTree(..) function
 */
AnyType
convert_to_random_forest_format::run(AnyType &args) {
    Tree dt = args[0].getAs<ByteString>();

    // number of nodes in the tree
    int n_nodes = 0;
    for (int i = 0; i < dt.feature_indices.size(); ++i) {
        if (dt.feature_indices(i) != dt.NODE_NON_EXISTING)
            n_nodes++;
    }

    // number of columns in the randomForest frame
    MutableNativeMatrix frame(this->allocateArray<double>(6, n_nodes),
                              n_nodes, 6);
    int row = 0;
    int node_index = 1;
    for (int i = 0; i