--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow graph analysis.
+
+Given a Python AST we construct a control flow graph, with edges both to the
+next and previous statements (so it can easily walk the graph both ways). Its
+nodes contain the AST of the statements. It can then perform forward or backward
+analysis on this CFG.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import functools
+import operator
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct.static_analysis import activity
+
+
+class CfgNode(object):
+ """A node in the CFG."""
+ __slots__ = ['next', 'value', 'prev']
+
+ def __init__(self, value):
+ self.next = set()
+ self.prev = set()
+ self.value = value
+
+
+class Cfg(namedtuple('Cfg', ['entry', 'exit'])):
+ """A Control Flow Graph.
+
+ Each statement is represented as a node. For control flow statements such
+ as conditionals and loops the conditional itself is a node which either
+ branches or cycles, respectively.
+ Attributes:
+ entry: The entry node, which contains the `gast.arguments` node of the
+ function definition.
+ exit: The exit node. This node is special because it has no value (i.e. no
+ corresponding AST node). This is because Python functions can have
+ multiple return statements.
+ """
+ pass
+
+
+class CfgBuilder(gast.NodeVisitor):
+ """Construct a control flow graph.
+
+ Construct a CFG starting from a FunctionDef node.
+ Usage:
+ cfg_obj = CfgBuilder().build_cfg(fndef_node)
+ """
+
+ def __init__(self):
+ # The current leaves of the CFG
+ self.current_leaves = []
+ # TODO(alexbw): generalize to break, return, continue, yield, etc.
+ # A stack of lists, tracking continue statements
+ self.continue_ = []
+ # A stack of lists tracking break nodes
+ self.break_ = []
+
+ def set_current_leaves(self, cfg_node):
+ """Link this cfg_node to the current leaves.
+
+ This is the central function for building the CFG. It links the current
+ head cfg_nodes to the passed cfg_node. It then resets the head to the
+ passed cfg_node.
+
+ Args:
+ cfg_node: A CfgNode instance.
+ """
+ for head in self.current_leaves:
+ head.next.add(cfg_node)
+ # While we're linking the CFG forward, add backlinks
+ cfg_node.prev.add(head)
+ self.current_leaves = [cfg_node]
+
+ def build_cfg(self, node):
+ """Build a CFG for a function.
+
+ Implementation of building a CFG for dataflow analysis. See, e.g.:
+ https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf
+
+ Args:
+ node: A function definition the body of which to analyze.
+ Returns:
+ A CFG object.
+ Raises:
+ TypeError: If the input is not a function definition.
+ """
+ if not isinstance(node, gast.FunctionDef):
+ raise TypeError('input must be a function definition')
+ entry_cfg_node = CfgNode(node.args)
+ self.current_leaves = [entry_cfg_node]
+ self.visit_statements(node.body)
+ exit_cfg_node = CfgNode(None)
+ self.set_current_leaves(exit_cfg_node)
+ return Cfg(entry_cfg_node, exit_cfg_node)
+
+ def visit_statements(self, nodes):
+ for node in nodes:
+ # Check for control flow
+ if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break,
+ gast.Continue, gast.With)):
+ self.visit(node)
+ else:
+ expr = CfgNode(node)
+ self.set_current_leaves(expr)
+
+ def generic_visit(self, node):
+ raise ValueError('unknown control flow')
+
+ def visit_If(self, node):
+ # TODO(alexbw): change this to use immutable tuples instead of lists
+ # The current head will hold the conditional
+ test = CfgNode(node.test)
+ self.set_current_leaves(test)
+ # Handle the body
+ self.visit_statements(node.body)
+ body_exit = self.current_leaves
+ self.current_leaves = []
+ self.current_leaves.append(test)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ self.current_leaves.extend(body_exit)
+
+ def visit_While(self, node):
+ test = CfgNode(node.test)
+ self.set_current_leaves(test)
+ # Start a new level of nesting
+ self.break_.append([])
+ self.continue_.append([])
+ # Handle the body
+ self.visit_statements(node.body)
+ self.current_leaves.extend(self.continue_.pop())
+ self.set_current_leaves(test)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ # The break statements and the test go to the next node
+ self.current_leaves.extend(self.break_.pop())
+
+ def visit_For(self, node):
+ iter_ = CfgNode(node.iter)
+ self.set_current_leaves(iter_)
+ self.break_.append([])
+ self.continue_.append([])
+ self.visit_statements(node.body)
+ self.current_leaves.extend(self.continue_.pop())
+ self.set_current_leaves(iter_)
+ self.current_leaves.extend(self.break_.pop())
+
+ def visit_Break(self, node):
+ self.break_[-1].extend(self.current_leaves)
+ self.current_leaves[:] = []
+
+ def visit_Continue(self, node):
+ self.continue_[-1].extend(self.current_leaves)
+ self.current_leaves[:] = []
+
+ def visit_Try(self, node):
+ self.visit_statements(node.body)
+ body = self.current_leaves
+ handlers = []
+ for handler in node.handlers:
+ self.current_leaves = body[:]
+ self.visit_statements(handler.body)
+ handlers.extend(self.current_leaves)
+ self.current_leaves = body
+ self.visit_statements(node.orelse)
+ self.current_leaves = handlers + self.current_leaves
+ self.visit_statements(node.finalbody)
+
+ def visit_With(self, node):
+ for item in node.items:
+ self.set_current_leaves(CfgNode(item))
+ self.visit_statements(node.body)
+
+
+# TODO(alexbw): once CFG analysis occurs at a block level,
+# this extra class will not be necessary
+class PropagateAnalysis(gast.NodeVisitor):
+ """Port analysis annotations from statements to their enclosing blocks."""
+
+ def __init__(self, analysis):
+ self.transfer_fn = analysis.transfer_fn
+ self.in_label = analysis.in_label
+ self.out_label = analysis.out_label
+ super(PropagateAnalysis, self).__init__()
+
+ def visit_If(self, node):
+ # Depth-first.
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ incoming |= anno.getanno(node.test, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ outgoing |= anno.getanno(node.test, self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_For(self, node):
+ self.generic_visit(node)
+ incoming = set(anno.getanno(node.body[0], self.in_label))
+ incoming -= set((anno.getanno(node.target, anno.Basic.QN),))
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, frozenset(incoming))
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_While(self, node):
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ incoming |= anno.getanno(node.test, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_With(self, node):
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ for item in node.items:
+ incoming |= anno.getanno(item, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+
+# TODO(alexbw): Abstract the CFG walking machinery into a superclass
+# which is parameterized on which fields it selects when walking.
+# TODO(alexbw): Abstract the application of dataflow analysis
+class Forward(object):
+ """Forward analysis on CFG.
+
+ Args:
+ label: A name for this analysis e.g. 'active' for activity analysis. The AST
+ nodes in the CFG will be given annotations 'name_in', 'name_out',
+ 'name_gen' and 'name_kill' which contain the incoming values, outgoing
+ values, values generated by the statement, and values deleted by the
+ statement respectively.
+ transfer_fn: Either the AND or OR operator. If the AND operator is used it
+ turns into forward must analysis (i.e. a value will only be carried
+ forward if it appears on all incoming paths). The OR operator means that
+ forward may analysis is done (i.e. the union of incoming values will be
+ taken).
+ """
+
+ def __init__(self, label, context, transfer_fn=operator.or_):
+ self.transfer_fn = transfer_fn
+ self.context = context
+ self.out_label = label + '_out'
+ self.in_label = label + '_in'
+ self.gen_label = label + '_gen'
+ self.kill_label = label + '_kill'
+
+ # TODO(alexbw): see if we can simplify by visiting breadth-first
+ def visit(self, node):
+ """Depth-first walking the CFG, applying dataflow information propagtion."""
+ # node.value is None only for the exit CfgNode.
+ if not node.value:
+ return
+
+ if anno.hasanno(node.value, self.out_label):
+ before = hash(anno.getanno(node.value, self.out_label))
+ else:
+ before = None
+ preds = [
+ anno.getanno(pred.value, self.out_label)
+ for pred in node.prev
+ if anno.hasanno(pred.value, self.out_label)
+ ]
+ if preds:
+ incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
+ else:
+ incoming = frozenset()
+ anno.setanno(node.value, self.in_label, incoming)
+ gen, kill = self.get_gen_kill(node, incoming)
+ anno.setanno(node.value, self.gen_label, gen)
+ anno.setanno(node.value, self.kill_label, kill)
+ anno.setanno(node.value, self.out_label, (incoming - kill) | gen)
+
+ if hash(anno.getanno(node.value, self.out_label)) != before:
+ for succ in node.next:
+ self.visit(succ)
+
+ def get_gen_kill(self, cfg_node, incoming):
+ """Calculate Gen and Kill properties of a CFG node in dataflow analysis.
+
+ A function which takes the CFG node as well as a set of incoming
+ values. It must return a set of newly generated values by the statement as
+ well as a set of deleted (killed) values.
+
+ Args:
+ cfg_node: A CfgNode instance.
+ incoming:
+ """
+ raise NotImplementedError()
+
+
+class Backward(Forward):
+ """Backward analysis on CFG."""
+
+ def visit(self, cfg_node):
+ # cfg_node.value is None for the exit node, which will be visited only once
+ if not cfg_node.value:
+ for pred in cfg_node.prev:
+ self.visit(pred)
+ return
+
+ if anno.hasanno(cfg_node.value, self.in_label):
+ before = hash(anno.getanno(cfg_node.value, self.in_label))
+ else:
+ before = None
+ succs = [
+ anno.getanno(succ.value, self.in_label)
+ for succ in cfg_node.next
+ if anno.hasanno(succ.value, self.in_label)
+ ]
+ if succs:
+ incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
+ else:
+ incoming = frozenset()
+ anno.setanno(cfg_node.value, self.out_label, incoming)
+ gen, kill = self.get_gen_kill(cfg_node, incoming)
+ anno.setanno(cfg_node.value, self.gen_label, gen)
+ anno.setanno(cfg_node.value, self.kill_label, kill)
+ anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen)
+ if hash(anno.getanno(cfg_node.value, self.in_label)) != before:
+ for pred in cfg_node.prev:
+ self.visit(pred)
+
+
+def run_analyses(node, analyses):
+ """Perform dataflow analysis on all functions within an AST.
+
+ Args:
+ node: An AST node on which to run dataflow analysis.
+ analyses: Either an instance of the Forward or Backward dataflow analysis
+ class, or a list or tuple of them.
+
+ Returns:
+ node: The node, but now with annotations on the AST nodes containing the
+ results of the dataflow analyses.
+ """
+ if not isinstance(analyses, (tuple, list)):
+ analyses = (analyses,)
+ for analysis in analyses:
+ if not isinstance(analysis, (Forward, Backward)):
+ raise TypeError('not a valid forward analysis object')
+
+ for child_node in gast.walk(node):
+ if isinstance(child_node, gast.FunctionDef):
+ cfg_obj = CfgBuilder().build_cfg(child_node)
+ for analysis in analyses:
+ if isinstance(analysis, Backward):
+ analysis.visit(cfg_obj.exit)
+ elif isinstance(analysis, Forward):
+ analysis.visit(cfg_obj.entry)
+ for analysis in analyses:
+ PropagateAnalysis(analysis).visit(node)
+ return node
+
+
+class Liveness(Backward):
+ """Perform a liveness analysis.
+
+ Each statement is annotated with a set of variables that may be used
+ later in the program.
+ """
+
+ def __init__(self, context):
+ super(Liveness, self).__init__('live', context)
+
+ def get_gen_kill(self, node, _):
+ gen = activity.get_read(node.value, self.context)
+ kill = activity.get_updated(node.value, self.context)
+ return gen, kill
+
+
+class ReachingDefinitions(Forward):
+ """Perform reaching definition analysis.
+
+ Each statement is annotated with a set of (variable, definition) pairs.
+ """
+
+ def __init__(self, context):
+ super(ReachingDefinitions, self).__init__('definitions', context)
+
+ def get_gen_kill(self, node, incoming):
+ definitions = activity.get_updated(node.value, self.context)
+ gen = frozenset((id_, node.value) for id_ in definitions)
+ kill = frozenset(def_ for def_ in incoming if def_[0] in definitions)
+ return gen, kill
+
+
+class Defined(Forward):
+ """Perform defined variable analysis.
+
+ Each statement is annotated with a set of variables which are guaranteed to
+ be defined at that point.
+ """
+
+ def __init__(self, context):
+ super(Defined, self).__init__('defined', context, transfer_fn=operator.and_)
+
+ def get_gen_kill(self, node, _):
+ gen = activity.get_updated(node.value, self.context)
+ return gen, frozenset()
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cfg module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import context
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct.static_analysis import cfg
+from tensorflow.python.platform import test
+
+
+class CFGTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
+ arg_types = arg_types or {}
+ node, source = parser.parse_entity(test_fn)
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=None,
+ recursive=True)
+ node = qual_names.resolve(node)
+ return node, ctx
+
+ def _check_anno_matches(self, node, anno_name, var_names):
+ if isinstance(var_names, str):
+ var_names = (var_names,)
+ qual_vars = set()
+ for var_name in var_names:
+ if isinstance(var_name, str):
+ if '[' in var_name or ']' in var_name:
+ raise ValueError('Annotation matching not supported with subscript.')
+ if '.' not in var_name:
+ qual_vars.add(qual_names.QN(var_name))
+ else:
+ attrs = var_name.split('.')
+ this_qn = functools.reduce(qual_names.QN, attrs[1:],
+ qual_names.QN(attrs[0]))
+ qual_vars.add(this_qn)
+ self.assertEqual(anno.getanno(node, anno_name), qual_vars)
+
+ def test_reaching(self):
+
+ def f(x):
+ print(x)
+ while True:
+ x = x
+ x = x
+ return x
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
+ body = node.body[0].body
+ # Only the argument reaches the expression
+ def_in = anno.getanno(body[0], 'definitions_in')
+ # One element, x, from arguments
+ self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,)))
+
+ while_body = body[1].body
+ def_in = anno.getanno(while_body[0], 'definitions_in')
+ # One definition, two possible sources.
+ # - One from an assignment (if the loop is entered)
+ # - The other from the arguments (if loop is not entered)
+ self.assertEqual(
+ set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign)))
+
+ def_in = anno.getanno(while_body[1], 'definitions_in')
+ # If we've reached this line, the only reaching definition of x is the
+ # Assign node in previous line
+ self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,)))
+
+ def_in = anno.getanno(body[2], 'definitions_in')
+ # Same situation as while_body[0]
+ self.assertEqual(
+ set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign)))
+
+ def test_defined(self):
+
+ def f(x):
+ if x:
+ y = 2 # pylint: disable=unused-variable
+ return x
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+ body = node.body[0].body
+ # only x is for sure defined at the end
+ self._check_anno_matches(body[1], 'defined_in', 'x')
+ # at the end of the if body both x and y are defined
+ if_body = body[0].body
+ self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y'))
+
+ # TODO(alexbw): b/73926938 split this test up
+ def test_live(self):
+
+ def get_live_annotated_fnbody(f):
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Liveness(ctx))
+ body = node.body[0].body
+ return body
+
+ def f1(x):
+ a = g(x) # pylint: disable=undefined-variable
+ b = h(a) # pylint: disable=undefined-variable, unused-variable
+ return x
+
+ def f2(x, a): # pylint: disable=unused-argument
+ if a > 0: # x should not be live
+ x = 0
+ if a > 1:
+ x = 1
+ else:
+ x = 2
+
+ def f3(x, a):
+ if a > 0: # x and a should be live
+ x = 0
+ if a > 1: # x and a should be live_in
+ x = 1
+ return x # x should be live
+
+ def f4(x, a):
+ if a > 0: # x should be live
+ x = 0
+ x += 1
+
+ def f5(x, a):
+ if a > 0: # x.y should be live
+ x.y = 0
+ return x.y
+
+ def f6(x):
+ return x # should this cause x.* to be live?
+
+ def f7(x, n):
+ for i in range(n):
+ x += i
+ return x
+
+ def f8(x, f):
+ with f:
+ x += 1
+
+ body = get_live_annotated_fnbody(f1)
+ self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x'))
+ self._check_anno_matches(body[2], 'live_in', ('x'))
+ self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x'))
+ self._check_anno_matches(body[2], 'live_out', ())
+
+ body = get_live_annotated_fnbody(f2)
+ self._check_anno_matches(body[0], 'live_in', ('a'))
+ self._check_anno_matches(body[1], 'live_in', ('a'))
+
+ body = get_live_annotated_fnbody(f3)
+ self._check_anno_matches(body[0], 'live_in', ('a', 'x'))
+ self._check_anno_matches(body[1], 'live_in', ('a', 'x'))
+ self._check_anno_matches(body[2], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f4)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'a'))
+ self._check_anno_matches(body[1], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f5)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a'))
+
+ body = get_live_annotated_fnbody(f6)
+ self._check_anno_matches(body[0], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f7)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range'))
+ self._check_anno_matches(body[1], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f8)
+ self._check_anno_matches(body[0], 'live_in', ('f', 'x'))
+
+ def test_node_equality(self):
+ node_a = gast.parse('y = x').body[0]
+ node_b = gast.parse('y = x').body[0]
+ self.assertNotEqual(node_a, node_b)
+
+ def test_nested_functions_defined(self):
+
+ def f(x):
+ y = x * 2
+
+ def g(z):
+ return z + y
+
+ return g(x)
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+
+ body = node.body[0].body
+ self.assertEqual(
+ anno.getanno(body[2], 'defined_in'),
+ frozenset(map(qual_names.QN, ('g', 'x', 'y'))))
+
+ # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries.
+ # NOTE: 'z' is easy to find, but 'y' is not identified as
+ # defined, because CFG analysis is applied with each function separately.
+ # fndef_body = body[1].body
+ # self.assertEqual(
+ # anno.getanno(fndef_body[0], 'defined_in'),
+ # frozenset(map(qual_names.QN, ('z', 'y'))))
+
+ def test_nested_functions_dont_leak_definitions(self):
+
+ def f(x):
+ print(x)
+
+ def g():
+ y = 2
+ return y
+
+ return g() # y is not defined here
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+ body = node.body[0].body
+ self.assertEqual(
+ anno.getanno(body[2], 'defined_in'),
+ frozenset(map(qual_names.QN, ('x', 'g'))))
+
+
+if __name__ == '__main__':
+ test.main()