From 0295c3c212546147e57609a791f126aaca44ba44 Mon Sep 17 00:00:00 2001 From: Younghee Kwon Date: Mon, 9 Apr 2018 17:45:13 -0700 Subject: [PATCH] boosted_trees: early stop hooks are fixed to stop at the right moment by reading tensor values in a separate session after train_op run. PiperOrigin-RevId: 192217338 --- .../python/estimator/boosted_trees_test.py | 97 ++++++++-------------- .../python/estimator/canned/boosted_trees.py | 33 +++----- .../python/estimator/canned/boosted_trees_test.py | 63 ++++++-------- 3 files changed, 71 insertions(+), 122 deletions(-) diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index e99a87f..eee5910 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.estimator.python.estimator import boosted_trees +from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column @@ -69,10 +70,18 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): for i in range(NUM_FEATURES) } - def _assert_checkpoint(self, model_dir, expected_global_step): - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(model_dir, - ops.GraphKeys.GLOBAL_STEP)) + def _assert_checkpoint(self, model_dir, global_step, finalized_trees, + attempted_layers): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + serialized = reader.get_tensor('boosted_trees:0_serialized') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertEqual( + finalized_trees, + sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized])) + self.assertEqual(attempted_layers, + ensemble_proto.growing_metadata.num_layers_attempted) def testTrainAndEvaluateEstimator(self): input_fn = _make_train_input_fn(is_classification=False) @@ -88,9 +97,10 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): num_steps = 100 # Train for a few steps, and validate final checkpoint. est.train(input_fn, steps=num_steps) - self._assert_checkpoint(est.model_dir, 11) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 0.913176) + self.assertAllClose(eval_res['average_loss'], 1.008551) def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) @@ -108,31 +118,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): num_steps = 100 # Train for a few steps, and validate final checkpoint. est.train(train_input_fn, steps=num_steps) - self._assert_checkpoint(est.model_dir, 6) - + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Validate predictions. predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertEquals(5, len(predictions)) - self.assertAllClose([0.703549], predictions[0]['predictions']) - self.assertAllClose([0.266539], predictions[1]['predictions']) - self.assertAllClose([0.256479], predictions[2]['predictions']) - self.assertAllClose([1.088732], predictions[3]['predictions']) - self.assertAllClose([1.901732], predictions[4]['predictions']) - - -class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._feature_columns = { - feature_column.bucketized_column( - feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), - BUCKET_BOUNDARIES) - for i in range(NUM_FEATURES) - } - - def _assert_checkpoint(self, model_dir, expected_global_step): - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(model_dir, - ops.GraphKeys.GLOBAL_STEP)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=True) @@ -145,36 +137,16 @@ class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase): n_trees=1, max_depth=5) # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint(est.model_dir, 6) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) # Check eval. eval_res = est.evaluate(input_fn=train_input_fn, steps=1) self.assertAllClose(eval_res['accuracy'], 1.0) - - # Check predict that all labels are correct. + # Validate predictions. predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertEquals(5, len(predictions)) - self.assertAllClose([0], predictions[0]['class_ids']) - self.assertAllClose([1], predictions[1]['class_ids']) - self.assertAllClose([1], predictions[2]['class_ids']) - self.assertAllClose([0], predictions[3]['class_ids']) - self.assertAllClose([0], predictions[4]['class_ids']) - - -class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._feature_columns = { - feature_column.bucketized_column( - feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), - BUCKET_BOUNDARIES) - for i in range(NUM_FEATURES) - } - - def _assert_checkpoint(self, model_dir, expected_global_step): - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(model_dir, - ops.GraphKeys.GLOBAL_STEP)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) def testRegressorTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=False) @@ -187,20 +159,17 @@ class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase): n_trees=1, max_depth=5) # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint(est.model_dir, 6) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) # Check eval. eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 2.2136638) - + self.assertAllClose(eval_res['average_loss'], 2.478283) # Validate predictions. predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertEquals(5, len(predictions)) - self.assertAllClose([0.703549], predictions[0]['predictions']) - self.assertAllClose([0.266539], predictions[1]['predictions']) - self.assertAllClose([0.256479], predictions[2]['predictions']) - self.assertAllClose([1.088732], predictions[3]['predictions']) - self.assertAllClose([1.901732], predictions[4]['predictions']) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) if __name__ == '__main__': diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 500ea03..c5d5455 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -209,8 +209,8 @@ class _CacheTrainingStatesUsingVariables(object): name='cache_insert') -class StopAtAttemptsHook(session_run_hook.SessionRunHook): - """Hook that requests stop at the number of trees.""" +class _StopAtAttemptsHook(session_run_hook.SessionRunHook): + """Hook that requests stop at the number of attempts.""" def __init__(self, num_finalized_trees_tensor, num_attempted_layers_tensor, max_trees, max_depth): @@ -224,25 +224,17 @@ class StopAtAttemptsHook(session_run_hook.SessionRunHook): [self._num_finalized_trees_tensor, self._num_attempted_layers_tensor]) def after_run(self, run_context, run_values): + # num_* tensors should be retrieved by a separate session than the training + # one, in order to read the values after growing. + # So, if it's approaching to the limit, get the actual value by additional + # session. num_finalized_trees, num_attempted_layers = run_values.results + if (num_finalized_trees >= self._max_trees - 1 or + num_attempted_layers > 2 * self._max_trees * self._max_depth - 1): + num_finalized_trees, num_attempted_layers = run_context.session.run( + [self._num_finalized_trees_tensor, self._num_attempted_layers_tensor]) if (num_finalized_trees >= self._max_trees or - 1.0 * num_attempted_layers / self._max_depth > 2 * self._max_trees): - run_context.request_stop() - - -class StopAtNumTreesHook(session_run_hook.SessionRunHook): - """Hook that requests stop at the number of trees.""" - - def __init__(self, num_trees_tensor, max_trees): - self._num_trees_tensor = num_trees_tensor - self._max_trees = max_trees - - def before_run(self, run_context): - return session_run_hook.SessionRunArgs(self._num_trees_tensor) - - def after_run(self, run_context, run_values): - num_trees = run_values.results - if num_trees > self._max_trees: + num_attempted_layers > 2 * self._max_trees * self._max_depth): run_context.request_stop() @@ -468,7 +460,8 @@ def _bt_model_fn( # Add an early stop hook. estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + - (StopAtNumTreesHook(num_trees, tree_hparams.n_trees),)) + (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers, + tree_hparams.n_trees, tree_hparams.max_depth),)) return estimator_spec diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 01e5cc7..625745a 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -69,7 +69,7 @@ def _make_train_input_fn(is_classification): return _input_fn -class BoostedTreesClassifierTest(test_util.TensorFlowTestCase): +class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): self._feature_columns = { @@ -79,10 +79,18 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase): for i in range(NUM_FEATURES) } - def _assert_checkpoint(self, model_dir, expected_global_step): - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(model_dir, - ops.GraphKeys.GLOBAL_STEP)) + def _assert_checkpoint(self, model_dir, global_step, finalized_trees, + attempted_layers): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + serialized = reader.get_tensor('boosted_trees:0_serialized') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertEqual( + finalized_trees, + sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized])) + self.assertEqual(attempted_layers, + ensemble_proto.growing_metadata.num_layers_attempted) def testTrainAndEvaluateBinaryClassifier(self): input_fn = _make_train_input_fn(is_classification=True) @@ -97,7 +105,8 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase): num_steps = 100 # Train for a few steps, and validate final checkpoint. est.train(input_fn, steps=num_steps) - self._assert_checkpoint(est.model_dir, 6) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['accuracy'], 1.0) @@ -118,29 +127,9 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase): est.train(train_input_fn, steps=num_steps) predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertEquals(5, len(predictions)) # All labels are correct. - self.assertAllClose([0], predictions[0]['class_ids']) - self.assertAllClose([1], predictions[1]['class_ids']) - self.assertAllClose([1], predictions[2]['class_ids']) - self.assertAllClose([0], predictions[3]['class_ids']) - self.assertAllClose([0], predictions[4]['class_ids']) - - -class BoostedTreesRegressionTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._feature_columns = { - feature_column.bucketized_column( - feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), - BUCKET_BOUNDARIES) - for i in range(NUM_FEATURES) - } - - def _assert_checkpoint(self, model_dir, expected_global_step): - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(model_dir, - ops.GraphKeys.GLOBAL_STEP)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) def testTrainAndEvaluateRegressor(self): input_fn = _make_train_input_fn(is_classification=False) @@ -155,9 +144,10 @@ class BoostedTreesRegressionTest(test_util.TensorFlowTestCase): num_steps = 100 # Train for a few steps, and validate final checkpoint. est.train(input_fn, steps=num_steps) - self._assert_checkpoint(est.model_dir, 11) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 0.913176) + self.assertAllClose(eval_res['average_loss'], 1.008551) def testInferRegressor(self): train_input_fn = _make_train_input_fn(is_classification=False) @@ -174,16 +164,13 @@ class BoostedTreesRegressionTest(test_util.TensorFlowTestCase): num_steps = 100 # Train for a few steps, and validate final checkpoint. est.train(train_input_fn, steps=num_steps) - self._assert_checkpoint(est.model_dir, 6) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) predictions = list(est.predict(input_fn=predict_input_fn)) - - self.assertEquals(5, len(predictions)) - self.assertAllClose([0.703549], predictions[0]['predictions']) - self.assertAllClose([0.266539], predictions[1]['predictions']) - self.assertAllClose([0.256479], predictions[2]['predictions']) - self.assertAllClose([1.088732], predictions[3]['predictions']) - self.assertAllClose([1.901732], predictions[4]['predictions']) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) class ModelFnTests(test_util.TensorFlowTestCase): -- 2.7.4