Basic feature selection for boosted trees.
author A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 18 Jan 2018 09:35:01 +0000 (01:35 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 18 Jan 2018 09:38:27 +0000 (01:38 -0800)
The idea is that we grow the trees normally until we reach the requested number of unique feature columns; once that limit is reached, we stop introducing new features and only split on features that are already in use.

PiperOrigin-RevId: 182336278
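
In outline: the kernel keeps a sorted list of the handler (feature column) ids that have already produced a split, and once that list reaches max_number_of_unique_feature_columns, only those handlers may contribute further splits. A minimal standalone sketch of that gate in C++ (illustrative only; AllowedHandlers and IsHandlerAllowed are not names from this commit):

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns the handler ids that may still produce splits. An empty result
// means every handler is allowed (selection disabled or limit not reached).
std::vector<int64_t> AllowedHandlers(const std::vector<int64_t>& used_sorted,
                                     int64_t max_unique) {
  if (max_unique <= 0) return {};  // Feature selection disabled.
  if (static_cast<int64_t>(used_sorted.size()) < max_unique) return {};
  return used_sorted;  // Limit reached: restrict to already-used handlers.
}

bool IsHandlerAllowed(const std::vector<int64_t>& allowed, int64_t handler_id) {
  // The used-handler list is kept in sorted order, so binary search works.
  return allowed.empty() ||
         std::binary_search(allowed.begin(), allowed.end(), handler_id);
}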

tensorflow/contrib/boosted_trees/kernels/training_ops.cc
tensorflow/contrib/boosted_trees/proto/learner.proto
tensorflow/contrib/boosted_trees/proto/tree_config.proto
tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h

diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index c77d90e243c304ec8e9a10a0b63401f9bd825c3e..7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -361,10 +361,27 @@ class GrowTreeEnsembleOp : public OpKernel {
     // Increment attempt stats.
     ensemble_resource->IncrementAttempts();
 
+    // If feature selection is requested, build the list of handlers used so
+    // far; once the limit is reached, no new features may be introduced.
+    std::vector<int64> allowed_handlers;
+    if (learner_config_.constraints().max_number_of_unique_feature_columns() >
+        0) {
+      allowed_handlers = ensemble_resource->GetUsedHandlers();
+      // TODO(soroush): We can disable handlers that are not going to be used to
+      // avoid unnecessary computations.
+      if (allowed_handlers.size() <
+          learner_config_.constraints()
+              .max_number_of_unique_feature_columns()) {
+        // We have not reached the limit yet. Clear the list of allowed
+        // handlers, which means we can keep adding new features.
+        allowed_handlers.clear();
+      }
+    }
+
     // Find best splits for each active partition.
     std::map<int32, SplitCandidate> best_splits;
-    FindBestSplitsPerPartition(context, partition_ids_list, gains_list,
-                               splits_list, &best_splits);
+    FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list,
+                               gains_list, splits_list, &best_splits);
 
     // No-op if no new splits can be considered.
     if (best_splits.empty()) {
@@ -381,7 +398,8 @@ class GrowTreeEnsembleOp : public OpKernel {
 
     // Split tree nodes.
     for (auto& split_entry : best_splits) {
-      SplitTreeNode(split_entry.first, &split_entry.second, tree_config);
+      SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
+                    ensemble_resource);
     }
 
     // Post-prune finalized tree if needed.
@@ -403,12 +421,20 @@ class GrowTreeEnsembleOp : public OpKernel {
   // Helper method which effectively does a reduce over all split candidates
   // and finds the best split for each partition.
   void FindBestSplitsPerPartition(
-      OpKernelContext* const context, const OpInputList& partition_ids_list,
-      const OpInputList& gains_list, const OpInputList& splits_list,
+      OpKernelContext* const context,
+      const std::vector<int64>& allowed_handlers,  // Empty means all handlers.
+      const OpInputList& partition_ids_list, const OpInputList& gains_list,
+      const OpInputList& splits_list,
       std::map<int32, SplitCandidate>* best_splits) {
     // Find best split per partition going through every feature candidate.
     // TODO(salehay): Is this worth parallelizing?
     for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
+      if (!allowed_handlers.empty()) {
+        if (!std::binary_search(allowed_handlers.begin(),
+                                allowed_handlers.end(), handler_id)) {
+          continue;
+        }
+      }
       const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
       const auto& gains = gains_list[handler_id].vec<float>();
       const auto& splits = splits_list[handler_id].vec<string>();
@@ -592,8 +618,10 @@ class GrowTreeEnsembleOp : public OpKernel {
 
   // Helper method to split a tree node and append its respective
   // leaf children given the split candidate.
-  void SplitTreeNode(const int32 node_id, SplitCandidate* split,
-                     boosted_trees::trees::DecisionTreeConfig* tree_config) {
+  void SplitTreeNode(
+      const int32 node_id, SplitCandidate* split,
+      boosted_trees::trees::DecisionTreeConfig* tree_config,
+      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
     // No-op if we have no real node.
     CHECK(node_id < tree_config->nodes_size())
         << "Invalid node " << node_id << " to split.";
@@ -633,6 +661,9 @@ class GrowTreeEnsembleOp : public OpKernel {
     // Replace node in tree.
     (*tree_config->mutable_nodes(node_id)) =
         *split->split_info.mutable_split_node();
+    if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
+      ensemble_resource->MaybeAddUsedHandler(split->handler_id);
+    }
   }
 
   void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
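
The reduce in FindBestSplitsPerPartition now skips disallowed handlers before comparing gains. Below is a simplified, self-contained sketch of that loop; SplitCandidate and the flat (partition, gain) inputs are stand-ins for the kernel's OpInputList plumbing, not the real types:

#include <algorithm>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

struct SplitCandidate {
  int64_t handler_id;
  float gain;
};

// 'allowed' is empty when every handler may split (see the kernel code above).
void ReduceBestSplits(
    const std::vector<int64_t>& allowed,
    const std::vector<std::vector<std::pair<int32_t, float>>>& per_handler,
    std::map<int32_t, SplitCandidate>* best_splits) {
  for (int64_t h = 0; h < static_cast<int64_t>(per_handler.size()); ++h) {
    if (!allowed.empty() &&
        !std::binary_search(allowed.begin(), allowed.end(), h)) {
      continue;  // Limit reached and handler h has never been used.
    }
    for (const auto& entry : per_handler[h]) {
      auto it = best_splits->find(entry.first);
      if (it == best_splits->end() || entry.second > it->second.gain) {
        (*best_splits)[entry.first] = SplitCandidate{h, entry.second};
      }
    }
  }
}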
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
index 919e7cd81427c27cf892bc77998f52406d2bcf15..d84ba7438e7f03685d5bafca52ff8283f0fce898 100644
--- a/tensorflow/contrib/boosted_trees/proto/learner.proto
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -22,6 +22,10 @@ message TreeConstraintsConfig {
 
   // Min hessian weight per node.
   float min_node_weight = 2;
+
+  // Maximum number of unique feature columns used across the ensemble. Zero
+  // means there is no limit.
+  int64 max_number_of_unique_feature_columns = 3;
 }
 
 // LearningRateConfig describes all supported learning rate tuners.
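
Enabling the new constraint is a one-line proto mutation. A hedged example using the generated C++ API (the namespace is assumed from the kernel's use of boosted_trees::learner; the setter name follows standard protobuf codegen for the field above):

#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"

int main() {
  tensorflow::boosted_trees::learner::LearnerConfig config;
  // Allow at most five unique feature columns across the ensemble; the
  // default of zero keeps the old, unconstrained behavior.
  config.mutable_constraints()->set_max_number_of_unique_feature_columns(5);
  return 0;
}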
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index fc570c1083d01a65760a456c109dad93afd9f62a..4407c4d981785a279b6296f4726a221cacb4c5b1 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -128,6 +128,10 @@ message GrowingMetadata {
   // Number of layers that we have attempted to build. After pruning, these
   // layers might have been removed.
   int64 num_layers_attempted = 2;
+
+  // Sorted list of column handlers that have been used in at least one split
+  // so far.
+  repeated int64 used_handler_ids = 3;
 }
 
 // DecisionTreeEnsembleConfig describes an ensemble of decision trees.
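
The bookkeeping side lives in GrowingMetadata. Reading the new field back is plain repeated-field iteration; a short sketch, assuming the usual generated accessors:

#include <cstdio>

#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"

void PrintUsedHandlers(
    const tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig& config) {
  // MaybeAddUsedHandler (see the resource header below) keeps these sorted.
  for (auto id : config.growing_metadata().used_handler_ids()) {
    std::printf("used handler: %lld\n", static_cast<long long>(id));
  }
}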
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index c2e65b643df90e88aadb0bb9acaf692da35b1a16..8ca1aabacaf53b66aaba184962922294427d6803 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -63,7 +63,7 @@ def _gen_learner_config(num_classes,
   if dropout_prob_of_skipping is not None:
     config.learning_rate_tuner.dropout.dropout_prob_of_skipping = (
         dropout_prob_of_skipping)
-  return config.SerializeToString()
+  return config
 
 
 def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
@@ -145,7 +145,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
           growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
           # Dropout does not change anything here.
-          dropout_probability=0.5)
+          dropout_probability=0.5).SerializeToString()
 
       # Center bias for the initial step.
       grads = constant_op.constant([0.4, -0.3])
@@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
           growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
           # Dropout does not change anything here, tree is not finalized.
-          dropout_probability=0.5)
+          dropout_probability=0.5).SerializeToString()
 
       # Prepare handler inputs.
       # Note that handlers 1 & 3 have the same gain but different splits.
@@ -443,7 +443,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
           growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
           # Dropout does not change anything here - tree is not finalized.
-          dropout_probability=0.5)
+          dropout_probability=0.5).SerializeToString()
 
       # Prepare handler inputs.
       # Handler 1 only has a candidate for partition 1, handler 2 has candidates
@@ -632,7 +632,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           max_depth=1,
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
-          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+          )
 
       # Prepare handler inputs.
       handler1_partitions = np.array([0], dtype=np.int32)
@@ -772,7 +773,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           max_depth=1,
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
-          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+          )
 
       # Prepare handler inputs.
       # All handlers have negative gain.
@@ -837,7 +839,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           max_depth=1,
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
-          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+          )
 
       # Prepare handler inputs.
       # Note that handlers 1 & 3 have the same gain but different splits.
@@ -943,7 +946,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           max_depth=2,
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
-          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+          )
 
       # Prepare handler inputs.
       # All handlers have negative gain.
@@ -1090,7 +1094,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           max_depth=2,
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
-          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
+          )
 
       # Prepare handler inputs.
       # Second handler has positive gain.
@@ -1330,7 +1335,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
           growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER,
           # Dropout will have no effect, since the tree will not be fully grown.
-          dropout_probability=1.0)
+          dropout_probability=1.0).SerializeToString()
 
       # Prepare handler inputs.
       # Handler 1 only has a candidate for partition 1, handler 2 has candidates
@@ -1538,7 +1543,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
           min_node_weight=0,
           pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
           growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
-          dropout_probability=1.0)
+          dropout_probability=1.0).SerializeToString()
 
       # Prepare handler inputs.
       handler1_partitions = np.array([0], dtype=np.int32)
@@ -1583,6 +1588,301 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
       self.assertEqual(
           2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates)
 
+  def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self):
+    """Test growing a tree with feature selection."""
+    with self.test_session() as session:
+      # Create existing ensemble with one root split and one bias tree.
+      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+      text_format.Merge("""
+        trees {
+          nodes {
+            leaf {
+              vector {
+                value: -0.32
+                value: 0.28
+              }
+            }
+          }
+        }
+        trees {
+          nodes {
+            categorical_id_binary_split {
+              feature_column: 3
+              feature_id: 7
+              left_id: 1
+              right_id: 2
+            }
+            node_metadata {
+              gain: 1.3
+            }
+          }
+          nodes {
+            leaf {
+              sparse_vector {
+                index: 0
+                value: 2.3
+              }
+            }
+          }
+          nodes {
+            leaf {
+              sparse_vector {
+                index: 0
+                value: -0.9
+              }
+            }
+          }
+        }
+        tree_weights: 0.7
+        tree_weights: 1
+        tree_metadata {
+          num_tree_weight_updates: 1
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+          num_tree_weight_updates: 5
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 2
+          used_handler_ids: 2
+          used_handler_ids: 5
+        }
+      """, tree_ensemble_config)
+      tree_ensemble_handle = model_ops.tree_ensemble_variable(
+          stamp_token=0,
+          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+          name="tree_ensemble")
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare learner config.
+      learner_config = _gen_learner_config(
+          num_classes=2,
+          l1_reg=0,
+          l2_reg=0,
+          tree_complexity=0,
+          max_depth=1,
+          min_node_weight=0,
+          pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+      # used_handler_ids already holds 2 ids, so the limit of 2 is reached,
+      # but handler_id 2 is among them, so trees can still grow using it.
+      learner_config.constraints.max_number_of_unique_feature_columns = 2
+      learner_config = learner_config.SerializeToString()
+      # Prepare handler inputs.
+      handler1_partitions = np.array([0], dtype=np.int32)
+      handler1_gains = np.array([7.62], dtype=np.float32)
+      handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+      handler2_partitions = np.array([0], dtype=np.int32)
+      handler2_gains = np.array([0.63], dtype=np.float32)
+      handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+      handler3_partitions = np.array([0], dtype=np.int32)
+      handler3_gains = np.array([7.62], dtype=np.float32)
+      handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
+
+      # Grow tree ensemble.
+      grow_op = training_ops.grow_tree_ensemble(
+          tree_ensemble_handle,
+          stamp_token=0,
+          next_stamp_token=1,
+          learning_rate=1,
+          partition_ids=[
+              handler1_partitions, handler2_partitions, handler3_partitions
+          ],
+          gains=[handler1_gains, handler2_gains, handler3_gains],
+          splits=[handler1_split, handler2_split, handler3_split],
+          learner_config=learner_config,
+          dropout_seed=123,
+          center_bias=True)
+      session.run(grow_op)
+
+      # Expect a new tree with handler 3's split (the only allowed handler).
+      _, serialized = session.run(
+          model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+      tree_ensemble_config.ParseFromString(serialized)
+      self.assertEqual(3, len(tree_ensemble_config.trees))
+      self.assertEqual(
+          2, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
+  def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self):
+    """Test growing a tree with feature selection with empty ensemble."""
+    with self.test_session() as session:
+      # Start from an empty ensemble.
+      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+      tree_ensemble_handle = model_ops.tree_ensemble_variable(
+          stamp_token=0,
+          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+          name="tree_ensemble")
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare learner config.
+      learner_config = _gen_learner_config(
+          num_classes=2,
+          l1_reg=0,
+          l2_reg=0,
+          tree_complexity=0,
+          max_depth=1,
+          min_node_weight=0,
+          pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+      learner_config.constraints.max_number_of_unique_feature_columns = 2
+      learner_config = learner_config.SerializeToString()
+      # Prepare handler inputs.
+      handler1_partitions = np.array([0], dtype=np.int32)
+      handler1_gains = np.array([7.62], dtype=np.float32)
+      handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+      handler2_partitions = np.array([0], dtype=np.int32)
+      handler2_gains = np.array([0.63], dtype=np.float32)
+      handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+      handler3_partitions = np.array([0], dtype=np.int32)
+      handler3_gains = np.array([7.62], dtype=np.float32)
+      handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
+
+      # Grow tree ensemble.
+      grow_op = training_ops.grow_tree_ensemble(
+          tree_ensemble_handle,
+          stamp_token=0,
+          next_stamp_token=1,
+          learning_rate=1,
+          partition_ids=[
+              handler1_partitions, handler2_partitions, handler3_partitions
+          ],
+          gains=[handler1_gains, handler2_gains, handler3_gains],
+          splits=[handler1_split, handler2_split, handler3_split],
+          learner_config=learner_config,
+          dropout_seed=123,
+          center_bias=True)
+      session.run(grow_op)
+
+      _, serialized = session.run(
+          model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+      tree_ensemble_config.ParseFromString(serialized)
+      self.assertEqual(1, len(tree_ensemble_config.trees))
+      self.assertEqual(
+          1, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
+  def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self):
+    """Test growing a tree with feature selection with empty ensemble."""
+    with self.test_session() as session:
+      # Create existing ensemble with one root split and one bias tree.
+      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+      text_format.Merge("""
+        trees {
+          nodes {
+            leaf {
+              vector {
+                value: -0.32
+                value: 0.28
+              }
+            }
+          }
+        }
+        trees {
+          nodes {
+            categorical_id_binary_split {
+              feature_column: 3
+              feature_id: 7
+              left_id: 1
+              right_id: 2
+            }
+            node_metadata {
+              gain: 1.3
+            }
+          }
+          nodes {
+            leaf {
+              sparse_vector {
+                index: 0
+                value: 2.3
+              }
+            }
+          }
+          nodes {
+            leaf {
+              sparse_vector {
+                index: 0
+                value: -0.9
+              }
+            }
+          }
+        }
+        tree_weights: 0.7
+        tree_weights: 1
+        tree_metadata {
+          num_tree_weight_updates: 1
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+          num_tree_weight_updates: 5
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 2
+          used_handler_ids: 4
+          used_handler_ids: 5
+        }
+      """, tree_ensemble_config)
+      tree_ensemble_handle = model_ops.tree_ensemble_variable(
+          stamp_token=0,
+          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+          name="tree_ensemble")
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare learner config.
+      learner_config = _gen_learner_config(
+          num_classes=2,
+          l1_reg=0,
+          l2_reg=0,
+          tree_complexity=0,
+          max_depth=1,
+          min_node_weight=0,
+          pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+          growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+      learner_config.constraints.max_number_of_unique_feature_columns = 2
+      learner_config = learner_config.SerializeToString()
+      # Prepare handler inputs.
+      handler1_partitions = np.array([0], dtype=np.int32)
+      handler1_gains = np.array([7.62], dtype=np.float32)
+      handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)]
+      handler2_partitions = np.array([0], dtype=np.int32)
+      handler2_gains = np.array([0.63], dtype=np.float32)
+      handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)]
+      handler3_partitions = np.array([0], dtype=np.int32)
+      handler3_gains = np.array([7.62], dtype=np.float32)
+      handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)]
+
+      # Grow tree ensemble.
+      grow_op = training_ops.grow_tree_ensemble(
+          tree_ensemble_handle,
+          stamp_token=0,
+          next_stamp_token=1,
+          learning_rate=1,
+          partition_ids=[
+              handler1_partitions, handler2_partitions, handler3_partitions
+          ],
+          gains=[handler1_gains, handler2_gains, handler3_gains],
+          splits=[handler1_split, handler2_split, handler3_split],
+          learner_config=learner_config,
+          dropout_seed=123,
+          center_bias=True)
+      session.run(grow_op)
+
+      _, serialized = session.run(
+          model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+      tree_ensemble_config.ParseFromString(serialized)
+      # We can't grow a tree since we have reached the limit of 2 unique
+      # features [4, 5] and the only available splits are from
+      # handlers [0, 1, 2].
+      self.assertEqual(2, len(tree_ensemble_config.trees))
+      self.assertEqual(
+          2, len(tree_ensemble_config.growing_metadata.used_handler_ids))
+
 
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
index 284ad5cdb9abf374650940ade7bb36663d72c0dd..ad9c8961aaadbc4c1ff6bdc7793171d0ad48d75f 100644
--- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
+++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
@@ -111,6 +111,39 @@ class DecisionTreeEnsembleResource : public StampedResource {
     return decision_tree_ensemble_->tree_weights(index);
   }
 
+  void MaybeAddUsedHandler(const int32 handler_id) {
+    protobuf::RepeatedField<protobuf_int64>* used_ids =
+        decision_tree_ensemble_->mutable_growing_metadata()
+            ->mutable_used_handler_ids();
+    protobuf::RepeatedField<protobuf_int64>::iterator first =
+        std::lower_bound(used_ids->begin(), used_ids->end(), handler_id);
+    if (first == used_ids->end()) {
+      used_ids->Add(handler_id);
+      return;
+    }
+    if (handler_id == *first) {
+      // It is a duplicate entry.
+      return;
+    }
+    // Add() may reallocate the repeated field, invalidating 'first', so
+    // remember the insertion offset before appending.
+    const int64 offset = first - used_ids->begin();
+    used_ids->Add(handler_id);
+    std::rotate(used_ids->begin() + offset, used_ids->end() - 1,
+                used_ids->end());
+  }
+
+  std::vector<int64> GetUsedHandlers() const {
+    std::vector<int64> result;
+    result.reserve(
+        decision_tree_ensemble_->growing_metadata().used_handler_ids().size());
+    for (int64 h :
+         decision_tree_ensemble_->growing_metadata().used_handler_ids()) {
+      result.push_back(h);
+    }
+    return result;
+  }
+
   // Sets the weight of i'th tree, and increment num_updates in tree_metadata.
   void SetTreeWeight(const int32 index, const float weight,
                      const int32 increment_num_updates) {
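
For reference, the lower_bound/Add/rotate pattern in MaybeAddUsedHandler is just an insertion into a sorted sequence without shifting elements by hand. A standalone equivalent over std::vector (note that the insertion point must be saved as an offset before the append, since growing the container can invalidate iterators; the same caveat applies to the repeated field above):

#include <algorithm>
#include <cstdint>
#include <vector>

// Inserts handler_id into the sorted vector 'ids', ignoring duplicates.
void InsertUniqueSorted(std::vector<int64_t>* ids, int64_t handler_id) {
  auto pos = std::lower_bound(ids->begin(), ids->end(), handler_id);
  if (pos != ids->end() && *pos == handler_id) return;  // Already present.
  const auto offset = pos - ids->begin();
  ids->push_back(handler_id);  // May reallocate; 'pos' is now invalid.
  std::rotate(ids->begin() + offset, ids->end() - 1, ids->end());
}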