from __future__ import division
from __future__ import print_function
+import gast
+
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
def _guard_if_present(self, block, var_name):
"""Prevents the block from executing if var_name is set."""
+
+ # If we don't have statements that immediately depend on the break
+ # we still need to make sure that the break variable remains
+ # used, in case the break becomes useful in later stages of transformation.
+ # Not having this broke the break_in_inner_loop test.
if not block:
- return block
+ block = [gast.Pass()]
template = """
if not var_name:
block
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Handles control flow statements: while, if."""
+"""Handles control flow statements: while, for, if."""
from __future__ import absolute_import
from __future__ import division
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import cfg
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
class ControlFlowTransformer(transformer.Base):
"""Transforms control flow structures like loops an conditionals."""
- def __init__(self, context):
- super(ControlFlowTransformer, self).__init__(context)
-
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
-
- if body_scope.created - orelse_scope.created:
- raise ValueError(
- 'The if branch creates new symbols that the else branch does not.')
- if orelse_scope.created - body_scope.created:
- raise ValueError(
- 'The else branch creates new symbols that the if branch does not.')
-
- modified = tuple(body_scope.modified | orelse_scope.modified)
- all_referenced = body_scope.referenced | orelse_scope.referenced
+ body_defs = body_scope.created | body_scope.modified
+ orelse_defs = orelse_scope.created | orelse_scope.modified
+ live = anno.getanno(node, 'live_out')
+
+ # We'll need to check if we're closing over variables that are defined
+ # elsewhere in the function
+ # NOTE: we can only detect syntactic closure in the scope
+ # of the code passed in. If the AutoGraph'd function itself closes
+ # over other variables, this analysis won't take that into account.
+ defined = anno.getanno(node, 'defined_in')
+
+ # We only need to return variables that are
+ # - modified by one or both branches
+ # - live (or has a live parent) at the end of the conditional
+ modified = []
+ for def_ in body_defs | orelse_defs:
+ def_with_parents = set((def_,)) | def_.support_set
+ if live & def_with_parents:
+ modified.append(def_)
+
+ # We need to check if live created variables are balanced
+ # in both branches
+ created = live & (body_scope.created | orelse_scope.created)
+
+ # The if statement is illegal if there are variables that are created,
+ # that are also live, but both branches don't create them.
+ if created:
+ if created != (body_scope.created & live):
+ raise ValueError(
+ 'The main branch does not create all live symbols that the else '
+ 'branch does.')
+ if created != (orelse_scope.created & live):
+ raise ValueError(
+ 'The else branch does not create all live symbols that the main '
+ 'branch does.')
# Alias the closure variables inside the conditional functions
# to avoid errors caused by the local variables created in the branch
# functions.
- need_alias = (
- (body_scope.modified | orelse_scope.modified) -
- (body_scope.created | orelse_scope.created))
- aliased_orig_names = tuple(need_alias)
- aliased_new_names = tuple(
- self.context.namer.new_symbol(s.ssf(), all_referenced)
- for s in aliased_orig_names)
- alias_map = dict(zip(aliased_orig_names, aliased_new_names))
- node_body = ast_util.rename_symbols(node.body, alias_map)
- node_orelse = ast_util.rename_symbols(node.orelse, alias_map)
+ # We will alias variables independently for body and orelse scope,
+ # because different branches might write different variables.
+ aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
+ aliased_orelse_orig_names = tuple(orelse_scope.modified -
+ orelse_scope.created)
+ aliased_body_new_names = tuple(
+ self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
+ for s in aliased_body_orig_names)
+ aliased_orelse_new_names = tuple(
+ self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
+ for s in aliased_orelse_orig_names)
+
+ alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
+ alias_orelse_map = dict(
+ zip(aliased_orelse_orig_names, aliased_orelse_new_names))
+
+ node_body = ast_util.rename_symbols(node.body, alias_body_map)
+ node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
if not modified:
# When the cond would return no value, we leave the cond called without
else:
results = gast.Tuple([s.ast() for s in modified], None)
- body_name = self.context.namer.new_symbol('if_true', all_referenced)
- orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
+ body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
+ orelse_name = self.context.namer.new_symbol('if_false',
+ orelse_scope.referenced)
if modified:
- body_returns = tuple(
- alias_map[s] if s in aliased_orig_names else s for s in modified)
+
+ def build_returns(aliased_names, alias_map, scope):
+ """Builds list of return variables for a branch of a conditional."""
+ returns = []
+ for s in modified:
+ if s in aliased_names:
+ returns.append(alias_map[s])
+ else:
+ if s not in scope.created | defined:
+ raise ValueError(
+ 'Attempting to return variable "%s" from the true branch of '
+ 'a conditional, but it was not closed over, or created in '
+ 'this branch.' % str(s))
+ else:
+ returns.append(s)
+ return tuple(returns)
+
+ body_returns = build_returns(aliased_body_orig_names, alias_body_map,
+ body_scope)
+ orelse_returns = build_returns(aliased_orelse_orig_names,
+ alias_orelse_map, orelse_scope)
+
else:
- body_returns = templates.replace('tf.ones(())')[0].value
+ body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value
body_def = self._create_cond_branch(
body_name,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
+ aliased_orig_names=tuple(aliased_body_orig_names),
+ aliased_new_names=tuple(aliased_body_new_names),
body=node_body,
returns=body_returns)
orelse_def = self._create_cond_branch(
orelse_name,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
+ aliased_orig_names=tuple(aliased_orelse_orig_names),
+ aliased_new_names=tuple(aliased_orelse_new_names),
body=node_orelse,
- returns=body_returns)
+ returns=orelse_returns)
cond_expr = self._create_cond_expr(results, node.test, body_name,
orelse_name)
def transform(node, context):
- t = ControlFlowTransformer(context)
- node = t.visit(node)
+ cfg.run_analyses(node, cfg.Liveness(context))
+ cfg.run_analyses(node, cfg.Defined(context))
+ node = ControlFlowTransformer(context).visit(node)
return node
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
with self.test_session() as sess:
self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+ def test_imbalanced_aliasing(self):
+
+ def test_fn(n):
+ if n > 0:
+ n = 3
+ return n
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond) as result:
+ with self.test_session() as sess:
+ self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2))))
+ self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+ def test_ignore_unread_variable(self):
+
+ def test_fn(n):
+ b = 3 # pylint: disable=unused-variable
+ if n > 0:
+ b = 4
+ return n
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3))))
+ self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+ def test_handle_temp_variable(self):
+
+ def test_fn_using_temp(x, y, w):
+ if x < y:
+ z = x + y
+ else:
+ w = 2
+ tmp = w
+ z = x - tmp
+ return z, w
+
+ node = self.parse_and_analyze(test_fn_using_temp, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ z, w = sess.run(
+ result.test_fn_using_temp(
+ constant_op.constant(-3), constant_op.constant(3),
+ constant_op.constant(3)))
+ self.assertEqual(0, z)
+ self.assertEqual(3, w)
+ z, w = sess.run(
+ result.test_fn_using_temp(
+ constant_op.constant(3), constant_op.constant(-3),
+ constant_op.constant(3)))
+ self.assertEqual(1, z)
+ self.assertEqual(2, w)
+
+ def test_fn_ignoring_temp(x, y, w):
+ if x < y:
+ z = x + y
+ else:
+ w = 2
+ tmp = w
+ z = x - tmp
+ return z
+
+ node = self.parse_and_analyze(test_fn_ignoring_temp, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ z = sess.run(
+ result.test_fn_ignoring_temp(
+ constant_op.constant(-3), constant_op.constant(3),
+ constant_op.constant(3)))
+ self.assertEqual(0, z)
+ z = sess.run(
+ result.test_fn_ignoring_temp(
+ constant_op.constant(3), constant_op.constant(-3),
+ constant_op.constant(3)))
+ self.assertEqual(1, z)
+
def test_simple_for(self):
def test_fn(l):
# Handle the body
self.visit_statements(node.body)
body_exit = self.current_leaves
- self.current_leaves = []
- self.current_leaves.append(test)
+ self.current_leaves = [test]
# Handle the orelse
self.visit_statements(node.orelse)
self.current_leaves.extend(body_exit)
self.continue_.append([])
# Handle the body
self.visit_statements(node.body)
+ body_exit = self.current_leaves
self.current_leaves.extend(self.continue_.pop())
self.set_current_leaves(test)
# Handle the orelse
self.visit_statements(node.orelse)
# The break statements and the test go to the next node
self.current_leaves.extend(self.break_.pop())
+ # Body and orelse statements can reach out of the loop
+ self.current_leaves.extend(body_exit)
def visit_For(self, node):
iter_ = CfgNode(node.iter)
self.break_.append([])
self.continue_.append([])
self.visit_statements(node.body)
+ body_exit = self.current_leaves
self.current_leaves.extend(self.continue_.pop())
self.set_current_leaves(iter_)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ # The break statements and the test go to the next node
self.current_leaves.extend(self.break_.pop())
+ # Body and orelse statements can reach out of the loop
+ self.current_leaves.extend(body_exit)
def visit_Break(self, node):
self.break_[-1].extend(self.current_leaves)
super(Liveness, self).__init__('live', context)
def get_gen_kill(self, node, _):
+ # A variable's parents are live if it is live
+ # e.g. x is live if x.y is live. This means gen needs to return
+ # all parents of a variable (if it's an Attribute or Subscript).
+ # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x)
gen = activity.get_read(node.value, self.context)
+ gen = functools.reduce(lambda left, right: left | right.support_set, gen,
+ gen)
kill = activity.get_updated(node.value, self.context)
return gen, kill
anno.getanno(body[2], 'defined_in'),
frozenset(map(qual_names.QN, ('x', 'g'))))
+ def test_loop_else(self):
+
+ # Disabling useless-else-on-loop error, because 'break' and 'continue'
+ # canonicalization are a separate analysis pass, and here we test
+ # the CFG analysis in isolation.
+ def for_orelse(x):
+ y = 0
+ for i in range(len(x)):
+ x += i
+ else: # pylint: disable=useless-else-on-loop
+ y = 1
+ return x, y
+
+ def while_orelse(x, i):
+ y = 0
+ while x < 10:
+ x += i
+ else: # pylint: disable=useless-else-on-loop
+ y = 1
+ return x, y
+
+ for f in (for_orelse, while_orelse):
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
+ body = node.body[0].body
+ return_node = body[-1]
+ reaching_defs = anno.getanno(return_node, 'definitions_in')
+
+ # Y could be defined by Assign(Num(0)) or Assign(Num(1))
+ # X could be defined as an argument or an AugAssign.
+ y_defs = [node for var, node in reaching_defs if str(var) == 'y']
+ x_defs = [node for var, node in reaching_defs if str(var) == 'x']
+
+ self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs))
+ self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs))
+ self.assertEqual(len(y_defs), 2)
+ self.assertEqual(
+ set((gast.arguments, gast.AugAssign)),
+ set(type(def_) for def_ in x_defs))
+ self.assertEqual(len(x_defs), 2)
+
if __name__ == '__main__':
test.main()