BoostedTreesEstimator in contrib: train_in_memory works with input_fns returning...
author     Younghee Kwon <youngheek@google.com>
           Tue, 17 Apr 2018 02:10:10 +0000 (19:10 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 17 Apr 2018 02:12:41 +0000 (19:12 -0700)
Only one batch of data is expected, so dataset.batch() is disallowed,
and dataset.repeat() is ignored (only the first repetition is used).

PiperOrigin-RevId: 193137094

tensorflow/contrib/estimator/python/estimator/boosted_trees.py
tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
tensorflow/python/estimator/canned/boosted_trees.py
tensorflow/python/estimator/canned/boosted_trees_test.py
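
For illustration, a minimal sketch of an input_fn matching the new contract (the feature names and values here are hypothetical; TF 1.x API): the entire training set is returned as a single un-batched Dataset element.

    import numpy as np
    import tensorflow as tf

    def train_input_fn():
      # The whole training set is one Dataset element; no .batch() call.
      features = {'f1': np.array([1., 2., 3., 4., 5.], dtype=np.float32)}
      labels = np.array([[0.], [1.], [1.], [0.], [0.]], dtype=np.float32)
      return tf.data.Dataset.zip(
          (tf.data.Dataset.from_tensors(features),
           tf.data.Dataset.from_tensors(labels)))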

diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 00356ce..bd64101 100644
@@ -17,10 +17,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
 
 
+def _validate_input_fn_and_repeat_dataset(train_input_fn):
+  """Validates whether the input_fn is valid, and repeat() if tf.Dataset."""
+  def _input_fn():
+    result_input_fn = train_input_fn()
+    if isinstance(result_input_fn, dataset_ops.Dataset):
+      return result_input_fn.repeat()
+    return result_input_fn
+
+  return _input_fn
+
+
 class _BoostedTreesEstimator(estimator.Estimator):
   """An Estimator for Tensorflow Boosted Trees models."""
 
@@ -113,10 +125,13 @@ def boosted_trees_classifier_train_in_memory(
   bucketized_feature_2 = bucketized_column(
     numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
-  def input_fn_train():
+  def train_input_fn():
     dataset = create-dataset-from-training-data
-    # Don't use repeat or cache, since it is assumed to be one epoch
-    # This is either tf.data.Dataset, or a tuple of feature dict and label.
+    # This is a tf.data.Dataset of a tuple of a feature dict and labels.
+    #   e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+    #                     Dataset.from_tensors(label_array)))
+    # The returned Dataset shouldn't be batched.
+    # If the Dataset repeats, only the first repetition is used for training.
     return dataset
 
   classifier = boosted_trees_classifier_train_in_memory(
@@ -210,7 +225,9 @@ def boosted_trees_classifier_train_in_memory(
   in_memory_classifier = estimator.Estimator(
       model_fn=_model_fn, model_dir=model_dir, config=config)
 
-  in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+  in_memory_classifier.train(
+      input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+      hooks=train_hooks)
 
   return in_memory_classifier
   # pylint: enable=protected-access
@@ -241,10 +258,13 @@ def boosted_trees_regressor_train_in_memory(
   bucketized_feature_2 = bucketized_column(
     numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
-  def input_fn_train():
+  def train_input_fn():
     dataset = create-dataset-from-training-data
-    # Don't use repeat or cache, since it is assumed to be one epoch
-    # This is either tf.data.Dataset, or a tuple of feature dict and label.
+    # This is a tf.data.Dataset of a tuple of a feature dict and labels.
+    #   e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
+    #                     Dataset.from_tensors(label_array)))
+    # The returned Dataset shouldn't be batched.
+    # If the Dataset repeats, only the first repetition is used for training.
     return dataset
 
   regressor = boosted_trees_regressor_train_in_memory(
@@ -329,7 +349,9 @@ def boosted_trees_regressor_train_in_memory(
   in_memory_regressor = estimator.Estimator(
       model_fn=_model_fn, model_dir=model_dir, config=config)
 
-  in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks)
+  in_memory_regressor.train(
+      input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn),
+      hooks=train_hooks)
 
   return in_memory_regressor
   # pylint: enable=protected-access
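
Putting this file's pieces together, a hedged usage sketch (the bucket boundaries are hypothetical, and the tf.contrib.estimator export path is an assumption; train_input_fn is the single-element Dataset sketch above):

    import tensorflow as tf

    # Hypothetical bucketized feature column for the 'f1' feature above.
    bucketized_f1 = tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column('f1'),
        boundaries=[1.5, 2.5, 3.5, 4.5])

    # train_input_fn returns an un-batched, single-element Dataset.
    classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
        train_input_fn=train_input_fn,
        feature_columns=[bucketized_f1],
        n_trees=1,
        max_depth=5)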
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index eee5910..76cbefe 100644
@@ -21,6 +21,7 @@ import numpy as np
 
 from tensorflow.contrib.estimator.python.estimator import boosted_trees
 from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
 from tensorflow.python.estimator.inputs import numpy_io
 from tensorflow.python.feature_column import feature_column
@@ -49,12 +50,24 @@ def _make_train_input_fn(is_classification):
   """Makes train input_fn for classification/regression."""
 
   def _input_fn():
-    features = dict(FEATURES_DICT)
-    if is_classification:
-      labels = CLASSIFICATION_LABELS
-    else:
-      labels = REGRESSION_LABELS
-    return features, labels
+    features_dict = dict(FEATURES_DICT)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    return features_dict, labels
+
+  return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification):
+  """Makes input_fn using Dataset."""
+
+  def _input_fn():
+    features_dict = dict(FEATURES_DICT)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    ds = dataset_ops.Dataset.zip(
+        (dataset_ops.Dataset.from_tensors(features_dict),
+         dataset_ops.Dataset.from_tensors(labels)
+        ))
+    return ds
 
   return _input_fn
 
@@ -132,15 +145,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
         x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
 
     est = boosted_trees.boosted_trees_classifier_train_in_memory(
-        train_input_fn=train_input_fn,
-        feature_columns=self._feature_columns,
-        n_trees=1,
-        max_depth=5)
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
     # It will stop after 5 steps because of the max depth and num trees.
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
 
-    # Check eval.
+    # Check evaluate and predict.
     eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
     self.assertAllClose(eval_res['accuracy'], 1.0)
     # Validate predictions.
@@ -148,24 +159,59 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     self.assertAllClose([[0], [1], [1], [0], [0]],
                         [pred['class_ids'] for pred in predictions])
 
+  def testBinaryClassifierTrainInMemoryWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_classifier_train_in_memory(
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+
+    # Check evaluate and predict.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
+
   def testRegressorTrainInMemoryAndEvalAndInfer(self):
     train_input_fn = _make_train_input_fn(is_classification=False)
     predict_input_fn = numpy_io.numpy_input_fn(
         x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
 
     est = boosted_trees.boosted_trees_regressor_train_in_memory(
-        train_input_fn=train_input_fn,
-        feature_columns=self._feature_columns,
-        n_trees=1,
-        max_depth=5)
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
     # It will stop after 5 steps because of the max depth and num trees.
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
 
-    # Check eval.
+    # Check evaluate and predict.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testRegressorTrainInMemoryWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_regressor_train_in_memory(
+        train_input_fn=train_input_fn, feature_columns=self._feature_columns,
+        n_trees=1, max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    # Check evaluate and predict.
     eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
     self.assertAllClose(eval_res['average_loss'], 2.478283)
-    # Validate predictions.
     predictions = list(est.predict(input_fn=predict_input_fn))
     self.assertAllClose(
         [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
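
The canned-estimator tests further below exercise the fact that a single-element Dataset built with from_tensors is interchangeable with slicing the same arrays and re-batching at the full data size; a standalone sketch of the two constructions (the labels array is hypothetical):

    import numpy as np
    import tensorflow as tf

    labels = np.array([[0.], [1.], [1.], [0.], [0.]], dtype=np.float32)
    # One element of shape [5, 1]:
    ds_a = tf.data.Dataset.from_tensors(labels)
    # Five elements of shape [1], reassembled into one [5, 1] batch:
    ds_b = tf.data.Dataset.from_tensor_slices(labels).batch(5)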
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 536bd2b..085dace 100644
@@ -32,6 +32,7 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
@@ -50,6 +51,32 @@ _HOLD_FOR_MULTI_CLASS_SUPPORT = object()
 _HOLD_FOR_MULTI_DIM_SUPPORT = object()
 
 
+def _get_max_buckets(feature_columns):
+  """Gets the maximum number of buckets from feature_columns.
+
+  Args:
+    feature_columns: a list/set of tf.feature_column.
+
+  Returns:
+    max_buckets: the maximum number of buckets among bucketized_columns.
+
+  Raises:
+    ValueError: when unsupported feature_columns are given.
+  """
+  if not feature_columns:
+    raise ValueError('feature_columns must be a non-empty list/set of '
+                     'tf.feature_column.')
+  max_buckets = 1
+  for fc in feature_columns:
+    if isinstance(fc, feature_column_lib._BucketizedColumn):  # pylint:disable=protected-access
+      # N boundaries creates (N+1) buckets.
+      max_buckets = max(max_buckets, len(fc.boundaries) + 1)
+    else:
+      raise ValueError('For now, only bucketized_column is supported but '
+                       'got: {}'.format(fc))
+  return max_buckets
+
+
 def _get_transformed_features(features, feature_columns):
   """Gets the transformed features from features/feature_columns pair.
 
@@ -59,36 +86,31 @@ def _get_transformed_features(features, feature_columns):
 
   Returns:
     result_features: a list of the transformed features, sorted by the name.
-    num_buckets: the maximum number of buckets across bucketized_columns.
 
   Raises:
     ValueError: when unsupported features/columns are tried.
   """
-  num_buckets = 1
   # pylint:disable=protected-access
   for fc in feature_columns:
-    if isinstance(fc, feature_column_lib._BucketizedColumn):
-      # N boundaries creates (N+1) buckets.
-      num_buckets = max(num_buckets, len(fc.boundaries) + 1)
-    else:
+    if not isinstance(fc, feature_column_lib._BucketizedColumn):
       raise ValueError('For now, only bucketized_column is supported but '
                        'got: {}'.format(fc))
-  transformed = feature_column_lib._transform_features(features,
-                                                       feature_columns)
+  transformed_features = feature_column_lib._transform_features(
+      features, feature_columns)
   # pylint:enable=protected-access
   result_features = []
-  for column in sorted(transformed, key=lambda tc: tc.name):
+  for column in sorted(transformed_features, key=lambda tc: tc.name):
     source_name = column.source_column.name
-    squeezed_tensor = array_ops.squeeze(transformed[column], axis=1)
+    squeezed_tensor = array_ops.squeeze(transformed_features[column], axis=1)
     if len(squeezed_tensor.shape) > 1:
       raise ValueError('For now, only supports features equivalent to rank 1 '
                        'but column `{}` got: {}'.format(
                            source_name, features[source_name].shape))
     result_features.append(squeezed_tensor)
-  return result_features, num_buckets
+  return result_features
 
 
-def _keep_as_local_variable(tensor, name=None):
+def _local_variable(tensor, name=None):
   """Stores a tensor as a local Variable for faster read."""
   return variable_scope.variable(
       initial_value=tensor,
@@ -98,6 +120,48 @@ def _keep_as_local_variable(tensor, name=None):
       name=name)
 
 
+def _cache_transformed_features(features, feature_columns, batch_size):
+  """Transform features and cache, then returns (cached_features, cache_op)."""
+  num_features = len(feature_columns)
+  cached_features = [
+      _local_variable(
+          array_ops.zeros([batch_size], dtype=dtypes.int32),
+          name='cached_feature_{}'.format(i))
+      for i in range(num_features)
+  ]
+  are_features_cached = _local_variable(False, name='are_features_cached')
+
+  def cache_features_and_return():
+    """Caches transoformed features.
+
+    The intention is to hide get_transformed_features() from the graph by
+    caching the result except the first step, since bucketize operation
+    (inside get_transformed_features) is expensive.
+
+    Returns:
+      input_feature_list: a list of input features.
+      cache_flip_op: op to add to graph to make sure cache update is included to
+          the graph.
+    """
+
+    transformed_features = _get_transformed_features(features, feature_columns)
+    cached = [
+        state_ops.assign(cached_features[i], transformed_features[i])
+        for i in range(num_features)
+    ]
+    # TODO(youngheek): Try other combination of dependencies so that the
+    # function returns a single result, not a tuple.
+    with ops.control_dependencies(cached):
+      cache_flip_op = are_features_cached.assign(True)
+    return cached, cache_flip_op
+
+  input_feature_list, cache_flip_op = control_flow_ops.cond(
+      are_features_cached,
+      lambda: (cached_features, control_flow_ops.no_op()),
+      cache_features_and_return)
+  return input_feature_list, cache_flip_op
+
+
 class _CacheTrainingStatesUsingHashTable(object):
   """Caching logits, etc. using MutableHashTable."""
 
@@ -186,13 +250,13 @@ class _CacheTrainingStatesUsingVariables(object):
       logits_dimension: a constant (int) for the dimension of logits.
     """
     self._logits_dimension = logits_dimension
-    self._tree_ids = _keep_as_local_variable(
+    self._tree_ids = _local_variable(
         array_ops.zeros([batch_size], dtype=dtypes.int32),
         name='tree_ids_cache')
-    self._node_ids = _keep_as_local_variable(
+    self._node_ids = _local_variable(
         array_ops.zeros([batch_size], dtype=dtypes.int32),
         name='node_ids_cache')
-    self._logits = _keep_as_local_variable(
+    self._logits = _local_variable(
         array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
         name='logits_cache')
 
@@ -290,33 +354,38 @@ def _bt_model_fn(
         'When train_in_memory is enabled, input_fn should return the entire '
         'dataset as a single batch, and n_batches_per_layer should be set as '
         '1.')
+    if (not config.is_chief or config.num_worker_replicas > 1 or
+        config.num_ps_replicas > 0):
+      raise ValueError('train_in_memory is supported only for '
+                       'non-distributed training.')
   worker_device = control_flow_ops.no_op().device
  # maximum number of splits possible in the whole tree = 2^(D-1) - 1
   # TODO(youngheek): perhaps storage could be optimized by storing stats with
   # the dimension max_splits_per_layer, instead of max_splits (for the entire
   # tree).
   max_splits = (1 << tree_hparams.max_depth) - 1
+  max_buckets = _get_max_buckets(feature_columns)
+  train_op = []
   with ops.name_scope(name) as name:
     # Prepare.
     global_step = training_util.get_or_create_global_step()
-    input_feature_list, num_buckets = _get_transformed_features(
-        features, feature_columns)
-    if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
-      input_feature_list = [
-          _keep_as_local_variable(feature) for feature in input_feature_list
-      ]
-    num_features = len(input_feature_list)
-
-    cache = None
-    if mode == model_fn.ModeKeys.TRAIN:
-      if train_in_memory and is_single_machine:  # maybe just train_in_memory?
-        batch_size = array_ops.shape(input_feature_list[0])[0]
-        cache = _CacheTrainingStatesUsingVariables(batch_size,
-                                                   head.logits_dimension)
-      elif example_id_column_name:
+    num_features = len(feature_columns)
+    # Extract input features and set up cache for training.
+    training_state_cache = None
+    if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+      # cache transformed features as well for in-memory training.
+      batch_size = array_ops.shape(labels)[0]
+      input_feature_list, input_cache_op = _cache_transformed_features(
+          features, feature_columns, batch_size)
+      train_op.append(input_cache_op)
+      training_state_cache = _CacheTrainingStatesUsingVariables(
+          batch_size, head.logits_dimension)
+    else:
+      input_feature_list = _get_transformed_features(features, feature_columns)
+      if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
         example_ids = features[example_id_column_name]
-        cache = _CacheTrainingStatesUsingHashTable(example_ids,
-                                                   head.logits_dimension)
+        training_state_cache = _CacheTrainingStatesUsingHashTable(
+            example_ids, head.logits_dimension)
 
     # Create Ensemble resources.
     tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
@@ -340,11 +409,12 @@ def _bt_model_fn(
         # TODO(soroush): Do partial updates if this becomes a bottleneck.
         ensemble_reload = local_tree_ensemble.deserialize(
             *tree_ensemble.serialize())
-      if cache:
-        cached_tree_ids, cached_node_ids, cached_logits = cache.lookup()
+      if training_state_cache:
+        cached_tree_ids, cached_node_ids, cached_logits = (
+            training_state_cache.lookup())
       else:
         # Always start from the beginning when no cache is set up.
-        batch_size = array_ops.shape(input_feature_list[0])[0]
+        batch_size = array_ops.shape(labels)[0]
         cached_tree_ids, cached_node_ids, cached_logits = (
             array_ops.zeros([batch_size], dtype=dtypes.int32),
             array_ops.zeros([batch_size], dtype=dtypes.int32),
@@ -368,9 +438,8 @@ def _bt_model_fn(
     # Create training graph.
     def _train_op_fn(loss):
       """Run one training iteration."""
-      train_op = []
-      if cache:
-        train_op.append(cache.insert(tree_ids, node_ids, logits))
+      if training_state_cache:
+        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
       if closed_form_grad_and_hess_fn:
         gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
       else:
@@ -385,7 +454,7 @@ def _bt_model_fn(
                   hessians=hessians,
                   bucketized_features_list=[input_feature_list[f]],
                   max_splits=max_splits,
-                  num_buckets=num_buckets),
+                  num_buckets=max_buckets),
               axis=0) for f in range(num_features)
       ]
 
@@ -422,7 +491,7 @@ def _bt_model_fn(
         summary_accumulator = data_flow_ops.ConditionalAccumulator(
             dtype=dtypes.float32,
             # The stats consist of gradients and hessians (the last dimension).
-            shape=[num_features, max_splits, num_buckets, 2],
+            shape=[num_features, max_splits, max_buckets, 2],
             shared_name='stats_summary_accumulator')
         apply_grad = summary_accumulator.apply_grad(
             array_ops.stack(stats_summary_list, axis=0), stamp_token)
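
A simplified standalone sketch (TF 1.x; the helper name cache_once is hypothetical) of the caching pattern used by _cache_transformed_features above: run the expensive transform once, store the result in a non-trainable local variable, and read the cached copy on later steps via a cond.

    import tensorflow as tf

    def cache_once(expensive_fn, shape, dtype):
      # Non-trainable local variables hold the cached value and a flag.
      cached = tf.Variable(
          tf.zeros(shape, dtype=dtype), trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES], name='cached_value')
      is_cached = tf.Variable(
          False, trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES], name='is_cached')

      def compute_and_cache():
        value = expensive_fn()
        with tf.control_dependencies([tf.assign(cached, value)]):
          flip = tf.assign(is_cached, True)
        with tf.control_dependencies([flip]):
          return tf.identity(value)

      # First run computes and caches; later runs read the cached variable.
      return tf.cond(is_cached, cached.read_value, compute_and_cache)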
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 56e67a6..c8c52d3 100644
@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import model_fn
 from tensorflow.python.estimator import run_config
 from tensorflow.python.estimator.canned import boosted_trees
@@ -58,13 +59,32 @@ def _make_train_input_fn(is_classification):
   """Makes train input_fn for classification/regression."""
 
   def _input_fn():
-    features = dict(FEATURES_DICT)
-    features[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
-    if is_classification:
-      labels = CLASSIFICATION_LABELS
+    features_dict = dict(FEATURES_DICT)
+    features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    return features_dict, labels
+
+  return _input_fn
+
+
+def _make_train_input_fn_dataset(is_classification, batch=None, repeat=None):
+  """Makes input_fn using Dataset."""
+
+  def _input_fn():
+    features_dict = dict(FEATURES_DICT)
+    features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+    labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS
+    if batch:
+      ds = dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensor_slices(features_dict),
+           dataset_ops.Dataset.from_tensor_slices(labels))).batch(batch)
     else:
-      labels = REGRESSION_LABELS
-    return features, labels
+      ds = dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensors(features_dict),
+           dataset_ops.Dataset.from_tensors(labels)))
+    # Repeat indefinitely by default, or stop after `repeat` repetitions.
+    ds = ds.repeat(repeat)
+    return ds
 
   return _input_fn
 
@@ -125,9 +145,28 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     num_steps = 100
     # Train for a few steps, and validate final checkpoint.
     est.train(train_input_fn, steps=num_steps)
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
 
+  def testTrainClassifierWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
     predictions = list(est.predict(input_fn=predict_input_fn))
-    # All labels are correct.
     self.assertAllClose([[0], [1], [1], [0], [0]],
                         [pred['class_ids'] for pred in predictions])
 
@@ -166,12 +205,126 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     est.train(train_input_fn, steps=num_steps)
     self._assert_checkpoint(
         est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDataset(self):
+    train_input_fn = _make_train_input_fn_dataset(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetBatch(self):
+    # A batch_size equal to the entire data size should yield the same result
+    # as a dataset without batching.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=5)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetLargerBatch(self):
+    # A batch_size that is a multiple of the entire data size should still
+    # yield the same result.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=15)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(train_input_fn, steps=100)  # will stop after 5 steps anyway.
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+        [pred['predictions'] for pred in predictions])
+
+  def testTrainRegressorWithDatasetSmallerBatch(self):
+    # Even with small batches, if (n_batches_per_layer * batch_size) covers
+    # the entire data, the result should be the same.
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, batch=1)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
 
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=5,
+        n_trees=1,
+        max_depth=5)
+    # Train stops after (n_batches_per_layer * n_trees * max_depth) steps.
+    est.train(train_input_fn, steps=100)
+    self._assert_checkpoint(
+        est.model_dir, global_step=25, finalized_trees=1, attempted_layers=5)
+    # 5 batches = one epoch.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=5)
+    self.assertAllClose(eval_res['average_loss'], 2.478283)
     predictions = list(est.predict(input_fn=predict_input_fn))
     self.assertAllClose(
         [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
         [pred['predictions'] for pred in predictions])
 
+  def testTrainRegressorWithDatasetWhenInputIsOverEarlier(self):
+    train_input_fn = _make_train_input_fn_dataset(
+        is_classification=False, repeat=3)  # to stop input after 3 steps.
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    # Note that training will stop when the input is exhausted.
+    # This might not be a typical pattern, but dataset.repeat(3) causes
+    # the input stream to cease after 3 steps.
+    est.train(train_input_fn, steps=100)
+    self._assert_checkpoint(
+        est.model_dir, global_step=3, finalized_trees=0, attempted_layers=3)
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 3.777295)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose(
+        [[0.353850], [0.254100], [0.106850], [0.712100], [1.012100]],
+        [pred['predictions'] for pred in predictions])
+
 
 class ModelFnTests(test_util.TensorFlowTestCase):
   """Tests bt_model_fn including unexposed internal functionalities."""