Add hook for checkpointing input pipeline while training with Estimator.
authorSaurabh Saxena <srbs@google.com>
Sat, 12 May 2018 00:14:29 +0000 (17:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 12 May 2018 00:17:23 +0000 (17:17 -0700)
PiperOrigin-RevId: 196331223

tensorflow/contrib/data/__init__.py
tensorflow/contrib/data/python/ops/BUILD
tensorflow/contrib/data/python/ops/iterator_ops.py
tensorflow/contrib/data/python/ops/iterator_ops_test.py [new file with mode: 0644]
tensorflow/python/data/ops/iterator_ops.py

index 077cbba..4f2c72b 100644 (file)
@@ -72,6 +72,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
 from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
 from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
 from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
+from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
 from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
 from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
index 5b04c53..144460f 100644 (file)
@@ -45,6 +45,27 @@ py_library(
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:iterator_ops",
+    ],
+)
+
+py_test(
+    name = "iterator_ops_test",
+    size = "small",
+    srcs = ["iterator_ops_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":iterator_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:model_fn",
     ],
 )
 
index d736029..f1d0e5c 100644 (file)
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
 
 
 def make_saveable_from_iterator(iterator):
@@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator):
   return _Saveable(iterator._iterator_resource)  # pylint: disable=protected-access
 
 
-class _Saveable(saver.BaseSaverBuilder.SaveableObject):
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
   """SaveableObject for saving/restoring iterator state."""
 
   def __init__(self, iterator_resource):
     serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
     specs = [
-        saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
-                                        iterator_resource.name + "-state")
+        saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+                                            iterator_resource.name + "-state")
     ]
     super(_Saveable, self).__init__(iterator_resource, specs,
                                     iterator_resource.name)
@@ -75,3 +77,160 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject):
   def restore(self, restored_tensors, unused_restored_shapes):
     with ops.colocate_with(self.op):
       return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+  """Checkpoints input pipeline state every N steps or seconds.
+
+  This hook saves the state of the iterators in the `Graph` so that when
+  training is resumed the input pipeline continues from where it left off.
+  This could potentially avoid overfitting in certain pipelines where the
+  number of training steps per eval are small compared to the dataset
+  size or if the training pipeline is pre-empted.
+
+  Differences from `CheckpointSaverHook`:
+  1. Saves only the input pipelines in the "iterators" collection and not the
+     global variables or other saveable objects.
+  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+  Example of checkpointing the training pipeline:
+
+  ```python
+  est = tf.estimator.Estimator(model_fn)
+  while True:
+    est.train(
+        train_input_fn,
+        hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)],
+        steps=train_steps_per_eval)
+    # Note: We do not pass the hook here.
+    metrics = est.evaluate(eval_input_fn)
+    if should_stop_the_training(metrics):
+      break
+  ```
+
+  This hook should be used if the input pipeline state needs to be saved
+  separate from the model checkpoint. Doing so may be useful for a few reasons:
+  1. The input pipeline checkpoint may be large, if there are large shuffle
+     or prefetch buffers for instance, and may bloat the checkpoint size.
+  2. If the input pipeline is shared between training and validation, restoring
+     the checkpoint during validation may override the validation input
+     pipeline.
+
+  For saving the input pipeline checkpoint alongside the model weights use
+  @{tf.contrib.data.make_saveable_from_iterator} directly to create a
+  `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
+  that you will need to be careful not to restore the training iterator during
+  eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+  collector when building the eval graph.
+  """
+
+  def __init__(self, estimator):
+    """Initializes a `CheckpointInputPipelineHook`.
+
+    Args:
+      estimator: Estimator.
+
+    Raises:
+      ValueError: One of `save_steps` or `save_secs` should be set.
+      ValueError: At most one of saver or scaffold should be set.
+    """
+    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+    # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+    # to be different to avoid conflicts with the model checkpoint.
+
+    # pylint: disable=protected-access
+    checkpoint_prefix = "input"
+    if estimator._config.num_worker_replicas > 1:
+      # Distributed setting.
+      suffix = "_{}_{}".format(estimator._config.task_type,
+                               estimator._config.task_id)
+      checkpoint_prefix += suffix
+    # pylint: enable=protected-access
+
+    # We use a composition paradigm instead of inheriting from
+    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+    # to check whether a `CheckpointSaverHook` is already present in the list
+    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+    # would thwart this behavior. This hook checkpoints *only the iterators*
+    # and not the graph variables.
+    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+        estimator.model_dir,
+        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
+        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
+        checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+    # Name for the protocol buffer file that will contain the list of most
+    # recent checkpoints stored as a `CheckpointState` protocol buffer.
+    # This file, kept in the same directory as the checkpoint files, is
+    # automatically managed by the `Saver` to keep track of recent checkpoints.
+    # The default name used by the `Saver` for this file is "checkpoint". Here
+    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+    # `checkpoint_dir` is the same as the model checkpoint directory, there are
+    # no conflicts during restore.
+    self._latest_filename = "checkpoint_" + checkpoint_prefix
+
+  def begin(self):
+    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+    # collection if no `Saver` or `Scaffold` is provided.
+    # pylint: disable=protected-access
+    if (self._checkpoint_saver_hook._saver is None and
+        self._checkpoint_saver_hook._scaffold is None):
+      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+      saveables = [_Saveable(i) for i in iterators]
+      self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+                                                        self._latest_filename)
+    # pylint: enable=protected-access
+    self._checkpoint_saver_hook.begin()
+
+  def after_create_session(self, session, coord):
+    # Check if there is an existing checkpoint. If so, restore from it.
+    # pylint: disable=protected-access
+    latest_checkpoint_path = saver_lib.latest_checkpoint(
+        self._checkpoint_saver_hook._checkpoint_dir,
+        latest_filename=self._latest_filename)
+    if latest_checkpoint_path:
+      self._checkpoint_saver_hook._get_saver().restore(session,
+                                                       latest_checkpoint_path)
+    else:
+      # The checkpoint saved here is the state at step "global_step".
+      # Note: We do not save the GraphDef or MetaGraphDef here.
+      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+      self._checkpoint_saver_hook._save(session, global_step)
+      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+    # pylint: enable=protected-access
+
+  def before_run(self, run_context):
+    return self._checkpoint_saver_hook.before_run(run_context)
+
+  def after_run(self, run_context, run_values):
+    self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+  def end(self, session):
+    self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+  """`Saver` with a different default `latest_filename`.
+
+  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+  the model ckpt saved by the `CheckpointSaverHook`.
+  """
+
+  def __init__(self, var_list, latest_filename):
+    super(_CustomSaver, self).__init__(var_list)
+    self._latest_filename = latest_filename
+
+  def save(self,
+           sess,
+           save_path,
+           global_step=None,
+           latest_filename=None,
+           meta_graph_suffix="meta",
+           write_meta_graph=True,
+           write_state=True,
+           strip_default_attrs=False):
+    return super(_CustomSaver, self).save(
+        sess, save_path, global_step, latest_filename or self._latest_filename,
+        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
new file mode 100644 (file)
index 0000000..30a993b
--- /dev/null
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test.TestCase):
+
+  @staticmethod
+  def _model_fn(features, labels, mode, config):
+    del labels
+    del mode
+    del config
+    global_step = training_util.get_or_create_global_step()
+    update_global_step_op = global_step.assign_add(1)
+    latest_feature = variables.Variable(
+        0, name='latest_feature', dtype=dtypes.int64)
+    store_latest_feature_op = latest_feature.assign(features)
+    ops.add_to_collection('my_vars', global_step)
+    ops.add_to_collection('my_vars', latest_feature)
+    return model_fn.EstimatorSpec(
+        mode='train',
+        train_op=control_flow_ops.group(
+            [update_global_step_op, store_latest_feature_op]),
+        loss=constant_op.constant(2.0))
+
+  def _read_vars(self, model_dir):
+    """Returns (global_step, latest_feature)."""
+    with ops.Graph().as_default() as g:
+      ckpt_path = saver_lib.latest_checkpoint(model_dir)
+      meta_filename = ckpt_path + '.meta'
+      saver_lib.import_meta_graph(meta_filename)
+      saver = saver_lib.Saver()
+      with self.test_session(graph=g) as sess:
+        saver.restore(sess, ckpt_path)
+        return sess.run(ops.get_collection('my_vars'))
+
+  def _build_iterator_saver_hook(self, est):
+    return iterator_ops.CheckpointInputPipelineHook(est)
+
+  def testReturnDatasetFromInputFn(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testBuildIteratorInInputFn(self):
+
+    def _input_fn():
+      ds = dataset_ops.Dataset.range(10)
+      iterator = ds.make_one_shot_iterator()
+      return iterator.get_next()
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testDoNotRestore(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+    # Hook not provided, input pipeline was not restored.
+    est.train(_input_fn, steps=2)
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+  def testRaiseErrorIfNoIterator(self):
+
+    def _input_fn():
+      return constant_op.constant(1, dtype=dtypes.int64)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    with self.assertRaises(ValueError):
+      est.train(
+          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+  test.main()
index 0c76afd..fd16427 100644 (file)
@@ -52,6 +52,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
     "`next_element` as the input to some computation that is invoked inside "
     "the loop.")
 
+# Collection of all IteratorResources in the `Graph`.
+GLOBAL_ITERATORS = "iterators"
+
 
 @tf_export("data.Iterator")
 class Iterator(object):
@@ -75,8 +78,7 @@ class Iterator(object):
       output_shapes: A nested structure of `tf.TensorShape` objects
         corresponding to each component of an element of this dataset.
       output_classes: A nested structure of Python `type` object corresponding
-        to each
-        component of an element of this iterator.
+        to each component of an element of this iterator.
     """
     self._iterator_resource = iterator_resource
     self._initializer = initializer
@@ -86,6 +88,7 @@ class Iterator(object):
     self._string_handle = gen_dataset_ops.iterator_to_string_handle(
         self._iterator_resource)
     self._get_next_call_count = 0
+    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
 
   @staticmethod
   def from_structure(output_types,