Defer logging infeed error messages for a short time to see if the main session returns.
authorRussell Power <power@google.com>
Thu, 25 Jan 2018 00:41:50 +0000 (16:41 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 25 Jan 2018 00:45:13 +0000 (16:45 -0800)
PiperOrigin-RevId: 183162866

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index b6d685b3fca22a14c6f97d2d3b7c5668ebf4e297..2ae3a26a853bf4941ac3855ec525293b5a508a2a 100644 (file)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===================================================================
-
 """TPUEstimator class."""
 
 from __future__ import absolute_import
@@ -24,6 +23,7 @@ from contextlib import contextmanager
 import copy
 import threading
 import time
+import traceback
 
 import six
 from six.moves import queue as Queue  # pylint: disable=redefined-builtin
@@ -60,7 +60,6 @@ from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
 
-
 _INITIAL_LOSS = 1e7
 _ZERO_LOSS = 0.
 _TPU_ESTIMATOR = 'tpu_estimator'
@@ -86,8 +85,7 @@ def _create_global_step(graph):
         initializer=init_ops.zeros_initializer(),
         trainable=False,
         use_resource=True,
-        collections=[ops.GraphKeys.GLOBAL_VARIABLES,
-                     ops.GraphKeys.GLOBAL_STEP])
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
 
 
 def _create_or_get_iterations_per_loop():
@@ -100,8 +98,8 @@ def _create_or_get_iterations_per_loop():
     raise RuntimeError('Multiple iterations_per_loop_var in collection.')
 
   with ops.colocate_with(training_util.get_global_step()):
-    with variable_scope.variable_scope(_TPU_ESTIMATOR,
-                                       reuse=variable_scope.AUTO_REUSE):
+    with variable_scope.variable_scope(
+        _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
       return variable_scope.get_variable(
           _ITERATIONS_PER_LOOP_VAR,
           initializer=init_ops.zeros_initializer(),
@@ -242,9 +240,9 @@ class _TPUContext(object):
         return self._eval_batch_size
       return None
 
-    global_batch_size = (self._train_batch_size if
-                         mode == model_fn_lib.ModeKeys.TRAIN
-                         else self._eval_batch_size)
+    global_batch_size = (
+        self._train_batch_size
+        if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size)
     # On TPU
     if self.is_input_sharded_per_core():
       return global_batch_size // self.num_cores
@@ -291,8 +289,9 @@ class _TPUContext(object):
     # The tpu job is determined by the run_config. Right now, this method is
     # required as tpu_config is not part of the RunConfig.
     mode = self._assert_mode()
-    master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL
-              else run_config.master)
+    master = (
+        run_config.evaluation_master
+        if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
     if master in _LOCAL_MASTERS:
       return None
 
@@ -319,6 +318,7 @@ class _TPUContext(object):
   def tpu_host_placement_function(self):
     """Returns the TPU host place function."""
     master = self.master_job
+
     def _placement_function(_sentinal=None, core_id=None, host_id=None):  # pylint: disable=invalid-name
       assert _sentinal is None
       if core_id is not None and host_id is not None:
@@ -333,19 +333,23 @@ class _TPUContext(object):
         if core_id is not None:
           host_id = core_id / 8
         return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
+
     return _placement_function
 
   @property
   def tpu_device_placement_function(self):
     master = self.master_job
     job_device = '' if master is None else ('/job:%s' % master)
+
     def _placement_function(i):
       return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)
+
     return _placement_function
 
   @property
   def tpu_ordinal_function(self):
     """Returns the TPU ordinal fn."""
+
     def _tpu_ordinal_function(index):
       """Return the TPU ordinal associated with a shard.
 
@@ -358,6 +362,7 @@ class _TPUContext(object):
         The ordinal of the TPU device the shard's infeed should be placed on.
       """
       return index % 8
+
     return _tpu_ordinal_function
 
 
@@ -371,14 +376,16 @@ class _SIGNAL(object):
   STOP = -2
 
 
-class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
-    'mode',
-    'predictions',
-    'loss',
-    'train_op',
-    'eval_metrics',
-    'export_outputs',
-    'scaffold_fn'])):
+class TPUEstimatorSpec(
+    collections.namedtuple('TPUEstimatorSpec', [
+        'mode',
+        'predictions',
+        'loss',
+        'train_op',
+        'eval_metrics',
+        'export_outputs',
+        'scaffold_fn'
+    ])):
   """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
 
   See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
@@ -416,111 +423,116 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
     """Creates a validated `TPUEstimatorSpec` instance."""
     if eval_metrics is not None:
       _EvalMetrics.validate(eval_metrics)
-    return super(TPUEstimatorSpec, cls).__new__(cls,
-                                                mode=mode,
-                                                predictions=predictions,
-                                                loss=loss,
-                                                train_op=train_op,
-                                                eval_metrics=eval_metrics,
-                                                export_outputs=export_outputs,
-                                                scaffold_fn=scaffold_fn)
+    return super(TPUEstimatorSpec, cls).__new__(
+        cls,
+        mode=mode,
+        predictions=predictions,
+        loss=loss,
+        train_op=train_op,
+        eval_metrics=eval_metrics,
+        export_outputs=export_outputs,
+        scaffold_fn=scaffold_fn)
 
   def as_estimator_spec(self):
     """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
     eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu(
         self.eval_metrics)
     scaffold = self.scaffold_fn() if self.scaffold_fn else None
-    return model_fn_lib.EstimatorSpec(mode=self.mode,
-                                      predictions=self.predictions,
-                                      loss=self.loss,
-                                      train_op=self.train_op,
-                                      eval_metric_ops=eval_metric_ops,
-                                      export_outputs=self.export_outputs,
-                                      scaffold=scaffold)
+    return model_fn_lib.EstimatorSpec(
+        mode=self.mode,
+        predictions=self.predictions,
+        loss=self.loss,
+        train_op=self.train_op,
+        eval_metric_ops=eval_metric_ops,
+        export_outputs=self.export_outputs,
+        scaffold=scaffold)
+
+
+class _OpQueueContext(object):
+  """Manages work queue and thread for a infeed/outfeed thread."""
+
+  def __init__(self, name, target, args):
+    self._name = name
+    self._queue = Queue.Queue()
+    args = (self,) + args
+    self._thread = threading.Thread(name=name, target=target, args=args)
+    self._thread.daemon = True
+    self._thread.start()
+
+  def stop(self):
+    self._queue.put(_SIGNAL.STOP)
+
+  def send_next_batch_signal(self, iterations):
+    self._queue.put(iterations)
+
+  def read_iteration_counts(self):
+    while True:
+      signal = self._queue.get(block=True)
+      logging.debug('%s read signal %s', self._name, signal)
+      if signal == _SIGNAL.STOP:
+        logging.info('%s received signal, stopping.', self._name)
+        return
+      yield signal
 
+  def join(self):
+    logging.info('Shutting down %s thread.' % self._name)
+    self.stop()
+    self._thread.join()
 
-class _InfeedOutfeedThreadBaseController(object):
-  """This wraps the infeed/outfeed thread and stops when Estimator finishes."""
 
-  def __init__(self, thd):
-    self._signal_queue = Queue.Queue()
-    thd.daemon = True
-    thd.start()
-    self._thd = thd
+class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
+  """A Session hook setting up the TPU initialization, infeed, and outfeed.
 
-  def block_and_get_signal(self):
-    return self._signal_queue.get()
+  This hook does two major things:
+  1. initialize and shutdown TPU system.
+  2. launch and join the threads for infeed enqueue and (optional) outfeed
+     dequeue.
+  """
 
-  def send_next_batch_signal(self, signal=_SIGNAL.NEXT_BATCH):
-    self._signal_queue.put(signal)
+  def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
+    self._master_job = ctx.master_job
+    self._enqueue_ops = enqueue_ops
+    self._dequeue_ops = dequeue_ops
+    self._initial_infeed_sleep_secs = (
+        ctx.config.tpu_config.initial_infeed_sleep_secs)
+    self._session_cancel_timer = None
 
-  def join(self):
-    self._signal_queue.put(_SIGNAL.STOP)
-    self._thd.join()
+    self._feed_error = None
+    self._finished = False
 
+  def begin(self):
+    logging.info('TPU job name %s', self._master_job)
+    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
+    self._init_op = [tpu.initialize_system(job=self._master_job)]
+    self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
 
-class _OutfeedThreadController(_InfeedOutfeedThreadBaseController):
-  """This wraps the outfeed thread and stops when Estimator finishes."""
+  def _log_error(self, session, error):
+    """Log an infeed or outfeed error.
 
-  def __init__(self, session, dequeue_ops):
-    super(_OutfeedThreadController, self).__init__(
-        threading.Thread(target=self._execute_dequeue_ops,
-                         args=(session, dequeue_ops)))
+    This logs a short error message immediately, and schedules a timer to
+    emit the full stack trace and error message after a short period of time.
+    If the main session has terminated by the time the timer triggers, we
+    assume the real source of the error was from the main session and avoid
+    emitting a stack trace for the infeed.
 
-  def _execute_dequeue_ops(self, session, dequeue_ops):
-    count = 0
-    while True:
-      signal = self.block_and_get_signal()
-      if signal == _SIGNAL.STOP:
-        logging.info('Stop outfeed thread.')
-        return
+    Args:
+      session: `tf.Session`, session to be terminated
+      error: exception that triggered logging.
+    """
+    logging.warning(
+        '\n\n'
+        'Error occurred during infeed/outfeed.  This may be due to a compile '
+        'error in the main session.  Waiting for a short time for the main '
+        'session to come back.\n\n%s', error)
 
-      iterations = signal
-      for i in range(iterations):
-        logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
-        session.run(dequeue_ops)
-      count += 1
+    self._feed_error = traceback.format_exc()
 
-  def join(self):
-    logging.info('Waiting for Outfeed Thread to exit.')
-    super(_OutfeedThreadController, self).join()
-
-
-class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
-  """This wraps the infeed thread and stops when Estimator finishes."""
-
-  def __init__(self, session, enqueue_ops, initial_infeed_sleep_secs):
-    super(_InfeedThreadController, self).__init__(
-        threading.Thread(
-            target=self._input_thread_fn_for_loading,
-            args=(session, enqueue_ops, initial_infeed_sleep_secs)))
-
-  def _input_thread_fn_for_loading(self, session, enqueue_ops,
-                                   initial_infeed_sleep_secs):
-    count = 0
-    if initial_infeed_sleep_secs:
-      logging.info('Infeed thread sleeping for %d seconds.',
-                   initial_infeed_sleep_secs)
-      time.sleep(initial_infeed_sleep_secs)
-      logging.info('Infeed thread starting after sleep')
-    try:
-      while True:
-        signal = self._signal_queue.get()
-        if signal == _SIGNAL.STOP:
-          logging.info('Stop Infeed input thread.')
-          return
-
-        if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-          # Enqueue batches for next loop.
-          session.run(enqueue_ops)
-        else:
-          iterations = signal
-          for i in range(iterations):
-            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
-            session.run(enqueue_ops)
-          count += 1
+    # If we've already encountered a feed error, don't schedule another
+    # cancellation op.
+    if self._session_cancel_timer:
+      return
 
-    except Exception:  # pylint: disable=broad-except
+    def _cancel_session():
       # Close the session to avoid the main thread from hanging. If input
       # pipeline triggers any error, the infeed thread dies but the main thread
       # for TPU computation waits for the infeed enqueue forever. Close the
@@ -535,77 +547,94 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
       # exception in the main thread, instead of the expected compile error.
       # User code that depends on having the proper exception type will
       # therefore be confused.
-      logging.error(
-          'Failed running infeed, closing session.\n'
-          'You may see an exception from your main session after this. '
-          'Sleep for 2 minutes before close Session from infeed thread to '
-          'allow the main thread returning an error first, if any.',
-          exc_info=1
-      )
-      time.sleep(120)
-      logging.error('Closing the failed session.')
-      session.close()
-
-  def join(self):
-    logging.info('Waiting for Infeed Thread to exit.')
-    super(_InfeedThreadController, self).join()
-
-
-class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
-  """A Session hook setting up the TPU initialization, infeed, and outfeed.
-
-  This hook does two major things:
-  1. initialize and shutdown TPU system.
-  2. launch and join the threads for infeed enqueue and (optional) outfeed
-     dequeue.
-  """
+      time.sleep(5)
+
+      # If the main session is still running, the infeed/outfeed errors are
+      # legitimate, and should be logged.
+      if not self._finished:
+        logging.error('Feed error: %s', self._feed_error)
+        logging.error('Closing session.  A RuntimeError should follow.')
+        session.close()
+
+    self._session_cancel_timer = threading.Thread(target=_cancel_session)
+    self._session_cancel_timer.daemon = True
+    self._session_cancel_timer.start()
+
+  def _run_infeed(self, queue_ctx, session):
+    logging.info('Starting infeed thread controller.')
+    if self._initial_infeed_sleep_secs:
+      logging.info('%s thread sleeping for %d seconds.', self._name,
+                   self._initial_infeed_sleep_secs)
+      time.sleep(self._initial_infeed_sleep_secs)
+      logging.info('%s thread starting after sleep', self._name)
 
-  def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
-    self._master_job = ctx.master_job
-    self._enqueue_ops = enqueue_ops
-    self._dequeue_ops = dequeue_ops
-    self._initial_infeed_sleep_secs = (
-        ctx.config.tpu_config.initial_infeed_sleep_secs)
+    try:
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        for _ in queue_ctx.read_iteration_counts():
+          session.run(self._enqueue_ops)
+      else:
+        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+          for i in xrange(steps):
+            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
+            session.run(self._enqueue_ops)
+      logging.debug('Infeed thread finished, shutting down.')
+    except Exception as e:  # pylint: disable=broad-except
+      self._log_error(session, e)
 
-  def begin(self):
-    logging.info('TPU job name %s', self._master_job)
-    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
-    self._init_op = [tpu.initialize_system(job=self._master_job)]
-    self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
+  def _run_outfeed(self, queue_ctx, session):
+    logging.info('Starting outfeed thread controller.')
+    try:
+      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+        for i in xrange(steps):
+          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
+          session.run(self._dequeue_ops)
+    except Exception as e:  # pylint: disable=broad-except
+      self._log_error(session, e)
 
   def after_create_session(self, session, coord):
     logging.info('Init TPU system')
-    session.run(self._init_op,
-                options=config_pb2.RunOptions(timeout_in_ms=5*60*1000))
+    session.run(
+        self._init_op,
+        options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
 
     logging.info('Start infeed thread controller')
-    self._infeed_thd_controller = _InfeedThreadController(
-        session, self._enqueue_ops, self._initial_infeed_sleep_secs)
+    self._infeed_controller = _OpQueueContext(
+        name='InfeedController', target=self._run_infeed, args=(session,))
 
     if self._dequeue_ops is not None:
       logging.info('Start outfeed thread controller')
-      self._outfeed_thd_controller = _OutfeedThreadController(
-          session, self._dequeue_ops)
+      self._outfeed_controller = _OpQueueContext(
+          name='OutfeedController', target=self._run_outfeed, args=(session,))
 
   def before_run(self, run_context):
+    if self._feed_error:
+      logging.warning('Feed error occurred, terminating session.')
+      run_context.request_stop()
+      return
+
     iterations = run_context.session.run(self._iterations_per_loop_var)
 
     logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+    self._infeed_controller.send_next_batch_signal(iterations)
 
-    self._infeed_thd_controller.send_next_batch_signal(iterations)
     if self._dequeue_ops is not None:
       # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
-      logging.info(
-          'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
-      self._outfeed_thd_controller.send_next_batch_signal(iterations)
+      logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
+                   iterations)
+      self._outfeed_controller.send_next_batch_signal(iterations)
 
   def end(self, session):
+    if self._session_cancel_timer:
+      logging.warning('Feed error occurred; waiting for message.')
+      self._session_cancel_timer.join()
+
+    self._finished = True
     logging.info('Stop infeed thread controller')
-    self._infeed_thd_controller.join()
+    self._infeed_controller.join()
 
     if self._dequeue_ops is not None:
       logging.info('Stop output thread controller')
-      self._outfeed_thd_controller.join()
+      self._outfeed_controller.join()
 
     logging.info('Shutdown TPU system.')
     session.run(self._finalize_op)
@@ -676,8 +705,8 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
       run_context.request_stop()
     else:
       iterations = self._next_iterations(global_step, self._last_step)
-      self._iterations_per_loop_var.load(iterations,
-                                         session=run_context.session)
+      self._iterations_per_loop_var.load(
+          iterations, session=run_context.session)
 
 
 class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
@@ -698,8 +727,8 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
     self._iterations_per_loop_var.load(self._num_steps, session=session)
 
 
-def generate_per_core_enqueue_ops_fn_for_host(
-    ctx, input_fn, inputs_structure_recorder):
+def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
+                                              inputs_structure_recorder):
   """Generates infeed enqueue ops for per-core input_fn on a single host."""
   captured_infeed_queue = _CapturedObject()
 
@@ -729,9 +758,9 @@ def generate_per_core_enqueue_ops_fn_for_host(
         per_host_sharded_inputs)
 
     per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
-        per_host_sharded_inputs,
-        tpu_ordinal_function=ctx.tpu_ordinal_function)
+        per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
     return per_host_enqueue_ops
+
   return enqueue_ops_fn, captured_infeed_queue
 
 
@@ -748,8 +777,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
         features, labels = inputs
       else:
         features, labels = inputs, None
-      inputs_structure_recorder.validate_and_record_structure(
-          features, labels)
+      inputs_structure_recorder.validate_and_record_structure(features, labels)
       unsharded_tensor_list = (
           inputs_structure_recorder.flatten_features_and_labels(
               features, labels))
@@ -763,9 +791,9 @@ def generate_per_host_enqueue_ops_fn_for_host(
 
       per_host_enqueue_ops = (
           infeed_queue.split_inputs_and_generate_enqueue_ops(
-              unsharded_tensor_list,
-              placement_function=lambda x: device))
+              unsharded_tensor_list, placement_function=lambda x: device))
       return per_host_enqueue_ops
+
   return enqueue_ops_fn, captured_infeed_queue
 
 
@@ -815,6 +843,7 @@ class _InputPipeline(object):
 
     def validate_and_record_structure(self, features, labels):
       """Validates and records the structure of features` and `labels`."""
+
       def _extract_key_names(tensor_or_dict):
         if tensor_or_dict is None:
           return []
@@ -842,8 +871,8 @@ class _InputPipeline(object):
       flattened_inputs = []
       if self._feature_names:
         # We need a fixed ordering for enqueueing and dequeueing.
-        flattened_inputs.extend([features[name]
-                                 for name in self._feature_names])
+        flattened_inputs.extend(
+            [features[name] for name in self._feature_names])
       else:
         flattened_inputs.append(features)
 
@@ -870,11 +899,11 @@ class _InputPipeline(object):
         ValueError: If the number of expected tensors from `flattened_inputs`
           mismatches the recorded structure.
       """
-      expected_num_features = (len(self._feature_names) if self._feature_names
-                               else 1)
+      expected_num_features = (
+          len(self._feature_names) if self._feature_names else 1)
       if self._has_labels:
-        expected_num_labels = (len(self._label_names) if self._label_names
-                               else 1)
+        expected_num_labels = (
+            len(self._label_names) if self._label_names else 1)
       else:
         expected_num_labels = 0
 
@@ -895,8 +924,8 @@ class _InputPipeline(object):
       if expected_num_labels == 0:
         unflattened_label = None
       elif self._label_names:
-        unflattened_label = dict(zip(self._label_names,
-                                     flattened_inputs[expected_num_features:]))
+        unflattened_label = dict(
+            zip(self._label_names, flattened_inputs[expected_num_features:]))
       else:
         # Single tensor case.
         unflattened_label = flattened_inputs[expected_num_features]
@@ -961,8 +990,9 @@ class _InputPipeline(object):
                     self._ctx, self._input_fn, self._inputs_structure_recorder))
 
             if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-              enqueue_ops.append(_wrap_computation_in_while_loop(
-                  device=host_device, op_fn=enqueue_ops_fn))
+              enqueue_ops.append(
+                  _wrap_computation_in_while_loop(
+                      device=host_device, op_fn=enqueue_ops_fn))
             else:
               enqueue_ops.append(enqueue_ops_fn())
             # Infeed_queue_getter must be called after enqueue_ops_fn is called.
@@ -979,8 +1009,9 @@ class _InputPipeline(object):
                     self._batch_axis, host_device))
 
             if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-              enqueue_ops.append(_wrap_computation_in_while_loop(
-                  device=host_device, op_fn=enqueue_ops_fn))
+              enqueue_ops.append(
+                  _wrap_computation_in_while_loop(
+                      device=host_device, op_fn=enqueue_ops_fn))
             else:
               enqueue_ops.append(enqueue_ops_fn())
             infeed_queues.append(captured_infeed_queue.get())
@@ -1066,6 +1097,7 @@ class _ModelFnWrapper(object):
 
       with ops.control_dependencies([train_op]):
         return array_ops.identity(loss)
+
     return train_step, captured_scaffold_fn
 
   def convert_to_single_tpu_eval_step(self, dequeue_fn):
@@ -1114,6 +1146,7 @@ class _ModelFnWrapper(object):
 
       with ops.control_dependencies([outfeed_ops]):
         return math_ops.add(total_loss, loss)
+
     return eval_step, eval_metrics, captured_scaffold_fn
 
   def _call_model_fn(self, features, labels):
@@ -1138,10 +1171,9 @@ class _ModelFnWrapper(object):
       kwargs['params'] = params
 
     if 'params' not in model_fn_args:
-      raise ValueError(
-          'model_fn ({}) does not include params argument, '
-          'required by TPUEstimator to pass batch size as '
-          'params[\'batch_size\']'.format(self._model_fn))
+      raise ValueError('model_fn ({}) does not include params argument, '
+                       'required by TPUEstimator to pass batch size as '
+                       'params[\'batch_size\']'.format(self._model_fn))
 
     batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
     if batch_size_for_model_fn is not None:
@@ -1348,8 +1380,9 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
   def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
     examples_per_sec = self._batch_size * elapsed_steps / elapsed_time
     if self._summary_writer is not None:
-      example_summary = Summary(value=[Summary.Value(
-          tag='examples_sec', simple_value=examples_per_sec)])
+      example_summary = Summary(value=[
+          Summary.Value(tag='examples_sec', simple_value=examples_per_sec)
+      ])
       self._summary_writer.add_summary(example_summary, global_step)
     logging.info('examples/sec: %g', examples_per_sec)
 
@@ -1488,9 +1521,8 @@ class TPUEstimator(estimator_lib.Estimator):
           '`config` must be provided with type `tpu_config.RunConfig`')
 
     if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
-      raise ValueError(
-          '{} are reserved keys but existed in params {}.'.format(
-              _RESERVED_PARAMS_KEYS, params))
+      raise ValueError('{} are reserved keys but existed in params {}.'.format(
+          _RESERVED_PARAMS_KEYS, params))
 
     if use_tpu:
       if train_batch_size is None:
@@ -1571,8 +1603,9 @@ class TPUEstimator(estimator_lib.Estimator):
     if max_steps is not None:
       util_lib.check_positive_integer(max_steps, 'Train max_steps')
 
-    return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps,
-                               max_steps)]
+    return [
+        _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)
+    ]
 
   def _convert_eval_steps_to_hooks(self, steps):
     with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
@@ -1640,6 +1673,7 @@ class TPUEstimator(estimator_lib.Estimator):
       # `features` in `model_fn` signature.
       def _input_fn():
         return input_fn(**kwargs)
+
       return _input_fn
 
   def _augment_model_fn(self, model_fn, batch_axis):
@@ -1695,9 +1729,10 @@ class TPUEstimator(estimator_lib.Estimator):
         total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system(
             ctx, model_fn_wrapper, dequeue_fn)
         iterations_per_loop_var = _create_or_get_iterations_per_loop()
-        mean_loss = math_ops.div(
-            total_loss,
-            math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
+        mean_loss = math_ops.div(total_loss,
+                                 math_ops.cast(
+                                     iterations_per_loop_var,
+                                     dtype=total_loss.dtype))
 
         # Creates a dummy metric update_op for all metrics. Estimator expects
         # all metrics in eval_metric_ops have update_op and calls them one by
@@ -1725,6 +1760,7 @@ class TPUEstimator(estimator_lib.Estimator):
             evaluation_hooks=hooks,
             eval_metric_ops=eval_metric_ops,
             scaffold=scaffold)
+
     return _model_fn
 
 
@@ -1737,15 +1773,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
       model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
 
   def multi_tpu_eval_steps_on_single_shard():
-    return training_loop.repeat(iterations_per_loop_var,
-                                single_tpu_eval_step,
-                                [_ZERO_LOSS],
-                                name='loop')
+    return training_loop.repeat(
+        iterations_per_loop_var,
+        single_tpu_eval_step, [_ZERO_LOSS],
+        name='loop')
 
-  (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard,
-                      inputs=[],
-                      num_shards=num_cores,
-                      outputs_from_all_shards=False)
+  (loss,) = tpu.shard(
+      multi_tpu_eval_steps_on_single_shard,
+      inputs=[],
+      num_shards=num_cores,
+      outputs_from_all_shards=False)
 
   scaffold = _get_scaffold(captured_scaffold_fn)
   return loss, eval_metric_ops, scaffold
@@ -1762,14 +1799,14 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   def multi_tpu_train_steps_on_single_shard():
     return training_loop.repeat(
         iterations_per_loop_var,
-        single_tpu_train_step,
-        [_INITIAL_LOSS],
+        single_tpu_train_step, [_INITIAL_LOSS],
         name=b'loop')
 
-  (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
-                      inputs=[],
-                      num_shards=num_cores,
-                      outputs_from_all_shards=False)
+  (loss,) = tpu.shard(
+      multi_tpu_train_steps_on_single_shard,
+      inputs=[],
+      num_shards=num_cores,
+      outputs_from_all_shards=False)
 
   scaffold = _get_scaffold(captured_scaffold_fn)
   return loss, scaffold
@@ -1777,6 +1814,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
 
 def _wrap_computation_in_while_loop(device, op_fn):
   """Wraps the ops generated by `op_fn` in tf.while_loop."""
+
   def computation(i):
     with ops.control_dependencies(op_fn()):
       return i + 1
@@ -1788,7 +1826,8 @@ def _wrap_computation_in_while_loop(device, op_fn):
     iterations = array_ops.identity(iterations_per_loop_var)
     return control_flow_ops.while_loop(
         lambda i: i < iterations,
-        computation, [constant_op.constant(0)], parallel_iterations=1)
+        computation, [constant_op.constant(0)],
+        parallel_iterations=1)
 
 
 def _validate_tpu_training_graph():
@@ -1801,8 +1840,9 @@ def _validate_tpu_training_graph():
 
   # Check if there is atleast one CrossReplicaSum operation in the graph
   # This should be introduced by using the CrossShardOptimizer wrapper
-  cross_replica_sum_ops = [o for o in operations
-                           if o.type == _CROSS_REPLICA_SUM_OP]
+  cross_replica_sum_ops = [
+      o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
+  ]
   if not cross_replica_sum_ops:
     raise ValueError(
         'CrossShardOptimizer must be used for model training on TPUs.')
@@ -1849,9 +1889,11 @@ def _get_scaffold(captured_scaffold_fn):
 
   if scaffold:
     wrapped_finalize = scaffold.finalize
+
     def _finalize():
       with _CapturingContext('Inside Scaffold.finalize'):
         wrapped_finalize()
+
     scaffold.finalize = _finalize
   return scaffold
 
@@ -1866,9 +1908,8 @@ class _CapturingContext(control_flow_ops.ControlFlowContext):
   def AddOp(self, op):  # pylint: disable=invalid-name
     for c in op.inputs:
       if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
-        raise ValueError(
-            '{}: Op {} depends on TPU computation {}, '
-            'which is not allowed.'.format(self._message, op, c))
+        raise ValueError('{}: Op {} depends on TPU computation {}, '
+                         'which is not allowed.'.format(self._message, op, c))
 
   def __enter__(self):
     # pylint: disable=protected-access