--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Operations for handling session logging and shutdown notifications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+import time
+from google.protobuf import text_format
+
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+
+class CoordinatorShutdownException(Exception):
+  """Raised to tell the coordinator to shut down (e.g. by a shutdown hook)."""
+  pass
+
+
+class WorkerHeartbeatManager(object):
+  """Manages the status/heartbeat monitor for a set of workers."""
+
+  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
+    """Construct a new WorkerHeartbeatManager.
+
+    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
+
+    Args:
+      session: `tf.Session`, session to use for heartbeat operations.
+      devices: `list[string]` Set of devices to connect to.
+      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
+      request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
+        the WorkerHeartbeatRequest protocol buffer.
+    """
+    self._session = session
+    self._devices = devices
+    self._ops = heartbeat_ops
+    self._request_placeholder = request_placeholder
+
+  @staticmethod
+  def from_devices(session, devices):
+    """Construct a heartbeat manager for the given devices."""
+    if not devices:
+      # NOTE(review): this only logs; construction still proceeds and yields a
+      # manager whose device/op lists are empty.
+      logging.error('Trying to create heartbeat manager with no devices?')
+
+    logging.info('Creating heartbeat manager for %s', devices)
+    # A single shared placeholder feeds the serialized WorkerHeartbeatRequest
+    # to every per-device heartbeat op.
+    request_placeholder = array_ops.placeholder(
+        name='worker_heartbeat_request', dtype=dtypes.string)
+
+    heartbeat_ops = []
+    for device in devices:
+      # Pin one heartbeat op to each worker device so one session.run fans
+      # out to all workers at once.
+      with ops.device(device):
+        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+
+    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+                                  request_placeholder)
+
+  def configure(self, message):
+    """Configure heartbeat manager for all devices.
+
+    Args:
+      message: `event_pb2.WorkerHeartbeatRequest`
+
+    Returns: `None`
+
+    """
+    logging.info('Configuring worker heartbeat: %s',
+                 text_format.MessageToString(message))
+    # Unlike ping(), no RunOptions timeout is set here, so this blocks until
+    # all heartbeat ops complete.
+    self._session.run(self._ops,
+                      {self._request_placeholder: message.SerializeToString()})
+
+  def ping(self, request=None, timeout_in_ms=5000):
+    """Ping all workers, returning the parsed status results."""
+    if request is None:
+      # Empty request: a pure status probe with no configuration payload.
+      request = event_pb2.WorkerHeartbeatRequest()
+
+    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
+    results = self._session.run(
+        self._ops,
+        feed_dict={self._request_placeholder: request.SerializeToString()},
+        options=options)
+    # Each element is a serialized WorkerHeartbeatResponse, one per device,
+    # in the same order as self._devices.
+    parsed_results = [
+        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
+        for res_pb in results
+    ]
+    logging.info('Results: %s', parsed_results)
+    return parsed_results
+
+  def lame_workers(self):
+    """Ping all workers, returning manager containing lame workers (or None)."""
+    ping_results = self.ping()
+    lame_workers = []
+
+    # ping() results are ordered like self._devices/self._ops, so zip pairs
+    # each response with its originating device and op.
+    for ping_response, device, op in zip(ping_results, self._devices,
+                                         self._ops):
+      if ping_response.health_status != event_pb2.OK:
+        lame_workers.append((device, op))
+
+    if not lame_workers:
+      return None
+
+    # Wrap only the unhealthy (device, op) pairs so callers can target
+    # follow-up actions (e.g. shutdown) at just those workers.
+    bad_devices, bad_ops = zip(*lame_workers)
+    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
+                                  self._request_placeholder)
+
+  def shutdown(self, timeout_ms=10000):
+    """Shutdown all workers after `timeout_ms` milliseconds."""
+    req = event_pb2.WorkerHeartbeatRequest(
+        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
+    self.configure(req)
+
+
+def all_worker_devices(session):
+  """Return a list of devices for each worker in the system."""
+  devices = session.list_devices()
+  # Keeps any device whose name contains 'CPU' — presumably one CPU device
+  # per worker task; TODO confirm against the cluster's device naming.
+  return [device.name for device in devices if 'CPU' in device.name]
+
+
+class WatchdogManager(threading.Thread):
+  """Configures worker watchdog timer and handles periodic pings.
+
+  Usage:
+    # Ping workers every minute, shutting down workers if they haven't received
+    # a ping after 1 hour.
+    watchdog_manager = WatchdogManager(
+      session, ping_interval=60, shutdown_timeout=3600
+    )
+
+    # Use as a context manager, resetting watchdog on context exit:
+    with watchdog_manager:
+      session.run(...)
+
+    # Or setup globally; watchdog will remain active until program exit.
+    watchdog_manager.configure_and_run()
+  """
+
+  def __init__(self,
+               session,
+               devices=None,
+               ping_interval=60,
+               shutdown_timeout=3600):
+    """Initialize a watchdog manager.
+
+    Args:
+
+      session: Session connected to worker devices. A cloned session and graph
+        will be created for managing worker pings.
+      devices: Set of devices to monitor. If none, all workers will be
+        monitored.
+      ping_interval: Time, in seconds, between watchdog pings.
+      shutdown_timeout: Time, in seconds, before watchdog timeout.
+    """
+    threading.Thread.__init__(self)
+    self.ping_interval = ping_interval
+    self.shutdown_timeout = shutdown_timeout
+    # Daemon thread: does not block interpreter exit when used in the
+    # "setup globally" mode shown in the class docstring.
+    self.daemon = True
+    # Plain bool flag written by configure_and_run/__exit__ and read by run();
+    # no lock — presumably relying on atomic attribute assignment. TODO confirm.
+    self._running = False
+    # Heartbeats run on a private graph and a cloned session (same target) so
+    # they stay independent of the caller's training graph.
+    self._graph = ops.Graph()
+    self._session = session_lib.Session(
+        target=session.sess_str, graph=self._graph)
+
+    with self._graph.as_default():
+      if devices is None:
+        devices = all_worker_devices(self._session)
+      self._worker_manager = WorkerHeartbeatManager.from_devices(
+          self._session, devices)
+
+  def configure_and_run(self):
+    """Arm the worker watchdogs and start the background ping thread."""
+    logging.info('Enabling worker watchdog.')
+    self._running = True
+    # Arm each worker's watchdog: shutdown_timeout is in seconds, the proto
+    # field is milliseconds, hence the * 1000.
+    self._worker_manager.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            watchdog_config=event_pb2.WatchdogConfig(
+                timeout_ms=self.shutdown_timeout * 1000,)))
+
+    self.start()
+
+  def __enter__(self):
+    # NOTE(review): returns None, so `with WatchdogManager(...) as w:` binds
+    # `w` to None; keep a reference to the manager variable instead.
+    self.configure_and_run()
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    logging.info('Disabling worker watchdog.')
+    # timeout_ms=-1 disables the watchdog on the workers (see log message).
+    self._worker_manager.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,)))
+    self._running = False
+    # run() only re-checks _running after its sleep, so this join can block
+    # for up to ping_interval seconds.
+    self.join()
+
+  def run(self):
+    # Don't fetch logs or adjust timing: just ping the watchdog.
+    while self._running:
+      self._worker_manager.ping(request=None)
+      time.sleep(self.ping_interval)
+
+
+class GracefulShutdownHook(session_run_hook.SessionRunHook):
+  """Session hook that watches for shutdown events.
+
+  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a
+  SystemShutdown exception is raised to terminate the main session. If `saver`
+  is None the `SAVERS` collection will be read to find a saver.
+
+  `on_shutdown_hooks` is an optional list of functions that should be called
+  after checkpointing. The function is called with (`run_context`,
+  `all_workers`, `lame_workers`).
+
+  If `heartbeat_group` is not specified, it will default to all CPU workers
+  in the system.
+  """
+
+  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
+    # saver may be None; saver() below then falls back to the SAVERS
+    # collection of the default graph.
+    self._saver = saver
+    self._checkpoint_prefix = checkpoint_prefix
+    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
+
+    # Worker heartbeats are managed independently of the main training graph.
+    self._graph = ops.Graph()
+    self._workers = None
+    self._session = None
+
+  def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
+    # N.B. We have to pull the global step here to avoid it being unavailable
+    # at checkpoint time; the graph has been frozen at that point.
+    if training_util.get_global_step() is None and self.saver() is not None:
+      raise ValueError(
+          'Saver defined but no global step. Run `get_or_create_global_step()`'
+          ' in your model definition to allow checkpointing.')
+
+    with self._graph.as_default():
+      # Connect to the same target as the training session, but on the
+      # private heartbeat graph created in __init__.
+      self._session = session_lib.Session(
+          target=training_session.sess_str, graph=self._graph)
+      self._workers = WorkerHeartbeatManager.from_devices(
+          self._session, all_worker_devices(self._session))
+
+    # Tell workers to wait for the coordinator (this hook) to drive any
+    # shutdown, per the WAIT_FOR_COORDINATOR shutdown mode.
+    self._workers.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
+
+  def saver(self):
+    """Return the saver to checkpoint with, or None if none is available."""
+    if self._saver:
+      return self._saver
+
+    # NOTE(review): indexing [0] here raises IndexError when the SAVERS
+    # collection is empty, making the `if not savers` guard below unreachable;
+    # it also makes the isinstance(list) branch dead unless the collection
+    # holds nested lists. The [0] most likely belongs on the final return only.
+    savers = ops.get_collection(ops.GraphKeys.SAVERS)[0]
+    if not savers:
+      return None
+
+    if not isinstance(savers, list):
+      return savers
+
+    assert len(savers) == 1, 'Only one saver supported.'
+    return savers[0]
+
+  def after_run(self, run_context, run_values):
+    del run_values
+
+    # Pings every worker after each step; any non-OK health status triggers
+    # the checkpoint + shutdown-hook sequence below.
+    lame_workers = self._workers.lame_workers()
+    if lame_workers:
+      logging.info('ShutdownHook: lame workers found: %s', lame_workers)
+
+      if self.saver():
+        logging.info('ShutdownHook: saving checkpoint to %s',
+                     self._checkpoint_prefix)
+        self.saver().save(
+            run_context.session,
+            self._checkpoint_prefix,
+            global_step=training_util.get_global_step(),
+            write_state=True,
+        )
+      else:
+        logging.info('ShutdownHook: no Saver defined.')
+
+      # Hooks run after checkpointing; they may raise (see
+      # restart_computation) to terminate the coordinator.
+      for fn in self._on_shutdown_hooks:
+        fn(run_context, self._workers, lame_workers)
+
+
+def restart_computation(run_context, all_workers, lame_workers):
+  """Shutdown hook: shut down every worker, then terminate the coordinator.
+
+  Intended for use in `GracefulShutdownHook(on_shutdown_hooks=[...])`; raises
+  CoordinatorShutdownException so the coordinator exits as well.
+  """
+  del run_context, lame_workers
+  logging.info('Shutting down all workers.')
+  all_workers.shutdown()
+
+  logging.info('Terminating coordinator.')
+  raise CoordinatorShutdownException()
+
+
+def shutdown_lame_workers(run_context, all_workers, lame_workers):
+  """Shutdown hook: shut down only the workers that reported unhealthy."""
+  del run_context, all_workers
+  logging.info('Shutting down %s', lame_workers)
+  lame_workers.shutdown()
+++ /dev/null
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/util/session_message.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/util/event.pb.h"
-
-static const int kMaxLogEvents = 1000;
-
-namespace tensorflow {
-
-SessionLogger::SessionLogger() : status_(new SessionStatus) {}
-
-SessionLogger::~SessionLogger() {}
-
-string SessionLogger::DebugString() { return "SessionLogger"; }
-
-void SessionLogger::Log(StringPiece message) {
- mutex_lock lock(mu_);
-
- Event* event = status_->add_event();
- event->set_wall_time(Env::Default()->NowMicros());
- event->set_step(0);
- LogMessage* log = event->mutable_log_message();
- log->set_message(message.ToString());
- log->set_level(LogMessage::INFO);
-
- // Clip log events by 10% if we overflow
- if (status_->event_size() > kMaxLogEvents) {
- auto events = status_->mutable_event();
- events->DeleteSubrange(0, kMaxLogEvents / 10);
- }
-}
-
-SessionLogger* GetSessionLogger(ResourceMgr* rm) {
- SessionLogger* logger;
-
- std::function<Status(SessionLogger**)> status_creator =
- [](SessionLogger** result) {
- *result = new SessionLogger();
- return Status::OK();
- };
-
- if (!rm->LookupOrCreate<SessionLogger>("session", "status", &logger,
- status_creator)
- .ok()) {
- return nullptr;
- }
-
- return logger;
-}
-
-void LogSessionMessage(ResourceMgr* rm, StringPiece message) {
- return GetSessionLogger(rm)->Log(message);
-}
-
-} // namespace tensorflow