tfdbg: fix issue where total source file size exceeds gRPC message size limit
author: Shanqing Cai <cais@google.com>
Thu, 24 May 2018 21:02:30 +0000 (14:02 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 21:05:21 +0000 (14:05 -0700)
* Source file contents are now sent one file at a time, making it less likely that
  individual messages will exceed the 4-MB gRPC message size limit.
* In case the message for a single source file exceeds the limit, the client handles
  it gracefully by skipping the send and printing a warning message.

Fixes: https://github.com/tensorflow/tensorboard/issues/1118
PiperOrigin-RevId: 197949416

tensorflow/python/debug/BUILD
tensorflow/python/debug/lib/grpc_debug_test_server.py
tensorflow/python/debug/lib/source_remote.py
tensorflow/python/debug/lib/source_remote_test.py

index 16ae74a..09062ab 100644 (file)
@@ -572,6 +572,7 @@ py_test(
         ":source_utils",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
index 9170046..a7be209 100644 (file)
@@ -245,7 +245,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
     self._origin_id_to_strings = []
     self._graph_tracebacks = []
     self._graph_versions = []
-    self._source_files = None
+    self._source_files = []
 
   def _initialize_toggle_watch_state(self, toggle_watches):
     self._toggle_watches = toggle_watches
@@ -274,7 +274,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
     self._origin_id_to_strings = []
     self._graph_tracebacks = []
     self._graph_versions = []
-    self._source_files = None
+    self._source_files = []
 
   def SendTracebacks(self, request, context):
     self._call_types.append(request.call_type)
@@ -286,7 +286,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
     return debug_service_pb2.EventReply()
 
   def SendSourceFiles(self, request, context):
-    self._source_files = request
+    self._source_files.append(request)
     return debug_service_pb2.EventReply()
 
   def query_op_traceback(self, op_name):
@@ -351,9 +351,10 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
     if not self._source_files:
       raise ValueError(
           "This debug server has not received any source file contents yet.")
-    for source_file_proto in self._source_files.source_files:
-      if source_file_proto.file_path == file_path:
-        return source_file_proto.lines[lineno - 1]
+    for source_files in self._source_files:
+      for source_file_proto in source_files.source_files:
+        if source_file_proto.file_path == file_path:
+          return source_file_proto.lines[lineno - 1]
     raise ValueError(
         "Source file at path %s has not been received by the debug server",
         file_path)
index 4b6b2b9..4afae41 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import common
 from tensorflow.python.debug.lib import debug_service_pb2_grpc
 from tensorflow.python.debug.lib import source_utils
 from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.profiler import tfprof_logger
 
 
@@ -95,6 +96,11 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
   return non_tf_files
 
 
+def grpc_message_length_bytes():
+  """Maximum gRPC message length in bytes."""
+  return 4 * 1024 * 1024
+
+
 def _send_call_tracebacks(destinations,
                           origin_stack,
                           is_eager_execution=False,
@@ -155,17 +161,28 @@ def _send_call_tracebacks(destinations,
     source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
         [call_traceback.origin_stack], call_traceback.origin_id_to_string))
 
-    debugged_source_files = debug_pb2.DebuggedSourceFiles()
+    debugged_source_files = []
     for file_path in source_file_paths:
+      source_files = debug_pb2.DebuggedSourceFiles()
       _load_debugged_source_file(
-          file_path, debugged_source_files.source_files.add())
+          file_path, source_files.source_files.add())
+      debugged_source_files.append(source_files)
 
   for destination in destinations:
     channel = grpc.insecure_channel(destination)
     stub = debug_service_pb2_grpc.EventListenerStub(channel)
     stub.SendTracebacks(call_traceback)
     if send_source:
-      stub.SendSourceFiles(debugged_source_files)
+      for path, source_files in zip(
+          source_file_paths, debugged_source_files):
+        if source_files.ByteSize() < grpc_message_length_bytes():
+          stub.SendSourceFiles(source_files)
+        else:
+          tf_logging.warn(
+              "The content of the source file at %s is not sent to "
+              "gRPC debug server %s, because the message size exceeds "
+              "gRPC message length limit (%d bytes)." % (
+                  path, destination, grpc_message_length_bytes()))
 
 
 def send_graph_tracebacks(destinations,
index 27bafa4..29add42 100644 (file)
@@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
 from tensorflow.python.util import tf_inspect
 
 
@@ -155,6 +156,51 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
         self.assertEqual(["dummy_run_key"], server.query_call_keys())
         self.assertEqual([sess.graph.version], server.query_graph_versions())
 
+  def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
+    """In case source file size exceeds the grpc message length limit.
+
+    it ought not to have been sent to the server.
+    """
+    this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
+
+    # Patch the method to simulate a very small message length limit.
+    with test.mock.patch.object(
+        source_remote, "grpc_message_length_bytes", return_value=2):
+      with session.Session() as sess:
+        a = variables.Variable(21.0, name="two/a")
+        a_lineno = line_number_above()
+        b = variables.Variable(2.0, name="two/b")
+        b_lineno = line_number_above()
+        x = math_ops.add(a, b, name="two/x")
+        x_lineno = line_number_above()
+
+        send_traceback = traceback.extract_stack()
+        send_lineno = line_number_above()
+        source_remote.send_graph_tracebacks(
+            [self._server_address, self._server_address_2],
+            "dummy_run_key", send_traceback, sess.graph)
+
+        servers = [self._server, self._server_2]
+        for server in servers:
+          # Even though the source file content is not sent, the traceback
+          # should have been sent.
+          tb = server.query_op_traceback("two/a")
+          self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
+          tb = server.query_op_traceback("two/b")
+          self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
+          tb = server.query_op_traceback("two/x")
+          self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
+
+          self.assertIn(
+              (self._curr_file_path, send_lineno, this_func_name),
+              server.query_origin_stack()[-1])
+
+          tf_trace_file_path = (
+              self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
+          # Verify that the source content is not sent to the server.
+          with self.assertRaises(ValueError):
+            self._server.query_source_file_line(tf_trace_file_path, 0)
+
   def testSendEagerTracebacksToSingleDebugServer(self):
     this_func_name = "testSendEagerTracebacksToSingleDebugServer"
     send_traceback = traceback.extract_stack()