def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.while_loop(
+ state_ast_tuple = ag__.while_stmt(
test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
node_body = ast_util.rename_symbols(node.body, ssf_map)
- if anno.hasanno(node, 'extra_cond'):
- extra_cond = anno.getanno(node, 'extra_cond')
- extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+ if anno.hasanno(node, 'extra_test'):
+ extra_test = anno.getanno(node, 'extra_test')
+ extra_test = ast_util.rename_symbols(extra_test, ssf_map)
else:
- extra_cond = parser.parse_expression('True')
+ extra_test = parser.parse_expression('True')
template = """
- def extra_cond_name(state_ssf):
- return extra_cond_expr
+ def extra_test_name(state_ssf):
+ return extra_test_expr
def body_name(iterate, state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.for_loop(
- iterated, extra_cond_name, body_name, (state,))
+ state_ast_tuple = ag__.for_stmt(
+ iter_, extra_test_name, body_name, (state,))
"""
node = templates.replace(
template,
state=state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- iterated=node.iter,
+ iter_=node.iter,
iterate=node.target,
- extra_cond_name=self.context.namer.new_symbol('extra_cond',
+ extra_test_name=self.context.namer.new_symbol('extra_test',
all_referenced),
- extra_cond_expr=extra_cond,
+ extra_test_expr=extra_test,
body_name=self.context.namer.new_symbol('loop_body', all_referenced),
body=node_body)
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
-# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
-# TODO(mdan): Rename arguments to match the AST names.
-
-def for_loop(iterated, extra_cond, loop_body, init_state):
+def for_stmt(iter_, extra_test, body, init_state):
"""Functional form of a for statement.
- The loop operates on a so-called state, which includes all symbols that are
- variant across loop iterations, excluding the iterate. In what follows we
- refer to state as either a tuple of entities that represent an actual state,
- or a list of arguments of the corresponding types.
+ The loop operates on a state, which includes all symbols that are
+ variant across loop iterations, excluding the iterate as well as the
+ variables local to the loop.
+
+ For example, given the loop below that calculates the geometric and
+ arithmetic means or some numbers:
+
+ geo_mean = 1
+ arith_mean = 0
+ for i in range(n):
+ a = numbers[i]
+ geo_mean *= a
+ arith_mean += a
+
+ The state is represented by the variables geo_mean and arith_mean. The
+ argument for initial_state may contain the tuple (1, 0), the body will
+ include the arguments geo_mean and arith_mean and will return a tuple
+ representing the new values for geo_mean and respectively arith_mean.
Args:
- iterated: The entity being iterated over.
- extra_cond: Callable with the state as arguments, and boolean return type.
+ iter_: The entity being iterated over.
+ extra_test: Callable with the state as arguments, and boolean return type.
An additionnal loop condition.
- loop_body: Callable with the iterate and the state as arguments, and
+ body: Callable with the iterate and the state as arguments, and
state as return type. The actual loop body.
init_state: Tuple containing the initial state.
Returns:
Tuple containing the final state.
"""
- if tensor_util.is_tensor(iterated):
- return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
- elif isinstance(iterated, dataset_ops.Dataset):
- return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+ if tensor_util.is_tensor(iter_):
+ return _known_len_for_stmt(iter_, extra_test, body, init_state)
+ elif isinstance(iter_, dataset_ops.Dataset):
+ return _dataset_for_stmt(iter_, extra_test, body, init_state)
else:
- return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+ return _py_for_stmt(iter_, extra_test, body, init_state)
-def _py_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that executes a Python for loop."""
+def _py_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that executes a Python for loop."""
state = init_state
- for iterate in iterated:
- if not extra_cond(*state):
+ for target in iter_:
+ if not extra_test(*state):
break
- state = loop_body(iterate, *state)
+ state = body(target, *state)
# TODO(mdan): Remove this special case.
if len(state) == 1:
return state
-def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over objects that define a length."""
- n = builtins.dynamic_len(iterated)
+def _known_len_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over objects that define a length."""
+ n = builtins.dynamic_len(iter_)
def while_body(iterate_index, *state):
- iterate = iterated[iterate_index]
- new_state = loop_body(iterate, *state)
+ iterate = iter_[iterate_index]
+ new_state = body(iterate, *state)
return (iterate_index + 1,) + new_state
def while_cond(iterate_index, *state):
- return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+ return gen_math_ops.logical_and(iterate_index < n, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(0,) + init_state,
- extra_deps=(iterated,),
+ extra_deps=(iter_,),
opts=dict(maximum_iterations=n))
# Dropping the iteration index because it's not syntactically visible.
results = results[1:]
return results
-def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over TF Datasets."""
+def _dataset_for_stmt(ds, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over TF Datasets."""
# Because Datsets only expose get_next, in the style of Python iterators,
# we are forced to unpack the loop as:
#
epoch_number, iterate = iterator.get_next()
def while_body(epoch_number, iterate, *state):
- new_state = loop_body(iterate, *state)
+ new_state = body(iterate, *state)
epoch_number, iterate = iterator.get_next()
return (epoch_number, iterate) + new_state
def while_cond(epoch_number, iterate, *state):
del iterate
- return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+ return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(epoch_number, iterate) + init_state,
return results
-def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
+def while_stmt(test, body, init_state, extra_deps, opts=None):
"""Functional form of a while statement.
The loop operates on a so-called state, which includes all symbols that are
of the corresponding types.
Args:
- loop_cond: Callable with the state as arguments, and boolean return type.
+ test: Callable with the state as arguments, and boolean return type.
The loop condition.
- loop_body: Callable with the state as arguments, and state as return type.
+ body: Callable with the state as arguments, and state as return type.
The actual loop body.
init_state: Tuple containing the initial state.
extra_deps: Tuple containing additional entities on which the loop may
- depend, such as loop invariants referenced by loop_cond. Used
+ depend, such as loop invariants referenced by test. Used
exclusively for dispatch control.
opts: Optional dict of extra loop parameters.
# That could be somethins as simple as a collection of dispatch rules, with
# some prioritization.
if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
- return _tf_while_loop(loop_cond, loop_body, init_state, opts)
+ return _tf_while_stmt(test, body, init_state, opts)
else:
- return _py_while_loop(loop_cond, loop_body, init_state, opts)
+ return _py_while_stmt(test, body, init_state, opts)
-def _tf_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that stages a TF while_loop."""
+def _tf_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that stages a TF while_stmt."""
if opts is None:
opts = {}
- return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts)
+ return control_flow_ops.while_loop(test, body, init_state, **opts)
-def _py_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that executes a Python while loop."""
+def _py_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that executes a Python while loop."""
del opts
state = init_state
- while loop_cond(*state):
- state = loop_body(*state)
+ while test(*state):
+ state = body(*state)
return state
class ForLoopTest(test.TestCase):
def test_tensor(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
constant_op.constant([1, 2, 3, 4]),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
def test_python(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
range(5),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
self.assertEqual(10, s)
def test_dataset(self):
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
dataset_ops.Dataset.range(5).map(to_int32),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
def test_tensor(self):
n = constant_op.constant(5)
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i,),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
extra_deps=(n,))
with self.test_session() as sess:
def test_python(self):
n = 5
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i),
init_state=(0, 0),
extra_deps=(n,))
self.assertEqual((5, 10), results)