"continue_statements.py",
"control_flow.py",
"decorators.py",
- "for_loops.py",
"ifexp.py",
"list_comprehension.py",
"lists.py",
)
py_test(
- name = "for_loops_test",
- srcs = ["for_loops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":test_lib",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "name_scopes_test",
srcs = ["name_scopes_test.py"],
deps = [
def __init__(self, context):
super(BreakCanonicalizationTransformer, self).__init__(context)
# This is a stack structure, to correctly process nested loops.
+ # Each item is a list [break_used, break_variable_name]
self.break_uses = []
def _create_break_check(self):
self.break_uses.append([False, break_var])
node.body = self._manual_visit_list(node.body)
if self.break_uses[-1][0]:
- anno.setanno(node, 'extra_cond',
- gast.UnaryOp(gast.Not(),
- gast.Name(break_var, gast.Load(), None)))
+ extra_cond = templates.replace_as_expression(
+ 'not var_name', var_name=break_var)
+ anno.setanno(node, 'extra_cond', extra_cond)
final_nodes = [self._create_break_init(), node]
else:
final_nodes = node
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
def __init__(self, context):
super(ControlFlowTransformer, self).__init__(context)
- # pylint:disable=invalid-name
-
- def visit_For(self, node):
- assert False, 'for statement should have been canonicalized at this point'
-
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
+ cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
+ cond_closure = set()
+ for s in cond_scope.referenced:
+ for root in s.support_set:
+ if root not in body_scope.created:
+ cond_closure.add(root)
+
state = list(body_closure)
if not state:
# TODO(mdan): Implement this properly.
def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = autograph_utils.run_while(test_name, body_name, [state])
+ state_ast_tuple = __ops.while_loop(
+ test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
template,
test=test,
body_name=self.context.namer.new_symbol('loop_body',
body_scope.referenced),
- body=node_body)
+ body=node_body,
+ extra_deps=tuple(s.ast() for s in cond_closure),
+ )
return node
- # pylint:enable=invalid-name
+ def visit_For(self, node):
+ self.generic_visit(node)
+
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ body_closure = body_scope.modified - body_scope.created
+ all_referenced = body_scope.referenced
+
+ state = list(body_closure)
+
+ state_ssf = [
+ self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ ]
+ ssf_map = {
+ name: ssf
+ for name, ssf in zip(state, state_ssf)
+ if str(name) != ssf
+ }
+
+ if len(state) == 1:
+ state = state[0]
+ state_ssf = state_ssf[0]
+ state_ast_tuple = state
+ else:
+ state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+ node_body = ast_util.rename_symbols(node.body, ssf_map)
+ if anno.hasanno(node, 'extra_cond'):
+ extra_cond = anno.getanno(node, 'extra_cond')
+ extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+ else:
+ extra_cond = parser.parse_expression('True')
+
+ template = """
+ def extra_cond_name(state_ssf):
+ return extra_cond_expr
+ def body_name(iterate, state_ssf):
+ body
+ return state_ssf,
+ state_ast_tuple = __ops.for_loop(
+ iterated, extra_cond_name, body_name, (state,))
+ """
+ node = templates.replace(
+ template,
+ state=state,
+ state_ssf=state_ssf,
+ state_ast_tuple=state_ast_tuple,
+ iterated=node.iter,
+ iterate=node.target,
+ extra_cond_name=self.context.namer.new_symbol('extra_cond',
+ all_referenced),
+ extra_cond_expr=extra_cond,
+ body_name=self.context.namer.new_symbol('loop_body', all_referenced),
+ body=node_body)
+
+ return node
def transform(node, context):
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
with self.test_session() as sess:
self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+ def test_simple_for(self):
+
+ def test_fn(l):
+ s1 = 0
+ s2 = 0
+ for e in l:
+ s1 += e
+ s2 += e * e
+ return s1, s2
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ l = [1, 2, 3]
+ self.assertEqual(
+ test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+ l = []
+ self.assertEqual(
+ test_fn(l),
+ sess.run(
+ result.test_fn(
+ constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+ def test_for_single_var(self):
+
+ def test_fn(l):
+ s = 0
+ for e in l:
+ s += e
+ return s
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ l = [1, 2, 3]
+ self.assertEqual(
+ test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+ l = []
+ self.assertEqual(
+ test_fn(l),
+ sess.run(
+ result.test_fn(
+ constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+ def test_for_with_iterated_expression(self):
+
+ eval_count = [0]
+
+ def count_evals(x):
+ eval_count[0] += 1
+ return x
+
+ def test_fn(n):
+ s = 0
+ for e in count_evals(range(n)):
+ s += e
+ return s
+
+ node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ result.count_evals = count_evals
+ self.assertEqual(test_fn(5), result.test_fn(5))
+ # count_evals ran twice, once for test_fn and another for result.test_fn
+ self.assertEqual(eval_count[0], 2)
+
if __name__ == '__main__':
test.main()
result.tf = self.make_fake_mod('fake_tf', *symbols)
result.autograph_utils = utils
result.autograph_api = self.make_fake_mod('fake_api', converted_call)
- result.__ops = operators # pylint:disable=protected-access
+ result.__dict__['__ops'] = operators
yield result
except Exception: # pylint:disable=broad-except
if source is None:
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Canonicalizes for loops into while loops.
-
-This canonicalizer uses the len function on its argument. That should be
-converted to a tf.shape separately.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-
-
-class ForLoopCanonicalizationTransformer(transformer.Base):
- """Canonicalizes for loops (e.g. into while loops)."""
-
- def __init__(self, context):
- super(ForLoopCanonicalizationTransformer, self).__init__(context)
-
- def visit_For(self, node):
- self.generic_visit(node)
- body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- i_var = self.context.namer.new_symbol('i', body_scope.referenced)
- smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter',
- body_scope.referenced)
- cont_var = self.context.namer.new_symbol('cont', body_scope.referenced)
- # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
- if anno.hasanno(node, 'extra_cond'):
- template = """
- i = 0
- smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- while cont and extra_cond:
- body
- i += 1
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- """
- return templates.replace(
- template,
- loop_iter=node.iter,
- target=node.target,
- body=node.body,
- i=i_var,
- smart_loop_iter=smart_loop_iter_var,
- cont=cont_var,
- extra_cond=anno.getanno(node, 'extra_cond'))
- else:
- template = """
- i = 0
- smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- while cont:
- body
- i += 1
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- """
- repl = templates.replace(
- template,
- loop_iter=node.iter,
- target=node.target,
- body=node.body,
- i=i_var,
- smart_loop_iter=smart_loop_iter_var,
- cont=cont_var)
- return repl
-
- def visit_Continue(self, node):
- assert False, 'continue statement should be desugared at this point'
-
- def visit_Break(self, node):
- assert False, 'break statement should be desugared at this point'
-
-
-def transform(node, context):
- return ForLoopCanonicalizationTransformer(context).visit(node)
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for for_loops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.converters import converter_test_base
-from tensorflow.contrib.autograph.converters import for_loops
-from tensorflow.python.platform import test
-
-
-class ControlFlowTest(converter_test_base.TestCase):
-
- def test_basic_for(self):
-
- def test_fn(l):
- s = 0
- for e in l:
- s += e
- return s
-
- node = self.parse_and_analyze(test_fn, {})
- node = for_loops.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- l = [1, 2, 3]
- self.assertEqual(test_fn(l), result.test_fn(l))
- l = []
- self.assertEqual(test_fn(l), result.test_fn(l))
-
- def test_for_with_iterated_expression(self):
-
- eval_count = [0]
-
- def count_evals(x):
- eval_count[0] += 1
- return x
-
- def test_fn(n):
- s = 0
- for e in count_evals(range(n)):
- s += e
- return s
-
- node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
- node = for_loops.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- result.count_evals = count_evals
- self.assertEqual(test_fn(5), result.test_fn(5))
- # count_evals ran twice, once for test_fn and another for result.test_fn
- self.assertEqual(eval_count[0], 2)
-
-
-if __name__ == '__main__':
- test.main()
compiled_code = api.to_code(test_fn)
- # Just check for some key words and that it is parseable Python code.
- self.assertRegexpMatches(compiled_code, 'autograph_utils\\.run_while')
+ # Just check that it is parseable Python code.
self.assertIsNotNone(parser.parse_str(compiled_code))
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import for_loops
from tensorflow.contrib.autograph.converters import ifexp
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.converters import logical_expressions
node = _static_analysis_pass(node, ctx)
node = lists.transform(node, ctx)
- node = for_loops.transform(node, ctx)
- # for_loops may insert new global references.
node = builtin_functions.transform(node, ctx)
node = _static_analysis_pass(node, ctx)
from __future__ import print_function
from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_dataset
-from tensorflow.contrib.autograph.utils.builtins import dynamic_for_cond
from tensorflow.contrib.autograph.utils.builtins import dynamic_print
from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_while
from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
from tensorflow.contrib.autograph.utils.testing import fake_tf
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
return py_func.wrap_py_func(
flushed_print, None, values, use_dummy_return=True)
-
-
-def dynamic_dataset(iterated):
- """Implementartion of smart tf.data.Dataset epoch wrapping.
-
- The function checks if the input is a tf.data.Dataset and if so then wraps it
- so that for each element it returns it also returns the current epoch the
- dataset iteration is in, for two epochs. If the input is not a
- tf.data.Dataset then it just returns the input.
-
- Args:
- iterated: The iterable or tf.data.Dataset that is being iterated over.
- Returns:
- Either just the untouched input, or in the case of input being a
- tf.data.Dataset then it returns a wrapped tf.data.Dataset where for each
- element it returns it also returns the current epoch the dataset iteration
- is in.
- """
- if not isinstance(iterated, dataset_ops.Dataset):
- return iterated
-
- def epoch_dataset_number_helper(i):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(i).repeat(), iterated))
-
- epoch_numbers = dataset_ops.Dataset.range(2)
- return epoch_numbers.flat_map(epoch_dataset_number_helper)
-
-
-def dynamic_for_cond(iteration, iterated):
- """Implementartion of smart while-loop condition using dynamic dispatch.
-
- The function checks if it is iterating over a tf.data.Dataset or not, and in
- the case it is not then it simply returns if we are still in range of the
- iterated and the next element. If it is iterating over a dataset then it only
- iterates for a single epoch.
-
- Args:
- iteration: The current iteration of the loop.
- iterated: The iterable or tf.data.Dataset that is being iterated over.
- Returns:
- A tuple of a bool that indicates whether the loop should continue, and the
- next element in iterated.
- """
- # TODO(znado): Clean up.
- # TODO(znado): This won't work for unpacked iterates. Fix.
- if isinstance(iterated, dataset_ops.Dataset):
- curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next()
- return math_ops.less(curr_epoch, 1), next_elem
- elif tensor_util.is_tensor(iterated):
- if iterated.shape.ndims > 1:
- elem_shape = array_ops.shape(iterated)[1:]
- else:
- elem_shape = ()
- if iterated.shape.ndims == 0 or iterated.shape[0] == 0:
- return False, array_ops.zeros(elem_shape, iterated.dtype)
- return control_flow_ops.cond(
- math_ops.less(iteration, dynamic_len(iterated)),
- lambda: (True, iterated[iteration]),
- lambda: (False, array_ops.zeros(elem_shape, iterated.dtype)))
- elif hasattr(iterated, '__len__'):
- if iteration < len(iterated):
- return True, iterated[iteration]
- return False, None
- else:
- raise NotImplementedError('Python iterators not yet supported.')
from __future__ import division
from __future__ import print_function
-import six
-
from tensorflow.contrib.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
if len(results) == 1:
return results[0]
return results
-
-
-def run_while(cond_fn, body_fn, init_args):
- """Type-dependent functional while loop.
-
- Args:
- cond_fn: A Python callable implementing the stop conditions of the loop.
- body_fn: A Python callable implementing the body of the loop.
- init_args: The initial values of the arguments that will be passed to both
- cond_fn and body_fn.
-
- Returns:
- result: A list of values with the same shape and type as init_args. If any
- of the init_args, or any variables closed-over in cond_fn are Tensors,
- tf.while_loop will be used, otherwise a Python while loop will be ran.
-
- Raises:
- ValueError: if init_args is not a tuple or list with one or more elements.
- """
- if not isinstance(init_args, (tuple, list)) or not init_args:
- raise ValueError(
- 'init_args must be a non-empty list or tuple, found %s' % init_args)
-
- # TODO(alexbw): statically determine all active variables in cond_fn,
- # and pass them directly
- closure_vars = tuple(
- [c.cell_contents for c in six.get_function_closure(cond_fn) or []])
- possibly_tensors = tuple(init_args) + closure_vars
- if is_tensor(*possibly_tensors):
- return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
- else:
- return py_while_loop(cond_fn, body_fn, init_args)
-
-
-def py_while_loop(cond_fn, body_fn, init_args):
- state = init_args
- while cond_fn(*state):
- state = body_fn(*state)
- return state
out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
self.assertEqual(sess.run(out), 3)
- def test_run_while_python(self):
- cond_fn = lambda x, t, s: x > t
- body_fn = lambda x, t, s: (x * s, t, s)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5])
- self.assertEqual(x, 0.75)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5])
- self.assertEqual(x, 3.0)
-
- def test_run_while_tf(self):
- cond_fn = lambda x, t, s: x > t
- body_fn = lambda x, t, s: (x * s, t, s)
-
- with Session() as sess:
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
- [constant(3.0), 1.0, 0.5])
- self.assertEqual(sess.run(x), 0.75)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
- [constant(3.0), 4.0, 0.5])
- self.assertEqual(sess.run(x), 3.0)
-
if __name__ == '__main__':
test.main()