return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
self._request_placeholder)
+ def __repr__(self):
+ return 'WorkerHeartbeatManager(%s)' % ','.join(self._devices)
+
def shutdown(self, timeout_ms=10000):
"""Shutdown all workers after `shutdown_timeout_secs`."""
+ logging.info('Shutting down %s.', self)
req = event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
self.configure(req)
+ # Wait for workers to shut down. This isn't strictly required
+ # but it avoids triggering multiple checkpoints with the same lame worker.
+ logging.info('Waiting %dms for worker shutdown.', timeout_ms)
+ time.sleep(timeout_ms / 1000.0)
+
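
For reference, a minimal sketch of driving this shutdown path by hand, using the `WorkerHeartbeatManager.from_devices` constructor and the `all_worker_devices` helper that appear below; the session target and the module-level imports (`session_lib`, `logging`, `time`) are assumed from this file:

    # Build a heartbeat manager over every worker device, then ask all
    # workers to exit. shutdown() sends a WorkerHeartbeatRequest whose
    # WatchdogConfig carries the timeout, then sleeps so the workers have
    # time to act on it before the coordinator proceeds.
    session = session_lib.Session(target=master_target)  # master_target: illustrative
    workers = WorkerHeartbeatManager.from_devices(
        session, all_worker_devices(session))
    workers.shutdown(timeout_ms=10000)
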
def all_worker_devices(session):
"""Return a list of devices for each worker in the system."""
' in your model definition to allow checkpointing.')
with self._graph.as_default():
+ logging.info('Installing graceful shutdown hook.')
self._session = session_lib.Session(
target=training_session.sess_str, graph=self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
fn(run_context, self._workers, lame_workers)
-def restart_computation(run_context, all_workers, lame_workers):
- del run_context, lame_workers
- logging.info('Shutting down all workers.')
- all_workers.shutdown()
+class RestartComputation(object):
+ """Restart the entire computation.
+
+ This hook shuts down all workers and returns control to the top level by
+ raising a CoordinatorShutdownException.
+ """
+
+ def __init__(self, timeout_ms=10000):
+ self.timeout_ms = timeout_ms
+
+ def __call__(self, run_context, all_workers, lame_workers):
+ del run_context, lame_workers
+ all_workers.shutdown(timeout_ms=self.timeout_ms)
+
+ logging.info('Terminating coordinator.')
+ raise CoordinatorShutdownException()
+
- logging.info('Terminating coordinator.')
- raise CoordinatorShutdownException()
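
A sketch of how this class is expected to be wired in, mirroring the estimator change at the bottom of this diff; the checkpoint path is hypothetical:

    # Register the restart policy as an on_shutdown hook. When the
    # GracefulShutdownHook observes a missed heartbeat, it calls each
    # hook as fn(run_context, all_workers, lame_workers).
    hook = session_support.GracefulShutdownHook(
        checkpoint_prefix='/tmp/model/model.ckpt',  # hypothetical path
        on_shutdown_hooks=[session_support.RestartComputation(timeout_ms=1000)])
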
+class ShutdownLameWorkers(object):
+ """Shutdown lamed workers.
+
+ Processing will continue normally (typically by waiting for the down
+ workers to be restarted).
+ """
+ def __init__(self, timeout_ms=10000):
+ self.timeout_ms = timeout_ms
+
-def shutdown_lame_workers(run_context, all_workers, lame_workers):
- del run_context, all_workers
- logging.info('Shutting down %s', lame_workers)
- lame_workers.shutdown()
+ def __call__(self, run_context, all_workers, lame_workers):
+ lame_workers.shutdown(timeout_ms=self.timeout_ms)
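
Because the dispatch above (`fn(run_context, self._workers, lame_workers)`) treats each on_shutdown hook as a plain callable, a custom policy only needs the same signature; a hypothetical example:

    class LogThenShutdown(object):
      """Hypothetical hook: record which workers went lame, then stop them."""

      def __call__(self, run_context, all_workers, lame_workers):
        del run_context, all_workers  # unused, matching the built-in hooks
        logging.info('Lame workers detected: %s', lame_workers)
        lame_workers.shutdown(timeout_ms=5000)
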
host_ops = host_call.create_tpu_hostcall()
if host_ops is None:
host_ops = []
+
shutdown_hooks = []
- if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0':
- shutdown_hooks.append(session_support.GracefulShutdownHook())
+ shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
+ 'shutdown_worker')
+ if shutdown_mode:
+ if shutdown_mode == 'shutdown_worker':
+ finalizer_hooks = [
+ session_support.ShutdownLameWorkers(timeout_ms=1000),
+ ]
+ elif shutdown_mode == 'shutdown_computation':
+ finalizer_hooks = [
+ session_support.RestartComputation(timeout_ms=1000),
+ ]
+ else:
+ raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
+ shutdown_mode)
+
+ shutdown_hooks.append(session_support.GracefulShutdownHook(
+ checkpoint_prefix=os.path.join(self.model_dir, 'model.ckpt'),
+ on_shutdown_hooks=finalizer_hooks
+ ))
+
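
With 'shutdown_worker' as the default, the hook is now installed unless the variable is explicitly set to the empty string. A sketch of selecting the mode from Python, set before the estimator builds its training graph:

    # The dispatch above accepts 'shutdown_worker' (the default) or
    # 'shutdown_computation'; an empty value disables the hook entirely.
    os.environ['TF_TPU_GRACEFUL_SHUTDOWN_MODE'] = 'shutdown_computation'
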
with ops.control_dependencies([loss]):
global_step = array_ops.identity(training.get_global_step())
hooks = input_hooks + shutdown_hooks