From: Russell Power
Date: Sun, 29 Apr 2018 22:30:22 +0000 (-0700)
Subject: Add support for a clean checkpoint and shutdown in response to a termination notice.
X-Git-Tag: upstream/v1.9.0_rc1~190^2^2~2
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c41b546e4c193d61a79acf4cf4be621233d68ec0;p=platform%2Fupstream%2Ftensorflow.git

Add support for a clean checkpoint and shutdown in response to a
termination notice.

PiperOrigin-RevId: 194722985
---

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index eac2104..0bdf6f6 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -24,6 +24,7 @@ cc_library(
     name = "all_ops",
     deps = [
         ":cross_replica_ops_op_lib",
+        ":heartbeat_ops_op_lib",
         ":host_compute_ops_op_lib",
         ":infeed_ops_op_lib",
         ":outfeed_ops_op_lib",
@@ -71,6 +72,7 @@ py_library(
 tf_gen_op_libs(
     op_lib_names = [
         "cross_replica_ops",
+        "heartbeat_ops",
         "host_compute_ops",
         "infeed_ops",
         "outfeed_ops",
@@ -89,6 +91,7 @@ tf_custom_op_library(
     name = "python/ops/_tpu_ops.so",
     srcs = [
         "ops/cross_replica_ops.cc",
+        "ops/heartbeat_ops.cc",
         "ops/host_compute_ops.cc",
         "ops/infeed_ops.cc",
         "ops/outfeed_ops.cc",
@@ -106,6 +109,7 @@ tf_gen_op_wrapper_py(
     name = "tpu_ops",
    deps = [
         ":cross_replica_ops_op_lib",
+        ":heartbeat_ops_op_lib",
         ":host_compute_ops_op_lib",
         ":infeed_ops_op_lib",
         ":outfeed_ops_op_lib",
@@ -163,6 +167,7 @@ py_library(
         "python/tpu/bfloat16.py",
         "python/tpu/device_assignment.py",
         "python/tpu/keras_support.py",
+        "python/tpu/session_support.py",
         "python/tpu/topology.py",
         "python/tpu/tpu.py",
         "python/tpu/tpu_feed.py",
diff --git a/tensorflow/contrib/tpu/ops/heartbeat_ops.cc b/tensorflow/contrib/tpu/ops/heartbeat_ops.cc
new file mode 100644
index 0000000..ca0f5bc
--- /dev/null
+++ b/tensorflow/contrib/tpu/ops/heartbeat_ops.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("WorkerHeartbeat")
+    .Input("request: string")
+    .Output("response: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Worker heartbeat op.
+
+Heartbeats may be sent periodically to indicate the coordinator is still
+active, to retrieve the current worker status and to expedite shutdown when
+necessary.
+
+request: A string tensor containing a serialized WorkerHeartbeatRequest
+response: A string tensor containing a serialized WorkerHeartbeatResponse
+)doc");
+
+}  // namespace tensorflow
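
[Editorial sketch, not part of the commit: the new op can be driven by hand
from Python in the same way session_support.py (below) does. The device
string and the coordinator session `sess` are assumptions here.]

  # Sketch: ping one worker's WorkerHeartbeat op directly.
  from tensorflow.contrib.tpu.python.ops import tpu_ops
  from tensorflow.core.util import event_pb2
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import array_ops

  request_ph = array_ops.placeholder(
      dtype=dtypes.string, name='heartbeat_request')
  with ops.device('/job:worker/task:0/device:CPU:0'):  # assumed device
    heartbeat = tpu_ops.worker_heartbeat(request_ph)

  request = event_pb2.WorkerHeartbeatRequest()  # an empty request just pings
  response_bytes = sess.run(
      heartbeat, {request_ph: request.SerializeToString()})
  response = event_pb2.WorkerHeartbeatResponse.FromString(response_bytes)
  print(response.hostname, response.health_status)
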
diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
index 7bf5c21..d5600ee 100644
--- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
@@ -214,20 +214,4 @@ An op that shuts down a running distributed TPU system.
 The Op returns an error if no system is running.
 )doc");
 
-REGISTER_OP("SessionStatus")
-    .Input("fetch_start_timestamp: double")
-    .Output("status: string")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Not for public usage.
-
-Returns messages from the current session as a serialized SessionStatusProto.
-
-This includes the current state of the compiler, along with any critical
-logging or warning messages.
-
-fetch_start_timestamp: any messages earlier than this will be excluded from the
-returned proto.
-)doc");
-
 }  // end namespace tensorflow
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
new file mode 100644
index 0000000..7c25f66
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -0,0 +1,311 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Operations for handling session logging and shutdown notifications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import time
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+
+class CoordinatorShutdownException(Exception):
+  """Raised when the coordinator needs to shutdown."""
+  pass
+
+
+class WorkerHeartbeatManager(object):
+  """Manages the status/heartbeat monitor for a set of workers."""
+
+  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
+    """Construct a new WorkerHeartbeatManager.
+
+    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
+
+    Args:
+      session: `tf.Session`, session to use for heartbeat operations.
+      devices: `list[string]` Set of devices to connect to.
+      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
+      request_placeholder: `tf.Placeholder[String]` Placeholder used to
+        specify the WorkerHeartbeatRequest protocol buffer.
+    """
+    self._session = session
+    self._devices = devices
+    self._ops = heartbeat_ops
+    self._request_placeholder = request_placeholder
+
+  @staticmethod
+  def from_devices(session, devices):
+    """Construct a heartbeat manager for the given devices."""
+    if not devices:
+      logging.error('Trying to create heartbeat manager with no devices?')
+
+    logging.info('Creating heartbeat manager for %s', devices)
+    request_placeholder = array_ops.placeholder(
+        name='worker_heartbeat_request', dtype=dtypes.string)
+
+    heartbeat_ops = []
+    for device in devices:
+      with ops.device(device):
+        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+
+    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+                                  request_placeholder)
+
+  def configure(self, message):
+    """Configure heartbeat manager for all devices.
+
+    Args:
+      message: `event_pb2.WorkerHeartbeatRequest`
+
+    Returns:
+      `None`
+    """
+    logging.info('Configuring worker heartbeat: %s',
+                 text_format.MessageToString(message))
+    self._session.run(self._ops,
+                      {self._request_placeholder: message.SerializeToString()})
+
+  def ping(self, request=None, timeout_in_ms=5000):
+    """Ping all workers, returning the parsed status results."""
+    if request is None:
+      request = event_pb2.WorkerHeartbeatRequest()
+
+    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
+    results = self._session.run(
+        self._ops,
+        feed_dict={self._request_placeholder: request.SerializeToString()},
+        options=options)
+    parsed_results = [
+        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
+        for res_pb in results
+    ]
+    logging.info('Results: %s', parsed_results)
+    return parsed_results
+
+  def lame_workers(self):
+    """Ping all workers, returning manager containing lame workers (or None)."""
+    ping_results = self.ping()
+    lame_workers = []
+
+    for ping_response, device, op in zip(ping_results, self._devices,
+                                         self._ops):
+      if ping_response.health_status != event_pb2.OK:
+        lame_workers.append((device, op))
+
+    if not lame_workers:
+      return None
+
+    bad_devices, bad_ops = zip(*lame_workers)
+    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
+                                  self._request_placeholder)
+
+  def shutdown(self, timeout_ms=10000):
+    """Shut down all workers after `timeout_ms` milliseconds."""
+    req = event_pb2.WorkerHeartbeatRequest(
+        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
+    self.configure(req)
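
[Editorial sketch, not part of the commit: a coordinator-side use of the
manager above. The session `sess` is an assumption.]

  # Sketch: poll worker health from a coordinator session.
  from tensorflow.contrib.tpu.python.tpu import session_support

  workers = session_support.WorkerHeartbeatManager.from_devices(
      sess, session_support.all_worker_devices(sess))

  for response in workers.ping(timeout_in_ms=5000):
    print(response.hostname, response.health_status)

  lame = workers.lame_workers()      # manager over unhealthy workers, or None
  if lame is not None:
    lame.shutdown(timeout_ms=10000)  # arm each lame worker's watchdog
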
+ """ + self._session = session + self._devices = devices + self._ops = heartbeat_ops + self._request_placeholder = request_placeholder + + @staticmethod + def from_devices(session, devices): + """Construct a heartbeat manager for the given devices.""" + if not devices: + logging.error('Trying to create heartbeat manager with no devices?') + + logging.info('Creating heartbeat manager for %s', devices) + request_placeholder = array_ops.placeholder( + name='worker_heartbeat_request', dtype=dtypes.string) + + heartbeat_ops = [] + for device in devices: + with ops.device(device): + heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder)) + + return WorkerHeartbeatManager(session, devices, heartbeat_ops, + request_placeholder) + + def configure(self, message): + """Configure heartbeat manager for all devices. + + Args: + message: `event_pb2.WorkerHeartbeatRequest` + + Returns: `None` + + """ + logging.info('Configuring worker heartbeat: %s', + text_format.MessageToString(message)) + self._session.run(self._ops, + {self._request_placeholder: message.SerializeToString()}) + + def ping(self, request=None, timeout_in_ms=5000): + """Ping all workers, returning the parsed status results.""" + if request is None: + request = event_pb2.WorkerHeartbeatRequest() + + options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) + results = self._session.run( + self._ops, + feed_dict={self._request_placeholder: request.SerializeToString()}, + options=options) + parsed_results = [ + event_pb2.WorkerHeartbeatResponse.FromString(res_pb) + for res_pb in results + ] + logging.info('Results: %s', parsed_results) + return parsed_results + + def lame_workers(self): + """Ping all workers, returning manager containing lame workers (or None).""" + ping_results = self.ping() + lame_workers = [] + + for ping_response, device, op in zip(ping_results, self._devices, + self._ops): + if ping_response.health_status != event_pb2.OK: + lame_workers.append((device, op)) + + if not lame_workers: + return None + + bad_devices, bad_ops = zip(*lame_workers) + return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, + self._request_placeholder) + + def shutdown(self, timeout_ms=10000): + """Shutdown all workers after `shutdown_timeout_secs`.""" + req = event_pb2.WorkerHeartbeatRequest( + watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms)) + self.configure(req) + + +def all_worker_devices(session): + """Return a list of devices for each worker in the system.""" + devices = session.list_devices() + return [device.name for device in devices if 'CPU' in device.name] + + +class WatchdogManager(threading.Thread): + """Configures worker watchdog timer and handles periodic pings. + + Usage: + # Ping workers every minute, shutting down workers if they haven't received + # a ping after 1 hour. + watchdog_manager = WatchdogManager( + ping_interval=60, shutdown_timeout=3600 + ) + + # Use as a context manager, resetting watchdog on context exit: + with watchdog_manager: + session.run(...) + + # Or setup globally; watchdog will remain active until program exit. + watchdog_manager.configure_and_run() + """ + + def __init__(self, + session, + devices=None, + ping_interval=60, + shutdown_timeout=3600): + """Initialize a watchdog manager. + + Args: + + session: Session connected to worker devices. A cloned session and graph + will be created for managing worker pings. + devices: Set of devices to monitor. If none, all workers will be + monitored. + ping_interval: Time, in seconds, between watchdog pings. 
+
+
+class GracefulShutdownHook(session_run_hook.SessionRunHook):
+  """Session hook that watches for shutdown events.
+
+  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed,
+  and the registered shutdown hooks run; the default `restart_computation`
+  hook raises a `CoordinatorShutdownException` to terminate the main session.
+  If `saver` is None the `SAVERS` collection will be read to find a saver.
+
+  `on_shutdown_hooks` is an optional list of functions that should be called
+  after checkpointing. The function is called with (`run_context`,
+  `all_workers`, `lame_workers`).
+
+  The monitored heartbeat group defaults to all CPU workers in the system.
+  """
+
+  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
+    self._saver = saver
+    self._checkpoint_prefix = checkpoint_prefix
+    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
+
+    # Worker heartbeats are managed independently of the main training graph.
+    self._graph = ops.Graph()
+    self._workers = None
+    self._session = None
+
+  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
+    # N.B. We have to pull the global step here to avoid it being unavailable
+    # at checkpoint time; the graph has been frozen at that point.
+    if training_util.get_global_step() is None and self.saver() is not None:
+      raise ValueError(
+          'Saver defined but no global step. Run `get_or_create_global_step()`'
+          ' in your model definition to allow checkpointing.')
+
+    with self._graph.as_default():
+      self._session = session_lib.Session(
+          target=training_session.sess_str, graph=self._graph)
+      self._workers = WorkerHeartbeatManager.from_devices(
+          self._session, all_worker_devices(self._session))
+
+    self._workers.configure(
+        event_pb2.WorkerHeartbeatRequest(
+            shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
+
+  def saver(self):
+    if self._saver:
+      return self._saver
+
+    savers = ops.get_collection(ops.GraphKeys.SAVERS)
+    if not savers:
+      return None
+
+    if not isinstance(savers, list):
+      return savers
+
+    assert len(savers) == 1, 'Only one saver supported.'
+    return savers[0]
+
+  def after_run(self, run_context, run_values):
+    del run_values
+
+    lame_workers = self._workers.lame_workers()
+    if lame_workers:
+      logging.info('ShutdownHook: lame workers found: %s', lame_workers)
+
+      if self.saver():
+        logging.info('ShutdownHook: saving checkpoint to %s',
+                     self._checkpoint_prefix)
+        self.saver().save(
+            run_context.session,
+            self._checkpoint_prefix,
+            global_step=training_util.get_global_step(),
+            write_state=True,
+        )
+      else:
+        logging.info('ShutdownHook: no Saver defined.')
+
+      for fn in self._on_shutdown_hooks:
+        fn(run_context, self._workers, lame_workers)
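
[Editorial sketch, not part of the commit: wiring the hook into a training
loop by hand. The checkpoint path, master address, and `train_op` are
assumptions; the graph must define a global step and a Saver in the SAVERS
collection for checkpointing to occur.]

  from tensorflow.contrib.tpu.python.tpu import session_support
  from tensorflow.python.training import monitored_session

  hook = session_support.GracefulShutdownHook(
      checkpoint_prefix='/tmp/model/model.ckpt',  # assumed path
      on_shutdown_hooks=[session_support.restart_computation])

  with monitored_session.MonitoredTrainingSession(
      master='grpc://coordinator:8470',  # assumed master address
      hooks=[hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)  # train_op assumed
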
+
+
+def restart_computation(run_context, all_workers, lame_workers):
+  del run_context, lame_workers
+  logging.info('Shutting down all workers.')
+  all_workers.shutdown()
+
+  logging.info('Terminating coordinator.')
+  raise CoordinatorShutdownException()
+
+
+def shutdown_lame_workers(run_context, all_workers, lame_workers):
+  del run_context, all_workers
+  logging.info('Shutting down %s', lame_workers)
+  lame_workers.shutdown()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 98eb0e2..eb537b7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 import copy
+import os
 import signal
 import threading
 import time
@@ -31,6 +32,7 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import session_support
 from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_context
@@ -1551,7 +1553,7 @@ class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
 
 
 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
-  """Count examples during runtime."""
+  """Calculate and report the number of examples/sec during training."""
 
   def __init__(self,
                batch_size,
@@ -2037,6 +2039,14 @@ class TPUEstimator(estimator_lib.Estimator):
           host_ops = host_call.create_tpu_hostcall()
           if host_ops is None:
             host_ops = []
+
+          shutdown_hooks = []
+          if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0':
+            shutdown_hooks.append(
+                session_support.GracefulShutdownHook(
+                    checkpoint_prefix=os.path.join(
+                        self.model_dir, 'model.ckpt')))
+
           hooks = [
               TPUInfeedOutfeedSessionHook(
                   ctx,
@@ -2044,8 +2054,8 @@ class TPUEstimator(estimator_lib.Estimator):
                   host_ops,
                   run_infeed_loop_on_coordinator=(
                       run_infeed_loop_on_coordinator)),
-              ExamplesPerSecondHook(ctx.global_batch_size,
-                                    output_dir=self.model_dir),
+              ExamplesPerSecondHook(
+                  ctx.global_batch_size, output_dir=self.model_dir),
               InstallSignalHandlerHook(),
               training.LoggingTensorHook(
                   {
@@ -2053,7 +2063,8 @@ class TPUEstimator(estimator_lib.Estimator):
                       'step': training.get_global_step()
                   },
                   every_n_secs=30)
-          ] + input_hooks
+          ] + input_hooks + shutdown_hooks
+
           chief_hooks = []
           if (self._config.save_checkpoints_secs
               or self._config.save_checkpoints_steps):
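
[Editorial sketch, not part of the commit: with the estimator change above,
opting in is an environment toggle; any non-'0' value of
TF_TPU_GRACEFUL_SHUTDOWN adds the hook to the training hooks. The model
function, input function, and RunConfig are assumptions.]

  import os
  os.environ['TF_TPU_GRACEFUL_SHUTDOWN'] = '1'  # any non-'0' value opts in

  from tensorflow.contrib.tpu.python.tpu import tpu_config
  from tensorflow.contrib.tpu.python.tpu import tpu_estimator

  estimator = tpu_estimator.TPUEstimator(
      model_fn=my_model_fn,           # assumed model function
      config=tpu_config.RunConfig(),  # assumed cluster configuration
      train_batch_size=1024)
  estimator.train(input_fn=my_input_fn, max_steps=10000)  # assumed input_fn
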
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
index 3ae350c..894f21d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
@@ -60,7 +60,7 @@ def _query_tpu_system_metadata(master_address, run_config,
   with ops.Graph().as_default():
     with session_lib.Session(
         master_address,
-        config=_get_session_config_with_timeout(
+        config=get_session_config_with_timeout(
             _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess:
       devices = sess.list_devices()
       for device in devices:
@@ -133,7 +133,7 @@ def _obtain_topology(master_address, run_config):
                  'for model parallelism. This might take a while.',
                  master_address)
     with ops.Graph().as_default():
-      session_config = _get_session_config_with_timeout(
+      session_config = get_session_config_with_timeout(
           _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config)
       with session_lib.Session(
           master_address, config=session_config) as sess:
@@ -146,7 +146,7 @@ def _obtain_topology(master_address, run_config):
           master_address))
 
 
-def _get_session_config_with_timeout(timeout_in_secs, run_config):
+def get_session_config_with_timeout(timeout_in_secs, run_config):
   cluster_def = None
   if run_config.session_config and run_config.session_config.cluster_def.job:
     cluster_def = run_config.session_config.cluster_def
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 32ef0a9..2a849a3 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -457,17 +457,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "session_message",
-    srcs = ["util/session_message.cc"],
-    hdrs = ["util/session_message.h"],
-    deps = [
-        ":framework",
-        ":lib",
-        ":protos_all_cc",
-    ],
-)
-
 # Libraries that will eventually be moved into lib/core
 # Note that stringpiece_test can't be place here yet, because we are
 # required to use tf_cc_test, and that rule will change / into _
@@ -2149,7 +2138,6 @@ tf_cuda_library(
             "framework/resource_handle.cc",
             "util/memmapped_file_system.*",
             "util/memmapped_file_system_writer.*",
-            "util/session_message.cc",
             "util/version_info.cc",
         ],
     ) + select({
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
index 65d2c5a..9ce85be 100644
--- a/tensorflow/core/util/event.proto
+++ b/tensorflow/core/util/event.proto
@@ -81,7 +81,35 @@ message TaggedRunMetadata {
   bytes run_metadata = 2;
 }
 
-// For communicating live events back to a coordinator
-message SessionStatus {
-  repeated Event event = 1;
+// Worker heartbeat messages. Support for these operations is currently
+// internal and expected to change.
+
+// Current health status of a worker.
+enum WorkerHealth {
+  OK = 0;  // By default a worker is healthy.
+  RECEIVED_SHUTDOWN_SIGNAL = 1;
+  INTERNAL_ERROR = 2;
+}
+
+// Indicates the behavior of the worker when an internal error or shutdown
+// signal is received.
+enum WorkerShutdownMode {
+  DEFAULT = 0;
+  SHUTDOWN_IMMEDIATELY = 1;
+  WAIT_FOR_COORDINATOR = 2;
+}
+
+message WatchdogConfig {
+  int64 timeout_ms = 1;
+}
+
+message WorkerHeartbeatRequest {
+  WorkerShutdownMode shutdown_mode = 1;
+  WatchdogConfig watchdog_config = 2;
+}
+
+message WorkerHeartbeatResponse {
+  WorkerHealth health_status = 1;
+  repeated Event worker_log = 2;
+  string hostname = 3;
 }
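
[Editorial sketch, not part of the commit: the wire format these messages
imply, using the generated event_pb2 bindings. The hostname value is an
assumption.]

  from tensorflow.core.util import event_pb2

  # Coordinator side: build and serialize a request (the op's `request` input).
  request = event_pb2.WorkerHeartbeatRequest(
      shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR,
      watchdog_config=event_pb2.WatchdogConfig(timeout_ms=60 * 1000))
  wire_request = request.SerializeToString()

  # Worker side: build and serialize a response (the op's `response` output).
  reply = event_pb2.WorkerHeartbeatResponse(
      health_status=event_pb2.OK, hostname='worker-0')  # hostname assumed
  wire_reply = reply.SerializeToString()

  # Coordinator parses the reply, as WorkerHeartbeatManager.ping() does.
  parsed = event_pb2.WorkerHeartbeatResponse.FromString(wire_reply)
  assert parsed.health_status == event_pb2.OK
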
diff --git a/tensorflow/core/util/session_message.cc b/tensorflow/core/util/session_message.cc
deleted file mode 100644
index 28a6517..0000000
--- a/tensorflow/core/util/session_message.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/util/session_message.h"
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/util/event.pb.h"
-
-static const int kMaxLogEvents = 1000;
-
-namespace tensorflow {
-
-SessionLogger::SessionLogger() : status_(new SessionStatus) {}
-
-SessionLogger::~SessionLogger() {}
-
-string SessionLogger::DebugString() { return "SessionLogger"; }
-
-void SessionLogger::Log(StringPiece message) {
-  mutex_lock lock(mu_);
-
-  Event* event = status_->add_event();
-  event->set_wall_time(Env::Default()->NowMicros());
-  event->set_step(0);
-  LogMessage* log = event->mutable_log_message();
-  log->set_message(message.ToString());
-  log->set_level(LogMessage::INFO);
-
-  // Clip log events by 10% if we overflow
-  if (status_->event_size() > kMaxLogEvents) {
-    auto events = status_->mutable_event();
-    events->DeleteSubrange(0, kMaxLogEvents / 10);
-  }
-}
-
-SessionLogger* GetSessionLogger(ResourceMgr* rm) {
-  SessionLogger* logger;
-
-  std::function<Status(SessionLogger**)> status_creator =
-      [](SessionLogger** result) {
-        *result = new SessionLogger();
-        return Status::OK();
-      };
-
-  if (!rm->LookupOrCreate<SessionLogger>("session", "status", &logger,
-                                         status_creator)
-           .ok()) {
-    return nullptr;
-  }
-
-  return logger;
-}
-
-void LogSessionMessage(ResourceMgr* rm, StringPiece message) {
-  return GetSessionLogger(rm)->Log(message);
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/util/session_message.h b/tensorflow/core/util/session_message.h
deleted file mode 100644
index c0f3d78..0000000
--- a/tensorflow/core/util/session_message.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_
-#define TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_
-
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-
-class ResourceMgr;
-class SessionStatus;
-
-class SessionLogger : public ResourceBase {
- public:
-  SessionLogger();
-  ~SessionLogger();
-
-  void Log(StringPiece message);
-  string DebugString() override;
-
-  const SessionStatus& status() { return *status_; }
-
- private:
-  std::unique_ptr<SessionStatus> status_;
-  mutex mu_;
-};
-
-// Return a SessionLogger instance for the current session. If the logger
-// will be used across multiple computations, you must explicitly acquire
-// and release references using Ref()/Unref().
-//
-// Returns nullptr if a logger cannot be created.
-SessionLogger* GetSessionLogger(ResourceMgr* rm);
-
-// Attach `message` to the logger for the current session.
-void LogSessionMessage(ResourceMgr* rm, StringPiece message);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_