From: Saurabh Saxena
Date: Sat, 12 May 2018 00:14:29 +0000 (-0700)
Subject: Add hook for checkpointing input pipeline while training with Estimator.
X-Git-Tag: upstream/v1.9.0_rc1~116^2^2~76
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d43cb8d7358fecacef076fdab42dae03911edfc5;p=platform%2Fupstream%2Ftensorflow.git

Add hook for checkpointing input pipeline while training with Estimator.

PiperOrigin-RevId: 196331223
---

diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 077cbba..4f2c72b 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -72,6 +72,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
 from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
 from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
 from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
+from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
 from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
 from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 5b04c53..144460f 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -45,6 +45,27 @@ py_library(
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:iterator_ops",
+    ],
+)
+
+py_test(
+    name = "iterator_ops_test",
+    size = "small",
+    srcs = ["iterator_ops_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":iterator_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:model_fn",
     ],
 )
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index d736029..f1d0e5c 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,10 +16,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
 
 
 def make_saveable_from_iterator(iterator):
@@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator):
   return _Saveable(iterator._iterator_resource)  # pylint: disable=protected-access
 
 
-class _Saveable(saver.BaseSaverBuilder.SaveableObject):
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
   """SaveableObject for saving/restoring iterator state."""
 
   def __init__(self, iterator_resource):
     serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
     specs = [
-        saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
-                                        iterator_resource.name + "-state")
+        saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+                                            iterator_resource.name + "-state")
     ]
     super(_Saveable, self).__init__(iterator_resource, specs,
                                     iterator_resource.name)
@@ -75,3 +77,160 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject):
   def restore(self, restored_tensors, unused_restored_shapes):
     with ops.colocate_with(self.op):
       return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+  """Checkpoints input pipeline state every N steps or seconds.
+
+  This hook saves the state of the iterators in the `Graph` so that when
+  training is resumed, the input pipeline continues from where it left off.
+  This could potentially avoid overfitting in certain pipelines where the
+  number of training steps per eval is small compared to the dataset size,
+  or when the training pipeline is pre-empted.
+
+  Differences from `CheckpointSaverHook`:
+  1. Saves only the input pipelines in the "iterators" collection and not the
+     global variables or other saveable objects.
+  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+  Example of checkpointing the training pipeline:
+
+  ```python
+  est = tf.estimator.Estimator(model_fn)
+  while True:
+    est.train(
+        train_input_fn,
+        hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)],
+        steps=train_steps_per_eval)
+    # Note: We do not pass the hook here.
+    metrics = est.evaluate(eval_input_fn)
+    if should_stop_the_training(metrics):
+      break
+  ```
+
+  This hook should be used if the input pipeline state needs to be saved
+  separately from the model checkpoint. Doing so may be useful for a few
+  reasons:
+  1. The input pipeline checkpoint may be large, for instance if there are
+     large shuffle or prefetch buffers, and may bloat the checkpoint size.
+  2. If the input pipeline is shared between training and validation, restoring
+     the checkpoint during validation may override the validation input
+     pipeline.
+
+  For saving the input pipeline checkpoint alongside the model weights, use
+  @{tf.contrib.data.make_saveable_from_iterator} directly to create a
+  `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection. Note,
+  however, that you will need to be careful not to restore the training
+  iterator during eval. You can do that by not adding the iterator to the
+  `SAVEABLE_OBJECTS` collection when building the eval graph.
+  """
+
+  def __init__(self, estimator):
+    """Initializes a `CheckpointInputPipelineHook`.
+
+    Args:
+      estimator: The `Estimator` whose input pipeline should be checkpointed.
+
+    Raises:
+      ValueError: One of `save_steps` or `save_secs` should be set.
+      ValueError: At most one of saver or scaffold should be set.
+    """
+    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+    # "model.ckpt". We intentionally choose the input pipeline checkpoint
+    # prefix to be different to avoid conflicts with the model checkpoint.
+
+    # pylint: disable=protected-access
+    checkpoint_prefix = "input"
+    if estimator._config.num_worker_replicas > 1:
+      # Distributed setting.
+      suffix = "_{}_{}".format(estimator._config.task_type,
+                               estimator._config.task_id)
+      checkpoint_prefix += suffix
+    # pylint: enable=protected-access
+
+    # We use a composition paradigm instead of inheriting from
+    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+    # to check whether a `CheckpointSaverHook` is already present in the list
+    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+    # would thwart this behavior. This hook checkpoints *only the iterators*
+    # and not the graph variables.
+    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+        estimator.model_dir,
+        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
+        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
+        checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+    # Name for the protocol buffer file that will contain the list of most
+    # recent checkpoints stored as a `CheckpointState` protocol buffer.
+    # This file, kept in the same directory as the checkpoint files, is
+    # automatically managed by the `Saver` to keep track of recent checkpoints.
+    # The default name used by the `Saver` for this file is "checkpoint". Here
+    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+    # `checkpoint_dir` is the same as the model checkpoint directory, there are
+    # no conflicts during restore.
+    self._latest_filename = "checkpoint_" + checkpoint_prefix
+
+  def begin(self):
+    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+    # collection if no `Saver` or `Scaffold` is provided.
+    # pylint: disable=protected-access
+    if (self._checkpoint_saver_hook._saver is None and
+        self._checkpoint_saver_hook._scaffold is None):
+      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+      saveables = [_Saveable(i) for i in iterators]
+      self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+                                                        self._latest_filename)
+    # pylint: enable=protected-access
+    self._checkpoint_saver_hook.begin()
+
+  def after_create_session(self, session, coord):
+    # Check if there is an existing checkpoint. If so, restore from it.
+    # pylint: disable=protected-access
+    latest_checkpoint_path = saver_lib.latest_checkpoint(
+        self._checkpoint_saver_hook._checkpoint_dir,
+        latest_filename=self._latest_filename)
+    if latest_checkpoint_path:
+      self._checkpoint_saver_hook._get_saver().restore(session,
+                                                       latest_checkpoint_path)
+    else:
+      # The checkpoint saved here is the state at step "global_step".
+      # Note: We do not save the GraphDef or MetaGraphDef here.
+      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+      self._checkpoint_saver_hook._save(session, global_step)
+      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+    # pylint: enable=protected-access
+
+  def before_run(self, run_context):
+    return self._checkpoint_saver_hook.before_run(run_context)
+
+  def after_run(self, run_context, run_values):
+    self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+  def end(self, session):
+    self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+  """`Saver` with a different default `latest_filename`.
+
+  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+  the model checkpoint saved by the `CheckpointSaverHook`.
+  """
+
+  def __init__(self, var_list, latest_filename):
+    super(_CustomSaver, self).__init__(var_list)
+    self._latest_filename = latest_filename
+
+  def save(self,
+           sess,
+           save_path,
+           global_step=None,
+           latest_filename=None,
+           meta_graph_suffix="meta",
+           write_meta_graph=True,
+           write_state=True,
+           strip_default_attrs=False):
+    return super(_CustomSaver, self).save(
+        sess, save_path, global_step, latest_filename or self._latest_filename,
+        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
new file mode 100644
index 0000000..30a993b
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test.TestCase):
+
+  @staticmethod
+  def _model_fn(features, labels, mode, config):
+    del labels
+    del mode
+    del config
+    global_step = training_util.get_or_create_global_step()
+    update_global_step_op = global_step.assign_add(1)
+    latest_feature = variables.Variable(
+        0, name='latest_feature', dtype=dtypes.int64)
+    store_latest_feature_op = latest_feature.assign(features)
+    ops.add_to_collection('my_vars', global_step)
+    ops.add_to_collection('my_vars', latest_feature)
+    return model_fn.EstimatorSpec(
+        mode='train',
+        train_op=control_flow_ops.group(
+            [update_global_step_op, store_latest_feature_op]),
+        loss=constant_op.constant(2.0))
+
+  def _read_vars(self, model_dir):
+    """Returns (global_step, latest_feature)."""
+    with ops.Graph().as_default() as g:
+      ckpt_path = saver_lib.latest_checkpoint(model_dir)
+      meta_filename = ckpt_path + '.meta'
+      saver_lib.import_meta_graph(meta_filename)
+      saver = saver_lib.Saver()
+      with self.test_session(graph=g) as sess:
+        saver.restore(sess, ckpt_path)
+        return sess.run(ops.get_collection('my_vars'))
+
+  def _build_iterator_saver_hook(self, est):
+    return iterator_ops.CheckpointInputPipelineHook(est)
+
+  def testReturnDatasetFromInputFn(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testBuildIteratorInInputFn(self):
+
+    def _input_fn():
+      ds = dataset_ops.Dataset.range(10)
+      iterator = ds.make_one_shot_iterator()
+      return iterator.get_next()
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testDoNotRestore(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+    # Hook not provided, input pipeline was not restored.
+    est.train(_input_fn, steps=2)
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+  def testRaiseErrorIfNoIterator(self):
+
+    def _input_fn():
+      return constant_op.constant(1, dtype=dtypes.int64)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    with self.assertRaises(ValueError):
+      est.train(
+          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 0c76afd..fd16427 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -52,6 +52,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
     "`next_element` as the input to some computation that is invoked inside "
    "the loop.")
 
+# Collection of all IteratorResources in the `Graph`.
+GLOBAL_ITERATORS = "iterators"
+
 
 @tf_export("data.Iterator")
 class Iterator(object):
@@ -75,8 +78,7 @@ class Iterator(object):
       output_shapes: A nested structure of `tf.TensorShape` objects
         corresponding to each component of an element of this dataset.
       output_classes: A nested structure of Python `type` object corresponding
-        to each
-        component of an element of this iterator.
+        to each component of an element of this iterator.
     """
     self._iterator_resource = iterator_resource
     self._initializer = initializer
@@ -86,6 +88,7 @@ class Iterator(object):
     self._string_handle = gen_dataset_ops.iterator_to_string_handle(
         self._iterator_resource)
     self._get_next_call_count = 0
+    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
 
   @staticmethod
   def from_structure(output_types,
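
For context, here is a minimal end-to-end sketch of how the new hook is meant
to be used, written against the TF 1.x `tf.estimator` / `tf.contrib.data` APIs
this change targets. The `model_fn` is modeled on the unit test above; the
`model_dir` path and the toy dataset are illustrative stand-ins, not part of
this change:

```python
import tensorflow as tf

def model_fn(features, labels, mode):
  # Modeled on CheckpointInputPipelineHookTest._model_fn: bump the global
  # step and consume `features` so that the iterator actually advances.
  del labels
  global_step = tf.train.get_or_create_global_step()
  latest_feature = tf.Variable(0, name="latest_feature", dtype=tf.int64)
  return tf.estimator.EstimatorSpec(
      mode,
      loss=tf.constant(2.0),
      train_op=tf.group(global_step.assign_add(1),
                        latest_feature.assign(features)))

def input_fn():
  # Returning a Dataset lets Estimator create the iterator, which registers
  # itself in the GLOBAL_ITERATORS collection that the hook snapshots.
  return tf.data.Dataset.range(100)

est = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/pipeline_demo")

# The hook writes "input.ckpt*" files (tracked by a "checkpoint_input" state
# file) next to the model checkpoint. Re-running this call resumes the
# dataset at element 2 instead of restarting from 0.
est.train(input_fn, steps=2,
          hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)])
```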
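
The class docstring also mentions the alternative of saving the iterator state
alongside the model weights via @{tf.contrib.data.make_saveable_from_iterator}.
A rough sketch of that pattern, assuming the TF 1.x `Saver` picks up the
`SAVEABLE_OBJECTS` collection by default (the checkpoint path and toy dataset
are illustrative):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Wrap the iterator state in a SaveableObject and register it, so the same
# Saver that writes the model variables also serializes the iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_element))  # 0
  saver.save(sess, "/tmp/iter_ckpt")

with tf.Session() as sess:
  saver.restore(sess, "/tmp/iter_ckpt")
  print(sess.run(next_element))  # resumes at 1, not 0
```

When building the eval graph, skip the `add_to_collection` call so that
restoring a model checkpoint cannot clobber the eval input pipeline; this is
the caveat the docstring warns about.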