def get_forward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_outputs=None):
"""Do a forward graph walk and return all the visited ops.
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_outputs: a `util.ControlOutputs` instance or None.
If not `None`, it will be used while walking the graph forward.
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_inputs=False):
"""Do a backward graph walk and return all the visited ops.
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_inputs: if True, control inputs will be used while moving backward.
Returns:
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
within_ops: an iterable of tf.Operation within which the search is
restricted. If within_ops is None, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return [op for op in forward_ops if op in backward_ops]
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
resulting set.
within_ops: restrict the search within those operations. If within_ops is
None, the search is done within the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return util.concatenate_unique(forward_ops, backward_ops)
"""Test for ge.get_ops_ios."""
control_outputs = ge.util.ControlOutputs(self.graph)
self.assertEqual(
- len(ge.get_ops_ios(
- self.h.op, control_ios=control_outputs)), 3)
+ len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3)
self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2)
self.assertEqual(
- len(ge.get_ops_ios(
- self.c.op, control_ios=control_outputs)), 6)
+ len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6)
self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5)
def test_compute_boundary_ts_0(self):
ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
self.assertEqual(len(ops), 2)
+ ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.a.op in ops)
+ self.assertTrue(self.c.op in ops)
+ self.assertTrue(self.f.op in ops)
+
+ within_ops = [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops=within_ops)
+ self.assertEqual(len(ops), 0)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 0)
+
def test_get_walks_union(self):
"""Test for ge.get_walks_union_ops."""
ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
self.assertEqual(len(ops), 6)
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 8)
+
+ within_ops = [self.a.op, self.c.op, self.d.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops=within_ops)
+ self.assertEqual(len(ops), 4)
+ self.assertTrue(self.b.op not in ops)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.b.op not in ops)
+ self.assertTrue(self.d.op not in ops)
+
def test_select_ops(self):
parameters = (
(("^foo/",), 7),
(("^foo/bar/",), 4),
- (("^foo/bar/", "a"), 5),)
+ (("^foo/bar/", "a"), 5),
+ )
for param, length in parameters:
ops = ge.select_ops(*param, graph=self.graph)
self.assertEqual(len(ops), length)
def test_select_ts(self):
parameters = (
(".*:0", 8),
- (r".*/bar/\w+:0", 4),)
+ (r".*/bar/\w+:0", 4),
+ )
for regex, length in parameters:
ts = ge.select_ts(regex, graph=self.graph)
self.assertEqual(len(ts), length)
def test_select_ops_and_ts(self):
parameters = (
(("^foo/.*",), 7, 0),
- (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
+ (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),
+ )
for param, l0, l1 in parameters:
ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
self.assertEqual(len(ops), l0)
self.assertEqual(len(ts), l1)
+ def test_forward_walk_ops(self):
+ seed_ops = [self.a.op, self.d.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.e.op.
+ within_ops_fn = lambda op: op not in (self.e.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # No b.op since it's an independent source node.
+ # No g.op from within_ops.
+ # No e.op from within_ops_fn.
+ # No h.op from stop_at_ts and within_ops.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(
+ set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op]))
+
+ # Also no a.op and d.op when inclusive=False
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.f.op]))
+
+ # Not using within_ops_fn adds e.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op]))
+
+ # Not using stop_at_ts adds back h.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops, inclusive=False, within_ops=within_ops)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op]))
+
+ # Starting just form a (the tensor, not op) omits a, b, d.
+ ops = ge.select.get_forward_walk_ops([self.a], inclusive=True)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op,
+ self.h.op]))
+
+ def test_backward_walk_ops(self):
+ seed_ops = [self.h.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.c.op.
+ within_ops_fn = lambda op: op not in (self.c.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # Backward walk only includes h since we stop at f and g is not within.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.h.op]))
+
+ # If we do inclusive=False, the result is empty.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set())
+
+ # Removing stop_at_fs adds f.op, d.op.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
+
+ # Not using within_ops_fn adds back ops for a, b, c.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops, inclusive=True, within_ops=within_ops)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
+ ]))
+
+ # Vanially backward search via self.h.op includes everything excpet e.op.
+ ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
+ self.h.op
+ ]))
+
if __name__ == "__main__":
test.main()