&output_logits_t));
auto output_logits = output_logits_t->matrix<float>();
+ // Return zero logits if it's an empty ensemble.
+ if (resource->num_trees() <= 0) {
+ output_logits.setZero();
+ return;
+ }
+
const int32 latest_tree = resource->num_trees() - 1;
auto do_work = [&resource, &batch_bucketized_features, &output_logits,
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import session_run_hook
NUM_FEATURES = 3
return ensemble_proto
+ def testFirstCheckpointWorksFine(self):
+ """Tests that eval/pred doesn't crash with the very first checkpoint.
+
+ The step-0 checkpoint holds only an empty ensemble, and a separate eval
+ job might read from it and crash.
+ This test ensures that prediction and evaluation both work with it.
+ """
+ input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+ def before_run(self, run_context):
+ raise StopIteration('to bail out.')
+
+ est.train(input_fn, steps=100, # the hook raises first; step stays 0.
+ hooks=[BailOutWithoutTraining()])
+ self._assert_checkpoint(
+ est.model_dir, global_step=0, finalized_trees=0, attempted_layers=0)
+ # An empty ensemble yields zero logits, so every predicted class id is 0.
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 0.6)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [0], [0], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testTrainAndEvaluateBinaryClassifier(self):
input_fn = _make_train_input_fn(is_classification=True)
class PredictionOpsTest(test_util.TensorFlowTestCase):
"""Tests prediction ops for inference."""
+ def testPredictionOnEmptyEnsemble(self):
+ """Tests that prediction on an empty ensemble does not fail."""
+ with self.test_session() as session:
+ # Create an empty ensemble (the serialized proto is an empty string).
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto='')
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [11, 27]
+ expected_logits = [[0.0], [0.0]]
+
+ # Prediction must succeed and return all-zero logits.
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
def testPredictionMultipleTree(self):
"""Tests the predictions work when we have multiple trees."""
with self.test_session() as session:
# logit= 0.1*1.14+0.2*7.0-1*7.0
expected_logits = [[6.114], [-5.486]]
- # Do with parallelization, e.g. EVAL
- predict_op = boosted_trees_ops.predict(
- tree_ensemble_handle,
- bucketized_features=[feature_0_values, feature_1_values],
- logits_dimension=1)
-
- logits = session.run(predict_op)
- self.assertAllClose(expected_logits, logits)
-
- # Do without parallelization, e.g. INFER - the result is the same
+ # Prediction should work fine.
predict_op = boosted_trees_ops.predict(
tree_ensemble_handle,
bucketized_features=[feature_0_values, feature_1_values],