Adds a within_ops_fn parameter to get_forward_walk_ops and get_backward_walk_ops
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 20:21:34 +0000 (13:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 20:24:25 +0000 (13:24 -0700)
that allows setting a condition on ops that are within or not within.

Also adds tests for these methods that were missing.

PiperOrigin-RevId: 192176693

tensorflow/contrib/graph_editor/select.py
tensorflow/contrib/graph_editor/tests/select_test.py

index 3ea6ff4..d700e6e 100644 (file)
@@ -383,6 +383,7 @@ def get_within_boundary_ops(ops,
 def get_forward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
+                         within_ops_fn=None,
                          stop_at_ts=(),
                          control_outputs=None):
   """Do a forward graph walk and return all the visited ops.
@@ -395,6 +396,9 @@ def get_forward_walk_ops(seed_ops,
     within_ops: an iterable of `tf.Operation` within which the search is
       restricted. If `within_ops` is `None`, the search is performed within
       the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along within_ops,
+      in which case an op is within if it is also in within_ops.
     stop_at_ts: an iterable of tensors at which the graph walk stops.
     control_outputs: a `util.ControlOutputs` instance or None.
       If not `None`, it will be used while walking the graph forward.
@@ -423,7 +427,8 @@ def get_forward_walk_ops(seed_ops,
     seed_ops &= within_ops
 
   def is_within(op):
-    return within_ops is None or op in within_ops
+    return (within_ops is None or op in within_ops) and (
+        within_ops_fn is None or within_ops_fn(op))
 
   result = list(seed_ops)
   wave = set(seed_ops)
@@ -450,6 +455,7 @@ def get_forward_walk_ops(seed_ops,
 def get_backward_walk_ops(seed_ops,
                           inclusive=True,
                           within_ops=None,
+                          within_ops_fn=None,
                           stop_at_ts=(),
                           control_inputs=False):
   """Do a backward graph walk and return all the visited ops.
@@ -462,6 +468,9 @@ def get_backward_walk_ops(seed_ops,
     within_ops: an iterable of `tf.Operation` within which the search is
       restricted. If `within_ops` is `None`, the search is performed within
       the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along within_ops,
+      in which case an op is within if it is also in within_ops.
     stop_at_ts: an iterable of tensors at which the graph walk stops.
     control_inputs: if True, control inputs will be used while moving backward.
   Returns:
@@ -488,7 +497,8 @@ def get_backward_walk_ops(seed_ops,
     seed_ops &= within_ops
 
   def is_within(op):
-    return within_ops is None or op in within_ops
+    return (within_ops is None or op in within_ops) and (
+        within_ops_fn is None or within_ops_fn(op))
 
   result = list(seed_ops)
   wave = set(seed_ops)
@@ -516,6 +526,7 @@ def get_walks_intersection_ops(forward_seed_ops,
                                forward_inclusive=True,
                                backward_inclusive=True,
                                within_ops=None,
+                               within_ops_fn=None,
                                control_inputs=False,
                                control_outputs=None,
                                control_ios=None):
@@ -535,6 +546,9 @@ def get_walks_intersection_ops(forward_seed_ops,
     within_ops: an iterable of tf.Operation within which the search is
       restricted. If within_ops is None, the search is performed within
       the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along within_ops,
+      in which case an op is within if it is also in within_ops.
     control_inputs: A boolean indicating whether control inputs are enabled.
     control_outputs: An instance of util.ControlOutputs or None. If not None,
       control outputs are enabled.
@@ -555,11 +569,13 @@ def get_walks_intersection_ops(forward_seed_ops,
       forward_seed_ops,
       inclusive=forward_inclusive,
       within_ops=within_ops,
+      within_ops_fn=within_ops_fn,
       control_outputs=control_outputs)
   backward_ops = get_backward_walk_ops(
       backward_seed_ops,
       inclusive=backward_inclusive,
       within_ops=within_ops,
+      within_ops_fn=within_ops_fn,
       control_inputs=control_inputs)
   return [op for op in forward_ops if op in backward_ops]
 
@@ -569,6 +585,7 @@ def get_walks_union_ops(forward_seed_ops,
                         forward_inclusive=True,
                         backward_inclusive=True,
                         within_ops=None,
+                        within_ops_fn=None,
                         control_inputs=False,
                         control_outputs=None,
                         control_ios=None):
@@ -587,6 +604,9 @@ def get_walks_union_ops(forward_seed_ops,
       resulting set.
     within_ops: restrict the search within those operations. If within_ops is
       None, the search is done within the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along within_ops,
+      in which case an op is within if it is also in within_ops.
     control_inputs: A boolean indicating whether control inputs are enabled.
     control_outputs: An instance of util.ControlOutputs or None. If not None,
       control outputs are enabled.
@@ -607,11 +627,13 @@ def get_walks_union_ops(forward_seed_ops,
       forward_seed_ops,
       inclusive=forward_inclusive,
       within_ops=within_ops,
+      within_ops_fn=within_ops_fn,
       control_outputs=control_outputs)
   backward_ops = get_backward_walk_ops(
       backward_seed_ops,
       inclusive=backward_inclusive,
       within_ops=within_ops,
+      within_ops_fn=within_ops_fn,
       control_inputs=control_inputs)
   return util.concatenate_unique(forward_ops, backward_ops)
 
index 82f9996..d12c6d3 100644 (file)
@@ -77,12 +77,10 @@ class SelectTest(test.TestCase):
     """Test for ge.get_ops_ios."""
     control_outputs = ge.util.ControlOutputs(self.graph)
     self.assertEqual(
-        len(ge.get_ops_ios(
-            self.h.op, control_ios=control_outputs)), 3)
+        len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3)
     self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2)
     self.assertEqual(
-        len(ge.get_ops_ios(
-            self.c.op, control_ios=control_outputs)), 6)
+        len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6)
     self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5)
 
   def test_compute_boundary_ts_0(self):
@@ -135,16 +133,49 @@ class SelectTest(test.TestCase):
     ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
     self.assertEqual(len(ops), 2)
 
+    ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
+    self.assertEqual(len(ops), 3)
+    self.assertTrue(self.a.op in ops)
+    self.assertTrue(self.c.op in ops)
+    self.assertTrue(self.f.op in ops)
+
+    within_ops = [self.a.op, self.f.op]
+    ops = ge.get_walks_intersection_ops(
+        [self.a.op], [self.f.op], within_ops=within_ops)
+    self.assertEqual(len(ops), 0)
+
+    within_ops_fn = lambda op: op in [self.a.op, self.f.op]
+    ops = ge.get_walks_intersection_ops(
+        [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
+    self.assertEqual(len(ops), 0)
+
   def test_get_walks_union(self):
     """Test for ge.get_walks_union_ops."""
     ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
     self.assertEqual(len(ops), 6)
 
+    ops = ge.get_walks_union_ops([self.a.op], [self.f.op])
+    self.assertEqual(len(ops), 8)
+
+    within_ops = [self.a.op, self.c.op, self.d.op, self.f.op]
+    ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+                                 within_ops=within_ops)
+    self.assertEqual(len(ops), 4)
+    self.assertTrue(self.b.op not in ops)
+
+    within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op]
+    ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+                                 within_ops_fn=within_ops_fn)
+    self.assertEqual(len(ops), 3)
+    self.assertTrue(self.b.op not in ops)
+    self.assertTrue(self.d.op not in ops)
+
   def test_select_ops(self):
     parameters = (
         (("^foo/",), 7),
         (("^foo/bar/",), 4),
-        (("^foo/bar/", "a"), 5),)
+        (("^foo/bar/", "a"), 5),
+    )
     for param, length in parameters:
       ops = ge.select_ops(*param, graph=self.graph)
       self.assertEqual(len(ops), length)
@@ -152,7 +183,8 @@ class SelectTest(test.TestCase):
   def test_select_ts(self):
     parameters = (
         (".*:0", 8),
-        (r".*/bar/\w+:0", 4),)
+        (r".*/bar/\w+:0", 4),
+    )
     for regex, length in parameters:
       ts = ge.select_ts(regex, graph=self.graph)
       self.assertEqual(len(ts), length)
@@ -160,12 +192,121 @@ class SelectTest(test.TestCase):
   def test_select_ops_and_ts(self):
     parameters = (
         (("^foo/.*",), 7, 0),
-        (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
+        (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),
+    )
     for param, l0, l1 in parameters:
       ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
       self.assertEqual(len(ops), l0)
       self.assertEqual(len(ts), l1)
 
+  def test_forward_walk_ops(self):
+    seed_ops = [self.a.op, self.d.op]
+    # Include all ops except for self.g.op
+    within_ops = [
+        x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+    ]
+    # For the fn, exclude self.e.op.
+    within_ops_fn = lambda op: op not in (self.e.op,)
+    stop_at_ts = (self.f,)
+
+    with self.graph.as_default():
+      # No b.op since it's an independent source node.
+      # No g.op from within_ops.
+      # No e.op from within_ops_fn.
+      # No h.op from stop_at_ts and within_ops.
+      ops = ge.select.get_forward_walk_ops(
+          seed_ops,
+          inclusive=True,
+          within_ops=within_ops,
+          within_ops_fn=within_ops_fn,
+          stop_at_ts=stop_at_ts)
+      self.assertEqual(
+          set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op]))
+
+      # Also no a.op and d.op when inclusive=False
+      ops = ge.select.get_forward_walk_ops(
+          seed_ops,
+          inclusive=False,
+          within_ops=within_ops,
+          within_ops_fn=within_ops_fn,
+          stop_at_ts=stop_at_ts)
+      self.assertEqual(set(ops), set([self.c.op, self.f.op]))
+
+      # Not using within_ops_fn adds e.op.
+      ops = ge.select.get_forward_walk_ops(
+          seed_ops,
+          inclusive=False,
+          within_ops=within_ops,
+          stop_at_ts=stop_at_ts)
+      self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op]))
+
+      # Not using stop_at_ts adds back h.op.
+      ops = ge.select.get_forward_walk_ops(
+          seed_ops, inclusive=False, within_ops=within_ops)
+      self.assertEqual(
+          set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op]))
+
+      # Starting just form a (the tensor, not op) omits a, b, d.
+      ops = ge.select.get_forward_walk_ops([self.a], inclusive=True)
+      self.assertEqual(
+          set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op,
+                         self.h.op]))
+
+  def test_backward_walk_ops(self):
+    seed_ops = [self.h.op]
+    # Include all ops except for self.g.op
+    within_ops = [
+        x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+    ]
+    # For the fn, exclude self.c.op.
+    within_ops_fn = lambda op: op not in (self.c.op,)
+    stop_at_ts = (self.f,)
+
+    with self.graph.as_default():
+      # Backward walk only includes h since we stop at f and g is not within.
+      ops = ge.select.get_backward_walk_ops(
+          seed_ops,
+          inclusive=True,
+          within_ops=within_ops,
+          within_ops_fn=within_ops_fn,
+          stop_at_ts=stop_at_ts)
+      self.assertEqual(set(ops), set([self.h.op]))
+
+      # If we do inclusive=False, the result is empty.
+      ops = ge.select.get_backward_walk_ops(
+          seed_ops,
+          inclusive=False,
+          within_ops=within_ops,
+          within_ops_fn=within_ops_fn,
+          stop_at_ts=stop_at_ts)
+      self.assertEqual(set(ops), set())
+
+      # Removing stop_at_fs adds f.op, d.op.
+      ops = ge.select.get_backward_walk_ops(
+          seed_ops,
+          inclusive=True,
+          within_ops=within_ops,
+          within_ops_fn=within_ops_fn)
+      self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
+
+      # Not using within_ops_fn adds back ops for a, b, c.
+      ops = ge.select.get_backward_walk_ops(
+          seed_ops, inclusive=True, within_ops=within_ops)
+      self.assertEqual(
+          set(ops),
+          set([
+              self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
+          ]))
+
+      # Vanially backward search via self.h.op includes everything excpet e.op.
+      ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True)
+      self.assertEqual(
+          set(ops),
+          set([
+              self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
+              self.h.op
+          ]))
+
 
 if __name__ == "__main__":
   test.main()