Stricter analysis for functional conditional generation
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 14 May 2018 22:06:17 +0000 (15:06 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 22:08:58 +0000 (15:08 -0700)
PiperOrigin-RevId: 196573938

tensorflow/contrib/autograph/converters/break_statements.py
tensorflow/contrib/autograph/converters/control_flow.py
tensorflow/contrib/autograph/converters/control_flow_test.py
tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py

index 1be1c96..3587722 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gast
+
 from tensorflow.contrib.autograph.pyct import anno
 from tensorflow.contrib.autograph.pyct import templates
 from tensorflow.contrib.autograph.pyct import transformer
@@ -52,8 +54,13 @@ class BreakStatementTransformer(transformer.Base):
 
   def _guard_if_present(self, block, var_name):
     """Prevents the block from executing if var_name is set."""
+
+    # If we don't have statements that immediately depend on the break
+    # we still need to make sure that the break variable remains
+    # used, in case the break becomes useful in later stages of transformation.
+    # Not having this broke the break_in_inner_loop test.
     if not block:
-      return block
+      block = [gast.Pass()]
     template = """
         if not var_name:
           block
index 935a278..d7ddbe8 100644 (file)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Handles control flow statements: while, if."""
+"""Handles control flow statements: while, for, if."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import ast_util
 from tensorflow.contrib.autograph.pyct import parser
 from tensorflow.contrib.autograph.pyct import templates
 from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import cfg
 from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
 
 
@@ -47,9 +48,6 @@ class SymbolNamer(object):
 class ControlFlowTransformer(transformer.Base):
   """Transforms control flow structures like loops an conditionals."""
 
-  def __init__(self, context):
-    super(ControlFlowTransformer, self).__init__(context)
-
   def _create_cond_branch(self, body_name, aliased_orig_names,
                           aliased_new_names, body, returns):
     if aliased_orig_names:
@@ -98,30 +96,63 @@ class ControlFlowTransformer(transformer.Base):
 
     body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
-
-    if body_scope.created - orelse_scope.created:
-      raise ValueError(
-          'The if branch creates new symbols that the else branch does not.')
-    if orelse_scope.created - body_scope.created:
-      raise ValueError(
-          'The else branch creates new symbols that the if branch does not.')
-
-    modified = tuple(body_scope.modified | orelse_scope.modified)
-    all_referenced = body_scope.referenced | orelse_scope.referenced
+    body_defs = body_scope.created | body_scope.modified
+    orelse_defs = orelse_scope.created | orelse_scope.modified
+    live = anno.getanno(node, 'live_out')
+
+    # We'll need to check if we're closing over variables that are defined
+    # elsewhere in the function
+    # NOTE: we can only detect syntactic closure in the scope
+    # of the code passed in. If the AutoGraph'd function itself closes
+    # over other variables, this analysis won't take that into account.
+    defined = anno.getanno(node, 'defined_in')
+
+    # We only need to return variables that are
+    # - modified by one or both branches
+    # - live (or has a live parent) at the end of the conditional
+    modified = []
+    for def_ in body_defs | orelse_defs:
+      def_with_parents = set((def_,)) | def_.support_set
+      if live & def_with_parents:
+        modified.append(def_)
+
+    # We need to check if live created variables are balanced
+    # in both branches
+    created = live & (body_scope.created | orelse_scope.created)
+
+    # The if statement is illegal if there are variables that are created,
+    # that are also live, but both branches don't create them.
+    if created:
+      if created != (body_scope.created & live):
+        raise ValueError(
+            'The main branch does not create all live symbols that the else '
+            'branch does.')
+      if created != (orelse_scope.created & live):
+        raise ValueError(
+            'The else branch does not create all live symbols that the main '
+            'branch does.')
 
     # Alias the closure variables inside the conditional functions
     # to avoid errors caused by the local variables created in the branch
     # functions.
-    need_alias = (
-        (body_scope.modified | orelse_scope.modified) -
-        (body_scope.created | orelse_scope.created))
-    aliased_orig_names = tuple(need_alias)
-    aliased_new_names = tuple(
-        self.context.namer.new_symbol(s.ssf(), all_referenced)
-        for s in aliased_orig_names)
-    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
-    node_body = ast_util.rename_symbols(node.body, alias_map)
-    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)
+    # We will alias variables independently for body and orelse scope,
+    # because different branches might write different variables.
+    aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
+    aliased_orelse_orig_names = tuple(orelse_scope.modified -
+                                      orelse_scope.created)
+    aliased_body_new_names = tuple(
+        self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
+        for s in aliased_body_orig_names)
+    aliased_orelse_new_names = tuple(
+        self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
+        for s in aliased_orelse_orig_names)
+
+    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
+    alias_orelse_map = dict(
+        zip(aliased_orelse_orig_names, aliased_orelse_new_names))
+
+    node_body = ast_util.rename_symbols(node.body, alias_body_map)
+    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
 
     if not modified:
       # When the cond would return no value, we leave the cond called without
@@ -134,26 +165,47 @@ class ControlFlowTransformer(transformer.Base):
     else:
       results = gast.Tuple([s.ast() for s in modified], None)
 
-    body_name = self.context.namer.new_symbol('if_true', all_referenced)
-    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
+    body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
+    orelse_name = self.context.namer.new_symbol('if_false',
+                                                orelse_scope.referenced)
     if modified:
-      body_returns = tuple(
-          alias_map[s] if s in aliased_orig_names else s for s in modified)
+
+      def build_returns(aliased_names, alias_map, scope):
+        """Builds list of return variables for a branch of a conditional."""
+        returns = []
+        for s in modified:
+          if s in aliased_names:
+            returns.append(alias_map[s])
+          else:
+            if s not in scope.created | defined:
+              raise ValueError(
+                  'Attempting to return variable "%s" from the true branch of '
+                  'a conditional, but it was not closed over, or created in '
+                  'this branch.' % str(s))
+            else:
+              returns.append(s)
+        return tuple(returns)
+
+      body_returns = build_returns(aliased_body_orig_names, alias_body_map,
+                                   body_scope)
+      orelse_returns = build_returns(aliased_orelse_orig_names,
+                                     alias_orelse_map, orelse_scope)
+
     else:
-      body_returns = templates.replace('tf.ones(())')[0].value
+      body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value
 
     body_def = self._create_cond_branch(
         body_name,
-        aliased_orig_names=tuple(aliased_orig_names),
-        aliased_new_names=tuple(aliased_new_names),
+        aliased_orig_names=tuple(aliased_body_orig_names),
+        aliased_new_names=tuple(aliased_body_new_names),
         body=node_body,
         returns=body_returns)
     orelse_def = self._create_cond_branch(
         orelse_name,
-        aliased_orig_names=tuple(aliased_orig_names),
-        aliased_new_names=tuple(aliased_new_names),
+        aliased_orig_names=tuple(aliased_orelse_orig_names),
+        aliased_new_names=tuple(aliased_orelse_new_names),
         body=node_orelse,
-        returns=body_returns)
+        returns=orelse_returns)
     cond_expr = self._create_cond_expr(results, node.test, body_name,
                                        orelse_name)
 
@@ -284,6 +336,7 @@ class ControlFlowTransformer(transformer.Base):
 
 
 def transform(node, context):
-  t = ControlFlowTransformer(context)
-  node = t.visit(node)
+  cfg.run_analyses(node, cfg.Liveness(context))
+  cfg.run_analyses(node, cfg.Defined(context))
+  node = ControlFlowTransformer(context).visit(node)
   return node
index c5610b1..1a86359 100644 (file)
@@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import control_flow
 from tensorflow.contrib.autograph.converters import converter_test_base
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.platform import test
 
@@ -95,6 +96,91 @@ class ControlFlowTest(converter_test_base.TestCase):
       with self.test_session() as sess:
         self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
 
+  def test_imbalanced_aliasing(self):
+
+    def test_fn(n):
+      if n > 0:
+        n = 3
+      return n
+
+    node = self.parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node, control_flow_ops.cond) as result:
+      with self.test_session() as sess:
+        self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2))))
+        self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+  def test_ignore_unread_variable(self):
+
+    def test_fn(n):
+      b = 3  # pylint: disable=unused-variable
+      if n > 0:
+        b = 4
+      return n
+
+    node = self.parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+      with self.test_session() as sess:
+        self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3))))
+        self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+  def test_handle_temp_variable(self):
+
+    def test_fn_using_temp(x, y, w):
+      if x < y:
+        z = x + y
+      else:
+        w = 2
+        tmp = w
+        z = x - tmp
+      return z, w
+
+    node = self.parse_and_analyze(test_fn_using_temp, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+      with self.test_session() as sess:
+        z, w = sess.run(
+            result.test_fn_using_temp(
+                constant_op.constant(-3), constant_op.constant(3),
+                constant_op.constant(3)))
+        self.assertEqual(0, z)
+        self.assertEqual(3, w)
+        z, w = sess.run(
+            result.test_fn_using_temp(
+                constant_op.constant(3), constant_op.constant(-3),
+                constant_op.constant(3)))
+        self.assertEqual(1, z)
+        self.assertEqual(2, w)
+
+    def test_fn_ignoring_temp(x, y, w):
+      if x < y:
+        z = x + y
+      else:
+        w = 2
+        tmp = w
+        z = x - tmp
+      return z
+
+    node = self.parse_and_analyze(test_fn_ignoring_temp, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+      with self.test_session() as sess:
+        z = sess.run(
+            result.test_fn_ignoring_temp(
+                constant_op.constant(-3), constant_op.constant(3),
+                constant_op.constant(3)))
+        self.assertEqual(0, z)
+        z = sess.run(
+            result.test_fn_ignoring_temp(
+                constant_op.constant(3), constant_op.constant(-3),
+                constant_op.constant(3)))
+        self.assertEqual(1, z)
+
   def test_simple_for(self):
 
     def test_fn(l):
index 230e4cc..ad97fdf 100644 (file)
@@ -135,8 +135,7 @@ class CfgBuilder(gast.NodeVisitor):
     # Handle the body
     self.visit_statements(node.body)
     body_exit = self.current_leaves
-    self.current_leaves = []
-    self.current_leaves.append(test)
+    self.current_leaves = [test]
     # Handle the orelse
     self.visit_statements(node.orelse)
     self.current_leaves.extend(body_exit)
@@ -149,12 +148,15 @@ class CfgBuilder(gast.NodeVisitor):
     self.continue_.append([])
     # Handle the body
     self.visit_statements(node.body)
+    body_exit = self.current_leaves
     self.current_leaves.extend(self.continue_.pop())
     self.set_current_leaves(test)
     # Handle the orelse
     self.visit_statements(node.orelse)
     # The break statements and the test go to the next node
     self.current_leaves.extend(self.break_.pop())
+    # Body and orelse statements can reach out of the loop
+    self.current_leaves.extend(body_exit)
 
   def visit_For(self, node):
     iter_ = CfgNode(node.iter)
@@ -162,9 +164,15 @@ class CfgBuilder(gast.NodeVisitor):
     self.break_.append([])
     self.continue_.append([])
     self.visit_statements(node.body)
+    body_exit = self.current_leaves
     self.current_leaves.extend(self.continue_.pop())
     self.set_current_leaves(iter_)
+    # Handle the orelse
+    self.visit_statements(node.orelse)
+    # The break statements and the test go to the next node
     self.current_leaves.extend(self.break_.pop())
+    # Body and orelse statements can reach out of the loop
+    self.current_leaves.extend(body_exit)
 
   def visit_Break(self, node):
     self.break_[-1].extend(self.current_leaves)
@@ -395,7 +403,13 @@ class Liveness(Backward):
     super(Liveness, self).__init__('live', context)
 
   def get_gen_kill(self, node, _):
+    # A variable's parents are live if it is live
+    # e.g. x is live if x.y is live. This means gen needs to return
+    # all parents of a variable (if it's an Attribute or Subscript).
+    # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x)
     gen = activity.get_read(node.value, self.context)
+    gen = functools.reduce(lambda left, right: left | right.support_set, gen,
+                           gen)
     kill = activity.get_updated(node.value, self.context)
     return gen, kill
 
index af7eaf3..8d723ce 100644 (file)
@@ -247,6 +247,47 @@ class CFGTest(test.TestCase):
         anno.getanno(body[2], 'defined_in'),
         frozenset(map(qual_names.QN, ('x', 'g'))))
 
+  def test_loop_else(self):
+
+    # Disabling useless-else-on-loop error, because 'break' and 'continue'
+    # canonicalization are a separate analysis pass, and here we test
+    # the CFG analysis in isolation.
+    def for_orelse(x):
+      y = 0
+      for i in range(len(x)):
+        x += i
+      else:  # pylint: disable=useless-else-on-loop
+        y = 1
+      return x, y
+
+    def while_orelse(x, i):
+      y = 0
+      while x < 10:
+        x += i
+      else:  # pylint: disable=useless-else-on-loop
+        y = 1
+      return x, y
+
+    for f in (for_orelse, while_orelse):
+      node, ctx = self._parse_and_analyze(f, {})
+      cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
+      body = node.body[0].body
+      return_node = body[-1]
+      reaching_defs = anno.getanno(return_node, 'definitions_in')
+
+      # Y could be defined by Assign(Num(0)) or Assign(Num(1))
+      # X could be defined as an argument or an AugAssign.
+      y_defs = [node for var, node in reaching_defs if str(var) == 'y']
+      x_defs = [node for var, node in reaching_defs if str(var) == 'x']
+
+      self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs))
+      self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs))
+      self.assertEqual(len(y_defs), 2)
+      self.assertEqual(
+          set((gast.arguments, gast.AugAssign)),
+          set(type(def_) for def_ in x_defs))
+      self.assertEqual(len(x_defs), 2)
+
 
 if __name__ == '__main__':
   test.main()