Minor refactor: establish some operator naming conventions and apply them, so that...
authorDan Moldovan <mdan@google.com>
Wed, 2 May 2018 02:05:39 +0000 (19:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 2 May 2018 02:08:50 +0000 (19:08 -0700)
PiperOrigin-RevId: 195034691

tensorflow/contrib/autograph/converters/break_statements.py
tensorflow/contrib/autograph/converters/control_flow.py
tensorflow/contrib/autograph/operators/__init__.py
tensorflow/contrib/autograph/operators/control_flow.py
tensorflow/contrib/autograph/operators/control_flow_test.py

index 91de82f..1be1c96 100644 (file)
@@ -114,9 +114,9 @@ class BreakStatementTransformer(transformer.Base):
           template,
           var_name=break_var,
           for_stmt=node)
-      extra_cond = templates.replace_as_expression(
+      extra_test = templates.replace_as_expression(
           'not var_name', var_name=break_var)
-      anno.setanno(node[1], 'extra_cond', extra_cond)
+      anno.setanno(node[1], 'extra_test', extra_test)
 
     return node
 
index 2e26cdb..935a278 100644 (file)
@@ -207,7 +207,7 @@ class ControlFlowTransformer(transformer.Base):
       def body_name(state_ssf):
         body
         return state_ssf,
-      state_ast_tuple = ag__.while_loop(
+      state_ast_tuple = ag__.while_stmt(
           test_name, body_name, (state,), (extra_deps,))
     """
     node = templates.replace(
@@ -252,31 +252,31 @@ class ControlFlowTransformer(transformer.Base):
       state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
 
     node_body = ast_util.rename_symbols(node.body, ssf_map)
-    if anno.hasanno(node, 'extra_cond'):
-      extra_cond = anno.getanno(node, 'extra_cond')
-      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+    if anno.hasanno(node, 'extra_test'):
+      extra_test = anno.getanno(node, 'extra_test')
+      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
     else:
-      extra_cond = parser.parse_expression('True')
+      extra_test = parser.parse_expression('True')
 
     template = """
-      def extra_cond_name(state_ssf):
-        return extra_cond_expr
+      def extra_test_name(state_ssf):
+        return extra_test_expr
       def body_name(iterate, state_ssf):
         body
         return state_ssf,
-      state_ast_tuple = ag__.for_loop(
-          iterated, extra_cond_name, body_name, (state,))
+      state_ast_tuple = ag__.for_stmt(
+          iter_, extra_test_name, body_name, (state,))
     """
     node = templates.replace(
         template,
         state=state,
         state_ssf=state_ssf,
         state_ast_tuple=state_ast_tuple,
-        iterated=node.iter,
+        iter_=node.iter,
         iterate=node.target,
-        extra_cond_name=self.context.namer.new_symbol('extra_cond',
+        extra_test_name=self.context.namer.new_symbol('extra_test',
                                                       all_referenced),
-        extra_cond_expr=extra_cond,
+        extra_test_expr=extra_test,
         body_name=self.context.namer.new_symbol('loop_body', all_referenced),
         body=node_body)
 
index 04b4734..38b761d 100644 (file)
@@ -19,11 +19,19 @@ conditionals and loops, implemented in functional form, using for example
 closures for the body.
 """
 
+# Naming conventions:
+#  * operator names match the name usually used for the respective Python
+#    idiom; examples: for_stmt, list_append
+#  * operator arguments match either of:
+#    - the corresponding Python AST attribute (e.g. the condition of an if
+#      statement is called test) if the operator represents an AST construct
+#    - the names used in the Python docs, if the operator is a function (e.g.
+#      list_ and x for append, see
+#      https://docs.python.org/3.7/tutorial/datastructures.html)
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# TODO(mdan): Add a container for implementation-specific toggles (throughout).
-
-from tensorflow.contrib.autograph.operators.control_flow import for_loop
-from tensorflow.contrib.autograph.operators.control_flow import while_loop
+from tensorflow.contrib.autograph.operators.control_flow import for_stmt
+from tensorflow.contrib.autograph.operators.control_flow import while_stmt
index d9d8b0d..9f72028 100644 (file)
@@ -25,44 +25,55 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_math_ops
 
-# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
-# TODO(mdan): Rename arguments to match the AST names.
 
-
-def for_loop(iterated, extra_cond, loop_body, init_state):
+def for_stmt(iter_, extra_test, body, init_state):
   """Functional form of a for statement.
 
-  The loop operates on a so-called state, which includes all symbols that are
-  variant across loop iterations, excluding the iterate. In what follows we
-  refer to state as either a tuple of entities that represent an actual state,
-  or a list of arguments of the corresponding types.
+  The loop operates on a state, which includes all symbols that are
+  variant across loop iterations, excluding the iterate as well as the
+  variables local to the loop.
+
+  For example, given the loop below that calculates the geometric and
+  arithmetic means or some numbers:
+
+    geo_mean = 1
+    arith_mean = 0
+    for i in range(n):
+      a = numbers[i]
+      geo_mean *= a
+      arith_mean += a
+
+  The state is represented by the variables geo_mean and arith_mean. The
+  argument for initial_state may contain the tuple (1, 0), the body will
+  include the arguments geo_mean and arith_mean and will return a tuple
+  representing the new values for geo_mean and respectively arith_mean.
 
   Args:
-    iterated: The entity being iterated over.
-    extra_cond: Callable with the state as arguments, and boolean return type.
+    iter_: The entity being iterated over.
+    extra_test: Callable with the state as arguments, and boolean return type.
         An additionnal loop condition.
-    loop_body: Callable with the iterate and the state as arguments, and
+    body: Callable with the iterate and the state as arguments, and
         state as return type. The actual loop body.
     init_state: Tuple containing the initial state.
 
   Returns:
     Tuple containing the final state.
   """
-  if tensor_util.is_tensor(iterated):
-    return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
-  elif isinstance(iterated, dataset_ops.Dataset):
-    return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+  if tensor_util.is_tensor(iter_):
+    return _known_len_for_stmt(iter_, extra_test, body, init_state)
+  elif isinstance(iter_, dataset_ops.Dataset):
+    return _dataset_for_stmt(iter_, extra_test, body, init_state)
   else:
-    return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+    return _py_for_stmt(iter_, extra_test, body, init_state)
 
 
-def _py_for_loop(iterated, extra_cond, loop_body, init_state):
-  """Overload of for_loop that executes a Python for loop."""
+def _py_for_stmt(iter_, extra_test, body, init_state):
+  """Overload of for_stmt that executes a Python for loop."""
   state = init_state
-  for iterate in iterated:
-    if not extra_cond(*state):
+  for target in iter_:
+    if not extra_test(*state):
       break
-    state = loop_body(iterate, *state)
+    state = body(target, *state)
 
   # TODO(mdan): Remove this special case.
   if len(state) == 1:
@@ -70,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state):
   return state
 
 
-def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
-  """Overload of for_loop that iterates over objects that define a length."""
-  n = builtins.dynamic_len(iterated)
+def _known_len_for_stmt(iter_, extra_test, body, init_state):
+  """Overload of for_stmt that iterates over objects that define a length."""
+  n = builtins.dynamic_len(iter_)
 
   def while_body(iterate_index, *state):
-    iterate = iterated[iterate_index]
-    new_state = loop_body(iterate, *state)
+    iterate = iter_[iterate_index]
+    new_state = body(iterate, *state)
     return (iterate_index + 1,) + new_state
 
   def while_cond(iterate_index, *state):
-    return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+    return gen_math_ops.logical_and(iterate_index < n, extra_test(*state))
 
-  results = while_loop(
+  results = while_stmt(
       while_cond,
       while_body,
       init_state=(0,) + init_state,
-      extra_deps=(iterated,),
+      extra_deps=(iter_,),
       opts=dict(maximum_iterations=n))
   # Dropping the iteration index because it's not syntactically visible.
   results = results[1:]
@@ -97,8 +108,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
   return results
 
 
-def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
-  """Overload of for_loop that iterates over TF Datasets."""
+def _dataset_for_stmt(ds, extra_test, body, init_state):
+  """Overload of for_stmt that iterates over TF Datasets."""
   # Because Datsets only expose get_next, in the style of Python iterators,
   # we are forced to unpack the loop as:
   #
@@ -117,15 +128,15 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
     epoch_number, iterate = iterator.get_next()
 
     def while_body(epoch_number, iterate, *state):
-      new_state = loop_body(iterate, *state)
+      new_state = body(iterate, *state)
       epoch_number, iterate = iterator.get_next()
       return (epoch_number, iterate) + new_state
 
     def while_cond(epoch_number, iterate, *state):
       del iterate
-      return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+      return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state))
 
-    results = while_loop(
+    results = while_stmt(
         while_cond,
         while_body,
         init_state=(epoch_number, iterate) + init_state,
@@ -140,7 +151,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
   return results
 
 
-def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
+def while_stmt(test, body, init_state, extra_deps, opts=None):
   """Functional form of a while statement.
 
   The loop operates on a so-called state, which includes all symbols that are
@@ -149,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
   of the corresponding types.
 
   Args:
-    loop_cond: Callable with the state as arguments, and boolean return type.
+    test: Callable with the state as arguments, and boolean return type.
         The loop condition.
-    loop_body: Callable with the state as arguments, and state as return type.
+    body: Callable with the state as arguments, and state as return type.
         The actual loop body.
     init_state: Tuple containing the initial state.
     extra_deps: Tuple containing additional entities on which the loop may
-        depend, such as loop invariants referenced by loop_cond. Used
+        depend, such as loop invariants referenced by test. Used
         exclusively for dispatch control.
     opts: Optional dict of extra loop parameters.
 
@@ -166,24 +177,24 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
   # That could be somethins as simple as a collection of dispatch rules, with
   # some prioritization.
   if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
-    return _tf_while_loop(loop_cond, loop_body, init_state, opts)
+    return _tf_while_stmt(test, body, init_state, opts)
   else:
-    return _py_while_loop(loop_cond, loop_body, init_state, opts)
+    return _py_while_stmt(test, body, init_state, opts)
 
 
-def _tf_while_loop(loop_cond, loop_body, init_state, opts):
-  """Overload of while_loop that stages a TF while_loop."""
+def _tf_while_stmt(test, body, init_state, opts):
+  """Overload of while_stmt that stages a TF while_stmt."""
   if opts is None:
     opts = {}
-  return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts)
+  return control_flow_ops.while_loop(test, body, init_state, **opts)
 
 
-def _py_while_loop(loop_cond, loop_body, init_state, opts):
-  """Overload of while_loop that executes a Python while loop."""
+def _py_while_stmt(test, body, init_state, opts):
+  """Overload of while_stmt that executes a Python while loop."""
   del opts
   state = init_state
-  while loop_cond(*state):
-    state = loop_body(*state)
+  while test(*state):
+    state = body(*state)
   return state
 
 
index a0cd0bf..b14d7ed 100644 (file)
@@ -29,28 +29,28 @@ from tensorflow.python.platform import test
 class ForLoopTest(test.TestCase):
 
   def test_tensor(self):
-    s = control_flow.for_loop(
+    s = control_flow.for_stmt(
         constant_op.constant([1, 2, 3, 4]),
-        extra_cond=lambda s: True,
-        loop_body=lambda i, s: (s + i,),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
         init_state=(0,))
     with self.test_session() as sess:
       self.assertEqual((10,), sess.run(s))
 
   def test_python(self):
-    s = control_flow.for_loop(
+    s = control_flow.for_stmt(
         range(5),
-        extra_cond=lambda s: True,
-        loop_body=lambda i, s: (s + i,),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
         init_state=(0,))
     self.assertEqual(10, s)
 
   def test_dataset(self):
     to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
-    s = control_flow.for_loop(
+    s = control_flow.for_stmt(
         dataset_ops.Dataset.range(5).map(to_int32),
-        extra_cond=lambda s: True,
-        loop_body=lambda i, s: (s + i,),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
         init_state=(0,))
     with self.test_session() as sess:
       self.assertEqual((10,), sess.run(s))
@@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase):
 
   def test_tensor(self):
     n = constant_op.constant(5)
-    results = control_flow.while_loop(
-        loop_cond=lambda i, s: i < n,
-        loop_body=lambda i, s: (i + 1, s + i,),
+    results = control_flow.while_stmt(
+        test=lambda i, s: i < n,
+        body=lambda i, s: (i + 1, s + i,),
         init_state=(0, 0),
         extra_deps=(n,))
     with self.test_session() as sess:
@@ -70,9 +70,9 @@ class WhileLoopTest(test.TestCase):
 
   def test_python(self):
     n = 5
-    results = control_flow.while_loop(
-        loop_cond=lambda i, s: i < n,
-        loop_body=lambda i, s: (i + 1, s + i),
+    results = control_flow.while_stmt(
+        test=lambda i, s: i < n,
+        body=lambda i, s: (i + 1, s + i),
         init_state=(0, 0),
         extra_deps=(n,))
     self.assertEqual((5, 10), results)