boosted_trees: fixed the crash when eval/prediction is attempted with the initial...
authorYounghee Kwon <youngheek@google.com>
Thu, 17 May 2018 23:52:17 +0000 (16:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 23:58:20 +0000 (16:58 -0700)
PiperOrigin-RevId: 197073582

tensorflow/core/kernels/boosted_trees/prediction_ops.cc
tensorflow/python/estimator/canned/boosted_trees_test.py
tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py

index 1b5ce32..20359f2 100644 (file)
@@ -213,6 +213,12 @@ class BoostedTreesPredictOp : public OpKernel {
                                 &output_logits_t));
     auto output_logits = output_logits_t->matrix<float>();
 
+    // Return zero logits if it's an empty ensemble.
+    if (resource->num_trees() <= 0) {
+      output_logits.setZero();
+      return;
+    }
+
     const int32 latest_tree = resource->num_trees() - 1;
 
     auto do_work = [&resource, &batch_bucketized_features, &output_logits,
index 13595d4..0f2c1e1 100644 (file)
@@ -35,6 +35,7 @@ from tensorflow.python.ops import resources
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import session_run_hook
 
 NUM_FEATURES = 3
 
@@ -121,6 +122,39 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
 
     return ensemble_proto
 
+  def testFirstCheckpointWorksFine(self):
+    """Tests that eval/pred doesn't crash with the very first checkpoint.
+
+    The step-0 checkpoint will have only an empty ensemble, and a separate eval
+    job might read from it and crash.
+    This test ensures that prediction/evaluation works fine with it.
+    """
+    input_fn = _make_train_input_fn(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+      def before_run(self, run_context):
+        raise StopIteration('to bail out.')
+
+    est.train(input_fn, steps=100,  # must stop at 0 anyway.
+              hooks=[BailOutWithoutTraining()])
+    self._assert_checkpoint(
+        est.model_dir, global_step=0, finalized_trees=0, attempted_layers=0)
+    # Empty ensemble returns 0 logits, so that all output labels are 0.
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 0.6)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [0], [0], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
+
   def testTrainAndEvaluateBinaryClassifier(self):
     input_fn = _make_train_input_fn(is_classification=True)
 
index 54f33f3..92cd53a 100644 (file)
@@ -792,6 +792,28 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
 class PredictionOpsTest(test_util.TensorFlowTestCase):
   """Tests prediction ops for inference."""
 
+  def testPredictionOnEmptyEnsemble(self):
+    """Tests that prediction on a empty ensemble does not fail."""
+    with self.test_session() as session:
+      # Create an empty ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto='')
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      feature_0_values = [36, 32]
+      feature_1_values = [11, 27]
+      expected_logits = [[0.0], [0.0]]
+
+      # Prediction should work fine.
+      predict_op = boosted_trees_ops.predict(
+          tree_ensemble_handle,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits = session.run(predict_op)
+      self.assertAllClose(expected_logits, logits)
+
   def testPredictionMultipleTree(self):
     """Tests the predictions work when we have multiple trees."""
     with self.test_session() as session:
@@ -893,16 +915,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
       #            logit= 0.1*1.14+0.2*7.0-1*7.0
       expected_logits = [[6.114], [-5.486]]
 
-      # Do with parallelization, e.g. EVAL
-      predict_op = boosted_trees_ops.predict(
-          tree_ensemble_handle,
-          bucketized_features=[feature_0_values, feature_1_values],
-          logits_dimension=1)
-
-      logits = session.run(predict_op)
-      self.assertAllClose(expected_logits, logits)
-
-      # Do without parallelization, e.g. INFER - the result is the same
+      # Prediction should work fine.
       predict_op = boosted_trees_ops.predict(
           tree_ensemble_handle,
           bucketized_features=[feature_0_values, feature_1_values],