// Increment attempt stats.
ensemble_resource->IncrementAttempts();
+ // If feature selection is enabled, build the list of handlers used so far;
+ // once the limit is reached, only those handlers may contribute new splits.
+ std::vector<int64> allowed_handlers;
+ if (learner_config_.constraints().max_number_of_unique_feature_columns() >
+ 0) {
+ allowed_handlers = ensemble_resource->GetUsedHandlers();
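+ // GetUsedHandlers() is assumed to return handler ids in ascending order;
+ // the std::binary_search in FindBestSplitsPerPartition relies on this.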
+ // TODO(soroush): We can disable handlers that are not going to be used to
+ // avoid unnecessary computations.
+ if (allowed_handlers.size() <
+ learner_config_.constraints()
+ .max_number_of_unique_feature_columns()) {
+ // We have not reached the limit yet, so clear the list of allowed
+ // handlers; an empty list means any handler may add new features.
+ allowed_handlers.clear();
+ }
+ }
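+ // Design note: the empty vector doubles as a "no restriction" sentinel,
+ // avoiding an extra flag threaded through the split search below.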
+
// Find best splits for each active partition.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerPartition(context, partition_ids_list, gains_list,
- splits_list, &best_splits);
+ FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list,
+ gains_list, splits_list, &best_splits);
// No-op if no new splits can be considered.
if (best_splits.empty()) {
// Split tree nodes.
for (auto& split_entry : best_splits) {
- SplitTreeNode(split_entry.first, &split_entry.second, tree_config);
+ SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
+ ensemble_resource);
}
// Post-prune finalized tree if needed.
// Helper method which effectively does a reduce over all split candidates
// and finds the best split for each partition.
void FindBestSplitsPerPartition(
- OpKernelContext* const context, const OpInputList& partition_ids_list,
- const OpInputList& gains_list, const OpInputList& splits_list,
+ OpKernelContext* const context,
+ const std::vector<int64>& allowed_handlers, // Empty means all handlers.
+ const OpInputList& partition_ids_list, const OpInputList& gains_list,
+ const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
// TODO(salehay): Is this worth parallelizing?
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
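+ // When allowed_handlers is non-empty, the feature-column budget is
+ // exhausted; only handlers already used by the ensemble may propose
+ // splits, so skip all others.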
+ if (!allowed_handlers.empty()) {
+ if (!std::binary_search(allowed_handlers.begin(),
+ allowed_handlers.end(), handler_id)) {
+ continue;
+ }
+ }
const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
const auto& gains = gains_list[handler_id].vec<float>();
const auto& splits = splits_list[handler_id].vec<string>();
// Helper method to split a tree node and append its respective
// leaf children given the split candidate.
- void SplitTreeNode(const int32 node_id, SplitCandidate* split,
- boosted_trees::trees::DecisionTreeConfig* tree_config) {
+ void SplitTreeNode(
+ const int32 node_id, SplitCandidate* split,
+ boosted_trees::trees::DecisionTreeConfig* tree_config,
+ boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
// Fail fast if the node to split does not exist.
CHECK(node_id < tree_config->nodes_size())
<< "Invalid node " << node_id << " to split.";
// Replace node in tree.
(*tree_config->mutable_nodes(node_id)) =
*split->split_info.mutable_split_node();
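+ // Record the handler that produced this split so the unique
+ // feature-column budget can be enforced on later layers and trees.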
+ if (learner_config_.constraints().max_number_of_unique_feature_columns() >
+ 0) {
+ ensemble_resource->MaybeAddUsedHandler(split->handler_id);
+ }
}
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
if dropout_prob_of_skipping is not None:
config.learning_rate_tuner.dropout.dropout_prob_of_skipping = (
dropout_prob_of_skipping)
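+ # Return the proto itself (rather than its serialization) so callers can
+ # tweak fields such as constraints before calling SerializeToString().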
- return config.SerializeToString()
+ return config
def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
# Dropout does not change anything here.
- dropout_probability=0.5)
+ dropout_probability=0.5).SerializeToString()
# Center bias for the initial step.
grads = constant_op.constant([0.4, -0.3])
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
# Dropout does not change anything here, tree is not finalized.
- dropout_probability=0.5)
+ dropout_probability=0.5).SerializeToString()
# Prepare handler inputs.
# Note that handlers 1 & 3 have the same gain but different splits.
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
# Dropout does not change anything here - tree is not finalized.
- dropout_probability=0.5)
+ dropout_probability=0.5).SerializeToString()
# Prepare handler inputs.
# Handler 1 only has a candidate for partition 1, handler 2 has candidates
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+ )
# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+ )
# Prepare handler inputs.
# All handlers have negative gain.
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+ )
# Prepare handler inputs.
# Note that handlers 1 & 3 have the same gain but different splits.
max_depth=2,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+ )
# Prepare handler inputs.
# All handlers have negative gain.
max_depth=2,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+ )
# Prepare handler inputs.
# Second handler has positive gain.
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER,
# Dropout will have no effect, since the tree will not be fully grown.
- dropout_probability=1.0)
+ dropout_probability=1.0).SerializeToString()
# Prepare handler inputs.
# Handler 1 only has a candidate for partition 1, handler 2 has candidates
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
- dropout_probability=1.0)
+ dropout_probability=1.0).SerializeToString()
# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
self.assertEqual(
2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates)
+ def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self):
+ """Test growing a tree with feature selection."""
+ with self.test_session() as session:
+ # Create an existing ensemble with a bias tree and one root-split tree.
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge("""
+ trees {
+ nodes {
+ leaf {
+ vector {
+ value: -0.32
+ value: 0.28
+ }
+ }
+ }
+ }
+ trees {
+ nodes {
+ categorical_id_binary_split {
+ feature_column: 3
+ feature_id: 7
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 1.3
+ }
+ }
+ nodes {
+ leaf {
+ sparse_vector {
+ index: 0
+ value: 2.3
+ }
+ }
+ }
+ nodes {
+ leaf {
+ sparse_vector {
+ index: 0
+ value: -0.9
+ }
+ }
+ }
+ }
+ tree_weights: 0.7
+ tree_weights: 1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ num_tree_weight_updates: 5
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 2
+ used_handler_ids: 2
+ used_handler_ids: 5
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=1,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ # used_handler_ids already holds two ids, which is the limit, but one of
+ # them is handler 2, which has a candidate here, so growth can continue.
+ learner_config.constraints.max_number_of_unique_feature_columns = 2
+ learner_config = learner_config.SerializeToString()
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([7.62], dtype=np.float32)
+ handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([0.63], dtype=np.float32)
+ handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([7.62], dtype=np.float32)
+ handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
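+ # Note: handler ids are the 0-based positions in the lists passed to the
+ # grow op, so "handler3" above is handler_id 2, which is in
+ # used_handler_ids.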
+
+ # Grow tree ensemble.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config,
+ dropout_seed=123,
+ center_bias=True)
+ session.run(grow_op)
+
+ # Expect a new tree to be added with the split from handler 3
+ # (handler_id 2), the only handler permitted by the feature limit.
+ _, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ tree_ensemble_config.ParseFromString(serialized)
+ self.assertEqual(3, len(tree_ensemble_config.trees))
+ self.assertEqual(
+ 2, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
+ def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self):
+ """Test growing a tree with feature selection with empty ensemble."""
+ with self.test_session() as session:
+ # Create an empty ensemble.
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=1,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ learner_config.constraints.max_number_of_unique_feature_columns = 2
+ learner_config = learner_config.SerializeToString()
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([7.62], dtype=np.float32)
+ handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([0.63], dtype=np.float32)
+ handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([7.62], dtype=np.float32)
+ handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
+
+ # Grow tree ensemble.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config,
+ dropout_seed=123,
+ center_bias=True)
+ session.run(grow_op)
+
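+ # Handlers 1 and 3 tie on gain; whichever wins, a single split is chosen,
+ # so exactly one handler id should be recorded as used.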
+ _, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ tree_ensemble_config.ParseFromString(serialized)
+ self.assertEqual(1, len(tree_ensemble_config.trees))
+ self.assertEqual(
+ 1, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
+ def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self):
+ """Test growing a tree with feature selection with empty ensemble."""
+ with self.test_session() as session:
+ # Create an existing ensemble with a bias tree and one root-split tree.
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge("""
+ trees {
+ nodes {
+ leaf {
+ vector {
+ value: -0.32
+ value: 0.28
+ }
+ }
+ }
+ }
+ trees {
+ nodes {
+ categorical_id_binary_split {
+ feature_column: 3
+ feature_id: 7
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 1.3
+ }
+ }
+ nodes {
+ leaf {
+ sparse_vector {
+ index: 0
+ value: 2.3
+ }
+ }
+ }
+ nodes {
+ leaf {
+ sparse_vector {
+ index: 0
+ value: -0.9
+ }
+ }
+ }
+ }
+ tree_weights: 0.7
+ tree_weights: 1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ num_tree_weight_updates: 5
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 2
+ used_handler_ids: 4
+ used_handler_ids: 5
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=1,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+ learner_config.constraints.max_number_of_unique_feature_columns = 2
+ learner_config = learner_config.SerializeToString()
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([7.62], dtype=np.float32)
+ handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([0.63], dtype=np.float32)
+ handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([7.62], dtype=np.float32)
+ handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
+
+ # Grow tree ensemble.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config,
+ dropout_seed=123,
+ center_bias=True)
+ session.run(grow_op)
+
+ _, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ tree_ensemble_config.ParseFromString(serialized)
+ # We can't grow a tree since we have reached the limit of 2 unique
+ # features [4, 5] and the only available splits are from
+ # handlers [0, 1, 2].
+ self.assertEqual(2, len(tree_ensemble_config.trees))
+ self.assertEqual(
+ 2, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
if __name__ == "__main__":
googletest.main()