From: A. Unique TensorFlower Date: Mon, 7 May 2018 13:27:43 +0000 (-0700) Subject: Control flow graph with forward and backward analysis X-Git-Tag: upstream/v1.9.0_rc1~150^2~1^2~85 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=adcfdfcfec78b2021320d99c3543a054b470f56f;p=platform%2Fupstream%2Ftensorflow.git Control flow graph with forward and backward analysis PiperOrigin-RevId: 195654450 --- diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 83f3baf..68fbdf6 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], @@ -44,6 +45,18 @@ py_test( ) py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( name = "live_values_test", srcs = ["live_values_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 0000000..230e4cc --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,431 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. + """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [] + self.current_leaves.append(test) + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + self.current_leaves.extend(self.break_.pop()) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + outgoing |= anno.getanno(node.test, self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_For(self, node): + self.generic_visit(node) + incoming = set(anno.getanno(node.body[0], self.in_label)) + incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, frozenset(incoming)) + anno.setanno(node, self.out_label, outgoing) + + def visit_While(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_With(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + for item in node.items: + incoming |= anno.getanno(item, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + +# TODO(alexbw): Abstract the CFG walking machinery into a superclass +# which is parameterized on which fields it selects when walking. +# TODO(alexbw): Abstract the application of dataflow analysis +class Forward(object): + """Forward analysis on CFG. + + Args: + label: A name for this analysis e.g. 'active' for activity analysis. The AST + nodes in the CFG will be given annotations 'name_in', 'name_out', + 'name_gen' and 'name_kill' which contain the incoming values, outgoing + values, values generated by the statement, and values deleted by the + statement respectively. + transfer_fn: Either the AND or OR operator. If the AND operator is used it + turns into forward must analysis (i.e. a value will only be carried + forward if it appears on all incoming paths). The OR operator means that + forward may analysis is done (i.e. the union of incoming values will be + taken). + """ + + def __init__(self, label, context, transfer_fn=operator.or_): + self.transfer_fn = transfer_fn + self.context = context + self.out_label = label + '_out' + self.in_label = label + '_in' + self.gen_label = label + '_gen' + self.kill_label = label + '_kill' + + # TODO(alexbw): see if we can simplify by visiting breadth-first + def visit(self, node): + """Depth-first walking the CFG, applying dataflow information propagtion.""" + # node.value is None only for the exit CfgNode. + if not node.value: + return + + if anno.hasanno(node.value, self.out_label): + before = hash(anno.getanno(node.value, self.out_label)) + else: + before = None + preds = [ + anno.getanno(pred.value, self.out_label) + for pred in node.prev + if anno.hasanno(pred.value, self.out_label) + ] + if preds: + incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) + else: + incoming = frozenset() + anno.setanno(node.value, self.in_label, incoming) + gen, kill = self.get_gen_kill(node, incoming) + anno.setanno(node.value, self.gen_label, gen) + anno.setanno(node.value, self.kill_label, kill) + anno.setanno(node.value, self.out_label, (incoming - kill) | gen) + + if hash(anno.getanno(node.value, self.out_label)) != before: + for succ in node.next: + self.visit(succ) + + def get_gen_kill(self, cfg_node, incoming): + """Calculate Gen and Kill properties of a CFG node in dataflow analysis. + + A function which takes the CFG node as well as a set of incoming + values. It must return a set of newly generated values by the statement as + well as a set of deleted (killed) values. + + Args: + cfg_node: A CfgNode instance. + incoming: + """ + raise NotImplementedError() + + +class Backward(Forward): + """Backward analysis on CFG.""" + + def visit(self, cfg_node): + # cfg_node.value is None for the exit node, which will be visited only once + if not cfg_node.value: + for pred in cfg_node.prev: + self.visit(pred) + return + + if anno.hasanno(cfg_node.value, self.in_label): + before = hash(anno.getanno(cfg_node.value, self.in_label)) + else: + before = None + succs = [ + anno.getanno(succ.value, self.in_label) + for succ in cfg_node.next + if anno.hasanno(succ.value, self.in_label) + ] + if succs: + incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) + else: + incoming = frozenset() + anno.setanno(cfg_node.value, self.out_label, incoming) + gen, kill = self.get_gen_kill(cfg_node, incoming) + anno.setanno(cfg_node.value, self.gen_label, gen) + anno.setanno(cfg_node.value, self.kill_label, kill) + anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) + if hash(anno.getanno(cfg_node.value, self.in_label)) != before: + for pred in cfg_node.prev: + self.visit(pred) + + +def run_analyses(node, analyses): + """Perform dataflow analysis on all functions within an AST. + + Args: + node: An AST node on which to run dataflow analysis. + analyses: Either an instance of the Forward or Backward dataflow analysis + class, or a list or tuple of them. + + Returns: + node: The node, but now with annotations on the AST nodes containing the + results of the dataflow analyses. + """ + if not isinstance(analyses, (tuple, list)): + analyses = (analyses,) + for analysis in analyses: + if not isinstance(analysis, (Forward, Backward)): + raise TypeError('not a valid forward analysis object') + + for child_node in gast.walk(node): + if isinstance(child_node, gast.FunctionDef): + cfg_obj = CfgBuilder().build_cfg(child_node) + for analysis in analyses: + if isinstance(analysis, Backward): + analysis.visit(cfg_obj.exit) + elif isinstance(analysis, Forward): + analysis.visit(cfg_obj.entry) + for analysis in analyses: + PropagateAnalysis(analysis).visit(node) + return node + + +class Liveness(Backward): + """Perform a liveness analysis. + + Each statement is annotated with a set of variables that may be used + later in the program. + """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + gen = activity.get_read(node.value, self.context) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 0000000..af7eaf3 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,252 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. + # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + # TODO(alexbw): b/73926938 split this test up + def test_live(self): + + def get_live_annotated_fnbody(f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + def f6(x): + return x # should this cause x.* to be live? + + def f7(x, n): + for i in range(n): + x += i + return x + + def f8(x, f): + with f: + x += 1 + + body = get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + body = get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + body = get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + body = get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + +if __name__ == '__main__': + test.main()