Automated g4 rollback of changelist 198087342
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 23:07:25 +0000 (16:07 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 23:10:20 +0000 (16:10 -0700)
PiperOrigin-RevId: 198117552
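
This rolls back the change that added an `output_leaf_index` option to the
contrib boosted trees estimators, together with the
GradientTreesPredictionVerbose op/kernel and the leaf-index plumbing in
MultipleAdditiveTrees and gbdt_batch. For reference, a sketch of the API
being removed, reconstructed from the deleted docstrings and tests
(learner_config and the input functions are assumed to be set up as in
estimator_test.py):

    classifier = estimator.GradientBoostedDecisionTreeClassifier(
        learner_config=learner_config,
        num_trees=1,
        examples_per_layer=3,
        feature_columns=[contrib_feature_column.real_valued_column("x")],
        output_leaf_index=True)
    classifier.fit(input_fn=_train_input_fn, steps=15)
    for result in classifier.predict(input_fn=_eval_input_fn):
      # "leaf_index" maps to a list with one leaf id per tree.
      leaf_ids = result["leaf_index"]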

tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
tensorflow/contrib/boosted_trees/estimator_batch/model.py
tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py

diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index c8d401b..89d0d61 100644
@@ -41,8 +41,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
                feature_engineering_fn=None,
                logits_modifier_function=None,
                center_bias=True,
-               use_core_libs=False,
-               output_leaf_index=False):
+               use_core_libs=False):
     """Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
 
     Args:
@@ -67,14 +66,6 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
         the bias.
       use_core_libs: Whether feature columns and loss are from the core (as
         opposed to contrib) version of tensorflow.
-      output_leaf_index: whether to output leaf indices along with predictions
-        during inference. The leaf node indexes are available in predictions
-        dict by the key 'leaf_index'. For example,
-        result_dict = classifier.predict(...)
-        for example_prediction_result in result_dict:
-          # access leaf index list by example_prediction_result["leaf_index"]
-          # which contains one leaf index per tree
-
     Raises:
       ValueError: If learner_config is not valid.
     """
@@ -83,9 +74,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
       # supports second order derivative.
       def loss_fn(labels, logits, weights=None):
         result = losses.per_example_maxent_loss(
-            labels=labels,
-            logits=logits,
-            weights=weights,
+            labels=labels, logits=logits, weights=weights,
             num_classes=n_classes)
         return math_ops.reduce_mean(result[0])
     else:
@@ -113,7 +102,6 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
             'center_bias': center_bias,
             'logits_modifier_function': logits_modifier_function,
             'use_core_libs': use_core_libs,
-            'output_leaf_index': output_leaf_index,
         },
         model_dir=model_dir,
         config=config,
@@ -136,8 +124,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
                feature_engineering_fn=None,
                logits_modifier_function=None,
                center_bias=True,
-               use_core_libs=False,
-               output_leaf_index=False):
+               use_core_libs=False):
     """Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
 
     Args:
@@ -164,13 +151,6 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
         the bias.
       use_core_libs: Whether feature columns and loss are from the core (as
         opposed to contrib) version of tensorflow.
-      output_leaf_index: whether to output leaf indices along with predictions
-        during inference. The leaf node indexes are available in predictions
-        dict by the key 'leaf_index'. For example,
-        result_dict = classifier.predict(...)
-        for example_prediction_result in result_dict:
-          # access leaf index list by example_prediction_result["leaf_index"]
-          # which contains one leaf index per tree
     """
     head = head_lib.regression_head(
         label_name=label_name,
@@ -193,7 +173,6 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
             'logits_modifier_function': logits_modifier_function,
             'center_bias': center_bias,
             'use_core_libs': use_core_libs,
-            'output_leaf_index': False,
         },
         model_dir=model_dir,
         config=config,
@@ -218,8 +197,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
                feature_engineering_fn=None,
                logits_modifier_function=None,
                center_bias=True,
-               use_core_libs=False,
-               output_leaf_index=False):
+               use_core_libs=False):
     """Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
 
     Args:
@@ -242,13 +220,6 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
         the bias.
       use_core_libs: Whether feature columns and loss are from the core (as
         opposed to contrib) version of tensorflow.
-      output_leaf_index: whether to output leaf indices along with predictions
-        during inference. The leaf node indexes are available in predictions
-        dict by the key 'leaf_index'. For example,
-        result_dict = classifier.predict(...)
-        for example_prediction_result in result_dict:
-          # access leaf index list by example_prediction_result["leaf_index"]
-          # which contains one leaf index per tree
     """
     super(GradientBoostedDecisionTreeEstimator, self).__init__(
         model_fn=model.model_builder,
@@ -262,7 +233,6 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
             'logits_modifier_function': logits_modifier_function,
             'center_bias': center_bias,
             'use_core_libs': use_core_libs,
-            'output_leaf_index': False,
         },
         model_dir=model_dir,
         config=config,
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index fe91e52..0d58317 100644
@@ -62,34 +62,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
         examples_per_layer=3,
         model_dir=model_dir,
         config=config,
-        feature_columns=[contrib_feature_column.real_valued_column("x")],
-        output_leaf_index=False)
+        feature_columns=[contrib_feature_column.real_valued_column("x")])
 
     classifier.fit(input_fn=_train_input_fn, steps=15)
     classifier.evaluate(input_fn=_eval_input_fn, steps=1)
     classifier.export(self._export_dir_base)
 
-  def testThatLeafIndexIsInPredictions(self):
-    learner_config = learner_pb2.LearnerConfig()
-    learner_config.num_classes = 2
-    learner_config.constraints.max_tree_depth = 1
-    model_dir = tempfile.mkdtemp()
-    config = run_config.RunConfig()
-
-    classifier = estimator.GradientBoostedDecisionTreeClassifier(
-        learner_config=learner_config,
-        num_trees=1,
-        examples_per_layer=3,
-        model_dir=model_dir,
-        config=config,
-        feature_columns=[contrib_feature_column.real_valued_column("x")],
-        output_leaf_index=True)
-
-    classifier.fit(input_fn=_train_input_fn, steps=15)
-    result_dict = classifier.predict(input_fn=_eval_input_fn)
-    for prediction_item in result_dict:
-      self.assertTrue("leaf_index" in prediction_item)
-
   def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self):
     learner_config = learner_pb2.LearnerConfig()
     learner_config.num_classes = 2
@@ -109,8 +87,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
         model_dir=model_dir,
         config=config,
         feature_columns=[core_feature_column.numeric_column("x")],
-        use_core_libs=True,
-        output_leaf_index=False)
+        use_core_libs=True)
 
     model.fit(input_fn=_train_input_fn, steps=15)
     model.evaluate(input_fn=_eval_input_fn, steps=1)
@@ -130,8 +107,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
         model_dir=model_dir,
         config=config,
         feature_columns=[core_feature_column.numeric_column("x")],
-        use_core_libs=True,
-        output_leaf_index=False)
+        use_core_libs=True)
 
     classifier.fit(input_fn=_train_input_fn, steps=15)
     classifier.evaluate(input_fn=_eval_input_fn, steps=1)
@@ -151,8 +127,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
         model_dir=model_dir,
         config=config,
         feature_columns=[core_feature_column.numeric_column("x")],
-        use_core_libs=True,
-        output_leaf_index=False)
+        use_core_libs=True)
 
     regressor.fit(input_fn=_train_input_fn, steps=15)
     regressor.evaluate(input_fn=_eval_input_fn, steps=1)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 1ee8911..15ab6d8 100644
@@ -63,8 +63,6 @@ def model_builder(features, labels, mode, params, config):
   num_trees = params["num_trees"]
   use_core_libs = params["use_core_libs"]
   logits_modifier_function = params["logits_modifier_function"]
-  output_leaf_index = params["output_leaf_index"]
-
   if features is None:
     raise ValueError("At least one feature must be specified.")
 
@@ -98,8 +96,7 @@ def model_builder(features, labels, mode, params, config):
       feature_columns=feature_columns,
       logits_dimension=head.logits_dimension,
       features=training_features,
-      use_core_columns=use_core_libs,
-      output_leaf_index=output_leaf_index)
+      use_core_columns=use_core_libs)
   with ops.name_scope("gbdt", "gbdt_optimizer"):
     predictions_dict = gbdt_model.predict(mode)
     logits = predictions_dict["predictions"]
@@ -130,9 +127,6 @@ def model_builder(features, labels, mode, params, config):
         labels=labels,
         train_op_fn=_train_op_fn,
         logits=logits)
-  if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
-    model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
-        gbdt_batch.LEAF_INDEX]
   if num_trees:
     if center_bias:
       num_trees += 1
diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
index dcce8bc..b3fe386 100644
@@ -59,7 +59,6 @@ const char* kApplyDropoutAttributeName = "apply_dropout";
 const char* kApplyAveragingAttributeName = "apply_averaging";
 const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights";
 const char* kPredictionsTensorName = "predictions";
-const char* kLeafIndexTensorName = "leaf_index";
 
 void CalculateTreesToInclude(
     const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
@@ -171,16 +170,15 @@ class GradientTreesPredictionOp : public OpKernel {
     core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
-      DoCompute(context, ensemble_resource, false);
+      DoCompute(context, ensemble_resource);
     } else {
-      DoCompute(context, ensemble_resource, false);
+      DoCompute(context, ensemble_resource);
     }
   }
 
- protected:
-  virtual void DoCompute(OpKernelContext* context,
-                         DecisionTreeEnsembleResource* ensemble_resource,
-                         const bool is_output_leaf_index) {
+ private:
+  void DoCompute(OpKernelContext* context,
+                 DecisionTreeEnsembleResource* ensemble_resource) {
     // Read dense float features list;
     OpInputList dense_float_features_list;
     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@@ -269,14 +267,6 @@ class GradientTreesPredictionOp : public OpKernel {
                                           &output_predictions_t));
     auto output_predictions = output_predictions_t->matrix<float>();
 
-    // Allocate output leaf index matrix.
-    Tensor* output_leaf_index_t = nullptr;
-    if (is_output_leaf_index) {
-      OP_REQUIRES_OK(context, context->allocate_output(
-                                  kLeafIndexTensorName,
-                                  {batch_size, ensemble_resource->num_trees()},
-                                  &output_leaf_index_t));
-    }
     // Run predictor.
     thread::ThreadPool* const worker_threads =
         context->device()->tensorflow_cpu_worker_threads()->workers;
@@ -298,13 +288,11 @@ class GradientTreesPredictionOp : public OpKernel {
             i, weight * (num_ensembles - i + start_averaging) / num_ensembles);
       }
       MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features,
-                                     worker_threads, output_predictions,
-                                     output_leaf_index_t);
+                                     worker_threads, output_predictions);
     } else {
       MultipleAdditiveTrees::Predict(
           ensemble_resource->decision_tree_ensemble(), trees_to_include,
-          batch_features, worker_threads, output_predictions,
-          output_leaf_index_t);
+          batch_features, worker_threads, output_predictions);
     }
 
     // Output dropped trees and original weights.
@@ -314,6 +302,7 @@ class GradientTreesPredictionOp : public OpKernel {
                                 {2, static_cast<int64>(dropped_trees.size())},
                                 &output_dropout_info_t));
     auto output_dropout_info = output_dropout_info_t->matrix<float>();
+
     for (int32 i = 0; i < dropped_trees.size(); ++i) {
       output_dropout_info(0, i) = dropped_trees[i];
       output_dropout_info(1, i) = original_weights[i];
@@ -337,26 +326,6 @@ class GradientTreesPredictionOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU),
                         GradientTreesPredictionOp);
 
-// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp
-// and has an additional output: a rank 2 tensor containing the leaf id for
-// each tree where an instance ended up.
-class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
- public:
-  explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context)
-      : GradientTreesPredictionOp(context) {}
-
- protected:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource,
-                 bool is_output_leaf_index) override {
-    GradientTreesPredictionOp::DoCompute(context, ensemble_resource, true);
-  }
-};
-
-REGISTER_KERNEL_BUILDER(
-    Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU),
-    GradientTreesPredictionVerboseOp);
-
 class GradientTreesPartitionExamplesOp : public OpKernel {
  public:
   explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context)
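
(Note: GradientTreesPredictionVerboseOp, deleted above, backed the
`gradient_trees_prediction_verbose` Python wrapper removed later in this
change. A sketch of the difference at the call site, with argument names
taken from the deleted gbdt_batch.py code and the handle/feature inputs
assumed to be already constructed:

    # Plain op: two outputs (predictions, dropout info).
    predictions, _ = prediction_ops.gradient_trees_prediction(
        ensemble_handle, seed, dense_floats, sparse_float_indices,
        sparse_float_values, sparse_float_shapes, sparse_int_indices,
        sparse_int_values, sparse_int_shapes,
        learner_config=learner_config_serialized,
        apply_dropout=apply_dropout, apply_averaging=True,
        use_locking=True, center_bias=center_bias, reduce_dim=reduce_dim)

    # Verbose op, removed here: same inputs plus a third output of shape
    # [batch_size, num_trees] holding the leaf id per example per tree.
    predictions, _, leaf_index = prediction_ops.gradient_trees_prediction_verbose(
        ensemble_handle, seed, dense_floats, sparse_float_indices,
        sparse_float_values, sparse_float_shapes, sparse_int_indices,
        sparse_int_values, sparse_int_shapes,
        learner_config=learner_config_serialized,
        apply_dropout=apply_dropout, apply_averaging=True,
        use_locking=True, center_bias=center_bias, reduce_dim=reduce_dim))
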
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
index ee664f1..43b00d4 100644
@@ -26,8 +26,7 @@ void MultipleAdditiveTrees::Predict(
     const std::vector<int32>& trees_to_include,
     const boosted_trees::utils::BatchFeatures& features,
     tensorflow::thread::ThreadPool* const worker_threads,
-    tensorflow::TTypes<float>::Matrix output_predictions,
-    Tensor* output_leaf_indices) {
+    tensorflow::TTypes<float>::Matrix output_predictions) {
   // Zero out predictions as the model is additive.
   output_predictions.setZero();
 
@@ -39,8 +38,7 @@ void MultipleAdditiveTrees::Predict(
 
   // Lambda for doing a block of work.
   auto update_predictions = [&config, &features, &trees_to_include,
-                             &output_predictions,
-                             &output_leaf_indices](int64 start, int64 end) {
+                             &output_predictions](int64 start, int64 end) {
     auto examples_iterable = features.examples_iterable(start, end);
     for (const auto& example : examples_iterable) {
       for (const int32 tree_idx : trees_to_include) {
@@ -49,11 +47,6 @@ void MultipleAdditiveTrees::Predict(
         const float tree_weight = config.tree_weights(tree_idx);
         const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
         QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
-        // Checks if output leaf tree index is required.
-        if (output_leaf_indices != nullptr) {
-          output_leaf_indices->matrix<int>()(example.example_idx, tree_idx) =
-              leaf_idx;
-        }
         const auto& leaf_node = tree.nodes(leaf_idx);
         QCHECK(leaf_node.has_leaf())
             << "Invalid leaf node: " << leaf_node.DebugString();
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
index be7c155..cc3dc22 100644
@@ -33,17 +33,12 @@ class MultipleAdditiveTrees {
  public:
   // Predict runs tree ensemble on the given batch and updates
   // output predictions accordingly, for the given list of trees.
-  // output_leaf_indices is a pointer to a 2-dimensional tensor. If it is not
-  // null, this method fills output_leaf_indices with a per-tree leaf id for
-  // each of the instances from 'features'. Its shape is num examples x num
-  // trees. When nullptr, leaf ids are not output.
   static void Predict(
       const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
       const std::vector<int32>& trees_to_include,
       const boosted_trees::utils::BatchFeatures& features,
       tensorflow::thread::ThreadPool* const worker_threads,
-      tensorflow::TTypes<float>::Matrix output_predictions,
-      Tensor* output_leaf_indices);
+      tensorflow::TTypes<float>::Matrix output_predictions);
 };
 
 }  // namespace models
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
index caad023..4ca18be 100644
@@ -62,7 +62,7 @@ TEST_F(MultipleAdditiveTreesTest, Empty) {
   tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
                                          kNumThreadsSingleThreaded);
   MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
-                                 &threads, output_matrix, nullptr);
+                                 &threads, output_matrix);
   EXPECT_EQ(0, output_matrix(0, 0));
   EXPECT_EQ(0, output_matrix(1, 0));
 }
@@ -99,38 +99,17 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
   // Normal case.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
-                                   batch_features_, &threads, output_matrix,
-                                   nullptr);
+                                   batch_features_, &threads, output_matrix);
     EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
     EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (bias) + 0.9 (leaf 1).
   }
-  // Normal case with leaf node.
-  {
-    // Initialize the output leaf index tensor with -1, since valid leaf
-    // indices are non-negative. With 2 examples and 2 trees, the tensor
-    // holds 2 * 2 entries.
-    auto output_leaf_index_tensor = AsTensor<int>({-1, -1, -1, -1}, {2, 2});
-    MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
-                                   batch_features_, &threads, output_matrix,
-                                   &output_leaf_index_tensor);
-    EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
-    EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (bias) + 0.9 (leaf 1).
-    EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
-                           0, 0));  // 1st leaf for the first example
-    EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
-                           1, 0));  // 1st leaf for the second example
-    EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix<int>()(
-                           0, 1));  // 2nd leaf for the first example
-    EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix<int>()(
-                           1, 1));  // 2nd leaf for the second example
-  }
   // Weighted case
   {
     DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
     weighted.set_tree_weights(0, 6.0);
     weighted.set_tree_weights(1, 3.2);
     MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
-                                   output_matrix, nullptr);
+                                   output_matrix);
     // -0.4 (bias) + 0.2 (leaf 2).
     EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
     // -0.4 (bias) + 0.9 (leaf 1).
@@ -139,21 +118,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
   // Drop first tree.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0));  // 0.2 (leaf 2).
     EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0));  // 0.9 (leaf 1).
   }
   // Drop second tree.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias).
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0));  // -0.4 (bias).
   }
   // Drop all trees.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
     EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
   }
@@ -193,8 +172,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
   // Normal case.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
-                                   batch_features_, &threads, output_matrix,
-                                   nullptr);
+                                   batch_features_, &threads, output_matrix);
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias)
     EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1));  // -0.7 (bias) + 0.2 (leaf 2)
     EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (bias) + 0.9 (leaf 1)
@@ -206,7 +184,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
     weighted.set_tree_weights(0, 6.0);
     weighted.set_tree_weights(1, 3.2);
     MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
-                                   output_matrix, nullptr);
+                                   output_matrix);
     // bias
     EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
     // bias + leaf 2
@@ -219,7 +197,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
   // Dropout first tree.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
     EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1));  // 0.2 (leaf 2)
     EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0));  // 0.9 (leaf 2)
@@ -228,7 +206,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
   // Dropout second tree.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias)
     EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1));  // -0.7 (bias)
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0));  // -0.4 (bias)
@@ -237,7 +215,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
   // Drop both trees.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
-                                   &threads, output_matrix, nullptr);
+                                   &threads, output_matrix);
     EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
     EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
     EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
@@ -280,8 +258,7 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
   // Normal case.
   {
     MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
-                                   batch_features_, &threads, output_matrix,
-                                   nullptr);
+                                   batch_features_, &threads, output_matrix);
     EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0));  // -0.4 (tree1) + 0.2 (leaf 2)
     EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1));  // -0.7 (tree1) + 0.3 (leaf 2)
     EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2));   // 3.0 (tree1) + 0.4 (leaf 2)
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
index 6491d58..d66f645 100644
@@ -40,24 +40,6 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
-static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) {
-  string learner_config_str;
-  c->GetAttr("learner_config", &learner_config_str).IgnoreError();
-  LearnerConfig learner_config;
-  ParseProtoUnlimited(&learner_config, learner_config_str);
-
-  bool reduce_dim;
-  c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
-  // Sets the shape of the output as a matrix.
-  c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
-                              reduce_dim ? learner_config.num_classes() - 1
-                                         : learner_config.num_classes())});
-  c->set_output(1, {c->UnknownShape()});
-  c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim,
-                              InferenceContext::kUnknownDim)});
-  return Status::OK();
-}
-
 REGISTER_OP("GradientTreesPrediction")
     .Attr("learner_config: string")
     .Attr("num_dense_float_features: int >= 0")
@@ -108,58 +90,6 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
 and original weights of those trees during prediction.
 )doc");
 
-REGISTER_OP("GradientTreesPredictionVerbose")
-    .Attr("learner_config: string")
-    .Attr("num_dense_float_features: int >= 0")
-    .Attr("num_sparse_float_features: int >= 0")
-    .Attr("num_sparse_int_features: int >= 0")
-    .Attr("use_locking: bool = false")
-    .Attr("apply_dropout: bool")
-    .Attr("apply_averaging: bool")
-    .Attr("center_bias: bool")
-    .Attr("reduce_dim: bool")
-    .Input("tree_ensemble_handle: resource")
-    .Input("seed: int64")
-    .Input("dense_float_features: num_dense_float_features * float")
-    .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
-    .Input("sparse_float_feature_values: num_sparse_float_features * float")
-    .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
-    .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
-    .Input("sparse_int_feature_values: num_sparse_int_features * int64")
-    .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
-    .Output("predictions: float")
-    .Output("drop_out_tree_indices_weights: float")
-    .Output("leaf_index: int32")
-    .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn)
-    .Doc(R"doc(
-Runs multiple additive regression forests predictors on input instances
-and computes the final prediction for each class, and outputs a matrix of
-leaf ids per each tree in an ensemble.
-
-learner_config: Config for the learner of type LearnerConfig proto. Prediction
-ops for now use only the LearningRateDropoutDrivenConfig config from the learner.
-num_dense_float_features: Number of dense float features.
-num_sparse_float_features: Number of sparse float features.
-num_sparse_int_features: Number of sparse int features.
-use_locking: Whether to use locking.
-seed: random seed to be used for dropout.
-reduce_dim: whether to reduce the dimension (legacy impl) or not.
-apply_dropout: whether to apply dropout during prediction.
-apply_averaging: whether averaging of tree ensembles should take place. If set
-to true, will be based on AveragingConfig from learner_config.
-tree_ensemble_handle: The handle to the tree ensemble.
-dense_float_features: Rank 2 Tensors containing dense float feature values.
-sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
-sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
-sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
-sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
-sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
-sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
-predictions: Rank 2 Tensor containing predictions per example per class.
-drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
-leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up.
-)doc");
-
 REGISTER_OP("GradientTreesPartitionExamples")
     .Attr("num_dense_float_features: int >= 0")
     .Attr("num_sparse_float_features: int >= 0")
diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
index 7f6e55a..58f0d36 100644
@@ -21,5 +21,4 @@ from __future__ import print_function
 from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
 from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples
 from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction
-from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose
 # pylint: enable=unused-import
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 35ccb45..5dd2e0c 100644
@@ -58,7 +58,6 @@ NUM_LAYERS_ATTEMPTED = "num_layers"
 NUM_TREES_ATTEMPTED = "num_trees"
 NUM_USED_HANDLERS = "num_used_handlers"
 USED_HANDLERS_MASK = "used_handlers_mask"
-LEAF_INDEX = "leaf_index"
 _FEATURE_NAME_TEMPLATE = "%s_%d"
 
 
@@ -72,25 +71,18 @@ def _get_column_by_index(tensor, indices):
   return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1])
 
 
-def _make_predictions_dict(stamp,
-                           logits,
-                           partition_ids,
-                           ensemble_stats,
-                           used_handlers,
-                           output_leaf_index=False,
-                           leaf_index=None):
+def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats,
+                           used_handlers):
   """Returns predictions for the given logits and n_classes.
 
   Args:
     stamp: The ensemble stamp.
-    logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that
-      contains predictions when no dropout was applied.
+    logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] that
+        contains predictions when no dropout was applied.
     partition_ids: A rank 1 `Tensor` with shape [batch_size].
     ensemble_stats: A TreeEnsembleStatsOp result tuple.
     used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a
-      boolean mask.
-    leaf_index: A boolean variable indicating whether to output leaf index into
-      predictions dictionary.
+        boolean mask.
 
   Returns:
     A dict of predictions.
@@ -103,8 +95,6 @@ def _make_predictions_dict(stamp,
   result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees
   result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers
   result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask
-  if output_leaf_index:
-    result[LEAF_INDEX] = leaf_index
   return result
 
 
@@ -278,8 +268,7 @@ class GradientBoostedDecisionTreeModel(object):
                features,
                logits_dimension,
                feature_columns=None,
-               use_core_columns=False,
-               output_leaf_index=False):
+               use_core_columns=False):
     """Construct a new GradientBoostedDecisionTreeModel function.
 
     Args:
@@ -287,15 +276,13 @@ class GradientBoostedDecisionTreeModel(object):
       num_ps_replicas: Number of parameter server replicas, can be 0.
       ensemble_handle: A handle to the ensemble variable.
       center_bias: Whether to center the bias before growing trees.
-      examples_per_layer: Number of examples to accumulate before growing a tree
-        layer. It can also be a function that computes the number of examples
-        based on the depth of the layer that's being built.
+      examples_per_layer: Number of examples to accumulate before growing
+        a tree layer. It can also be a function that computes the number of
+        examples based on the depth of the layer that's being built.
       learner_config: A learner config.
       features: `dict` of `Tensor` objects.
       logits_dimension: An int, the dimension of logits.
       feature_columns: A list of feature columns.
-      output_leaf_index: A boolean variable indicating whether to output leaf
-        index into predictions dictionary.
 
     Raises:
       ValueError: if inputs are not valid.
@@ -372,7 +359,6 @@ class GradientBoostedDecisionTreeModel(object):
         self._learner_config.multi_class_strategy ==
         learner_pb2.LearnerConfig.TREE_PER_CLASS and
         learner_config.num_classes == 2)
-    self._output_leaf_index = output_leaf_index
 
   def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
     """Runs prediction and returns a dictionary of the prediction results.
@@ -402,44 +388,22 @@ class GradientBoostedDecisionTreeModel(object):
     # Make sure ensemble stats run. This will check that the ensemble has
     # the right stamp.
     with ops.control_dependencies(ensemble_stats):
-      leaf_matrix = []
-      # Only used in infer (predict), not used in train and eval.
-      if self._output_leaf_index and mode == learn.ModeKeys.INFER:
-        predictions, _, leaf_matrix = (
-            prediction_ops).gradient_trees_prediction_verbose(
-                ensemble_handle,
-                seed,
-                self._dense_floats,
-                self._sparse_float_indices,
-                self._sparse_float_values,
-                self._sparse_float_shapes,
-                self._sparse_int_indices,
-                self._sparse_int_values,
-                self._sparse_int_shapes,
-                learner_config=self._learner_config_serialized,
-                apply_dropout=apply_dropout,
-                apply_averaging=mode != learn.ModeKeys.TRAIN,
-                use_locking=True,
-                center_bias=self._center_bias,
-                reduce_dim=self._reduce_dim)
-
-      else:
-        predictions, _ = prediction_ops.gradient_trees_prediction(
-            ensemble_handle,
-            seed,
-            self._dense_floats,
-            self._sparse_float_indices,
-            self._sparse_float_values,
-            self._sparse_float_shapes,
-            self._sparse_int_indices,
-            self._sparse_int_values,
-            self._sparse_int_shapes,
-            learner_config=self._learner_config_serialized,
-            apply_dropout=apply_dropout,
-            apply_averaging=mode != learn.ModeKeys.TRAIN,
-            use_locking=True,
-            center_bias=self._center_bias,
-            reduce_dim=self._reduce_dim)
+      predictions, _ = prediction_ops.gradient_trees_prediction(
+          ensemble_handle,
+          seed,
+          self._dense_floats,
+          self._sparse_float_indices,
+          self._sparse_float_values,
+          self._sparse_float_shapes,
+          self._sparse_int_indices,
+          self._sparse_int_values,
+          self._sparse_int_shapes,
+          learner_config=self._learner_config_serialized,
+          apply_dropout=apply_dropout,
+          apply_averaging=mode != learn.ModeKeys.TRAIN,
+          use_locking=True,
+          center_bias=self._center_bias,
+          reduce_dim=self._reduce_dim)
       partition_ids = prediction_ops.gradient_trees_partition_examples(
           ensemble_handle,
           self._dense_floats,
@@ -452,8 +416,7 @@ class GradientBoostedDecisionTreeModel(object):
           use_locking=True)
 
     return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
-                                  ensemble_stats, used_handlers,
-                                  self._output_leaf_index, leaf_matrix)
+                                  ensemble_stats, used_handlers)
 
   def predict(self, mode):
     """Returns predictions given the features and mode.
@@ -558,7 +521,7 @@ class GradientBoostedDecisionTreeModel(object):
         aggregation_method=None)[0]
     strategy = self._learner_config.multi_class_strategy
 
-    class_id = -1
+    class_id = constant_op.constant(-1, dtype=dtypes.int32)
     # Handle different multiclass strategies.
     if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
       # We build one vs rest trees.
@@ -612,31 +575,39 @@ class GradientBoostedDecisionTreeModel(object):
     # Get the weights for each example for quantiles calculation,
     weights = self._get_weights(hessian_shape, squeezed_hessians)
 
-    regularization_config = self._learner_config.regularization
-    min_node_weight = self._learner_config.constraints.min_node_weight
     # Create all handlers ensuring resources are evenly allocated across PS.
     fc_name_idx = 0
     handlers = []
     init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
+    l1_regularization = constant_op.constant(
+        self._learner_config.regularization.l1, dtypes.float32)
+    l2_regularization = constant_op.constant(
+        self._learner_config.regularization.l2, dtypes.float32)
+    tree_complexity_regularization = constant_op.constant(
+        self._learner_config.regularization.tree_complexity, dtypes.float32)
+    min_node_weight = constant_op.constant(
+        self._learner_config.constraints.min_node_weight, dtypes.float32)
+    epsilon = 0.01
+    num_quantiles = 100
+    strategy_tensor = constant_op.constant(strategy)
     with ops.device(self._get_replica_device_setter(worker_device)):
       # Create handlers for dense float columns
       for dense_float_column_idx in range(len(self._dense_floats)):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             ordinal_split_handler.DenseSplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=dense_float_column_idx,
-                epsilon=0.01,
-                num_quantiles=100,
+                epsilon=epsilon,
+                num_quantiles=num_quantiles,
                 dense_float_column=self._dense_floats[dense_float_column_idx],
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -645,14 +616,13 @@ class GradientBoostedDecisionTreeModel(object):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             ordinal_split_handler.SparseSplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=sparse_float_column_idx,
-                epsilon=0.01,
-                num_quantiles=100,
+                epsilon=epsilon,
+                num_quantiles=num_quantiles,
                 sparse_float_column=sparse_tensor.SparseTensor(
                     self._sparse_float_indices[sparse_float_column_idx],
                     self._sparse_float_values[sparse_float_column_idx],
@@ -660,7 +630,7 @@ class GradientBoostedDecisionTreeModel(object):
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -669,10 +639,9 @@ class GradientBoostedDecisionTreeModel(object):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             categorical_split_handler.EqualitySplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=sparse_int_column_idx,
                 sparse_int_column=sparse_tensor.SparseTensor(
@@ -682,7 +651,7 @@ class GradientBoostedDecisionTreeModel(object):
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -804,6 +773,7 @@ class GradientBoostedDecisionTreeModel(object):
     empty_hessians = constant_op.constant(
         [], dtype=dtypes.float32, shape=empty_hess_shape)
 
+    active_handlers = array_ops.unstack(active_handlers, axis=0)
     for handler_idx in range(len(handlers)):
       handler = handlers[handler_idx]
       is_active = active_handlers[handler_idx]
@@ -1014,7 +984,7 @@ class GradientBoostedDecisionTreeModel(object):
       # This is a workaround for the slowness of graph building in tf.cond.
       # See (b/36554864).
       split_sizes = array_ops.reshape(
-          array_ops.shape_n(partition_ids_list), [-1])
+          array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
       partition_ids = array_ops.concat(partition_ids_list, axis=0)
       gains = array_ops.concat(gains_list, axis=0)
       split_infos = array_ops.concat(split_info_list, axis=0)
@@ -1079,8 +1049,11 @@ class GradientBoostedDecisionTreeModel(object):
 
       # Update ensemble.
       update_ops = [are_all_splits_ready]
-      update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
-                                           _grow_ensemble_fn)
+      if self._center_bias:
+        update_model = control_flow_ops.cond(continue_centering,
+                                             _center_bias_fn, _grow_ensemble_fn)
+      else:
+        update_model = _grow_ensemble_fn()
       update_ops.append(update_model)
 
       # Update ensemble stats.
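
(Note: besides removing the leaf-index plumbing, the gbdt_batch.py hunks
above hoist the regularization hyperparameters into constant tensors that
are created once and shared by all split handlers, and unstack the
per-handler activity tensor so it can be indexed in the Python loop. A
minimal runnable sketch of that unstack pattern, with an illustrative
shape and values:

    import tensorflow as tf

    # Rank-2 activity tensor; one row per handler (shape is illustrative).
    active_handlers = tf.constant([[1, 0], [1, 1], [0, 1]])
    # Unstacking along axis 0 yields a Python list with one tensor per
    # handler, so active_handlers[handler_idx] works in a plain for loop.
    active_handlers = tf.unstack(active_handlers, axis=0)
    assert len(active_handlers) == 3)
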
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 0665c6c..289fb19 100644
@@ -19,15 +19,18 @@ from __future__ import division
 from __future__ import print_function
 
 from google.protobuf import text_format
+
 from tensorflow.contrib import layers
 from tensorflow.contrib.boosted_trees.proto import learner_pb2
 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
 from tensorflow.contrib.boosted_trees.python.ops import model_ops
 from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
 from tensorflow.contrib.boosted_trees.python.utils import losses
+
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
 from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
 from tensorflow.contrib.learn.python.learn.estimators import model_fn
-from tensorflow.python.feature_column import feature_column_lib as core_feature_column
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
@@ -725,8 +728,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
       self.assertEquals(len(output.tree_weights), 0)
       self.assertEquals(stamp_token.eval(), 0)
 
-  def testPredictFnWithLeafIndexAdvancedLeft(self):
-    """Tests the predict function with output leaf ids."""
+  def testPredictFn(self):
+    """Tests the predict function."""
     with self.test_session() as sess:
       # Create ensemble with one bias node.
       ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
@@ -734,61 +737,12 @@ class GbdtTest(test_util.TensorFlowTestCase):
           """
           trees {
             nodes {
-                dense_float_binary_split {
-                  threshold: 1.0
-                  left_id: 1
-                  right_id: 2
-                }
-                node_metadata {
-                  gain: 0
-                }
-              }
-              nodes {
-                leaf {
-                  vector {
-                    value: 0.25
-                  }
-                }
-              }
-              nodes {
-                leaf {
-                  vector {
-                    value: 0.0
-                  }
-                }
-              }
-          }
-          tree_weights: 1.0
-          tree_metadata {
-            num_tree_weight_updates: 1
-            num_layers_grown: 1
-            is_finalized: true
-          }
-          trees {
-            nodes {
-                dense_float_binary_split {
-                  threshold: 0.99
-                  left_id: 1
-                  right_id: 2
-                }
-                node_metadata {
-                  gain: 0
-                }
-              }
-              nodes {
-                leaf {
-                  vector {
-                    value: 0.25
-                  }
-                }
-              }
-              nodes {
-                leaf {
-                  vector {
-                    value: 0.0
-                  }
+              leaf {
+                vector {
+                  value: 0.25
                 }
               }
+            }
           }
           tree_weights: 1.0
           tree_metadata {
@@ -809,8 +763,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
       learner_config.constraints.max_tree_depth = 1
       learner_config.constraints.min_node_weight = 0
       features = {}
-      features["dense_float"] = array_ops.constant(
-          [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32)
+      features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
       gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
           is_chief=False,
           num_ps_replicas=0,
@@ -819,20 +772,15 @@ class GbdtTest(test_util.TensorFlowTestCase):
           examples_per_layer=1,
           learner_config=learner_config,
           logits_dimension=1,
-          features=features,
-          output_leaf_index=True)
+          features=features)
 
       # Create predict op.
-      mode = model_fn.ModeKeys.INFER
+      mode = model_fn.ModeKeys.EVAL
       predictions_dict = sess.run(gbdt_model.predict(mode))
       self.assertEquals(predictions_dict["ensemble_stamp"], 3)
-      # here are how the first two numbers in expected results are calculated,
-      # 0.5 = 0.25 + 0.25, and 0.25 = 0.25 + 0
       self.assertAllClose(predictions_dict["predictions"],
-                          [[0.5], [0.25], [0], [0]])
+                          [[0.25], [0.25], [0.25], [0.25]])
       self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
-      self.assertAllClose(predictions_dict["leaf_index"],
-                          [[1, 1], [1, 2], [2, 2], [2, 2]])
 
   def testTrainFnMulticlassFullHessian(self):
     """Tests the GBDT train for multiclass full hessian."""