Clean up: handle the hidden additional clause on for loops in a way that's clearer...
authorDan Moldovan <mdan@google.com>
Tue, 29 May 2018 16:46:21 +0000 (09:46 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 16:49:02 +0000 (09:49 -0700)
PiperOrigin-RevId: 198409797

tensorflow/contrib/autograph/converters/break_statements.py

index 3587722..5b7508c 100644 (file)
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gast
-
 from tensorflow.contrib.autograph.pyct import anno
 from tensorflow.contrib.autograph.pyct import templates
 from tensorflow.contrib.autograph.pyct import transformer
@@ -54,13 +52,9 @@ class BreakStatementTransformer(transformer.Base):
 
   def _guard_if_present(self, block, var_name):
     """Prevents the block from executing if var_name is set."""
-
-    # If we don't have statements that immediately depend on the break
-    # we still need to make sure that the break variable remains
-    # used, in case the break becomes useful in later stages of transformation.
-    # Not having this broke the break_in_inner_loop test.
     if not block:
-      block = [gast.Pass()]
+      return block
+
     template = """
         if not var_name:
           block
@@ -73,7 +67,7 @@ class BreakStatementTransformer(transformer.Base):
 
   def visit_While(self, node):
     scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-    break_var = self.context.namer.new_symbol('break__', scope.referenced)
+    break_var = self.context.namer.new_symbol('break_', scope.referenced)
 
     node.test = self.visit(node.test)
     node.body, break_used = self._track_body(node.body, break_var)
@@ -81,6 +75,10 @@ class BreakStatementTransformer(transformer.Base):
     node.orelse = self.visit_block(node.orelse)
 
     if break_used:
+      # Python's else clause only triggers if the loop exited cleanly (e.g.
+      # break did not trigger).
+      guarded_orelse = self._guard_if_present(node.orelse, break_var)
+
       template = """
         var_name = False
         while test and not var_name:
@@ -88,20 +86,18 @@ class BreakStatementTransformer(transformer.Base):
         else:
           orelse
       """
-      # Python's else clause only triggers if the loop exited cleanly (e.g.
-      # break did not trigger).
       node = templates.replace(
           template,
           var_name=break_var,
           test=node.test,
           body=node.body,
-          orelse=self._guard_if_present(node.orelse, break_var))
+          orelse=guarded_orelse)
 
     return node
 
   def visit_For(self, node):
     scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-    break_var = self.context.namer.new_symbol('break__', scope.referenced)
+    break_var = self.context.namer.new_symbol('break_', scope.referenced)
 
     node.target = self.visit(node.target)
     node.iter = self.visit(node.iter)
@@ -110,19 +106,32 @@ class BreakStatementTransformer(transformer.Base):
     node.orelse = self.visit_block(node.orelse)
 
     if break_used:
-      node.orelse = self._guard_if_present(node.orelse, break_var)
+      # Python's else clause only triggers if the loop exited cleanly (e.g.
+      # break did not trigger).
+      guarded_orelse = self._guard_if_present(node.orelse, break_var)
+      extra_test = templates.replace_as_expression(
+          'not var_name', var_name=break_var)
+
+      # The extra test is hidden in the AST, which will confuse the static
+      # analysis. To mitigate that, we insert a no-op statement that ensures
+      # the control variable is marked as used.
+      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
       template = """
         var_name = False
-        for_stmt
+        for target in iter_:
+          (var_name,)
+          body
+        else:
+          orelse
       """
-      # Python's else clause only triggers if the loop exited cleanly (e.g.
-      # break did not trigger).
       node = templates.replace(
           template,
           var_name=break_var,
-          for_stmt=node)
-      extra_test = templates.replace_as_expression(
-          'not var_name', var_name=break_var)
+          iter_=node.iter,
+          target=node.target,
+          body=node.body,
+          orelse=guarded_orelse)
+
       anno.setanno(node[1], 'extra_test', extra_test)
 
     return node