From 5dd290ea8ccf3320215fd84e982724c91e72100c Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 29 May 2018 09:46:21 -0700 Subject: [PATCH] Clean up: handle the hidden additional clause on for loops in a way that's clearer about what it does. PiperOrigin-RevId: 198409797 --- .../autograph/converters/break_statements.py | 49 +++++++++++++--------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 3587722..5b7508c 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -54,13 +52,9 @@ class BreakStatementTransformer(transformer.Base): def _guard_if_present(self, block, var_name): """Prevents the block from executing if var_name is set.""" - - # If we don't have statements that immediately depend on the break - # we still need to make sure that the break variable remains - # used, in case the break becomes useful in later stages of transformation. - # Not having this broke the break_in_inner_loop test. if not block: - block = [gast.Pass()] + return block + template = """ if not var_name: block @@ -73,7 +67,7 @@ class BreakStatementTransformer(transformer.Base): def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -81,6 +75,10 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + template = """ var_name = False while test and not var_name: @@ -88,20 +86,18 @@ class BreakStatementTransformer(transformer.Base): else: orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, test=node.test, body=node.body, - orelse=self._guard_if_present(node.orelse, break_var)) + orelse=guarded_orelse) return node def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -110,19 +106,32 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: - node.orelse = self._guard_if_present(node.orelse, break_var) + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ var_name = False - for_stmt + for target in iter_: + (var_name,) + body + else: + orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, - for_stmt=node) - extra_test = templates.replace_as_expression( - 'not var_name', var_name=break_var) + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + anno.setanno(node[1], 'extra_test', extra_test) return node -- 2.7.4