tfdbg: split session_debug_grpc_test
authorShanqing Cai <cais@google.com>
Wed, 14 Mar 2018 00:11:56 +0000 (17:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 00:25:43 +0000 (17:25 -0700)
* so that the test sizes are medium for both the existing session_debug_grpc_test
  and the new grpc_large_data_test

Also in this CL
* Consolidate the functions for creating no-grappler-rewrite ConfigProtos
  in one place: in session_debug_testlib.py

PiperOrigin-RevId: 188955135

tensorflow/contrib/cmake/tf_tests.cmake
tensorflow/python/debug/BUILD
tensorflow/python/debug/lib/grpc_large_data_test.py [new file with mode: 0644]
tensorflow/python/debug/lib/session_debug_file_test.py
tensorflow/python/debug/lib/session_debug_grpc_test.py

index 1c4ebd7..9f96a4b 100644 (file)
@@ -222,6 +222,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
       "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py"
       # TFDBG grpc:// mode is not yet available on Windows.
       "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py"
+      "${tensorflow_source_dir}/tensorflow/python/debug/lib/grpc_large_data_test.py"
       "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py"
       "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py"
       # stl on windows handles overflows different
index 253588f..512d292 100644 (file)
@@ -957,7 +957,7 @@ cuda_py_test(
 
 cuda_py_test(
     name = "session_debug_grpc_test",
-    size = "large",
+    size = "medium",
     srcs = ["lib/session_debug_grpc_test.py"],
     additional_deps = [
         ":debug_data",
@@ -967,7 +967,6 @@ cuda_py_test(
         ":grpc_wrapper",
         ":hooks",
         ":session_debug_testlib",
-        "//third_party/py/numpy",
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
@@ -983,6 +982,29 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "grpc_large_data_test",
+    size = "medium",
+    srcs = ["lib/grpc_large_data_test.py"],
+    additional_deps = [
+        ":dumping_wrapper",
+        ":grpc_debug_test_server",
+        ":grpc_wrapper",
+        ":session_debug_testlib",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variables",
+    ],
+    tags = [
+        "no_oss",  # Test flaky due to port collisions.
+        "no_windows",
+        "oss_serial",
+    ],
+)
+
 # TODO(cais): Run the test in OSS, perhaps through a sh_test.
 cuda_py_test(
     name = "dist_session_debug_grpc_test",
diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py
new file mode 100644 (file)
index 0000000..5bc477a
--- /dev/null
@@ -0,0 +1,210 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sending large-size data through tfdbg grpc channels.
+
+"Large-size data" includes large GraphDef protos and large Tensor protos.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.debug.lib import grpc_debug_test_server
+from tensorflow.python.debug.lib import session_debug_testlib
+from tensorflow.python.debug.wrappers import framework
+from tensorflow.python.debug.wrappers import grpc_wrapper
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
+     cls.debug_server
+    ) = grpc_debug_test_server.start_server_on_separate_thread(
+        dump_to_filesystem=False)
+    tf_logging.info("debug server url: %s", cls.debug_server_url)
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.debug_server.stop_server().wait()
+    cls.debug_server_thread.join()
+
+  def tearDown(self):
+    ops.reset_default_graph()
+    self.debug_server.clear_data()
+
+  def testSendingLargeGraphDefsWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      u = variables.Variable(42.0, name="original_u")
+      for _ in xrange(50 * 1000):
+        u = array_ops.identity(u)
+      sess.run(variables.global_variables_initializer())
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"original_u")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      self.assertAllClose(42.0, sess.run(u))
+
+      self.assertAllClose(
+          [42.0],
+          self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
+      self.assertEqual(2 if test.is_gpu_available() else 1,
+                       len(self.debug_server.partition_graph_defs))
+      max_graph_def_size = max([
+          len(graph_def.SerializeToString())
+          for graph_def in self.debug_server.partition_graph_defs])
+      self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
+
+  def testSendingLargeFloatTensorWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      u_init_val_array = list(xrange(1200 * 1024))
+      # Size: 4 * 1200 * 1024 = 4800k > 4M
+
+      u_init = constant_op.constant(
+          u_init_val_array, dtype=dtypes.float32, name="u_init")
+      u = variables.Variable(u_init, name="u")
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds  # Unused by this watch_fn.
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"u_init")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      sess.run(u.initializer)
+
+      self.assertAllEqual(
+          u_init_val_array,
+          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+  def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      u_init_val = [
+          b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
+      u_init = constant_op.constant(
+          u_init_val, dtype=dtypes.string, name="u_init")
+      u = variables.Variable(u_init, name="u")
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"u_init")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      sess.run(u.initializer)
+
+      self.assertAllEqual(
+          u_init_val,
+          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+  def testSendingLargeStringTensorWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      strs_total_size_threshold = 5000 * 1024
+      cum_size = 0
+      u_init_val_array = []
+      while cum_size < strs_total_size_threshold:
+        strlen = np.random.randint(200)
+        u_init_val_array.append(b"A" * strlen)
+        cum_size += strlen
+
+      u_init = constant_op.constant(
+          u_init_val_array, dtype=dtypes.string, name="u_init")
+      u = variables.Variable(u_init, name="u")
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"u_init")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      sess.run(u.initializer)
+
+      self.assertAllEqual(
+          u_init_val_array,
+          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+  def testSendingEmptyFloatTensorWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      u_init = constant_op.constant(
+          [], dtype=dtypes.float32, shape=[0], name="u_init")
+      u = variables.Variable(u_init, name="u")
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"u_init")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      sess.run(u.initializer)
+
+      u_init_value = self.debug_server.debug_tensor_values[
+          "u_init:0:DebugIdentity"][0]
+      self.assertEqual(np.float32, u_init_value.dtype)
+      self.assertEqual(0, len(u_init_value))
+
+  def testSendingEmptyStringTensorWorks(self):
+    with self.test_session(
+        use_gpu=True,
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
+      u_init = constant_op.constant(
+          [], dtype=dtypes.string, shape=[0], name="u_init")
+      u = variables.Variable(u_init, name="u")
+
+      def watch_fn(fetches, feeds):
+        del fetches, feeds
+        return framework.WatchOptions(
+            debug_ops=["DebugIdentity"],
+            node_name_regex_whitelist=r"u_init")
+      sess = grpc_wrapper.GrpcDebugWrapperSession(
+          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+      sess.run(u.initializer)
+
+      u_init_value = self.debug_server.debug_tensor_values[
+          "u_init:0:DebugIdentity"][0]
+      self.assertEqual(np.object, u_init_value.dtype)
+      self.assertEqual(0, len(u_init_value))
+
+
+if __name__ == "__main__":
+  googletest.main()
index 1a6bedb..ba0f15b 100644 (file)
@@ -22,7 +22,6 @@ import shutil
 import tempfile
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.debug.lib import debug_data
 from tensorflow.python.debug.lib import debug_utils
@@ -36,13 +35,6 @@ from tensorflow.python.platform import googletest
 
 class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
 
-  def _no_rewrite_session_config(self):
-    rewriter_config = rewriter_config_pb2.RewriterConfig(
-        disable_model_pruning=True,
-        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
-    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
-    return config_pb2.ConfigProto(graph_options=graph_options)
-
   def _debug_urls(self, run_number=None):
     return ["file://%s" % self._debug_dump_dir(run_number=run_number)]
 
@@ -55,7 +47,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
   def testAllowsDifferentWatchesOnDifferentRuns(self):
     """Test watching different tensors on different runs of the same graph."""
 
-    with session.Session(config=self._no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
       v_init_val = [[2.0], [-1.0]]
 
index b623ee3..ff49b69 100644 (file)
@@ -24,11 +24,9 @@ from __future__ import print_function
 import os
 import shutil
 
-import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.debug.lib import debug_data
 from tensorflow.python.debug.lib import debug_utils
@@ -38,28 +36,15 @@ from tensorflow.python.debug.wrappers import framework
 from tensorflow.python.debug.wrappers import grpc_wrapper
 from tensorflow.python.debug.wrappers import hooks
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import monitored_session
 
 
-def no_rewrite_session_config():
-  rewriter_config = rewriter_config_pb2.RewriterConfig(
-      disable_model_pruning=True,
-      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
-      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
-  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
-  return config_pb2.ConfigProto(graph_options=graph_options)
-
-
 class GrpcDebugServerTest(test_util.TensorFlowTestCase):
 
   def testRepeatedRunServerRaisesException(self):
@@ -142,19 +127,22 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
       return os.path.join(self._dump_root, "run_%d" % run_number)
 
   def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     with self.assertRaisesRegexp(
         TypeError, "Expected type str or list in grpc_debug_server_addresses"):
       grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)
 
   def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     with self.assertRaisesRegexp(
         TypeError, "Expected type str in list grpc_debug_server_addresses"):
       grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
 
   def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     with self.assertRaises(TypeError):
       grpc_wrapper.GrpcDebugWrapperSession(
           sess, "localhost:%d" % self._server_port, watch_fn="foo")
@@ -164,7 +152,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     v = variables.Variable(20.0, name="v")
     w = math_ops.multiply(u, v, name="w")
 
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     sess.run(u.initializer)
     sess.run(v.initializer)
 
@@ -190,7 +179,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     v = variables.Variable(20.0, name="v")
     w = math_ops.multiply(u, v, name="w")
 
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     sess.run(u.initializer)
     sess.run(v.initializer)
 
@@ -223,7 +213,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     v = variables.Variable(20.0, name="v")
     w = math_ops.multiply(u, v, name="w")
 
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     sess.run(u.initializer)
     sess.run(v.initializer)
 
@@ -254,7 +245,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     v = variables.Variable(20.0, name="v")
     w = math_ops.multiply(u, v, name="w")
 
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     sess.run(u.initializer)
     sess.run(v.initializer)
 
@@ -298,7 +290,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     v = variables.Variable(20.0, name="v")
     w = math_ops.multiply(u, v, name="w")
 
-    sess = session.Session(config=no_rewrite_session_config())
+    sess = session.Session(
+        config=session_debug_testlib.no_rewrite_session_config())
     sess.run(variables.global_variables_initializer())
 
     grpc_debug_hook = hooks.TensorBoardDebugHook(
@@ -324,168 +317,6 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
     hooks.GrpcDebugHook(["foo:42424"])
 
 
-class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
-
-  @classmethod
-  def setUpClass(cls):
-    (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
-     cls.debug_server
-    ) = grpc_debug_test_server.start_server_on_separate_thread(
-        dump_to_filesystem=False)
-    tf_logging.info("debug server url: %s", cls.debug_server_url)
-
-  @classmethod
-  def tearDownClass(cls):
-    cls.debug_server.stop_server().wait()
-    cls.debug_server_thread.join()
-
-  def tearDown(self):
-    ops.reset_default_graph()
-    self.debug_server.clear_data()
-
-  def testSendingLargeGraphDefsWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      u = variables.Variable(42.0, name="original_u")
-      for _ in xrange(50 * 1000):
-        u = array_ops.identity(u)
-      sess.run(variables.global_variables_initializer())
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"original_u")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      self.assertAllClose(42.0, sess.run(u))
-
-      self.assertAllClose(
-          [42.0],
-          self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
-      self.assertEqual(2 if test.is_gpu_available() else 1,
-                       len(self.debug_server.partition_graph_defs))
-      max_graph_def_size = max([
-          len(graph_def.SerializeToString())
-          for graph_def in self.debug_server.partition_graph_defs])
-      self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
-
-  def testSendingLargeFloatTensorWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      u_init_val_array = list(xrange(1200 * 1024))
-      # Size: 4 * 1200 * 1024 = 4800k > 4M
-
-      u_init = constant_op.constant(
-          u_init_val_array, dtype=dtypes.float32, name="u_init")
-      u = variables.Variable(u_init, name="u")
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds  # Unused by this watch_fn.
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"u_init")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      sess.run(u.initializer)
-
-      self.assertAllEqual(
-          u_init_val_array,
-          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
-  def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      u_init_val = [
-          b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
-      u_init = constant_op.constant(
-          u_init_val, dtype=dtypes.string, name="u_init")
-      u = variables.Variable(u_init, name="u")
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"u_init")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      sess.run(u.initializer)
-
-      self.assertAllEqual(
-          u_init_val,
-          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
-  def testSendingLargeStringTensorWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      strs_total_size_threshold = 5000 * 1024
-      cum_size = 0
-      u_init_val_array = []
-      while cum_size < strs_total_size_threshold:
-        strlen = np.random.randint(200)
-        u_init_val_array.append(b"A" * strlen)
-        cum_size += strlen
-
-      u_init = constant_op.constant(
-          u_init_val_array, dtype=dtypes.string, name="u_init")
-      u = variables.Variable(u_init, name="u")
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"u_init")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      sess.run(u.initializer)
-
-      self.assertAllEqual(
-          u_init_val_array,
-          self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
-  def testSendingEmptyFloatTensorWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      u_init = constant_op.constant(
-          [], dtype=dtypes.float32, shape=[0], name="u_init")
-      u = variables.Variable(u_init, name="u")
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"u_init")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      sess.run(u.initializer)
-
-      u_init_value = self.debug_server.debug_tensor_values[
-          "u_init:0:DebugIdentity"][0]
-      self.assertEqual(np.float32, u_init_value.dtype)
-      self.assertEqual(0, len(u_init_value))
-
-  def testSendingEmptyStringTensorWorks(self):
-    with self.test_session(
-        use_gpu=True, config=no_rewrite_session_config()) as sess:
-      u_init = constant_op.constant(
-          [], dtype=dtypes.string, shape=[0], name="u_init")
-      u = variables.Variable(u_init, name="u")
-
-      def watch_fn(fetches, feeds):
-        del fetches, feeds
-        return framework.WatchOptions(
-            debug_ops=["DebugIdentity"],
-            node_name_regex_whitelist=r"u_init")
-      sess = grpc_wrapper.GrpcDebugWrapperSession(
-          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
-      sess.run(u.initializer)
-
-      u_init_value = self.debug_server.debug_tensor_values[
-          "u_init:0:DebugIdentity"][0]
-      self.assertEqual(np.object, u_init_value.dtype)
-      self.assertEqual(0, len(u_init_value))
-
-
 class SessionDebugConcurrentTest(
     session_debug_testlib.DebugConcurrentRunCallsTest):
 
@@ -548,7 +379,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
     self._server_2.clear_data()
 
   def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v_1 = variables.Variable(50.0, name="v_1")
       v_2 = variables.Variable(-50.0, name="v_1")
       delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -617,7 +449,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
                                         ("toggled_2", 0, "DebugIdentity")])
     self._servers_and_threads.append((server, server_thread))
 
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v_1 = variables.Variable(50.0, name="v_1")
       v_2 = variables.Variable(-50.0, name="v_1")
       # These two nodes have names that match those in the
@@ -656,7 +489,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
           self.assertEqual(0, len(server.debug_tensor_values))
 
   def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v = variables.Variable(50.0, name="v")
       delta = constant_op.constant(5.0, name="delta")
       inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -698,7 +532,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
           self.assertEqual(0, len(self._server_2.debug_tensor_values))
 
   def testToggleBreakpointsWorks(self):
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v_1 = variables.Variable(50.0, name="v_1")
       v_2 = variables.Variable(-50.0, name="v_2")
       delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -755,7 +590,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
           self.assertSetEqual(set(), self._server_1.breakpoints)
 
   def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v_1 = variables.Variable(50.0, name="v_1")
       v_2 = variables.Variable(-50.0, name="v_2")
       delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -827,7 +663,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
             self._server_1.query_source_file_line(__file__, 1)
 
   def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
-    with session.Session(config=no_rewrite_session_config()) as sess:
+    with session.Session(
+        config=session_debug_testlib.no_rewrite_session_config()) as sess:
       v_1 = variables.Variable(50.0, name="v_1")
       v_2 = variables.Variable(-50.0, name="v_2")
       delta_1 = constant_op.constant(5.0, name="delta_1")