from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import device_setter
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
config)
self._config = config
+ # The distribute field contains an instance of DistributionStrategy.
+ self._distribution = self._config.distribute
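+    # For example (hypothetical user code; the constructor argument name is
+    # an assumption, only the `distribute` property is confirmed here):
+    #   config = RunConfig(distribute=tf.contrib.distribute.MirroredStrategy())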
+
# Model directory.
model_dir = compat_internal.path_to_str(model_dir)
if (model_dir is not None) and (self._config.model_dir is not None):
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
"""Extracts the `features` and labels from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
+ # TODO(anjalisridhar): What about the default DistributionStrategy? Perhaps
+ # using any input is alright in that case. There is also a
+ # has_dataset_or_queue_runner function that we may want to extend and use.
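+    # For illustration, a conforming input_fn (hypothetical user code):
+    #   def input_fn():
+    #     return tf.data.Dataset.from_tensors(
+    #         ({'x': [1.]}, [0.])).repeat().batch(4)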
+ if (self._distribution is not None and
+ not isinstance(result, dataset_ops.Dataset)):
+ raise ValueError('input_fn() must return a tf.data.Dataset when using a '
+ 'DistributionStrategy.')
input_hooks = []
if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
+ if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+          # TODO(josh11b): This is currently using a one-shot iterator; we
+          # will update this to an initializable iterator once the
+          # necessary support for creating an initializable iterator is
+          # available.
+ result = self._distribution.distribute_dataset(result).get_next()
+ else:
+ iterator = result.make_initializable_iterator()
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
return model_fn_results
def _train_model(self, input_fn, hooks, saving_listeners):
+ if self._distribution:
+ return self._train_model_distributed(input_fn, hooks, saving_listeners)
+ else:
+ return self._train_model_default(input_fn, hooks, saving_listeners)
+
+ def _train_model_default(self, input_fn, hooks, saving_listeners):
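+    """Initiates training with `input_fn`, without `DistributionStrategy`."""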
worker_hooks = []
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
+ return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+ hooks, global_step_tensor,
+ saving_listeners)
- if self._warm_start_settings:
- logging.info('Warm-starting with WarmStartSettings: %s' %
- (self._warm_start_settings,))
- # pylint: disable=protected-access
- warm_starting_util.warm_start(*self._warm_start_settings)
- # pylint: enable=protected-access
- # Check if the user created a loss summary, and add one if they didn't.
- # We assume here that the summary is called 'loss'. If it is not, we will
- # make another one with the name 'loss' to ensure it shows up in the right
- # graph in TensorBoard.
- if not any([x.op.name == 'loss'
- for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
- summary.scalar('loss', estimator_spec.loss)
- ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
- worker_hooks.extend(hooks)
- worker_hooks.extend([
- training.NanTensorHook(estimator_spec.loss),
- training.LoggingTensorHook(
- {
- 'loss': estimator_spec.loss,
- 'step': global_step_tensor
- },
- every_n_iter=self._config.log_step_count_steps)
- ])
- worker_hooks.extend(estimator_spec.training_hooks)
-
- if not (estimator_spec.scaffold.saver or
- ops.get_collection(ops.GraphKeys.SAVERS)):
- ops.add_to_collection(
- ops.GraphKeys.SAVERS,
- training.Saver(
- sharded=True,
- max_to_keep=self._config.keep_checkpoint_max,
- keep_checkpoint_every_n_hours=(
- self._config.keep_checkpoint_every_n_hours),
- defer_build=True,
- save_relative_paths=True))
-
- chief_hooks = []
- all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
- saver_hooks = [
- h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
- if (self._config.save_checkpoints_secs or
- self._config.save_checkpoints_steps):
- if not saver_hooks:
- chief_hooks = [
- training.CheckpointSaverHook(
- self._model_dir,
- save_secs=self._config.save_checkpoints_secs,
- save_steps=self._config.save_checkpoints_steps,
- scaffold=estimator_spec.scaffold)
- ]
- saver_hooks = [chief_hooks[0]]
- if saving_listeners:
- if not saver_hooks:
- raise ValueError(
- 'There should be a CheckpointSaverHook to use saving_listeners. '
- 'Please set one of the RunConfig.save_checkpoints_steps or '
- 'RunConfig.save_checkpoints_secs.')
+ def _train_model_distributed(self, input_fn, hooks, saving_listeners):
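+    """Initiates training with `input_fn`, using `DistributionStrategy`."""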
+ worker_hooks = []
+ with ops.Graph().as_default() as g:
+ with self._distribution.scope():
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ features, labels, input_hooks = (
+ self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN))
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # The default destination for the global_step_tensor fetch call is the
+ # CPU.
+ global_step_read_tensor = self._distribution.fetch(global_step_tensor)
+        # We want to add to the global collection in the main thread, not in
+        # the tower threads.
+ ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
+ global_step_read_tensor)
+ grouped_estimator_spec = self._distribution.call_for_each_tower(
+ self._call_model_fn,
+ features,
+          labels,  # note: this appears to always be None in practice
+ model_fn_lib.ModeKeys.TRAIN,
+ self.config)
+
+      # TODO(anjalisridhar): Figure out how to resolve the following scaffold
+      # parameters: init_feed_dict, init_fn.
+ scaffold_list = self._distribution.unwrap(
+ grouped_estimator_spec.scaffold)
+      # Each Scaffold field is combined across towers the same way: collect
+      # the non-None values and group them into a single value (or None).
+      def _grouped_scaffold_attr(attr):
+        values = [
+            getattr(s, attr) for s in scaffold_list
+            if getattr(s, attr) is not None
+        ]
+        return self._distribution.group(values) if values else None
+
+      init_feed_dict = _grouped_scaffold_attr('init_feed_dict')
+      init_fn = _grouped_scaffold_attr('init_fn')
+      init_op = _grouped_scaffold_attr('init_op')
+      local_init_op = _grouped_scaffold_attr('local_init_op')
+      summary_op = _grouped_scaffold_attr('summary_op')
-      else:
- # It is expected to have one CheckpointSaverHook. If multiple, we pick
- # up the first one to add listener.
- saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
- with training.MonitoredTrainingSession(
- master=self._config.master,
- is_chief=self._config.is_chief,
- checkpoint_dir=self._model_dir,
- scaffold=estimator_spec.scaffold,
- hooks=worker_hooks,
- chief_only_hooks=(
- tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
- save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
- config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
- loss = None
- while not mon_sess.should_stop():
- _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
- return loss
+
+      ready_op = self._distribution.call_for_each_tower(
+          create_per_tower_ready_op, grouped_estimator_spec.scaffold)
+      if ready_op is not None:
+        ready_op = self._distribution.group(ready_op)
+
+      ready_for_local_init_op = self._distribution.call_for_each_tower(
+          create_per_tower_ready_for_local_init_op,
+          grouped_estimator_spec.scaffold)
+      if ready_for_local_init_op is not None:
+        ready_for_local_init_op = self._distribution.group(
+            ready_for_local_init_op)
+
+ scaffold = monitored_session.Scaffold(
+ init_op=init_op,
+ ready_op=ready_op,
+ ready_for_local_init_op=ready_for_local_init_op,
+ local_init_op=local_init_op,
+ summary_op=summary_op,
+ init_feed_dict=init_feed_dict,
+ init_fn=init_fn)
+
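+      # The hooks created on each tower are expected to be identical, so
+      # keeping the first tower's copy is sufficient.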
+ def get_hooks_from_the_first_device(per_device_hooks):
+ hooks_list = self._distribution.unwrap(per_device_hooks)
+ assert hooks_list
+ return hooks_list[0]
+
+ training_hooks = get_hooks_from_the_first_device(
+ grouped_estimator_spec.training_hooks)
+ training_chief_hooks = get_hooks_from_the_first_device(
+ grouped_estimator_spec.training_chief_hooks)
+
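+      # Build a single EstimatorSpec from the per-tower results: reduce the
+      # per-tower losses to one scalar placed on the CPU, and group the
+      # per-tower train_ops so that a single run() step updates every tower.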
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=grouped_estimator_spec.mode,
+ loss=self._distribution.unwrap(
+ self._distribution.reduce(distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0],
+ train_op=self._distribution.group(grouped_estimator_spec.train_op),
+ training_hooks=training_hooks,
+ training_chief_hooks=training_chief_hooks,
+ scaffold=scaffold)
+ return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+ hooks, global_step_read_tensor,
+ saving_listeners)
+
+ def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
+ global_step_tensor, saving_listeners):
+    """Train a model with the given `EstimatorSpec`."""
+ if self._warm_start_settings:
+ logging.info('Warm-starting with WarmStartSettings: %s' %
+ (self._warm_start_settings,))
+ # pylint: disable=protected-access
+ warm_starting_util.warm_start(*self._warm_start_settings)
+ # pylint: enable=protected-access
+ # Check if the user created a loss summary, and add one if they didn't.
+ # We assume here that the summary is called 'loss'. If it is not, we will
+ # make another one with the name 'loss' to ensure it shows up in the right
+ # graph in TensorBoard.
+ if not any([x.op.name == 'loss'
+ for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
+ summary.scalar('loss', estimator_spec.loss)
+ ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+ worker_hooks.extend(hooks)
+ worker_hooks.extend([
+ training.NanTensorHook(estimator_spec.loss),
+ training.LoggingTensorHook(
+ {
+ 'loss': estimator_spec.loss,
+ 'step': global_step_tensor
+ },
+ every_n_iter=self._config.log_step_count_steps)
+ ])
+ worker_hooks.extend(estimator_spec.training_hooks)
+
+ if not (estimator_spec.scaffold.saver or
+ ops.get_collection(ops.GraphKeys.SAVERS)):
+ ops.add_to_collection(
+ ops.GraphKeys.SAVERS,
+ training.Saver(
+ sharded=True,
+ max_to_keep=self._config.keep_checkpoint_max,
+ keep_checkpoint_every_n_hours=(
+ self._config.keep_checkpoint_every_n_hours),
+ defer_build=True,
+ save_relative_paths=True))
+
+ chief_hooks = []
+ all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
+ saver_hooks = [
+ h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
+ if (self._config.save_checkpoints_secs or
+ self._config.save_checkpoints_steps):
+ if not saver_hooks:
+ chief_hooks = [
+ training.CheckpointSaverHook(
+ self._model_dir,
+ save_secs=self._config.save_checkpoints_secs,
+ save_steps=self._config.save_checkpoints_steps,
+ scaffold=estimator_spec.scaffold)
+ ]
+ saver_hooks = [chief_hooks[0]]
+ if saving_listeners:
+ if not saver_hooks:
+ raise ValueError(
+ 'There should be a CheckpointSaverHook to use saving_listeners. '
+ 'Please set one of the RunConfig.save_checkpoints_steps or '
+ 'RunConfig.save_checkpoints_secs.')
+ else:
+        # There should be only one CheckpointSaverHook. If there are multiple,
+        # we pick the first one and attach the listeners to it.
+ saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+ with training.MonitoredTrainingSession(
+ master=self._config.master,
+ is_chief=self._config.is_chief,
+ checkpoint_dir=self._model_dir,
+ scaffold=estimator_spec.scaffold,
+ hooks=worker_hooks,
+ chief_only_hooks=(
+ tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
+ save_checkpoint_secs=0, # Saving is handled by a hook.
+ save_summaries_steps=self._config.save_summary_steps,
+ config=self._session_config,
+ log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ loss = None
+ while not mon_sess.should_stop():
+ _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
+ return loss
def _evaluate_model(self,
input_fn,
return eval_results
+def create_per_tower_ready_op(scaffold):
+ """Create a Scaffold.ready_op inside a tower."""
+ if scaffold.ready_op:
+ return scaffold.ready_op
+
+ def default_ready_op():
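+    # An empty report means the model is ready; concatenate the reports for
+    # uninitialized variables and uninitialized resources into one tensor.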
+ return array_ops.concat([
+ variables.report_uninitialized_variables(),
+ resources.report_uninitialized_resources()
+ ], 0)
+
+ return monitored_session.Scaffold.get_or_default(
+ 'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
+
+
+def create_per_tower_ready_for_local_init_op(scaffold):
+ """Create a Scaffold.ready_for_local_init_op inside a tower."""
+ if scaffold.ready_for_local_init_op:
+ return scaffold.ready_for_local_init_op
+
+ def default_ready_for_local_init_op():
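+    # Local init may run once all global variables are initialized, i.e.
+    # once this report comes back empty.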
+ return variables.report_uninitialized_variables(
+ variables.global_variables())
+
+ return monitored_session.Scaffold.get_or_default(
+ 'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
+ default_ready_for_local_init_op)
+
+
def _check_checkpoint_available(model_dir):
latest_path = saver.latest_checkpoint(model_dir)
if not latest_path: