Cleanup: update continue_statements.py to use the base transformer facilities for...
authorDan Moldovan <mdan@google.com>
Thu, 31 May 2018 15:53:36 +0000 (08:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 15:56:25 +0000 (08:56 -0700)
PiperOrigin-RevId: 198727946

tensorflow/contrib/autograph/converters/break_statements.py
tensorflow/contrib/autograph/converters/continue_statements.py
tensorflow/contrib/autograph/pyct/transformer.py
tensorflow/contrib/autograph/pyct/transformer_test.py

index 5b7508c..775d92c 100644 (file)
@@ -32,14 +32,6 @@ CONTROL_VAR_NAME = 'control_var_name'
 class BreakStatementTransformer(transformer.Base):
   """Canonicalizes break statements into additional conditionals."""
 
-  def _track_body(self, nodes, break_var):
-    self.enter_local_scope()
-    self.set_local(CONTROL_VAR_NAME, break_var)
-    nodes = self.visit_block(nodes)
-    break_used = self.get_local(BREAK_USED, False)
-    self.exit_local_scope()
-    return nodes, break_used
-
   def visit_Break(self, node):
     self.set_local(BREAK_USED, True)
     var_name = self.get_local(CONTROL_VAR_NAME)
@@ -65,6 +57,14 @@ class BreakStatementTransformer(transformer.Base):
         block=block)
     return node
 
+  def _track_body(self, nodes, break_var):
+    self.enter_local_scope()
+    self.set_local(CONTROL_VAR_NAME, break_var)
+    nodes = self.visit_block(nodes)
+    break_used = self.get_local(BREAK_USED, False)
+    self.exit_local_scope()
+    return nodes, break_used
+
   def visit_While(self, node):
     scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     break_var = self.context.namer.new_symbol('break_', scope.referenced)
index 4299a8a..0417817 100644 (file)
@@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct import transformer
 from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
 
 
-class ContinueCanonicalizationTransformer(transformer.Base):
-  """Canonicalizes continue statements into additional conditionals."""
+# Tags for local state.
+CONTROL_VAR_NAME = 'control_var_name'
+CONTINUE_USED = 'continue_used'
+GUARD_CREATED = 'guard_created'
+CREATE_GUARD_NEXT = 'create_guard_next'
 
-  def __init__(self, context):
-    super(ContinueCanonicalizationTransformer, self).__init__(context)
-    # This is a stack structure, to correctly process nested loops.
-    self.continuation_uses = []
 
-  def _create_continuation_check(self):
-    template = """
-      if not var_name:
-        pass
-    """
-    cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
-    cond.body = []
-    return cond
+class ContinueCanonicalizationTransformer(transformer.Base):
+  """Canonicalizes continue statements into additional conditionals."""
 
-  def _create_continuation_trigger(self):
+  def visit_Continue(self, node):
+    self.set_local(CONTINUE_USED, True)
     template = """
       var_name = True
     """
-    assign, = templates.replace(
-        template, var_name=self.continuation_uses[-1][1])
-    return assign
-
-  def _create_continuation_init(self):
-    template = """
-      var_name = False
-    """
-    assign, = templates.replace(
-        template, var_name=self.continuation_uses[-1][1])
-    return assign
-
-  def _visit_and_reindent_if_necessary(self, nodes):
-    reorganized_nodes = []
-    current_dest = reorganized_nodes
-    continue_used_in_block = False
-    for i, n in enumerate(nodes):
-      # TODO(mdan): This could be optimized if control structures are simple.
-      self.continuation_uses[-1][0] = False
-      n = self.visit(n)
-      current_dest.append(n)
-      if self.continuation_uses[-1][0]:
-        continue_used_in_block = True
-        if i < len(nodes) - 1:  # Last statement in block needs no protection.
-          cond = self._create_continuation_check()
-          current_dest.append(cond)
-          current_dest = cond.body
-    self.continuation_uses[-1][0] = continue_used_in_block
-    return reorganized_nodes
-
-  def _process_loop_block(self, block, scope):
-    cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
-    self.continuation_uses.append([False, cont_var])
-    block = self._visit_and_reindent_if_necessary(block)
-    if self.continuation_uses[-1][0]:
-      block.insert(0, self._create_continuation_init())
-    self.continuation_uses.pop()
-    return block
+    return templates.replace(
+        template, var_name=self.get_local(CONTROL_VAR_NAME))
+
+  def _postprocess_statement(self, node):
+    # Example of how the state machine below works:
+    #
+    #   1| stmt           # State: CONTINUE_USED = False
+    #    |                # Action: none
+    #   2| if cond:
+    #   3|   continue     # State: CONTINUE_USED = True,
+    #    |                #        GUARD_CREATED = False,
+    #    |                #        CREATE_GUARD_NEXT = False
+    #    |                # Action: set CREATE_GUARD_NEXT = True
+    #   4| stmt           # State: CONTINUE_USED = True,
+    #    |                #        GUARD_CREATED = False,
+    #    |                #        CREATE_GUARD_NEXT = True
+    #    |                # Action: create `if not continue_used`,
+    #    |                #         set GUARD_CREATED = True
+    #   5| stmt           # State: CONTINUE_USED = True, GUARD_CREATED = True
+    #    |                # Action: none (will be wrapped under previously
+    #    |                #         created if node)
+
+    if self.get_local(CONTINUE_USED, False):
+      if self.get_local(GUARD_CREATED, False):
+        return node, None
+
+      elif not self.get_local(CREATE_GUARD_NEXT, False):
+        self.set_local(CREATE_GUARD_NEXT, True)
+        return node, None
+
+      else:
+        self.set_local(GUARD_CREATED, True)
+        template = """
+          if not var_name:
+            original_node
+        """
+        cond, = templates.replace(
+            template,
+            var_name=self.get_local(CONTROL_VAR_NAME),
+            original_node=node)
+        return cond, cond.body
+    return node, None
+
+  def _visit_loop_body(self, node, nodes):
+    self.enter_local_scope()
+    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+    continue_var = self.context.namer.new_symbol('continue_', scope.referenced)
+    self.set_local(CONTROL_VAR_NAME, continue_var)
+
+    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+
+    if self.get_local(CONTINUE_USED, False):
+      template = """
+        var_name = False
+      """
+      control_var_init = templates.replace(template, var_name=continue_var)
+      nodes = control_var_init + nodes
+
+    self.exit_local_scope()
+    return nodes
+
+  def _visit_non_loop_body(self, nodes):
+    self.enter_local_scope(inherit=(CONTROL_VAR_NAME,))
+    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+    continue_used = self.get_local(CONTINUE_USED, False)
+    self.exit_local_scope(keep=(CONTINUE_USED,))
+    return nodes, continue_used
 
   def visit_While(self, node):
-    self.generic_visit(node.test)
-    node.body = self._process_loop_block(node.body,
-                                         anno.getanno(node,
-                                                      NodeAnno.BODY_SCOPE))
-    for n in node.orelse:
-      self.generic_visit(n)
+    node.test = self.visit(node.test)
+    node.body = self._visit_loop_body(node, node.body)
+    # A continue in the else clause applies to the containing scope.
+    node.orelse, _ = self._visit_non_loop_body(node.orelse)
     return node
 
   def visit_For(self, node):
-    self.generic_visit(node.target)
-    self.generic_visit(node.iter)
-    node.body = self._process_loop_block(node.body,
-                                         anno.getanno(node,
-                                                      NodeAnno.BODY_SCOPE))
-    for n in node.orelse:
-      self.generic_visit(n)
+    node.target = self.generic_visit(node.target)
+    node.iter = self.generic_visit(node.iter)
+    node.body = self._visit_loop_body(node, node.body)
+    # A continue in the else clause applies to the containing scope.
+    node.orelse, _ = self._visit_non_loop_body(node.orelse)
     return node
 
   def visit_If(self, node):
-    if self.continuation_uses:
-      self.generic_visit(node.test)
-      node.body = self._visit_and_reindent_if_necessary(node.body)
-      continue_used_in_body = self.continuation_uses[-1][0]
-      node.orelse = self._visit_and_reindent_if_necessary(node.orelse)
-      self.continuation_uses[-1][0] = (
-          continue_used_in_body or self.continuation_uses[-1][0])
-    else:
-      node = self.generic_visit(node)
+    node.test = self.generic_visit(node.test)
+    node.body, continue_used_body = self._visit_non_loop_body(node.body)
+    node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse)
+    self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse)
     return node
 
-  def visit_Continue(self, node):
-    self.continuation_uses[-1][0] = True
-    return self._create_continuation_trigger()
-
-  def visit_Break(self, node):
-    assert False, 'break statement should be desugared at this point'
+  def visit_With(self, node):
+    node.items = self.visit_block(node.items)
+    node.body, _ = self._visit_non_loop_body(node.body)
+    return node
 
 
 def transform(node, namer):
index 4c65edb..60bca8b 100644 (file)
@@ -70,14 +70,40 @@ class Base(gast.NodeTransformer):
     return tuple(self._enclosing_entities)
 
   @property
-  def locel_scope_level(self):
+  def local_scope_level(self):
     return len(self._local_scope_state)
 
-  def enter_local_scope(self):
-    self._local_scope_state.append({})
+  def enter_local_scope(self, inherit=None):
+    """Marks entry into a new local scope.
 
-  def exit_local_scope(self):
-    return self._local_scope_state.pop()
+    Args:
+      inherit: Optional enumerable of variable names to copy from the
+          parent scope.
+    """
+    scope_entered = {}
+    if inherit:
+      this_scope = self._local_scope_state[-1]
+      for name in inherit:
+        if name in this_scope:
+          scope_entered[name] = this_scope[name]
+    self._local_scope_state.append(scope_entered)
+
+  def exit_local_scope(self, keep=None):
+    """Marks exit from the current local scope.
+
+    Args:
+      keep: Optional enumerable of variable names to copy into the
+          parent scope.
+    Returns:
+      A dict containing the scope that has just been exited.
+    """
+    scope_left = self._local_scope_state.pop()
+    if keep:
+      this_scope = self._local_scope_state[-1]
+      for name in keep:
+        if name in scope_left:
+          this_scope[name] = scope_left[name]
+    return scope_left
 
   def set_local(self, name, value):
     self._local_scope_state[-1][name] = value
@@ -91,16 +117,76 @@ class Base(gast.NodeTransformer):
       print(pretty_printer.fmt(node))
     return node
 
-  def visit_block(self, nodes):
-    """Helper equivalent to generic_visit, but for node lists."""
+  def visit_block(self, nodes, before_visit=None, after_visit=None):
+    """A more powerful version of generic_visit for statement blocks.
+
+    An example of a block is the body of an if statement.
+
+    This function allows specifying a postprocessing callback (the
+    after_visit argument) argument which can be used to move nodes to a new
+    destination. This is done by after_visit by returning a non-null
+    second return value, e.g. return new_node, new_destination.
+
+    For example, a transformer could perform the following move:
+
+        foo()
+        bar()
+        baz()
+
+        foo()
+        if cond:
+          bar()
+          baz()
+
+    The above could be done with a postprocessor of this kind:
+
+        def after_visit(node):
+          if node_is_function_call(bar):
+            new_container_node = build_cond()
+            new_container_node.body.append(node)
+            return new_container_node, new_container_node.body
+          else:
+            # Once we set a new destination, all subsequent items will be
+            # moved to it, so we don't need to explicitly handle baz.
+            return node, None
+
+    Args:
+      nodes: enumerable of AST node objects
+      before_visit: optional callable that is called before visiting each item
+          in nodes
+      after_visit: optional callable that takes in an AST node and
+          returns a tuple (new_node, new_destination). It is called after
+          visiting each item in nodes. Is used in the same was as the
+          visit_* methods: new_node will replace the node; if not None,
+          new_destination must be a list, and subsequent nodes will be placed
+          in this list instead of the list returned by visit_block.
+    Returns:
+      A list of AST node objects containing the transformed items fron nodes,
+      except those nodes that have been relocated using after_visit.
+    """
     results = []
+    node_destination = results
     for node in nodes:
+      if before_visit:
+        # TODO(mdan): We can modify node here too, if ever needed.
+        before_visit()
+
       replacement = self.visit(node)
+
+      if after_visit and replacement:
+        replacement, new_destination = after_visit(replacement)
+      else:
+        new_destination = None
+
       if replacement:
         if isinstance(replacement, (list, tuple)):
-          results.extend(replacement)
+          node_destination.extend(replacement)
         else:
-          results.append(replacement)
+          node_destination.append(replacement)
+
+      # Allow the postprocessor to reroute the remaining nodes to a new list.
+      if new_destination is not None:
+        node_destination = new_destination
     return results
 
   # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
@@ -155,22 +241,39 @@ class Base(gast.NodeTransformer):
     source_code = self.context.source_code
     source_file = self.context.source_file
     did_enter_function = False
-    local_scope_state_size = len(self._local_scope_state)
+    local_scope_size_at_entry = len(self._local_scope_state)
 
     try:
       if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
-        self._enclosing_entities.append(node)
         did_enter_function = True
 
+      if did_enter_function:
+        self._enclosing_entities.append(node)
+
       if source_code and hasattr(node, 'lineno'):
         self._lineno = node.lineno
         self._col_offset = node.col_offset
-      if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
-        return node
-      return super(Base, self).visit(node)
 
-    except (ValueError, AttributeError, KeyError, NotImplementedError,
-            AssertionError) as e:
+      if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+        result = super(Base, self).visit(node)
+
+      # On exception, the local scope integrity is not guaranteed.
+      if did_enter_function:
+        self._enclosing_entities.pop()
+
+      if local_scope_size_at_entry != len(self._local_scope_state):
+        raise AssertionError(
+            'Inconsistent local scope stack. Before entering node %s, the'
+            ' stack had length %d, after exit it has length %d. This'
+            ' indicates enter_local_scope and exit_local_scope are not'
+            ' well paired.' % (
+                node,
+                local_scope_size_at_entry,
+                len(self._local_scope_state)
+            ))
+      return result
+
+    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
       msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
           e.__class__.__name__, str(e), try_ast_to_source(node),
           pretty_printer.fmt(node, color=False))
@@ -178,18 +281,11 @@ class Base(gast.NodeTransformer):
         line = source_code.splitlines()[self._lineno - 1]
       else:
         line = '<no source available>'
+      # TODO(mdan): Avoid the printing of the original exception.
+      # In other words, we need to find how to suppress the "During handling
+      # of the above exception, another exception occurred" message.
       six.reraise(AutographParseError,
                   AutographParseError(
                       msg,
                       (source_file, self._lineno, self._col_offset + 1, line)),
                   sys.exc_info()[2])
-    finally:
-      if did_enter_function:
-        self._enclosing_entities.pop()
-
-      if local_scope_state_size != len(self._local_scope_state):
-        raise AssertionError(
-            'Inconsistent local scope stack. Before entering node %s, the'
-            ' stack had length %d, after exit it has length %d. This'
-            ' indicates enter_local_scope and exit_local_scope are not'
-            ' well paired.')
index 1f1adf4..f110e79 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gast
+
 from tensorflow.contrib.autograph.pyct import anno
 from tensorflow.contrib.autograph.pyct import context
 from tensorflow.contrib.autograph.pyct import parser
@@ -27,7 +29,7 @@ from tensorflow.python.platform import test
 
 class TransformerTest(test.TestCase):
 
-  def _context_for_nodetesting(self):
+  def _context_for_testing(self):
     return context.EntityContext(
         namer=None,
         source_code=None,
@@ -53,7 +55,7 @@ class TransformerTest(test.TestCase):
         anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
         return self.generic_visit(node)
 
-    tr = TestTransformer(self._context_for_nodetesting())
+    tr = TestTransformer(self._context_for_testing())
 
     def test_function():
       a = 0
@@ -116,7 +118,7 @@ class TransformerTest(test.TestCase):
       def visit_For(self, node):
         return self._annotate_result(node)
 
-    tr = TestTransformer(self._context_for_nodetesting())
+    tr = TestTransformer(self._context_for_testing())
 
     def test_function(a):
       """Docstring."""
@@ -155,7 +157,7 @@ class TransformerTest(test.TestCase):
         self.exit_local_scope()
         return node
 
-    tr = TestTransformer(self._context_for_nodetesting())
+    tr = TestTransformer(self._context_for_testing())
 
     def no_exit(a):
       if a > 0:
@@ -174,6 +176,38 @@ class TransformerTest(test.TestCase):
     with self.assertRaises(AssertionError):
       tr.visit(node)
 
+  def test_visit_block_postprocessing(self):
+
+    class TestTransformer(transformer.Base):
+
+      def _process_body_item(self, node):
+        if isinstance(node, gast.Assign) and (node.value.id == 'y'):
+          if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
+          return if_node, if_node.body
+        return node, None
+
+      def visit_FunctionDef(self, node):
+        node.body = self.visit_block(
+            node.body, after_visit=self._process_body_item)
+        return node
+
+    def test_function(x, y):
+      z = x
+      z = y
+      return z
+
+    tr = TestTransformer(self._context_for_testing())
+
+    node, _ = parser.parse_entity(test_function)
+    node = tr.visit(node)
+    node = node.body[0]
+
+    self.assertEqual(len(node.body), 2)
+    self.assertTrue(isinstance(node.body[0], gast.Assign))
+    self.assertTrue(isinstance(node.body[1], gast.If))
+    self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
+    self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
+
 
 if __name__ == '__main__':
   test.main()