From d2090672fe8305289156460c43f7fcc1a5dd5422 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Thu, 24 May 2018 14:02:30 -0700 Subject: [PATCH] tfdbg: fix issue where total source file size exceeds gRPC message size limit * Source file contents are now sent one file at a time, making it less likely that individual messages will have sizes above the 4-MB gRPC message size limit. * In case the message for a single source file exceeds the limit, the client handles it gracefully by skipping the send and printing a warning message. Fixes: https://github.com/tensorflow/tensorboard/issues/1118 PiperOrigin-RevId: 197949416 --- tensorflow/python/debug/BUILD | 1 + .../python/debug/lib/grpc_debug_test_server.py | 13 +++--- tensorflow/python/debug/lib/source_remote.py | 23 +++++++++-- tensorflow/python/debug/lib/source_remote_test.py | 46 ++++++++++++++++++++++ 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 16ae74a..09062ab 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -572,6 +572,7 @@ py_test( ":source_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", + "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py index 9170046..a7be209 100644 --- a/tensorflow/python/debug/lib/grpc_debug_test_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py @@ -245,7 +245,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): self._origin_id_to_strings = [] self._graph_tracebacks = [] self._graph_versions = [] - self._source_files = None + self._source_files = [] def _initialize_toggle_watch_state(self, toggle_watches): self._toggle_watches = toggle_watches @@ -274,7 +274,7 @@ class
EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): self._origin_id_to_strings = [] self._graph_tracebacks = [] self._graph_versions = [] - self._source_files = None + self._source_files = [] def SendTracebacks(self, request, context): self._call_types.append(request.call_type) @@ -286,7 +286,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): return debug_service_pb2.EventReply() def SendSourceFiles(self, request, context): - self._source_files = request + self._source_files.append(request) return debug_service_pb2.EventReply() def query_op_traceback(self, op_name): @@ -351,9 +351,10 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): if not self._source_files: raise ValueError( "This debug server has not received any source file contents yet.") - for source_file_proto in self._source_files.source_files: - if source_file_proto.file_path == file_path: - return source_file_proto.lines[lineno - 1] + for source_files in self._source_files: + for source_file_proto in source_files.source_files: + if source_file_proto.file_path == file_path: + return source_file_proto.lines[lineno - 1] raise ValueError( "Source file at path %s has not been received by the debug server", file_path) diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py index 4b6b2b9..4afae41 100644 --- a/tensorflow/python/debug/lib/source_remote.py +++ b/tensorflow/python/debug/lib/source_remote.py @@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.debug.lib import source_utils from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging from tensorflow.python.profiler import tfprof_logger @@ -95,6 +96,11 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string): return non_tf_files +def grpc_message_length_bytes(): + 
"""Maximum gRPC message length in bytes.""" + return 4 * 1024 * 1024 + + def _send_call_tracebacks(destinations, origin_stack, is_eager_execution=False, @@ -155,17 +161,28 @@ def _send_call_tracebacks(destinations, source_file_paths.update(_source_file_paths_outside_tensorflow_py_library( [call_traceback.origin_stack], call_traceback.origin_id_to_string)) - debugged_source_files = debug_pb2.DebuggedSourceFiles() + debugged_source_files = [] for file_path in source_file_paths: + source_files = debug_pb2.DebuggedSourceFiles() _load_debugged_source_file( - file_path, debugged_source_files.source_files.add()) + file_path, source_files.source_files.add()) + debugged_source_files.append(source_files) for destination in destinations: channel = grpc.insecure_channel(destination) stub = debug_service_pb2_grpc.EventListenerStub(channel) stub.SendTracebacks(call_traceback) if send_source: - stub.SendSourceFiles(debugged_source_files) + for path, source_files in zip( + source_file_paths, debugged_source_files): + if source_files.ByteSize() < grpc_message_length_bytes(): + stub.SendSourceFiles(source_files) + else: + tf_logging.warn( + "The content of the source file at %s is not sent to " + "gRPC debug server %s, because the message size exceeds " + "gRPC message length limit (%d bytes)." 
% ( + path, destination, grpc_message_length_bytes())) def send_graph_tracebacks(destinations, diff --git a/tensorflow/python/debug/lib/source_remote_test.py b/tensorflow/python/debug/lib/source_remote_test.py index 27bafa4..29add42 100644 --- a/tensorflow/python/debug/lib/source_remote_test.py +++ b/tensorflow/python/debug/lib/source_remote_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect @@ -155,6 +156,51 @@ class SendTracebacksTest(test_util.TensorFlowTestCase): self.assertEqual(["dummy_run_key"], server.query_call_keys()) self.assertEqual([sess.graph.version], server.query_graph_versions()) + def testSourceFileSizeExceedsGrpcMessageLengthLimit(self): + """In case source file size exceeds the grpc message length limit. + + it ought not to have been sent to the server. + """ + this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit" + + # Patch the method to simulate a very small message length limit. + with test.mock.patch.object( + source_remote, "grpc_message_length_bytes", return_value=2): + with session.Session() as sess: + a = variables.Variable(21.0, name="two/a") + a_lineno = line_number_above() + b = variables.Variable(2.0, name="two/b") + b_lineno = line_number_above() + x = math_ops.add(a, b, name="two/x") + x_lineno = line_number_above() + + send_traceback = traceback.extract_stack() + send_lineno = line_number_above() + source_remote.send_graph_tracebacks( + [self._server_address, self._server_address_2], + "dummy_run_key", send_traceback, sess.graph) + + servers = [self._server, self._server_2] + for server in servers: + # Even though the source file content is not sent, the traceback + # should have been sent. 
+ tb = server.query_op_traceback("two/a") + self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) + tb = server.query_op_traceback("two/b") + self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) + tb = server.query_op_traceback("two/x") + self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) + + self.assertIn( + (self._curr_file_path, send_lineno, this_func_name), + server.query_origin_stack()[-1]) + + tf_trace_file_path = ( + self._findFirstTraceInsideTensorFlowPyLibrary(x.op)) + # Verify that the source content is not sent to the server. + with self.assertRaises(ValueError): + self._server.query_source_file_line(tf_trace_file_path, 0) + def testSendEagerTracebacksToSingleDebugServer(self): this_func_name = "testSendEagerTracebacksToSingleDebugServer" send_traceback = traceback.extract_stack() -- 2.7.4