tfdbg CLI: Allow node exclusion with tensor filters
authorShanqing Cai <cais@google.com>
Mon, 26 Mar 2018 19:16:34 +0000 (12:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Mar 2018 19:20:17 +0000 (12:20 -0700)
Fixes: #16619

See the referred GitHub issue for details, but users want to be
able to skip certain nodes when searching for inf/nans, because
some nodes generate inf/nans even in nominal conditions.

This CL adds a new optional flag `--filter_exclude_node_names`
(or `-fenn` for short), which allows users to do exactly that,
by using a regex for node names.

RELNOTES: tfdbg CLI: Allow exclusion of nodes by regular expressions
during tensor filter-enabled Session runs: see the new flags
`--filter_exclude_node_names` (or `-fenn` for short).
PiperOrigin-RevId: 190504225

tensorflow/docs_src/programmers_guide/debugger.md
tensorflow/python/debug/cli/analyzer_cli.py
tensorflow/python/debug/cli/analyzer_cli_test.py
tensorflow/python/debug/lib/debug_data.py
tensorflow/python/debug/lib/session_debug_testlib.py
tensorflow/python/debug/wrappers/local_cli_wrapper.py
tensorflow/python/debug/wrappers/local_cli_wrapper_test.py

index d139981..d1cd7e7 100644 (file)
@@ -155,6 +155,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
 | | `-n <name_pattern>` | List dumped tensors with names matching given regular-expression pattern. | `lt -n Softmax.*` |
 | | `-t <op_pattern>` | List dumped tensors with op types matching given regular-expression pattern. | `lt -t MatMul` |
 | | `-f <filter_name>` | List only the tensors that pass a registered tensor filter. | `lt -f has_inf_or_nan` |
+| | `-f <filter_name> -fenn <regex>` | List only the tensors that pass a registered tensor filter, excluding nodes with names matching the regular expression. | `lt -f has_inf_or_nan` `-fenn .*Sqrt.*` |
 | | `-s <sort_key>` | Sort the output by given `sort_key`, whose possible values are `timestamp` (default), `dump_size`, `op_type` and `tensor_name`. | `lt -s dump_size` |
 | | `-r` | Sort in reverse order. | `lt -r -s dump_size` |
 | **`pt`** | | **Print value of a dumped tensor.** | |
@@ -200,6 +201,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
 | | `-n` | Execute through the next `Session.run` without debugging, and drop to CLI right before the run after that. | `run -n` |
 | | `-t <T>` | Execute `Session.run` `T - 1` times without debugging, followed by a run with debugging. Then drop to CLI right after the debugged run. | `run -t 10` |
 | | `-f <filter_name>` | Continue executing `Session.run` until any intermediate tensor triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan` |
+| | `-f <filter_name> -fenn <regex>` | Continue executing `Session.run` until any intermediate tensor whose node names doesn't match the regular expression triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan -fenn .*Sqrt.*` |
 | | `--node_name_filter <pattern>` | Execute the next `Session.run`, watching only nodes with names matching the given regular-expression pattern. | `run --node_name_filter Softmax.*` |
 | | `--op_type_filter <pattern>` | Execute the next `Session.run`, watching only nodes with op types matching the given regular-expression pattern. | `run --op_type_filter Variable.*` |
 | | `--tensor_dtype_filter <pattern>` | Execute the next `Session.run`, dumping only Tensors with data types (`dtype`s) matching the given regular-expression pattern. | `run --tensor_dtype_filter int.*` |
@@ -813,6 +815,20 @@ sess.run(b)
 the constant-folding would not occur and `tfdbg` should show the intermediate
 tensor dumps.
 
+
+**Q**: I am debugging a model that generates unwanted infinities or NaNs. But
+       there are some nodes in my model that are known to generate infinities
+       or NaNs in their output tensors even under completely normal conditions.
+       How can I skip those nodes during my `run -f has_inf_or_nan` actions?
+
+**A**: Use the `--filter_exclude_node_names` (`-fenn` for short) flag. For
+       example, if you known you have a node with name matching the regular
+       expression `.*Sqrt.*` that generates infinities or NaNs regardless
+       of whether the model is behaving correctly, you can exclude the nodes
+       from the infinity/NaN-finding runs with the command
+       `run -f has_inf_or_nan -fenn .*Sqrt.*`.
+
+
 **Q**: Is there a GUI for tfdbg?
 
 **A**: Yes, the **TensorBoard Debugger Plugin** is the GUI of tfdbg.
index 156afdf..9a47cd1 100644 (file)
@@ -186,6 +186,15 @@ class DebugAnalyzer(object):
         default="",
         help="List only Tensors passing the filter of the specified name")
     ap.add_argument(
+        "-fenn",
+        "--filter_exclude_node_names",
+        dest="filter_exclude_node_names",
+        type=str,
+        default="",
+        help="When applying the tensor filter, exclude node with names "
+        "matching the regular expression. Applicable only if --tensor_filter "
+        "or -f is used.")
+    ap.add_argument(
         "-n",
         "--node_name_filter",
         dest="node_name_filter",
@@ -484,6 +493,10 @@ class DebugAnalyzer(object):
 
     Returns:
       Output text lines as a RichTextLines object.
+
+    Raises:
+      ValueError: If `--filter_exclude_node_names` is used without `-f` or
+        `--tensor_filter` being used.
     """
 
     # TODO(cais): Add annotations of substrings for dumped tensor names, to
@@ -520,8 +533,15 @@ class DebugAnalyzer(object):
         _add_main_menu(output, node_name=None, enable_list_tensors=False)
         return output
 
-      data_to_show = self._debug_dump.find(filter_callable)
+      data_to_show = self._debug_dump.find(
+          filter_callable,
+          exclude_node_names=parsed.filter_exclude_node_names)
     else:
+      if parsed.filter_exclude_node_names:
+        raise ValueError(
+            "The flag --filter_exclude_node_names is valid only when "
+            "the flag -f or --tensor_filter is used.")
+
       data_to_show = self._debug_dump.dumped_tensor_data
 
     # TODO(cais): Implement filter by lambda on tensor value.
index 6b110fd..5523195 100644 (file)
@@ -820,6 +820,32 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
         op_type_regex="(Add|MatMul)")
     check_main_menu(self, out, list_tensors_enabled=False)
 
+  def testListTensorWithFilterAndNodeNameExclusionWorks(self):
+    # First, create and register the filter.
+    def is_2x1_vector(datum, tensor):
+      del datum  # Unused.
+      return list(tensor.shape) == [2, 1]
+    self._analyzer.add_tensor_filter("is_2x1_vector", is_2x1_vector)
+
+    # Use shorthand alias for the command prefix.
+    out = self._registry.dispatch_command(
+        "lt", ["-f", "is_2x1_vector", "--filter_exclude_node_names", ".*v.*"])
+
+    # If the --filter_exclude_node_names were not used, then the matching
+    # tensors would be:
+    #   - simple_mul_add/v:0
+    #   - simple_mul_add/v/read:0
+    #   - simple_mul_add/matmul:0
+    #   - simple_mul_add/add:0
+    #
+    # With the --filter_exclude_node_names option, only the last two should
+    # show up in the result.
+    assert_listed_tensors(
+        self,
+        out, ["simple_mul_add/matmul:0", "simple_mul_add/add:0"],
+        ["MatMul", "Add"], tensor_filter_name="is_2x1_vector")
+    check_main_menu(self, out, list_tensors_enabled=False)
+
   def testListTensorsFilterNanOrInf(self):
     """Test register and invoke a tensor filter."""
 
index 8d355aa..8a65ad0 100644 (file)
@@ -23,6 +23,7 @@ import glob
 import json
 import os
 import platform
+import re
 
 import numpy as np
 import six
@@ -1411,7 +1412,11 @@ class DebugDumpDir(object):
 
     return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
 
-  def find(self, predicate, first_n=0, device_name=None):
+  def find(self,
+           predicate,
+           first_n=0,
+           device_name=None,
+           exclude_node_names=None):
     """Find dumped tensor data by a certain predicate.
 
     Args:
@@ -1430,17 +1435,24 @@ class DebugDumpDir(object):
         time order) for which the predicate returns True. To return all the
         `DebugTensotDatum` instances, let first_n be <= 0.
       device_name: optional device name.
+      exclude_node_names: Optional regular expression to exclude nodes with
+        names matching the regular expression.
 
     Returns:
       A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
        for which predicate returns True, sorted in ascending order of the
        timestamp.
     """
+    if exclude_node_names:
+      exclude_node_names = re.compile(exclude_node_names)
 
     matched_data = []
     for device in (self._dump_tensor_data if device_name is None
                    else (self._dump_tensor_data[device_name],)):
       for datum in self._dump_tensor_data[device]:
+        if exclude_node_names and exclude_node_names.match(datum.node_name):
+          continue
+
         if predicate(datum, datum.get_tensor()):
           matched_data.append(datum)
 
index f4fac14..070d9c4 100644 (file)
@@ -669,6 +669,55 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
       self.assertEqual(1, len(first_bad_datum))
       self.assertEqual(x_name, first_bad_datum[0].node_name)
 
+  def testFindInfOrNanWithOpNameExclusion(self):
+    with session.Session() as sess:
+      u_name = "testFindInfOrNanWithOpNameExclusion/u"
+      v_name = "testFindInfOrNanWithOpNameExclusion/v"
+      w_name = "testFindInfOrNanWithOpNameExclusion/w"
+      x_name = "testFindInfOrNanWithOpNameExclusion/x"
+      y_name = "testFindInfOrNanWithOpNameExclusion/y"
+      z_name = "testFindInfOrNanWithOpNameExclusion/z"
+
+      u_init = constant_op.constant([2.0, 4.0])
+      u = variables.Variable(u_init, name=u_name)
+      v_init = constant_op.constant([2.0, 1.0])
+      v = variables.Variable(v_init, name=v_name)
+
+      # Expected output: [0.0, 3.0]
+      w = math_ops.subtract(u, v, name=w_name)
+
+      # Expected output: [inf, 1.3333]
+      x = math_ops.div(u, w, name=x_name)
+
+      # Expected output: [nan, 4.0]
+      y = math_ops.multiply(w, x, name=y_name)
+
+      z = math_ops.multiply(y, y, name=z_name)
+
+      u.initializer.run()
+      v.initializer.run()
+
+      _, dump = self._debug_run_and_get_dump(
+          sess, z,
+          expected_partition_graph_count=self._expected_partition_graph_count)
+
+      # Find all "offending tensors".
+      bad_data = dump.find(debug_data.has_inf_or_nan,
+                           exclude_node_names=".*/x$")
+
+      # Verify that the nodes with bad values are caught through running find
+      # on the debug dump.
+      self.assertEqual(2, len(bad_data))
+      # Assert that the node `x` should have been excluded.
+      self.assertEqual(y_name, bad_data[0].node_name)
+      self.assertEqual(z_name, bad_data[1].node_name)
+
+      first_bad_datum = dump.find(
+          debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$")
+
+      self.assertEqual(1, len(first_bad_datum))
+      self.assertEqual(y_name, first_bad_datum[0].node_name)
+
   def _session_run_for_graph_structure_lookup(self):
     with session.Session(config=no_rewrite_session_config()) as sess:
       u_name = "testDumpGraphStructureLookup/u"
index 1465cb7..c862565 100644 (file)
@@ -115,6 +115,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
     #   unavailable (i.e., is None), the run-start CLI will be launched to ask
     #   the user. This is the case, e.g., right before the first run starts.
     self._active_tensor_filter = None
+    self._active_filter_exclude_node_names = None
     self._active_tensor_filter_run_start_response = None
     self._run_through_times = 1
     self._skip_debug = False
@@ -149,6 +150,15 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
         default="",
         help="Run until a tensor in the graph passes the specified filter.")
     ap.add_argument(
+        "-fenn",
+        "--filter_exclude_node_names",
+        dest="filter_exclude_node_names",
+        type=str,
+        default="",
+        help="When applying the tensor filter, exclude node with names "
+        "matching the regular expression. Applicable only if --tensor_filter "
+        "or -f is used.")
+    ap.add_argument(
         "--node_name_filter",
         dest="node_name_filter",
         type=str,
@@ -324,9 +334,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
       debug_dump.set_python_graph(self._sess.graph)
 
       passed_filter = None
+      passed_filter_exclude_node_names = None
       if self._active_tensor_filter:
         if not debug_dump.find(
-            self._tensor_filters[self._active_tensor_filter], first_n=1):
+            self._tensor_filters[self._active_tensor_filter], first_n=1,
+            exclude_node_names=self._active_filter_exclude_node_names):
           # No dumped tensor passes the filter in this run. Clean up the dump
           # directory and move on.
           self._remove_dump_root()
@@ -334,10 +346,14 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
         else:
           # Some dumped tensor(s) from this run passed the filter.
           passed_filter = self._active_tensor_filter
+          passed_filter_exclude_node_names = (
+              self._active_filter_exclude_node_names)
           self._active_tensor_filter = None
+          self._active_filter_exclude_node_names = None
 
       self._prep_debug_cli_for_run_end(
-          debug_dump, request.tf_error, passed_filter)
+          debug_dump, request.tf_error, passed_filter,
+          passed_filter_exclude_node_names)
 
       self._run_start_response = self._launch_cli()
 
@@ -358,7 +374,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
     if os.path.isdir(self._dump_root):
       shutil.rmtree(self._dump_root)
 
-  def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
+  def _prep_debug_cli_for_run_end(self,
+                                  debug_dump,
+                                  tf_error,
+                                  passed_filter,
+                                  passed_filter_exclude_node_names):
     """Prepare (but not launch) CLI for run-end, with debug dump from the run.
 
     Args:
@@ -368,6 +388,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
         (if any).
       passed_filter: (None or str) Name of the tensor filter that just passed
         and caused the preparation of this run-end CLI (if any).
+      passed_filter_exclude_node_names: (None or str) Regular expression used
+        with the tensor filter to exclude ops with names matching the regular
+        expresssion.
     """
 
     if tf_error:
@@ -383,6 +406,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
       if passed_filter is not None:
         # Some dumped tensor(s) from this run passed the filter.
         self._init_command = "lt -f %s" % passed_filter
+        if passed_filter_exclude_node_names:
+          self._init_command += (" --filter_exclude_node_names %s" %
+                                 passed_filter_exclude_node_names)
         self._title_color = "red_on_white"
 
     self._run_cli = analyzer_cli.create_analyzer_ui(
@@ -496,6 +522,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
     parsed.op_type_filter = parsed.op_type_filter or None
     parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
 
+    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
+      raise ValueError(
+          "The --filter_exclude_node_names (or -feon) flag is valid only if "
+          "the --till_filter_pass (or -f) flag is used.")
+
     if parsed.profile:
       raise debugger_cli_common.CommandLineExit(
           exit_token=framework.OnRunStartResponse(
@@ -525,6 +556,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
       if parsed.till_filter_pass in self._tensor_filters:
         action = framework.OnRunStartAction.DEBUG_RUN
         self._active_tensor_filter = parsed.till_filter_pass
+        self._active_filter_exclude_node_names = (
+            parsed.filter_exclude_node_names)
         self._active_tensor_filter_run_start_response = run_start_response
       else:
         # Handle invalid filter name.
index 490812c..b06fa26 100644 (file)
@@ -87,7 +87,11 @@ class LocalCLIDebuggerWrapperSessionForTest(
   def _prep_cli_for_run_start(self):
     pass
 
-  def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
+  def _prep_debug_cli_for_run_end(self,
+                                  debug_dump,
+                                  tf_error,
+                                  passed_filter,
+                                  passed_filter_exclude_op_names):
     self.observers["debug_dumps"].append(debug_dump)
     self.observers["tf_errors"].append(tf_error)
 
@@ -451,6 +455,36 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
     self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
     self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
 
+  def testRunTillFilterPassesWithExcludeOpNames(self):
+    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+        [["run", "-f", "greater_than_twelve",
+          "--filter_exclude_node_names", "inc_v.*"],
+         ["run"], ["run"]],
+        self.sess,
+        dump_root=self._tmp_dir)
+
+    def greater_than_twelve(datum, tensor):
+      del datum  # Unused.
+      return tensor > 12.0
+
+    # Verify that adding the same tensor filter more than once is tolerated
+    # (i.e., as if it were added only once).
+    wrapped_sess.add_tensor_filter("greater_than_twelve", greater_than_twelve)
+
+    # run five times.
+    wrapped_sess.run(self.inc_v)
+    wrapped_sess.run(self.inc_v)
+    wrapped_sess.run(self.inc_v)
+    wrapped_sess.run(self.inc_v)
+
+    self.assertAllClose(14.0, self.sess.run(self.v))
+
+    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
+
+    # Due to the --filter_exclude_op_names flag, the run-end CLI should show up
+    # not after run 3, but after run 4.
+    self.assertEqual([4], wrapped_sess.observers["run_end_cli_run_numbers"])
+
   def testRunTillFilterPassesWorksInConjunctionWithOtherNodeNameFilter(self):
     """Test that --.*_filter flags work in conjunction with -f.