From b50f6325143486eb82b5654f8794f0771b54dd4d Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 1 May 2018 19:05:39 -0700 Subject: [PATCH] Minor refactor: establish some operator naming conventions and apply them, so that the interface is a bit more consistent. PiperOrigin-RevId: 195034691 --- .../autograph/converters/break_statements.py | 4 +- .../contrib/autograph/converters/control_flow.py | 24 ++--- tensorflow/contrib/autograph/operators/__init__.py | 16 +++- .../contrib/autograph/operators/control_flow.py | 105 ++++++++++++--------- .../autograph/operators/control_flow_test.py | 30 +++--- 5 files changed, 99 insertions(+), 80 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 91de82f..1be1c96 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -114,9 +114,9 @@ class BreakStatementTransformer(transformer.Base): template, var_name=break_var, for_stmt=node) - extra_cond = templates.replace_as_expression( + extra_test = templates.replace_as_expression( 'not var_name', var_name=break_var) - anno.setanno(node[1], 'extra_cond', extra_cond) + anno.setanno(node[1], 'extra_test', extra_test) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 2e26cdb..935a278 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -207,7 +207,7 @@ class ControlFlowTransformer(transformer.Base): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = ag__.while_loop( + state_ast_tuple = ag__.while_stmt( test_name, body_name, (state,), (extra_deps,)) """ node = templates.replace( @@ -252,31 +252,31 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple = gast.Tuple([n.ast() for n in state], None) node_body = ast_util.rename_symbols(node.body, ssf_map) - if anno.hasanno(node, 'extra_cond'): - extra_cond = anno.getanno(node, 'extra_cond') - extra_cond = ast_util.rename_symbols(extra_cond, ssf_map) + if anno.hasanno(node, 'extra_test'): + extra_test = anno.getanno(node, 'extra_test') + extra_test = ast_util.rename_symbols(extra_test, ssf_map) else: - extra_cond = parser.parse_expression('True') + extra_test = parser.parse_expression('True') template = """ - def extra_cond_name(state_ssf): - return extra_cond_expr + def extra_test_name(state_ssf): + return extra_test_expr def body_name(iterate, state_ssf): body return state_ssf, - state_ast_tuple = ag__.for_loop( - iterated, extra_cond_name, body_name, (state,)) + state_ast_tuple = ag__.for_stmt( + iter_, extra_test_name, body_name, (state,)) """ node = templates.replace( template, state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - iterated=node.iter, + iter_=node.iter, iterate=node.target, - extra_cond_name=self.context.namer.new_symbol('extra_cond', + extra_test_name=self.context.namer.new_symbol('extra_test', all_referenced), - extra_cond_expr=extra_cond, + extra_test_expr=extra_test, body_name=self.context.namer.new_symbol('loop_body', all_referenced), body=node_body) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 04b4734..38b761d 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -19,11 +19,19 @@ conditionals and loops, implemented in functional form, using for example closures for the body. """ +# Naming conventions: +# * operator names match the name usually used for the respective Python +# idiom; examples: for_stmt, list_append +# * operator arguments match either of: +# - the corresponding Python AST attribute (e.g. the condition of an if +# statement is called test) if the operator represents an AST construct +# - the names used in the Python docs, if the operator is a function (e.g. +# list_ and x for append, see +# https://docs.python.org/3.7/tutorial/datastructures.html) + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Add a container for implementation-specific toggles (throughout). - -from tensorflow.contrib.autograph.operators.control_flow import for_loop -from tensorflow.contrib.autograph.operators.control_flow import while_loop +from tensorflow.contrib.autograph.operators.control_flow import for_stmt +from tensorflow.contrib.autograph.operators.control_flow import while_stmt diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index d9d8b0d..9f72028 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -25,44 +25,55 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_math_ops -# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature. -# TODO(mdan): Rename arguments to match the AST names. - -def for_loop(iterated, extra_cond, loop_body, init_state): +def for_stmt(iter_, extra_test, body, init_state): """Functional form of a for statement. - The loop operates on a so-called state, which includes all symbols that are - variant across loop iterations, excluding the iterate. In what follows we - refer to state as either a tuple of entities that represent an actual state, - or a list of arguments of the corresponding types. + The loop operates on a state, which includes all symbols that are + variant across loop iterations, excluding the iterate as well as the + variables local to the loop. + + For example, given the loop below that calculates the geometric and + arithmetic means or some numbers: + + geo_mean = 1 + arith_mean = 0 + for i in range(n): + a = numbers[i] + geo_mean *= a + arith_mean += a + + The state is represented by the variables geo_mean and arith_mean. The + argument for initial_state may contain the tuple (1, 0), the body will + include the arguments geo_mean and arith_mean and will return a tuple + representing the new values for geo_mean and respectively arith_mean. Args: - iterated: The entity being iterated over. - extra_cond: Callable with the state as arguments, and boolean return type. + iter_: The entity being iterated over. + extra_test: Callable with the state as arguments, and boolean return type. An additionnal loop condition. - loop_body: Callable with the iterate and the state as arguments, and + body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. Returns: Tuple containing the final state. """ - if tensor_util.is_tensor(iterated): - return _known_len_for_loop(iterated, extra_cond, loop_body, init_state) - elif isinstance(iterated, dataset_ops.Dataset): - return _dataset_for_loop(iterated, extra_cond, loop_body, init_state) + if tensor_util.is_tensor(iter_): + return _known_len_for_stmt(iter_, extra_test, body, init_state) + elif isinstance(iter_, dataset_ops.Dataset): + return _dataset_for_stmt(iter_, extra_test, body, init_state) else: - return _py_for_loop(iterated, extra_cond, loop_body, init_state) + return _py_for_stmt(iter_, extra_test, body, init_state) -def _py_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that executes a Python for loop.""" +def _py_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that executes a Python for loop.""" state = init_state - for iterate in iterated: - if not extra_cond(*state): + for target in iter_: + if not extra_test(*state): break - state = loop_body(iterate, *state) + state = body(target, *state) # TODO(mdan): Remove this special case. if len(state) == 1: @@ -70,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state): return state -def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over objects that define a length.""" - n = builtins.dynamic_len(iterated) +def _known_len_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that iterates over objects that define a length.""" + n = builtins.dynamic_len(iter_) def while_body(iterate_index, *state): - iterate = iterated[iterate_index] - new_state = loop_body(iterate, *state) + iterate = iter_[iterate_index] + new_state = body(iterate, *state) return (iterate_index + 1,) + new_state def while_cond(iterate_index, *state): - return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state)) + return gen_math_ops.logical_and(iterate_index < n, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(0,) + init_state, - extra_deps=(iterated,), + extra_deps=(iter_,), opts=dict(maximum_iterations=n)) # Dropping the iteration index because it's not syntactically visible. results = results[1:] @@ -97,8 +108,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): return results -def _dataset_for_loop(ds, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over TF Datasets.""" +def _dataset_for_stmt(ds, extra_test, body, init_state): + """Overload of for_stmt that iterates over TF Datasets.""" # Because Datsets only expose get_next, in the style of Python iterators, # we are forced to unpack the loop as: # @@ -117,15 +128,15 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state): epoch_number, iterate = iterator.get_next() def while_body(epoch_number, iterate, *state): - new_state = loop_body(iterate, *state) + new_state = body(iterate, *state) epoch_number, iterate = iterator.get_next() return (epoch_number, iterate) + new_state def while_cond(epoch_number, iterate, *state): del iterate - return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state)) + return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(epoch_number, iterate) + init_state, @@ -140,7 +151,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state): return results -def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): +def while_stmt(test, body, init_state, extra_deps, opts=None): """Functional form of a while statement. The loop operates on a so-called state, which includes all symbols that are @@ -149,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): of the corresponding types. Args: - loop_cond: Callable with the state as arguments, and boolean return type. + test: Callable with the state as arguments, and boolean return type. The loop condition. - loop_body: Callable with the state as arguments, and state as return type. + body: Callable with the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. extra_deps: Tuple containing additional entities on which the loop may - depend, such as loop invariants referenced by loop_cond. Used + depend, such as loop invariants referenced by test. Used exclusively for dispatch control. opts: Optional dict of extra loop parameters. @@ -166,24 +177,24 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): # That could be somethins as simple as a collection of dispatch rules, with # some prioritization. if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): - return _tf_while_loop(loop_cond, loop_body, init_state, opts) + return _tf_while_stmt(test, body, init_state, opts) else: - return _py_while_loop(loop_cond, loop_body, init_state, opts) + return _py_while_stmt(test, body, init_state, opts) -def _tf_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that stages a TF while_loop.""" +def _tf_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that stages a TF while_stmt.""" if opts is None: opts = {} - return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts) + return control_flow_ops.while_loop(test, body, init_state, **opts) -def _py_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that executes a Python while loop.""" +def _py_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that executes a Python while loop.""" del opts state = init_state - while loop_cond(*state): - state = loop_body(*state) + while test(*state): + state = body(*state) return state diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py index a0cd0bf..b14d7ed 100644 --- a/tensorflow/contrib/autograph/operators/control_flow_test.py +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -29,28 +29,28 @@ from tensorflow.python.platform import test class ForLoopTest(test.TestCase): def test_tensor(self): - s = control_flow.for_loop( + s = control_flow.for_stmt( constant_op.constant([1, 2, 3, 4]), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) def test_python(self): - s = control_flow.for_loop( + s = control_flow.for_stmt( range(5), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) self.assertEqual(10, s) def test_dataset(self): to_int32 = lambda i: math_ops.cast(i, dtypes.int32) - s = control_flow.for_loop( + s = control_flow.for_stmt( dataset_ops.Dataset.range(5).map(to_int32), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) @@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase): def test_tensor(self): n = constant_op.constant(5) - results = control_flow.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i,), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i,), init_state=(0, 0), extra_deps=(n,)) with self.test_session() as sess: @@ -70,9 +70,9 @@ class WhileLoopTest(test.TestCase): def test_python(self): n = 5 - results = control_flow.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i), init_state=(0, 0), extra_deps=(n,)) self.assertEqual((5, 10), results) -- 2.7.4