Add support for a clean checkpoint and shutdown in response to a termination notice.
author    Russell Power <power@google.com>
          Sun, 29 Apr 2018 22:30:22 +0000 (15:30 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Sun, 29 Apr 2018 22:32:53 +0000 (15:32 -0700)
PiperOrigin-RevId: 194722985
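
Graceful shutdown is opt-in via the TF_TPU_GRACEFUL_SHUTDOWN environment
variable.  A minimal usage sketch (my_model_fn, my_run_config, and my_input_fn
are illustrative, not part of this change):

    import os
    import tensorflow as tf

    # Opt in to the shutdown hook before constructing the estimator.
    os.environ['TF_TPU_GRACEFUL_SHUTDOWN'] = '1'

    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model_fn, config=my_run_config, train_batch_size=1024)
    # On a termination notice, the hook checkpoints and shuts down cleanly.
    estimator.train(input_fn=my_input_fn, max_steps=100000)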

tensorflow/contrib/tpu/BUILD
tensorflow/contrib/tpu/ops/heartbeat_ops.cc [new file with mode: 0644]
tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
tensorflow/contrib/tpu/python/tpu/session_support.py [new file with mode: 0644]
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
tensorflow/core/BUILD
tensorflow/core/util/event.proto
tensorflow/core/util/session_message.cc [deleted file]
tensorflow/core/util/session_message.h [deleted file]

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index eac2104..0bdf6f6 100644 (file)
@@ -24,6 +24,7 @@ cc_library(
     name = "all_ops",
     deps = [
         ":cross_replica_ops_op_lib",
+        ":heartbeat_ops_op_lib",
         ":host_compute_ops_op_lib",
         ":infeed_ops_op_lib",
         ":outfeed_ops_op_lib",
@@ -71,6 +72,7 @@ py_library(
 tf_gen_op_libs(
     op_lib_names = [
         "cross_replica_ops",
+        "heartbeat_ops",
         "host_compute_ops",
         "infeed_ops",
         "outfeed_ops",
@@ -89,6 +91,7 @@ tf_custom_op_library(
     name = "python/ops/_tpu_ops.so",
     srcs = [
         "ops/cross_replica_ops.cc",
+        "ops/heartbeat_ops.cc",
         "ops/host_compute_ops.cc",
         "ops/infeed_ops.cc",
         "ops/outfeed_ops.cc",
@@ -106,6 +109,7 @@ tf_gen_op_wrapper_py(
     name = "tpu_ops",
     deps = [
         ":cross_replica_ops_op_lib",
+        ":heartbeat_ops_op_lib",
         ":host_compute_ops_op_lib",
         ":infeed_ops_op_lib",
         ":outfeed_ops_op_lib",
@@ -163,6 +167,7 @@ py_library(
         "python/tpu/bfloat16.py",
         "python/tpu/device_assignment.py",
         "python/tpu/keras_support.py",
+        "python/tpu/session_support.py",
         "python/tpu/topology.py",
         "python/tpu/tpu.py",
         "python/tpu/tpu_feed.py",
diff --git a/tensorflow/contrib/tpu/ops/heartbeat_ops.cc b/tensorflow/contrib/tpu/ops/heartbeat_ops.cc
new file mode 100644 (file)
index 0000000..ca0f5bc
--- /dev/null
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("WorkerHeartbeat")
+    .Input("request: string")
+    .Output("response: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Worker heartbeat op.
+
+Heartbeats may be sent periodically to indicate that the coordinator is still
+active, to retrieve the current worker status, and to expedite shutdown when
+necessary.
+
+request: A string tensor containing a serialized WorkerHeartbeatRequest
+response: A string tensor containing a serialized WorkerHeartbeatResponse
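+
+For example, from Python (a sketch; uses the generated `tpu_ops` wrapper):
+
+  from tensorflow.core.util import event_pb2
+  request = event_pb2.WorkerHeartbeatRequest().SerializeToString()
+  response = tpu_ops.worker_heartbeat(request)  # scalar string tensor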
+)doc");
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
index 7bf5c21..d5600ee 100644 (file)
@@ -214,20 +214,4 @@ An op that shuts down a running distributed TPU system. The Op returns
 an error if no system is running.
 )doc");
 
-REGISTER_OP("SessionStatus")
-    .Input("fetch_start_timestamp: double")
-    .Output("status: string")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Not for public usage.
-
-Returns messages from the current session as a serialized SessionStatusProto.
-
-This includes the current state of the compiler, along with any critical
-logging or warning messages.
-
-fetch_start_timestamp: any messages earlier than this will be excluded from the
-returned proto.
-)doc");
-
 }  // end namespace tensorflow
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
new file mode 100644 (file)
index 0000000..7c25f66
--- /dev/null
@@ -0,0 +1,311 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Operations for handling session logging and shutdown notifications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import time
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+
+class CoordinatorShutdownException(Exception):
+  """Raised when the coordinator needs to shutdown."""
+  pass
+
+
+class WorkerHeartbeatManager(object):
+  """Manages the status/heartbeat monitor for a set of workers."""
+
+  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
+    """Construct a new WorkerHeartbeatManager.
+
+    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
+
+    Args:
+      session: `tf.Session`, session to use for heartbeat operations.
+      devices: `list[string]` Set of devices to connect to.
+      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
+      request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
+        the WorkerHeartbeatRequest protocol buffer.
+    """
+    self._session = session
+    self._devices = devices
+    self._ops = heartbeat_ops
+    self._request_placeholder = request_placeholder
+
+  @staticmethod
+  def from_devices(session, devices):
+    """Construct a heartbeat manager for the given devices."""
+    if not devices:
+      logging.error('Trying to create heartbeat manager with no devices?')
+
+    logging.info('Creating heartbeat manager for %s', devices)
+    request_placeholder = array_ops.placeholder(
+        name='worker_heartbeat_request', dtype=dtypes.string)
+
+    heartbeat_ops = []
+    for device in devices:
+      with ops.device(device):
+        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+
+    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+                                  request_placeholder)
+
+  def configure(self, message):
+    """Configure heartbeat manager for all devices.
+
+    Args:
+      message: `event_pb2.WorkerHeartbeatRequest`
+
+    Returns:
+      `None`
+
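+    Example (a sketch): arm a 60-second watchdog on every managed worker:
+
+      manager.configure(
+          event_pb2.WorkerHeartbeatRequest(
+              watchdog_config=event_pb2.WatchdogConfig(timeout_ms=60 * 1000)))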
+    """
+    logging.info('Configuring worker heartbeat: %s',
+                 text_format.MessageToString(message))
+    self._session.run(self._ops,
+                      {self._request_placeholder: message.SerializeToString()})
+
+  def ping(self, request=None, timeout_in_ms=5000):
+    """Ping all workers, returning the parsed status results."""
+    if request is None:
+      request = event_pb2.WorkerHeartbeatRequest()
+
+    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
+    results = self._session.run(
+        self._ops,
+        feed_dict={self._request_placeholder: request.SerializeToString()},
+        options=options)
+    parsed_results = [
+        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
+        for res_pb in results
+    ]
+    logging.info('Results: %s', parsed_results)
+    return parsed_results
+
+  def lame_workers(self):
+    """Ping all workers, returning manager containing lame workers (or None)."""
+    ping_results = self.ping()
+    lame_workers = []
+
+    for ping_response, device, op in zip(ping_results, self._devices,
+                                         self._ops):
+      if ping_response.health_status != event_pb2.OK:
+        lame_workers.append((device, op))
+
+    if not lame_workers:
+      return None
+
+    bad_devices, bad_ops = zip(*lame_workers)
+    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
+                                  self._request_placeholder)
+
+  def shutdown(self, timeout_ms=10000):
+    """Shutdown all workers after `shutdown_timeout_secs`."""
+    req = event_pb2.WorkerHeartbeatRequest(
+        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
+    self.configure(req)
+
+
+def all_worker_devices(session):
+  """Return a list of devices for each worker in the system."""
+  devices = session.list_devices()
+  return [device.name for device in devices if 'CPU' in device.name]
+
+
+class WatchdogManager(threading.Thread):
+  """Configures worker watchdog timer and handles periodic pings.
+
+  Usage:
+    # Ping workers every minute, shutting down workers if they haven't received
+    # a ping after 1 hour.
+    watchdog_manager = WatchdogManager(
+        session, ping_interval=60, shutdown_timeout=3600)
+
+    # Use as a context manager, resetting watchdog on context exit:
+    with watchdog_manager:
+      session.run(...)
+
+    # Or setup globally; watchdog will remain active until program exit.
+    watchdog_manager.configure_and_run()
+  """
+
+  def __init__(self,
+               session,
+               devices=None,
+               ping_interval=60,
+               shutdown_timeout=3600):
+    """Initialize a watchdog manager.
+
+    Args:
+      session: Session connected to worker devices.  A cloned session and graph
+        will be created for managing worker pings.
+      devices: Set of devices to monitor.  If None, all workers will be
+        monitored.
+      ping_interval: Time, in seconds, between watchdog pings.
+      shutdown_timeout: Time, in seconds, before watchdog timeout.
+    """
+    threading.Thread.__init__(self)
+    self.ping_interval = ping_interval
+    self.shutdown_timeout = shutdown_timeout
+    self.daemon = True
+    self._running = False
+    self._graph = ops.Graph()
+    self._session = session_lib.Session(
+        target=session.sess_str, graph=self._graph)
+
+    with self._graph.as_default():
+      if devices is None:
+        devices = all_worker_devices(self._session)
+      self._worker_manager = WorkerHeartbeatManager.from_devices(
+          self._session, devices)
+
+  def configure_and_run(self):
+    logging.info('Enabling worker watchdog.')
+    self._running = True
+    self._worker_manager.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            watchdog_config=event_pb2.WatchdogConfig(
+                timeout_ms=self.shutdown_timeout * 1000,)))
+
+    self.start()
+
+  def __enter__(self):
+    self.configure_and_run()
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    logging.info('Disabling worker watchdog.')
+    self._worker_manager.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,)))
+    self._running = False
+    self.join()
+
+  def run(self):
+    # Don't fetch logs or adjust timing: just ping the watchdog.
+    while self._running:
+      self._worker_manager.ping(request=None)
+      time.sleep(self.ping_interval)
+
+
+class GracefulShutdownHook(session_run_hook.SessionRunHook):
+  """Session hook that watches for shutdown events.
+
+  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed and
+  the registered `on_shutdown_hooks` are run; a hook may raise
+  `CoordinatorShutdownException` to terminate the main session.  If `saver` is
+  None the `SAVERS` collection will be read to find a saver.
+
+  `on_shutdown_hooks` is an optional list of functions that should be called
+  after checkpointing.  Each function is called with (`run_context`,
+  `all_workers`, `lame_workers`).
+
+  Heartbeats are monitored on all CPU workers in the system.
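+
+  Example (a sketch; `master` and `my_prefix` are illustrative):
+
+    hook = GracefulShutdownHook(
+        checkpoint_prefix=my_prefix,
+        on_shutdown_hooks=[restart_computation])
+    with tf.train.MonitoredTrainingSession(master=master,
+                                           hooks=[hook]) as sess:
+      ...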
+  """
+
+  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
+    self._saver = saver
+    self._checkpoint_prefix = checkpoint_prefix
+    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
+
+    # Worker heartbeats are managed independently of the main training graph.
+    self._graph = ops.Graph()
+    self._workers = None
+    self._session = None
+
+  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
+    # N.B. We have to pull the global step here to avoid it being unavailable
+    # at checkpoint time; the graph has been frozen at that point.
+    if training_util.get_global_step() is None and self.saver() is not None:
+      raise ValueError(
+          'Saver defined but no global step.  Run `get_or_create_global_step()`'
+          ' in your model definition to allow checkpointing.')
+
+    with self._graph.as_default():
+      self._session = session_lib.Session(
+          target=training_session.sess_str, graph=self._graph)
+      self._workers = WorkerHeartbeatManager.from_devices(
+          self._session, all_worker_devices(self._session))
+
+      self._workers.configure(
+          event_pb2.WorkerHeartbeatRequest(
+              shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
+
+  def saver(self):
+    if self._saver:
+      return self._saver
+
+    savers = ops.get_collection(ops.GraphKeys.SAVERS)
+    if not savers:
+      return None
+
+    if not isinstance(savers, list):
+      return savers
+
+    assert len(savers) == 1, 'Only one saver supported.'
+    return savers[0]
+
+  def after_run(self, run_context, run_values):
+    del run_values
+
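+    # A worker that received a termination notice reports non-OK health via
+    # its heartbeat response, which surfaces here as a lame worker.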
+    lame_workers = self._workers.lame_workers()
+    if lame_workers:
+      logging.info('ShutdownHook: lame workers found: %s', lame_workers)
+
+      if self.saver():
+        logging.info('ShutdownHook: saving checkpoint to %s',
+                     self._checkpoint_prefix)
+        self.saver().save(
+            run_context.session,
+            self._checkpoint_prefix,
+            global_step=training_util.get_global_step(),
+            write_state=True,
+        )
+      else:
+        logging.info('ShutdownHook: no Saver defined.')
+
+      for fn in self._on_shutdown_hooks:
+        fn(run_context, self._workers, lame_workers)
+
+
+def restart_computation(run_context, all_workers, lame_workers):
+  del run_context, lame_workers
+  logging.info('Shutting down all workers.')
+  all_workers.shutdown()
+
+  logging.info('Terminating coordinator.')
+  raise CoordinatorShutdownException()
+
+
+def shutdown_lame_workers(run_context, all_workers, lame_workers):
+  del run_context, all_workers
+  logging.info('Shutting down %s', lame_workers)
+  lame_workers.shutdown()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 98eb0e2..eb537b7 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 import copy
+import os
 import signal
 import threading
 import time
@@ -31,6 +32,7 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import session_support
 from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_context
@@ -1551,7 +1553,7 @@ class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
 
 
 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
-  """Count examples during runtime."""
+  """"Calculate and report the number of examples/sec during training."""
 
   def __init__(self,
                batch_size,
@@ -2037,6 +2039,11 @@ class TPUEstimator(estimator_lib.Estimator):
           host_ops = host_call.create_tpu_hostcall()
           if host_ops is None:
             host_ops = []
+
+          shutdown_hooks = []
+          if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0':
+            # GracefulShutdownHook requires a checkpoint prefix; use the
+            # standard `model.ckpt` prefix under model_dir.
+            shutdown_hooks.append(
+                session_support.GracefulShutdownHook(
+                    checkpoint_prefix=os.path.join(self.model_dir,
+                                                   'model.ckpt')))
+
           hooks = [
               TPUInfeedOutfeedSessionHook(
                   ctx,
@@ -2044,8 +2051,8 @@ class TPUEstimator(estimator_lib.Estimator):
                   host_ops,
                   run_infeed_loop_on_coordinator=(
                       run_infeed_loop_on_coordinator)),
-              ExamplesPerSecondHook(ctx.global_batch_size,
-                                    output_dir=self.model_dir),
+              ExamplesPerSecondHook(
+                  ctx.global_batch_size, output_dir=self.model_dir),
               InstallSignalHandlerHook(),
               training.LoggingTensorHook(
                   {
@@ -2053,7 +2060,8 @@ class TPUEstimator(estimator_lib.Estimator):
                       'step': training.get_global_step()
                   },
                   every_n_secs=30)
-          ] + input_hooks
+          ] + input_hooks + shutdown_hooks
+
           chief_hooks = []
           if (self._config.save_checkpoints_secs or
               self._config.save_checkpoints_steps):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
index 3ae350c..894f21d 100644 (file)
@@ -60,7 +60,7 @@ def _query_tpu_system_metadata(master_address, run_config,
       with ops.Graph().as_default():
         with session_lib.Session(
             master_address,
-            config=_get_session_config_with_timeout(
+            config=get_session_config_with_timeout(
                 _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess:
           devices = sess.list_devices()
           for device in devices:
@@ -133,7 +133,7 @@ def _obtain_topology(master_address, run_config):
                  'for model parallelism. This might take a while.',
                  master_address)
     with ops.Graph().as_default():
-      session_config = _get_session_config_with_timeout(
+      session_config = get_session_config_with_timeout(
           _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config)
       with session_lib.Session(
           master_address, config=session_config) as sess:
@@ -146,7 +146,7 @@ def _obtain_topology(master_address, run_config):
             master_address))
 
 
-def _get_session_config_with_timeout(timeout_in_secs, run_config):
+def get_session_config_with_timeout(timeout_in_secs, run_config):
   cluster_def = None
   if run_config.session_config and run_config.session_config.cluster_def.job:
     cluster_def = run_config.session_config.cluster_def
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 32ef0a9..2a849a3 100644 (file)
@@ -457,17 +457,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "session_message",
-    srcs = ["util/session_message.cc"],
-    hdrs = ["util/session_message.h"],
-    deps = [
-        ":framework",
-        ":lib",
-        ":protos_all_cc",
-    ],
-)
-
 # Libraries that will eventually be moved into lib/core
 # Note that stringpiece_test can't be place here yet, because we are
 # required to use tf_cc_test, and that rule will change / into _
@@ -2149,7 +2138,6 @@ tf_cuda_library(
             "framework/resource_handle.cc",
             "util/memmapped_file_system.*",
             "util/memmapped_file_system_writer.*",
-            "util/session_message.cc",
             "util/version_info.cc",
         ],
     ) + select({
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
index 65d2c5a..9ce85be 100644 (file)
@@ -81,7 +81,35 @@ message TaggedRunMetadata {
   bytes run_metadata = 2;
 }
 
-// For communicating live events back to a coordinator
-message SessionStatus {
-  repeated Event event = 1;
+// Worker heartbeat messages.  Support for these operations is currently
+// internal and expected to change.
+
+// Current health status of a worker.
+enum WorkerHealth {
+  OK = 0;  // By default a worker is healthy.
+  RECEIVED_SHUTDOWN_SIGNAL = 1;
+  INTERNAL_ERROR = 2;
+}
+
+// Indicates the behavior of the worker when an internal error or shutdown
+// signal is received.
+enum WorkerShutdownMode {
+  DEFAULT = 0;
+  SHUTDOWN_IMMEDIATELY = 1;
+  WAIT_FOR_COORDINATOR = 2;
+}
+
+message WatchdogConfig {
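+  // Shut down the worker if no ping arrives within timeout_ms
+  // (a non-positive timeout disables the watchdog).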
+  int64 timeout_ms = 1;
+}
+
+message WorkerHeartbeatRequest {
+  WorkerShutdownMode shutdown_mode = 1;
+  WatchdogConfig watchdog_config = 2;
+}
+
+message WorkerHeartbeatResponse {
+  WorkerHealth health_status = 1;
+  repeated Event worker_log = 2;
+  string hostname = 3;
 }
diff --git a/tensorflow/core/util/session_message.cc b/tensorflow/core/util/session_message.cc
deleted file mode 100644 (file)
index 28a6517..0000000
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/util/session_message.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/util/event.pb.h"
-
-static const int kMaxLogEvents = 1000;
-
-namespace tensorflow {
-
-SessionLogger::SessionLogger() : status_(new SessionStatus) {}
-
-SessionLogger::~SessionLogger() {}
-
-string SessionLogger::DebugString() { return "SessionLogger"; }
-
-void SessionLogger::Log(StringPiece message) {
-  mutex_lock lock(mu_);
-
-  Event* event = status_->add_event();
-  event->set_wall_time(Env::Default()->NowMicros());
-  event->set_step(0);
-  LogMessage* log = event->mutable_log_message();
-  log->set_message(message.ToString());
-  log->set_level(LogMessage::INFO);
-
-  // Clip log events by 10% if we overflow
-  if (status_->event_size() > kMaxLogEvents) {
-    auto events = status_->mutable_event();
-    events->DeleteSubrange(0, kMaxLogEvents / 10);
-  }
-}
-
-SessionLogger* GetSessionLogger(ResourceMgr* rm) {
-  SessionLogger* logger;
-
-  std::function<Status(SessionLogger**)> status_creator =
-      [](SessionLogger** result) {
-        *result = new SessionLogger();
-        return Status::OK();
-      };
-
-  if (!rm->LookupOrCreate<SessionLogger>("session", "status", &logger,
-                                         status_creator)
-           .ok()) {
-    return nullptr;
-  }
-
-  return logger;
-}
-
-void LogSessionMessage(ResourceMgr* rm, StringPiece message) {
-  return GetSessionLogger(rm)->Log(message);
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/util/session_message.h b/tensorflow/core/util/session_message.h
deleted file mode 100644 (file)
index c0f3d78..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_
-#define TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_
-
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-
-class ResourceMgr;
-class SessionStatus;
-
-class SessionLogger : public ResourceBase {
- public:
-  SessionLogger();
-  ~SessionLogger();
-
-  void Log(StringPiece message);
-  string DebugString() override;
-
-  const SessionStatus& status() { return *status_; }
-
- private:
-  std::unique_ptr<SessionStatus> status_;
-  mutex mu_;
-};
-
-// Return a SessionLogger instance for the current session.  If the logger
-// will be used across multiple computations, you must explicitly acquire
-// and release references using Ref()/Unref().
-//
-// Returns nullptr if a logger cannot be created.
-SessionLogger* GetSessionLogger(ResourceMgr* rm);
-
-// Attach `message` to the logger for the current session.
-void LogSessionMessage(ResourceMgr* rm, StringPiece message);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H