Adjust worker shutdown hooks for TPUs
authorRussell Power <power@google.com>
Thu, 3 May 2018 23:16:05 +0000 (16:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:35:35 +0000 (10:35 -0700)
PiperOrigin-RevId: 195328247

tensorflow/contrib/tpu/python/tpu/session_support.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 7c25f66..3455e0b 100644 (file)
@@ -126,12 +126,21 @@ class WorkerHeartbeatManager(object):
     return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
                                   self._request_placeholder)
 
+  def __repr__(self):
+    return 'HeartbeatManager(%s)' % ','.join(self._devices)
+
   def shutdown(self, timeout_ms=10000):
     """Shutdown all workers after `shutdown_timeout_secs`."""
+    logging.info('Shutting down %s.', self)
     req = event_pb2.WorkerHeartbeatRequest(
         watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
     self.configure(req)
 
+    # Wait for workers to shutdown.  This isn't strictly required
+    # but it avoids triggering multiple checkpoints with the same lame worker.
+    logging.info('Waiting %dms for worker shutdown.', timeout_ms)
+    time.sleep(timeout_ms / 1000)
+
 
 def all_worker_devices(session):
   """Return a list of devices for each worker in the system."""
@@ -250,6 +259,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
           ' in your model definition to allow checkpointing.')
 
     with self._graph.as_default():
+      logging.info('Installing graceful shutdown hook.')
       self._session = session_lib.Session(
           target=training_session.sess_str, graph=self._graph)
       self._workers = WorkerHeartbeatManager.from_devices(
@@ -296,16 +306,33 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
         fn(run_context, self._workers, lame_workers)
 
 
-def restart_computation(run_context, all_workers, lame_workers):
-  del run_context, lame_workers
-  logging.info('Shutting down all workers.')
-  all_workers.shutdown()
+class RestartComputation(object):
+  """Restart the entire computation.
+
+  This hook shuts down all workers and returns control to the top-level by
+  throwing a CoordinatorShutdownException.
+  """
+
+  def __init__(self, timeout_ms=10000):
+    self.timeout_ms = timeout_ms
+
+  def __call__(self, run_context, all_workers, lame_workers):
+    del run_context, lame_workers
+    all_workers.shutdown(timeout_ms=self.timeout_ms)
+
+    logging.info('Terminating coordinator.')
+    raise CoordinatorShutdownException()
+
 
-  logging.info('Terminating coordinator.')
-  raise CoordinatorShutdownException()
+class ShutdownLameWorkers(object):
+  """Shut down lame workers.
+
+  Processing will continue normally (typically by waiting for the down
+  workers to be restarted).
+  """
 
+  def __init__(self, timeout_ms=10000):
+    self.timeout_ms = timeout_ms
 
-def shutdown_lame_workers(run_context, all_workers, lame_workers):
-  del run_context, all_workers
-  logging.info('Shutting down %s', lame_workers)
-  lame_workers.shutdown()
+  def __call__(self, run_context, all_workers, lame_workers):
+    lame_workers.shutdown(timeout_ms=self.timeout_ms)
index 534042b..a69bfa9 100644 (file)
@@ -2049,9 +2049,28 @@ class TPUEstimator(estimator_lib.Estimator):
           host_ops = host_call.create_tpu_hostcall()
           if host_ops is None:
             host_ops = []
+
           shutdown_hooks = []
-          if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0':
-            shutdown_hooks.append(session_support.GracefulShutdownHook())
+          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
+                                         'shutdown_worker')
+          if shutdown_mode:
+            if shutdown_mode == 'shutdown_worker':
+              finalizer_hooks = [
+                  session_support.ShutdownLameWorkers(timeout_ms=1000),
+              ]
+            elif shutdown_mode == 'shutdown_computation':
+              finalizer_hooks = [
+                  session_support.RestartComputation(timeout_ms=1000),
+              ]
+            else:
+              raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
+                               shutdown_mode)
+
+            shutdown_hooks.append(session_support.GracefulShutdownHook(
+                checkpoint_prefix=self.model_dir + '/model.ckpt',
+                on_shutdown_hooks=finalizer_hooks
+            ))
+
           with ops.control_dependencies([loss]):
             global_step = array_ops.identity(training.get_global_step())
           hooks = input_hooks + shutdown_hooks