DistributionStrategy-enable Estimator.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 29 Mar 2018 04:52:30 +0000 (21:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 04:54:59 +0000 (21:54 -0700)
PiperOrigin-RevId: 190882152

tensorflow/contrib/learn/python/learn/estimators/run_config.py
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/run_config.py

index 1d16109..f3500bf 100644 (file)
@@ -290,8 +290,15 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
         Note - using this argument, it is easy to provide settings which break
         otherwise perfectly good models. Use with care.
     """
-    super(RunConfig, self).__init__(
-        master=master, evaluation_master=evaluation_master)
+    # Neither parent class calls super().__init__(), so here we have to
+    # manually call their __init__() methods.
+    ClusterConfig.__init__(
+        self, master=master, evaluation_master=evaluation_master)
+    # For too long this code didn't call:
+    #   core_run_config.RunConfig.__init__(self)
+    # so instead of breaking compatibility with that assumption, we
+    # just manually initialize this field:
+    self._distribute = None
 
     gpu_options = config_pb2.GPUOptions(
         per_process_gpu_memory_fraction=gpu_memory_fraction)
index 6a4132b..2fe521b 100644 (file)
@@ -41,8 +41,11 @@ from tensorflow.python.estimator.export.export import get_temp_export_dir
 from tensorflow.python.estimator.export.export import get_timestamped_export_dir
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import builder as saved_model_builder
@@ -50,6 +53,7 @@ from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary import summary
 from tensorflow.python.summary.writer import writer_cache
 from tensorflow.python.training import device_setter
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import evaluation
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
@@ -183,6 +187,9 @@ class Estimator(object):
             config)
       self._config = config
 
+    # The distribute field contains an instance of DistributionStrategy.
+    self._distribution = self._config.distribute
+
     # Model directory.
     model_dir = compat_internal.path_to_str(model_dir)
     if (model_dir is not None) and (self._config.model_dir is not None):
@@ -682,11 +689,25 @@ class Estimator(object):
   def _get_features_and_labels_from_input_fn(self, input_fn, mode):
     """Extracts the `features` and labels from return values of `input_fn`."""
     result = self._call_input_fn(input_fn, mode)
+    # TODO(anjalisridhar): What about the default DistributionStrategy? Perhaps
+    # using any input is alright in that case. There is also a
+    # has_dataset_or_queue_runner function that we may want to extend and use.
+    if (self._distribution is not None and
+        not isinstance(result, dataset_ops.Dataset)):
+      raise ValueError('input_fn() must return a tf.data.Dataset when using a '
+                       'DistributionStrategy.')
     input_hooks = []
     if isinstance(result, dataset_ops.Dataset):
-      iterator = result.make_initializable_iterator()
-      input_hooks.append(_DatasetInitializerHook(iterator))
-      result = iterator.get_next()
+      if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+        # TODO(josh11b): This is currently using a one-shot iterator, we
+        # will update this to an initializeable iterator once the
+        # necessory support for creating an initializable iterator is
+        # available.
+        result = self._distribution.distribute_dataset(result).get_next()
+      else:
+        iterator = result.make_initializable_iterator()
+        input_hooks.append(_DatasetInitializerHook(iterator))
+        result = iterator.get_next()
     if isinstance(result, (list, tuple)):
       if len(result) != 2:
         raise ValueError(
@@ -815,6 +836,12 @@ class Estimator(object):
     return model_fn_results
 
   def _train_model(self, input_fn, hooks, saving_listeners):
+    if self._distribution:
+      return self._train_model_distributed(input_fn, hooks, saving_listeners)
+    else:
+      return self._train_model_default(input_fn, hooks, saving_listeners)
+
+  def _train_model_default(self, input_fn, hooks, saving_listeners):
     worker_hooks = []
     with ops.Graph().as_default() as g, g.device(self._device_fn):
       random_seed.set_random_seed(self._config.tf_random_seed)
@@ -826,86 +853,209 @@ class Estimator(object):
       worker_hooks.extend(input_hooks)
       estimator_spec = self._call_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
+      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+                                             hooks, global_step_tensor,
+                                             saving_listeners)
 
-      if self._warm_start_settings:
-        logging.info('Warm-starting with WarmStartSettings: %s' %
-                     (self._warm_start_settings,))
-        # pylint: disable=protected-access
-        warm_starting_util.warm_start(*self._warm_start_settings)
-        # pylint: enable=protected-access
-      # Check if the user created a loss summary, and add one if they didn't.
-      # We assume here that the summary is called 'loss'. If it is not, we will
-      # make another one with the name 'loss' to ensure it shows up in the right
-      # graph in TensorBoard.
-      if not any([x.op.name == 'loss'
-                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
-        summary.scalar('loss', estimator_spec.loss)
-      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
-      worker_hooks.extend(hooks)
-      worker_hooks.extend([
-          training.NanTensorHook(estimator_spec.loss),
-          training.LoggingTensorHook(
-              {
-                  'loss': estimator_spec.loss,
-                  'step': global_step_tensor
-              },
-              every_n_iter=self._config.log_step_count_steps)
-      ])
-      worker_hooks.extend(estimator_spec.training_hooks)
-
-      if not (estimator_spec.scaffold.saver or
-              ops.get_collection(ops.GraphKeys.SAVERS)):
-        ops.add_to_collection(
-            ops.GraphKeys.SAVERS,
-            training.Saver(
-                sharded=True,
-                max_to_keep=self._config.keep_checkpoint_max,
-                keep_checkpoint_every_n_hours=(
-                    self._config.keep_checkpoint_every_n_hours),
-                defer_build=True,
-                save_relative_paths=True))
-
-      chief_hooks = []
-      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
-      saver_hooks = [
-          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
-      if (self._config.save_checkpoints_secs or
-          self._config.save_checkpoints_steps):
-        if not saver_hooks:
-          chief_hooks = [
-              training.CheckpointSaverHook(
-                  self._model_dir,
-                  save_secs=self._config.save_checkpoints_secs,
-                  save_steps=self._config.save_checkpoints_steps,
-                  scaffold=estimator_spec.scaffold)
-          ]
-          saver_hooks = [chief_hooks[0]]
-      if saving_listeners:
-        if not saver_hooks:
-          raise ValueError(
-              'There should be a CheckpointSaverHook to use saving_listeners. '
-              'Please set one of the RunConfig.save_checkpoints_steps or '
-              'RunConfig.save_checkpoints_secs.')
+  def _train_model_distributed(self, input_fn, hooks, saving_listeners):
+    worker_hooks = []
+    with ops.Graph().as_default() as g:
+      with self._distribution.scope():
+        random_seed.set_random_seed(self._config.tf_random_seed)
+        features, labels, input_hooks = (
+            self._get_features_and_labels_from_input_fn(
+                input_fn, model_fn_lib.ModeKeys.TRAIN))
+        worker_hooks.extend(input_hooks)
+        global_step_tensor = self._create_and_assert_global_step(g)
+        # The default destination for the global_step_tensor fetch call is the
+        # CPU.
+        global_step_read_tensor = self._distribution.fetch(global_step_tensor)
+        # we want to add to the global collection in the main thread not the
+        # tower threads.
+        ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
+                              global_step_read_tensor)
+        grouped_estimator_spec = self._distribution.call_for_each_tower(
+            self._call_model_fn,
+            features,
+            labels,  # although this will be None it seems
+            model_fn_lib.ModeKeys.TRAIN,
+            self.config)
+
+        # TODO(anjalisridhar): Figure out how to resolve the folowing scaffold
+        # parameters: init_feed_dict, init_fn.
+        scaffold_list = self._distribution.unwrap(
+            grouped_estimator_spec.scaffold)
+        init_feed_dict = [
+            s.init_feed_dict
+            for s in scaffold_list
+            if s.init_feed_dict is not None
+        ]
+        if init_feed_dict:
+          init_feed_dict = self._distribution.group(init_feed_dict)
         else:
-          # It is expected to have one CheckpointSaverHook. If multiple, we pick
-          # up the first one to add listener.
-          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
-      with training.MonitoredTrainingSession(
-          master=self._config.master,
-          is_chief=self._config.is_chief,
-          checkpoint_dir=self._model_dir,
-          scaffold=estimator_spec.scaffold,
-          hooks=worker_hooks,
-          chief_only_hooks=(
-              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
-          save_checkpoint_secs=0,  # Saving is handled by a hook.
-          save_summaries_steps=self._config.save_summary_steps,
-          config=self._session_config,
-          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
-        loss = None
-        while not mon_sess.should_stop():
-          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
-      return loss
+          init_feed_dict = None
+
+        init_fn = [s.init_fn for s in scaffold_list if s.init_fn is not None]
+        if init_fn:
+          init_fn = self._distribution.group(init_fn)
+        else:
+          init_fn = None
+
+        init_op = [s.init_op for s in scaffold_list if s.init_op is not None]
+        if init_op:
+          init_op = self._distribution.group(init_op)
+        else:
+          init_op = None
+
+        ready_op = self._distribution.call_for_each_tower(
+            create_per_tower_ready_op, grouped_estimator_spec.scaffold)
+        if ready_op is not None:
+          ready_op = self._distribution.group(ready_op)
+        else:
+          ready_op = None
+
+        ready_for_local_init_op = self._distribution.call_for_each_tower(
+            create_per_tower_ready_for_local_init_op,
+            grouped_estimator_spec.scaffold)
+        if ready_for_local_init_op is not None:
+          ready_for_local_init_op = self._distribution.group(
+              ready_for_local_init_op)
+        else:
+          ready_for_local_init_op = None
+
+        local_init_op = [
+            s.local_init_op
+            for s in scaffold_list
+            if s.local_init_op is not None
+        ]
+        if local_init_op:
+          local_init_op = self._distribution.group(local_init_op)
+        else:
+          local_init_op = None
+
+        summary_op = [
+            s.summary_op for s in scaffold_list if s.summary_op is not None
+        ]
+        if summary_op:
+          summary_op = self._distribution.group(summary_op)
+        else:
+          summary_op = None
+
+        scaffold = monitored_session.Scaffold(
+            init_op=init_op,
+            ready_op=ready_op,
+            ready_for_local_init_op=ready_for_local_init_op,
+            local_init_op=local_init_op,
+            summary_op=summary_op,
+            init_feed_dict=init_feed_dict,
+            init_fn=init_fn)
+
+        def get_hooks_from_the_first_device(per_device_hooks):
+          hooks_list = self._distribution.unwrap(per_device_hooks)
+          assert hooks_list
+          return hooks_list[0]
+
+        training_hooks = get_hooks_from_the_first_device(
+            grouped_estimator_spec.training_hooks)
+        training_chief_hooks = get_hooks_from_the_first_device(
+            grouped_estimator_spec.training_chief_hooks)
+
+        estimator_spec = model_fn_lib.EstimatorSpec(
+            mode=grouped_estimator_spec.mode,
+            loss=self._distribution.unwrap(
+                self._distribution.reduce(distribute_lib.get_loss_reduction(),
+                                          grouped_estimator_spec.loss,
+                                          destinations='/device:CPU:0'))[0],
+            train_op=self._distribution.group(grouped_estimator_spec.train_op),
+            training_hooks=training_hooks,
+            training_chief_hooks=training_chief_hooks,
+            scaffold=scaffold)
+        return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+                                               hooks, global_step_read_tensor,
+                                               saving_listeners)
+
+  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
+                                 global_step_tensor, saving_listeners):
+    """Train a model with the given Estimator Spec."""
+    if self._warm_start_settings:
+      logging.info('Warm-starting with WarmStartSettings: %s' %
+                   (self._warm_start_settings,))
+      # pylint: disable=protected-access
+      warm_starting_util.warm_start(*self._warm_start_settings)
+      # pylint: enable=protected-access
+    # Check if the user created a loss summary, and add one if they didn't.
+    # We assume here that the summary is called 'loss'. If it is not, we will
+    # make another one with the name 'loss' to ensure it shows up in the right
+    # graph in TensorBoard.
+    if not any([x.op.name == 'loss'
+                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
+      summary.scalar('loss', estimator_spec.loss)
+    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+    worker_hooks.extend(hooks)
+    worker_hooks.extend([
+        training.NanTensorHook(estimator_spec.loss),
+        training.LoggingTensorHook(
+            {
+                'loss': estimator_spec.loss,
+                'step': global_step_tensor
+            },
+            every_n_iter=self._config.log_step_count_steps)
+    ])
+    worker_hooks.extend(estimator_spec.training_hooks)
+
+    if not (estimator_spec.scaffold.saver or
+            ops.get_collection(ops.GraphKeys.SAVERS)):
+      ops.add_to_collection(
+          ops.GraphKeys.SAVERS,
+          training.Saver(
+              sharded=True,
+              max_to_keep=self._config.keep_checkpoint_max,
+              keep_checkpoint_every_n_hours=(
+                  self._config.keep_checkpoint_every_n_hours),
+              defer_build=True,
+              save_relative_paths=True))
+
+    chief_hooks = []
+    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
+    saver_hooks = [
+        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
+    if (self._config.save_checkpoints_secs or
+        self._config.save_checkpoints_steps):
+      if not saver_hooks:
+        chief_hooks = [
+            training.CheckpointSaverHook(
+                self._model_dir,
+                save_secs=self._config.save_checkpoints_secs,
+                save_steps=self._config.save_checkpoints_steps,
+                scaffold=estimator_spec.scaffold)
+        ]
+        saver_hooks = [chief_hooks[0]]
+    if saving_listeners:
+      if not saver_hooks:
+        raise ValueError(
+            'There should be a CheckpointSaverHook to use saving_listeners. '
+            'Please set one of the RunConfig.save_checkpoints_steps or '
+            'RunConfig.save_checkpoints_secs.')
+      else:
+        # It is expected to have one CheckpointSaverHook. If multiple, we pick
+        # up the first one to add listener.
+        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
+    with training.MonitoredTrainingSession(
+        master=self._config.master,
+        is_chief=self._config.is_chief,
+        checkpoint_dir=self._model_dir,
+        scaffold=estimator_spec.scaffold,
+        hooks=worker_hooks,
+        chief_only_hooks=(
+            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
+        save_checkpoint_secs=0,  # Saving is handled by a hook.
+        save_summaries_steps=self._config.save_summary_steps,
+        config=self._session_config,
+        log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+      loss = None
+      while not mon_sess.should_stop():
+        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
+    return loss
 
   def _evaluate_model(self,
                       input_fn,
@@ -972,6 +1122,35 @@ class Estimator(object):
     return eval_results
 
 
+def create_per_tower_ready_op(scaffold):
+  """Create a Scaffold.ready_op inside a tower."""
+  if scaffold.ready_op:
+    return scaffold.ready_op
+
+  def default_ready_op():
+    return array_ops.concat([
+        variables.report_uninitialized_variables(),
+        resources.report_uninitialized_resources()
+    ], 0)
+
+  return monitored_session.Scaffold.get_or_default(
+      'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
+
+
+def create_per_tower_ready_for_local_init_op(scaffold):
+  """Create a Scaffold.ready_for_local_init_op inside a tower."""
+  if scaffold.ready_for_local_init_op:
+    return scaffold.ready_for_local_init_op
+
+  def default_ready_for_local_init_op():
+    return variables.report_uninitialized_variables(
+        variables.global_variables())
+
+  return monitored_session.Scaffold.get_or_default(
+      'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
+      default_ready_for_local_init_op)
+
+
 def _check_checkpoint_available(model_dir):
   latest_path = saver.latest_checkpoint(model_dir)
   if not latest_path:
index 141eaef..41415b8 100644 (file)
@@ -688,7 +688,7 @@ class RunConfig(object):
 
     Only the properties in the following list are allowed to be replaced:
 
-      - `model_dir`.
+      - `model_dir`,
       - `tf_random_seed`,
       - `save_summary_steps`,
       - `save_checkpoints_steps`,
@@ -697,6 +697,7 @@ class RunConfig(object):
       - `keep_checkpoint_max`,
       - `keep_checkpoint_every_n_hours`,
       - `log_step_count_steps`,
+      - `distribute`.
 
     In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
     can be set (should not be both).