From: Younghee Kwon
Date: Tue, 17 Apr 2018 02:10:10 +0000 (-0700)
Subject: BoostedTreesEstimator in contrib: train_in_memory works with input_fns returning...
X-Git-Tag: upstream/v1.9.0_rc1~287^2~1^2~15
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d48c55db5fc8ab07d2bf679b4ea7c3c4c84ace76;p=platform%2Fupstream%2Ftensorflow.git

BoostedTreesEstimator in contrib: train_in_memory works with input_fns returning tf.data.Dataset.

Only one batch of data is expected, so dataset.batch() is disallowed, and
dataset.repeat() is effectively ignored (only the first repetition is used).

PiperOrigin-RevId: 193137094
---
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 00356ce..bd64101 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -17,10 +17,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
 
 
+def _validate_input_fn_and_repeat_dataset(train_input_fn):
+  """Validates the input_fn, and repeats its result if it is a tf.Dataset."""
+  def _input_fn():
+    result_input_fn = train_input_fn()
+    if isinstance(result_input_fn, dataset_ops.Dataset):
+      return result_input_fn.repeat()
+    return result_input_fn
+
+  return _input_fn
+
+
 class _BoostedTreesEstimator(estimator.Estimator):
   """An Estimator for TensorFlow Boosted Trees models."""
 
@@ -113,10 +125,13 @@ def boosted_trees_classifier_train_in_memory(
       bucketized_feature_2 = bucketized_column(
           numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
-    def input_fn_train():
+    def train_input_fn():
       dataset = create-dataset-from-training-data
-      # Don't use repeat or cache, since it is assumed to be one epoch
-      # This is either tf.data.Dataset, or a tuple of feature dict and label.
+      # This is a tf.data.Dataset of a tuple of feature dict and label.
+      # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+      #                   Dataset.from_tensors(label_array)))
+      # The returned Dataset shouldn't be batched.
+      # If the Dataset repeats, only the first repetition is used for training.
       return dataset
 
     classifier = boosted_trees_classifier_train_in_memory(
@@ -210,7 +225,9 @@ def boosted_trees_classifier_train_in_memory(
     in_memory_classifier = estimator.Estimator(
         model_fn=_model_fn, model_dir=model_dir, config=config)
 
-    in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+    in_memory_classifier.train(
+        input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+        hooks=train_hooks)
 
     return in_memory_classifier
   # pylint: enable=protected-access
 
@@ -241,10 +258,13 @@ def boosted_trees_regressor_train_in_memory(
       bucketized_feature_2 = bucketized_column(
           numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
-    def input_fn_train():
+    def train_input_fn():
       dataset = create-dataset-from-training-data
-      # Don't use repeat or cache, since it is assumed to be one epoch
-      # This is either tf.data.Dataset, or a tuple of feature dict and label.
+      # This is a tf.data.Dataset of a tuple of feature dict and label.
+      # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+      #                   Dataset.from_tensors(label_array)))
+      # The returned Dataset shouldn't be batched.
+      # If the Dataset repeats, only the first repetition is used for training.
       return dataset
 
     regressor = boosted_trees_regressor_train_in_memory(
@@ -329,7 +349,9 @@ def boosted_trees_regressor_train_in_memory(
     in_memory_regressor = estimator.Estimator(
         model_fn=_model_fn, model_dir=model_dir, config=config)
 
-    in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks)
+    in_memory_regressor.train(
+        input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+        hooks=train_hooks)
 
     return in_memory_regressor
   # pylint: enable=protected-access
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index eee5910..76cbefe 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -21,6 +21,7 @@ import numpy as np
 
 from tensorflow.contrib.estimator.python.estimator import boosted_trees
 from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
 from tensorflow.python.estimator.inputs import numpy_io
 from tensorflow.python.feature_column import feature_column
@@ -49,12 +50,24 @@ def _make_train_input_fn(is_classification):
   """Makes train input_fn for classification/regression."""
 
   def _input_fn():
-    features = dict(FEATURES_DICT)
-    if is_classification:
-      labels = CLASSIFICATION_LABELS
-    else:
-      labels = REGRESSION_LABELS
-    return features, labels
+    features_dict = dict(FEATURES_DICT)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    return features_dict, labels
+
+  return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification):
+  """Makes input_fn using Dataset."""
+
+  def _input_fn():
+    features_dict = dict(FEATURES_DICT)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    ds = dataset_ops.Dataset.zip(
+        (dataset_ops.Dataset.from_tensors(features_dict),
+         dataset_ops.Dataset.from_tensors(labels)))
+    return ds
 
   return _input_fn
 
@@ -132,15 +145,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
         x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
 
     est = boosted_trees.boosted_trees_classifier_train_in_memory(
-        train_input_fn=train_input_fn,
-        feature_columns=self._feature_columns,
-        n_trees=1,
-        max_depth=5)
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
     # It will stop after 5 steps because of the max depth and num trees.
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
 
-    # Check eval.
+    # Check evaluate and predict.
     eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
     self.assertAllClose(eval_res['accuracy'], 1.0)
     # Validate predictions.
@@ -148,24 +159,59 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     self.assertAllClose([[0], [1], [1], [0], [0]],
                         [pred['class_ids'] for pred in predictions])
 
+  def testBinaryClassifierTrainInMemoryWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_classifier_train_in_memory(
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+
+    # Check evaluate and predict.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
+
   def testRegressorTrainInMemoryAndEvalAndInfer(self):
     train_input_fn = _make_train_input_fn(is_classification=False)
     predict_input_fn = numpy_io.numpy_input_fn(
         x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
 
     est = boosted_trees.boosted_trees_regressor_train_in_memory(
-        train_input_fn=train_input_fn,
-        feature_columns=self._feature_columns,
-        n_trees=1,
-        max_depth=5)
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
     # It will stop after 5 steps because of the max depth and num trees.
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
 
-    # Check eval.
+    # Check evaluate and predict.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testRegressorTrainInMemoryWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_regressor_train_in_memory(
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    # Check evaluate and predict.
     eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
     self.assertAllClose(eval_res['average_loss'], 2.478283)
-    # Validate predictions.
     predictions = list(est.predict(input_fn=predict_input_fn))
     self.assertAllClose(
         [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
         [pred['predictions'] for pred in predictions])
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 536bd2b..085dace 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
@@ -50,6 +51,32 @@ _HOLD_FOR_MULTI_CLASS_SUPPORT = object()
 _HOLD_FOR_MULTI_DIM_SUPPORT = object()
 
 
+def _get_max_buckets(feature_columns):
+  """Gets the maximum number of buckets from feature_columns.
+
+  Args:
+    feature_columns: a list/set of tf.feature_column.
+
+  Returns:
+    max_buckets: the maximum number of buckets among bucketized_columns.
+
+  Raises:
+    ValueError: when unsupported feature_columns are given.
+ """ + if not feature_columns: + raise ValueError('feature_columns must be a non-empty list/set of ' + 'tf.feature_column.') + max_buckets = 1 + for fc in feature_columns: + if isinstance(fc, feature_column_lib._BucketizedColumn): # pylint:disable=protected-access + # N boundaries creates (N+1) buckets. + max_buckets = max(max_buckets, len(fc.boundaries) + 1) + else: + raise ValueError('For now, only bucketized_column is supported but ' + 'got: {}'.format(fc)) + return max_buckets + + def _get_transformed_features(features, feature_columns): """Gets the transformed features from features/feature_columns pair. @@ -59,36 +86,31 @@ def _get_transformed_features(features, feature_columns): Returns: result_features: a list of the transformed features, sorted by the name. - num_buckets: the maximum number of buckets across bucketized_columns. Raises: ValueError: when unsupported features/columns are tried. """ - num_buckets = 1 # pylint:disable=protected-access for fc in feature_columns: - if isinstance(fc, feature_column_lib._BucketizedColumn): - # N boundaries creates (N+1) buckets. - num_buckets = max(num_buckets, len(fc.boundaries) + 1) - else: + if not isinstance(fc, feature_column_lib._BucketizedColumn): raise ValueError('For now, only bucketized_column is supported but ' 'got: {}'.format(fc)) - transformed = feature_column_lib._transform_features(features, - feature_columns) + transformed_features = feature_column_lib._transform_features( + features, feature_columns) # pylint:enable=protected-access result_features = [] - for column in sorted(transformed, key=lambda tc: tc.name): + for column in sorted(transformed_features, key=lambda tc: tc.name): source_name = column.source_column.name - squeezed_tensor = array_ops.squeeze(transformed[column], axis=1) + squeezed_tensor = array_ops.squeeze(transformed_features[column], axis=1) if len(squeezed_tensor.shape) > 1: raise ValueError('For now, only supports features equivalent to rank 1 ' 'but column `{}` got: {}'.format( source_name, features[source_name].shape)) result_features.append(squeezed_tensor) - return result_features, num_buckets + return result_features -def _keep_as_local_variable(tensor, name=None): +def _local_variable(tensor, name=None): """Stores a tensor as a local Variable for faster read.""" return variable_scope.variable( initial_value=tensor, @@ -98,6 +120,48 @@ def _keep_as_local_variable(tensor, name=None): name=name) +def _cache_transformed_features(features, feature_columns, batch_size): + """Transform features and cache, then returns (cached_features, cache_op).""" + num_features = len(feature_columns) + cached_features = [ + _local_variable( + array_ops.zeros([batch_size], dtype=dtypes.int32), + name='cached_feature_{}'.format(i)) + for i in range(num_features) + ] + are_features_cached = _local_variable(False, name='are_features_cached') + + def cache_features_and_return(): + """Caches transoformed features. + + The intention is to hide get_transformed_features() from the graph by + caching the result except the first step, since bucketize operation + (inside get_transformed_features) is expensive. + + Returns: + input_feature_list: a list of input features. + cache_flip_op: op to add to graph to make sure cache update is included to + the graph. 
+ """ + + transformed_features = _get_transformed_features(features, feature_columns) + cached = [ + state_ops.assign(cached_features[i], transformed_features[i]) + for i in range(num_features) + ] + # TODO(youngheek): Try other combination of dependencies so that the + # function returns a single result, not a tuple. + with ops.control_dependencies(cached): + cache_flip_op = are_features_cached.assign(True) + return cached, cache_flip_op + + input_feature_list, cache_flip_op = control_flow_ops.cond( + are_features_cached, + lambda: (cached_features, control_flow_ops.no_op()), + cache_features_and_return) + return input_feature_list, cache_flip_op + + class _CacheTrainingStatesUsingHashTable(object): """Caching logits, etc. using MutableHashTable.""" @@ -186,13 +250,13 @@ class _CacheTrainingStatesUsingVariables(object): logits_dimension: a constant (int) for the dimension of logits. """ self._logits_dimension = logits_dimension - self._tree_ids = _keep_as_local_variable( + self._tree_ids = _local_variable( array_ops.zeros([batch_size], dtype=dtypes.int32), name='tree_ids_cache') - self._node_ids = _keep_as_local_variable( + self._node_ids = _local_variable( array_ops.zeros([batch_size], dtype=dtypes.int32), name='node_ids_cache') - self._logits = _keep_as_local_variable( + self._logits = _local_variable( array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32), name='logits_cache') @@ -290,33 +354,38 @@ def _bt_model_fn( 'When train_in_memory is enabled, input_fn should return the entire ' 'dataset as a single batch, and n_batches_per_layer should be set as ' '1.') + if (not config.is_chief or config.num_worker_replicas > 1 or + config.num_ps_replicas > 0): + raise ValueError('train_in_memory is supported only for ' + 'non-distributed training.') worker_device = control_flow_ops.no_op().device # maximum number of splits possible in the whole tree =2^(D-1)-1 # TODO(youngheek): perhaps storage could be optimized by storing stats with # the dimension max_splits_per_layer, instead of max_splits (for the entire # tree). max_splits = (1 << tree_hparams.max_depth) - 1 + max_buckets = _get_max_buckets(feature_columns) + train_op = [] with ops.name_scope(name) as name: # Prepare. global_step = training_util.get_or_create_global_step() - input_feature_list, num_buckets = _get_transformed_features( - features, feature_columns) - if train_in_memory and mode == model_fn.ModeKeys.TRAIN: - input_feature_list = [ - _keep_as_local_variable(feature) for feature in input_feature_list - ] - num_features = len(input_feature_list) - - cache = None - if mode == model_fn.ModeKeys.TRAIN: - if train_in_memory and is_single_machine: # maybe just train_in_memory? - batch_size = array_ops.shape(input_feature_list[0])[0] - cache = _CacheTrainingStatesUsingVariables(batch_size, - head.logits_dimension) - elif example_id_column_name: + num_features = len(feature_columns) + # Extract input features and set up cache for training. + training_state_cache = None + if mode == model_fn.ModeKeys.TRAIN and train_in_memory: + # cache transformed features as well for in-memory training. 
+      batch_size = array_ops.shape(labels)[0]
+      input_feature_list, input_cache_op = _cache_transformed_features(
+          features, feature_columns, batch_size)
+      train_op.append(input_cache_op)
+      training_state_cache = _CacheTrainingStatesUsingVariables(
+          batch_size, head.logits_dimension)
+    else:
+      input_feature_list = _get_transformed_features(features, feature_columns)
+      if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
         example_ids = features[example_id_column_name]
-        cache = _CacheTrainingStatesUsingHashTable(example_ids,
-                                                   head.logits_dimension)
+        training_state_cache = _CacheTrainingStatesUsingHashTable(
+            example_ids, head.logits_dimension)
 
     # Create Ensemble resources.
     tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
@@ -340,11 +409,12 @@
       # TODO(soroush): Do partial updates if this becomes a bottleneck.
       ensemble_reload = local_tree_ensemble.deserialize(
          *tree_ensemble.serialize())
-      if cache:
-        cached_tree_ids, cached_node_ids, cached_logits = cache.lookup()
+      if training_state_cache:
+        cached_tree_ids, cached_node_ids, cached_logits = (
+            training_state_cache.lookup())
       else:
         # Always start from the beginning when no cache is set up.
-        batch_size = array_ops.shape(input_feature_list[0])[0]
+        batch_size = array_ops.shape(labels)[0]
         cached_tree_ids, cached_node_ids, cached_logits = (
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros([batch_size], dtype=dtypes.int32),
@@ -368,9 +438,8 @@
     # Create training graph.
     def _train_op_fn(loss):
       """Run one training iteration."""
-      train_op = []
-      if cache:
-        train_op.append(cache.insert(tree_ids, node_ids, logits))
+      if training_state_cache:
+        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
       if closed_form_grad_and_hess_fn:
         gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
       else:
@@ -385,7 +454,7 @@
                  hessians=hessians,
                  bucketized_features_list=[input_feature_list[f]],
                  max_splits=max_splits,
-                 num_buckets=num_buckets),
+                 num_buckets=max_buckets),
              axis=0) for f in range(num_features)
       ]
@@ -422,7 +491,7 @@
       summary_accumulator = data_flow_ops.ConditionalAccumulator(
          dtype=dtypes.float32,
          # The stats consist of gradients and hessians (the last dimension).
-         shape=[num_features, max_splits, num_buckets, 2],
+         shape=[num_features, max_splits, max_buckets, 2],
          shared_name='stats_summary_accumulator')
      apply_grad = summary_accumulator.apply_grad(
          array_ops.stack(stats_summary_list, axis=0), stamp_token)
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 56e67a6..c8c52d3 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import model_fn
 from tensorflow.python.estimator import run_config
 from tensorflow.python.estimator.canned import boosted_trees
@@ -58,13 +59,32 @@ def _make_train_input_fn(is_classification):
   """Makes train input_fn for classification/regression."""
 
   def _input_fn():
-    features = dict(FEATURES_DICT)
-    features[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
-    if is_classification:
-      labels = CLASSIFICATION_LABELS
+    features_dict = dict(FEATURES_DICT)
+    features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    return features_dict, labels
+
+  return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification, batch=None, repeat=None):
+  """Makes input_fn using Dataset."""
+
+  def _input_fn():
+    features_dict = dict(FEATURES_DICT)
+    features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    if batch:
+      ds = dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensor_slices(features_dict),
+           dataset_ops.Dataset.from_tensor_slices(labels))).batch(batch)
     else:
-      labels = REGRESSION_LABELS
-    return features, labels
+      ds = dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensors(features_dict),
+           dataset_ops.Dataset.from_tensors(labels)))
+    # Repeat indefinitely by default, or stop after the given number of
+    # repetitions.
+    ds = ds.repeat(repeat)
+    return ds
 
   return _input_fn
 
@@ -125,9 +145,28 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     num_steps = 100
     # Train for a few steps, and validate final checkpoint.
     est.train(train_input_fn, steps=num_steps)
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
 
+  def testTrainClassifierWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
     predictions = list(est.predict(input_fn=predict_input_fn))
-    # All labels are correct.
     self.assertAllClose([[0], [1], [1], [0], [0]],
                         [pred['class_ids'] for pred in predictions])
 
@@ -166,12 +205,126 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     est.train(train_input_fn, steps=num_steps)
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetBatch(self):
+    # A batch_size equal to the entire data size should yield the same result
+    # as a dataset without batching.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=5)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetLargerBatch(self):
+    # A batch_size that is a multiple of the entire data size should still
+    # yield the same result.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=15)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetSmallerBatch(self):
+    # Even with small batches, as long as n_batches_per_layer * batch_size
+    # covers the entire data, the result should be the same.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=1)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=5,
+        n_trees=1,
+        max_depth=5)
+    # Train stops after (n_batches_per_layer * n_trees * max_depth) steps.
+    est.train(train_input_fn, steps=100)
+    self._assert_checkpoint(
+        est.model_dir, global_step=25, finalized_trees=1, attempted_layers=5)
+    # 5 batches = one epoch.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=5)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
     predictions = list(est.predict(input_fn=predict_input_fn))
     self.assertAllClose(
         [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
         [pred['predictions'] for pred in predictions])
 
+  def testTrainRegressorWithDatasetWhenInputIsOverEarlier(self):
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, repeat=3)  # to stop the input after 3 steps.
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    # Note that training stops when the input is exhausted.
+    # This might not be a typical pattern, but dataset.repeat(3) causes
+    # the input stream to cease after 3 steps.
+    est.train(train_input_fn, steps=100)
+    self._assert_checkpoint(
+        est.model_dir, global_step=3, finalized_trees=0, attempted_layers=3)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 3.777295)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.353850], [0.254100], [0.106850], [0.712100], [1.012100]],
+        [pred['predictions'] for pred in predictions])
+
 
 class ModelFnTests(test_util.TensorFlowTestCase):
   """Tests bt_model_fn including unexposed internal functionalities."""
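
Usage sketch (illustrative only, not part of the patch): after this change, the in-memory trainers accept a train_input_fn that returns an unbatched tf.data.Dataset holding the entire training data as a single element, and an explicit repeat() is unnecessary because _validate_input_fn_and_repeat_dataset repeats the Dataset internally. The feature name 'f_0', the bucket boundaries, and the toy values below are hypothetical; the snippet assumes the TF 1.9-era contrib API touched by this patch.

    import numpy as np
    import tensorflow as tf

    from tensorflow.contrib.estimator.python.estimator import boosted_trees

    # Hypothetical toy data: one numeric feature, binary labels.
    features = {'f_0': np.array([12., 5., 8., 21., 1.], dtype=np.float32)}
    labels = np.array([[0], [1], [1], [0], [0]], dtype=np.int32)

    feature_columns = [
        tf.feature_column.bucketized_column(
            tf.feature_column.numeric_column('f_0'), boundaries=[3., 10., 15.]),
    ]


    def train_input_fn():
      # The whole dataset as one element: zip the features dict with the
      # labels. Do not call .batch(); a .repeat() here would only have its
      # first repetition consumed.
      return tf.data.Dataset.zip(
          (tf.data.Dataset.from_tensors(features),
           tf.data.Dataset.from_tensors(labels)))


    est = boosted_trees.boosted_trees_classifier_train_in_memory(
        train_input_fn=train_input_fn,
        feature_columns=feature_columns,
        n_trees=1,
        max_depth=3)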