Simplify the template mechanism by specifying templates using multi-line strings...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 25 Jan 2018 18:36:25 +0000 (10:36 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 25 Jan 2018 18:39:58 +0000 (10:39 -0800)
Addresses #16318

PiperOrigin-RevId: 183260854

tensorflow/contrib/py2tf/converters/break_canonicalization.py
tensorflow/contrib/py2tf/converters/builtin_functions.py
tensorflow/contrib/py2tf/converters/call_trees.py
tensorflow/contrib/py2tf/converters/continue_canonicalization.py
tensorflow/contrib/py2tf/converters/control_flow.py
tensorflow/contrib/py2tf/converters/for_canonicalization.py
tensorflow/contrib/py2tf/converters/side_effect_guards.py
tensorflow/contrib/py2tf/pyct/templates.py
tensorflow/contrib/py2tf/pyct/templates_test.py

index ef585734454db1aa1ffdb798d93978fb09752f05..2ae65e3007466409433e9b4ea0081898907e19ac 100644 (file)
@@ -33,31 +33,25 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
     self.break_uses = []
 
   def _create_break_check(self):
-
-    def template(var_name):
-      (not var_name)  # pylint:disable=pointless-statement
-
-    expr, = templates.replace(
-        template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+    template = """
+      (not var_name)
+    """
+    expr, = templates.replace(template, var_name=self.break_uses[-1][1])
     return expr.value
 
   def _create_break_trigger(self):
-
-    def template(var_name):  # pylint:disable=unused-argument
+    template = """
       var_name = True
-
-    block = templates.replace(
-        template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+    """
+    block = templates.replace(template, var_name=self.break_uses[-1][1])
     block.append(gast.Continue())
     return block
 
   def _create_break_init(self):
-
-    def template(var_name):  # pylint:disable=unused-argument
+    template = """
       var_name = False
-
-    assign, = templates.replace(
-        template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+    """
+    assign, = templates.replace(template, var_name=self.break_uses[-1][1])
     return assign
 
   # TODO(mdan): Surely the transformer supports this better?
index b80c96c97ac0c55f449a83bd43f2b65cdbdba390..7f6b64a34c1b95f0dd6b92dbc587da672e6c9c28 100644 (file)
@@ -29,10 +29,9 @@ class BuiltinFunctionTransformer(gast.NodeTransformer):
   # TODO(mdan): Bring print_functions in here.
 
   def _convert_len(self, node):
-
-    def template(args):
-      tf.shape(args)[0]  # pylint:disable=undefined-variable,expression-not-assigned
-
+    template = """
+      tf.shape(args)[0]
+    """
     new_call = templates.replace(template, args=node.args)[0].value
     return new_call
 
index df071f596fc31502a98182f27bb66c54f71d2572..0aae030450ae2b981328f604bfddec2f25e13ec4 100644 (file)
@@ -151,7 +151,7 @@ class CallTreeTransformer(gast.NodeTransformer):
     else:
       new_name = self.namer.compiled_function_name(
           '__'.join(target_fqn), live_object=target_obj)
-    node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+    node.func = gast.Name(new_name, gast.Load(), None)
     return node
 
   def _rename_member_function_of_known_type(self, node):
@@ -184,26 +184,17 @@ class CallTreeTransformer(gast.NodeTransformer):
   def _wrap_to_py_func_no_return(self, node):
     args_scope = anno.getanno(node, 'args_scope')
     # TODO(mdan): Properly handle varargs, kwargs, etc.
-    args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
-
-    # pylint:disable=undefined-variable,unused-argument,function-redefined
-
-    def template(call, wrapper, args):
-
+    template = """
       def wrapper(args):
         call(args)
         return 1
-
       tf.py_func(wrapper, [args], [tf.int64])
-
-    # pylint:enable=undefined-variable,unused-argument,function-redefined
-
-    wrapper_name = self.namer.compiled_function_name(node.func.id)
+    """
     wrapper_def, call_expr = templates.replace(
         template,
         call=node.func,
-        wrapper=gast.Name(wrapper_name, gast.Load(), None),
-        args=args)
+        wrapper=self.namer.compiled_function_name(node.func.id),
+        args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
     anno.setanno(call_expr.value, 'args_scope', args_scope)
     # TODO(mdan): Rename this annotation to 'graph_ready'
     anno.setanno(wrapper_def, 'skip_processing', True)
index 7f8ace77a830ebcc4d49fcf2190e4bac920b1cde..486f0f6509d67d9d981e43ea6e5c77d14e6b23fc 100644 (file)
@@ -33,32 +33,28 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
     self.continuation_uses = []
 
   def _create_continuation_check(self):
-
-    def template(var_name):
+    template = """
       if not var_name:
         pass
-
-    cond, = templates.replace(
-        template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+    """
+    cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
     cond.body = []
     return cond
 
   def _create_continuation_trigger(self):
-
-    def template(var_name):  # pylint:disable=unused-argument
+    template = """
       var_name = True
-
+    """
     assign, = templates.replace(
-        template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+        template, var_name=self.continuation_uses[-1][1])
     return assign
 
   def _create_continuation_init(self):
-
-    def template(var_name):  # pylint:disable=unused-argument
+    template = """
       var_name = False
-
+    """
     assign, = templates.replace(
-        template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+        template, var_name=self.continuation_uses[-1][1])
     return assign
 
   def _visit_and_reindent_if_necessary(self, nodes):
index 8ebd9ad93dbc17814d1d7f53c3eac2e078030141..a40c7b28f7bc3b8483b0b18cf11dbf99456df645 100644 (file)
@@ -75,29 +75,6 @@ class ControlFlowTransformer(gast.NodeTransformer):
       raise ValueError(
           'The else branch creates new symbols that the if branch does not.')
 
-    def template(  # pylint:disable=missing-docstring
-        test,
-        body_name,
-        body,
-        orelse_name,
-        orelse,
-        aliased,
-        aliases,  # pylint:disable=unused-argument
-        aliased_results,
-        results):  # pylint:disable=unused-argument
-
-      def body_name():  # pylint:disable=function-redefined
-        aliases, = aliased,  # pylint:disable=unused-variable
-        body  # pylint:disable=pointless-statement
-        return (aliased_results,)
-
-      def orelse_name():  # pylint:disable=function-redefined
-        aliases, = aliased,  # pylint:disable=unused-variable
-        orelse  # pylint:disable=pointless-statement
-        return (aliased_results,)
-
-      results = tf.cond(test, body_name, orelse_name)  # pylint:disable=undefined-variable
-
     all_modified = tuple(body_scope.modified | orelse_scope.modified)
     all_referenced = body_scope.referenced | orelse_scope.referenced
 
@@ -107,10 +84,10 @@ class ControlFlowTransformer(gast.NodeTransformer):
     need_alias = (
         (body_scope.modified | orelse_scope.modified) -
         (body_scope.created | orelse_scope.created))
-    aliased = tuple(need_alias)
-    aliases = tuple(
-        self.namer.new_symbol(s, all_referenced) for s in aliased)
-    alias_map = dict(zip(aliased, aliases))
+    aliased_orig_names = tuple(need_alias)
+    aliased_new_names = tuple(
+        self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
     node_body = node.body
     node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
     node_orelse = node.orelse
@@ -122,20 +99,29 @@ class ControlFlowTransformer(gast.NodeTransformer):
       results = gast.Tuple(
           tuple(gast.Name(s, None, None) for s in all_modified), None)
 
+    template = """
+      def body_name():
+        aliased_new_names, = aliased_orig_names,
+        body
+        return (all_results,)
+      def orelse_name():
+        aliased_new_names, = aliased_orig_names,
+        orelse
+        return (all_results,)
+      results = tf.cond(test, body_name, orelse_name)
+    """
+    body_name = self.namer.new_symbol('if_true', all_referenced)
     return templates.replace(
         template,
         test=node.test,
-        body_name=gast.Name(
-            self.namer.new_symbol('if_true', all_referenced), None, None),
+        body_name=body_name,
         body=node_body,
-        orelse_name=gast.Name(
-            self.namer.new_symbol('if_false', all_referenced), None, None),
+        orelse_name=self.namer.new_symbol('if_false', all_referenced),
         orelse=node_orelse,
-        aliased=tuple(gast.Name(s, None, None) for s in aliased),
-        aliases=tuple(gast.Name(s, None, None) for s in aliases),
-        aliased_results=tuple(
-            gast.Name(alias_map[s] if s in aliased else s, None, None)
-            for s in all_modified),
+        aliased_orig_names=tuple(aliased_orig_names),
+        aliased_new_names=tuple(aliased_new_names),
+        all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+                          for s in all_modified),
         results=results)
 
   def visit_While(self, node):
@@ -144,38 +130,28 @@ class ControlFlowTransformer(gast.NodeTransformer):
     body_scope = anno.getanno(node, 'body_scope')
     body_closure = tuple(body_scope.modified - body_scope.created)
 
-    def template(
-        state,  # pylint:disable=unused-argument
-        state_ast_tuple,  # pylint:disable=unused-argument
-        test_name,
-        test,  # pylint:disable=unused-argument
-        body_name,
-        body):
-
-      def test_name(state):  # pylint:disable=function-redefined,unused-argument
-        return test
-
-      def body_name(state):  # pylint:disable=function-redefined,unused-argument
-        body  # pylint:disable=pointless-statement
-        return state,
-
-      state_ast_tuple = tf.while_loop(test_name, body_name, [state])  # pylint:disable=undefined-variable
-
-    test_name = self.namer.new_symbol('loop_test', body_scope.referenced)
-    body_name = self.namer.new_symbol('loop_body', body_scope.referenced)
     if len(body_closure) == 1:
-      state = gast.Name(body_closure[0], None, None)
+      state = body_closure[0]
       state_ast_tuple = state
     else:
-      state = tuple(gast.Name(n, None, None) for n in body_closure)
-      state_ast_tuple = gast.Tuple(state, None)
+      state = tuple(body_closure)
+      state_ast_tuple = gast.Tuple(
+          tuple(gast.Name(n, None, None) for n in state), None)
+    template = """
+      def test_name(state):
+        return test
+      def body_name(state):
+        body
+        return state,
+      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
+    """
     node = templates.replace(
         template,
         state=state,
         state_ast_tuple=state_ast_tuple,
-        test_name=gast.Name(test_name, gast.Load(), None),
+        test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
         test=node.test,
-        body_name=gast.Name(body_name, gast.Load(), None),
+        body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
         body=node.body)
 
     return node
index 52360789cdc25528d925092e3e269c9968f2022f..c284689b904c6f372f30e83c259416a51babe4a6 100644 (file)
@@ -42,46 +42,40 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
     # Or maybe we should replace range with tf.range?
 
     if anno.hasanno(node, 'extra_cond'):
-
-      def template(loop_iter, target, body, i, n, extra_cond):  # pylint:disable=unused-argument
+      template = """
         i = 0
-        n = len(loop_iter)  # pylint:disable=undefined-variable
+        n = len(loop_iter)
         while i < n and extra_cond:
           # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
           target = loop_iter[i]
-          body  # pylint:disable=pointless-statement
+          body
           i += 1
-
+      """
       return templates.replace(
           template,
           loop_iter=node.iter,
           target=node.target,
           body=node.body,
-          i=gast.Name(
-              self.namer.new_symbol('i', body_scope.referenced), None, None),
-          n=gast.Name(
-              self.namer.new_symbol('n', body_scope.referenced), None, None),
+          i=self.namer.new_symbol('i', body_scope.referenced),
+          n=self.namer.new_symbol('n', body_scope.referenced),
           extra_cond=anno.getanno(node, 'extra_cond'))
     else:
-
-      def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
+      template = """
         i = 0
-        n = len(loop_iter)  # pylint:disable=undefined-variable
+        n = len(loop_iter)
         while i < n:
           # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
           target = loop_iter[i]
           body  # pylint:disable=pointless-statement
           i += 1
-
+      """
       return templates.replace(
           template,
           loop_iter=node.iter,
           target=node.target,
           body=node.body,
-          i=gast.Name(
-              self.namer.new_symbol('i', body_scope.referenced), None, None),
-          n=gast.Name(
-              self.namer.new_symbol('n', body_scope.referenced), None, None))
+          i=self.namer.new_symbol('i', body_scope.referenced),
+          n=self.namer.new_symbol('n', body_scope.referenced))
 
   def visit_Continue(self, node):
     assert False, 'continue statement should be desugared at this point'
index 1f25303fbac1184d016a63d629ba2ecf17d7e426..a88828ff802e7f15310b9350a8a73928bf699ebd 100644 (file)
@@ -94,12 +94,10 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
     return node
 
   def _gate_symbols(self, guard_statement, guarded_args):
-
-    def template(args):  # pylint:disable=unused-argument
-      (args,) = (tf.identity(a) for a in (args,))  # pylint:disable=undefined-variable
-
-    guards = templates.replace(
-        template, args=tuple(gast.Name(a, None, None) for a in guarded_args))
+    template = """
+      (args,) = (tf.identity(a) for a in (args,))
+    """
+    guards = templates.replace(template, args=tuple(guarded_args))
     guard_statement.body.extend(guards)
     return guard_statement
 
@@ -110,29 +108,25 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
       #   opt.minimize(loss)
       # or:
       #   tf.py_func(...)
-
       args_scope = anno.getanno(node.value, 'args_scope')
       temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
       # TODO(mdan): Unsafe reference modification!
       args_scope.mark_write(temp_name)
-
-      def template(call, temp_result):
+      template = """
         temp_result = call
         if temp_result is not None:
           if not isinstance(temp_result, (list, tuple)):
             temp_result = (temp_result,)
-          ctx = tf.control_dependencies(temp_result)  # pylint:disable=undefined-variable
+          ctx = tf.control_dependencies(temp_result)
         else:
-          ctx = contextmanager(lambda: (yield))()  # pylint:disable=undefined-variable
+          ctx = contextmanager(lambda: (yield))()
         with ctx:
           # TODO(mdan): Also insert ops to re-fetch if variables are involved.
           pass  # Will be removed below.
-
-      # TODO(mdan): This is brittle. Reorganize this mechanism.
+      """
+      # TODO(mdan): This is brittle. Reorganize the mechanism.
       statements = templates.replace(
-          template,
-          call=node.value,
-          temp_result=gast.Name(temp_name, None, None))
+          template, call=node.value, temp_result=temp_name)
       control_deps_guard = statements[-1]
       control_deps_guard.body = []
 
index 4fadc793e6d1dfa8ddabea1d607de68ac6ad9c85..77c5fbe02a11ed4a6b3d2cd80a032858f5b07e33 100644 (file)
@@ -80,37 +80,46 @@ class ReplaceTransformer(gast.NodeTransformer):
       return node
 
 
+def _strings_to_names(n):
+  if isinstance(n, str):
+    # Note: the node will receive the ctx value from the template, see
+    # ReplaceTransformer.visit_Name.
+    return gast.Name(id=n, ctx=None, annotation=None)
+  if isinstance(n, list):
+    return [_strings_to_names(e) for e in n]
+  if isinstance(n, tuple):
+    return tuple(_strings_to_names(e) for e in n)
+  return n
+
+
 def replace(template, **replacements):
   """Replace placeholders in a Python template.
 
+  AST Name and Tuple nodes always receive the context that inferred from
+  the template. However, when replacing more complex nodes (that can potentially
+  contain Name children), then the caller is responsible for setting the
+  appropriate context.
+
   Args:
-    template: A function to be used as a template. Any placeholder is expected
-        to also be a function argument.
+    template: A string representing Python code. Any symbol name can be used
+        that appears in the template code can be used as placeholder.
     **replacements: A mapping from placeholder names to (lists of) AST nodes
-        that these placeholders will be replaced by.
+        that these placeholders will be replaced by. String values are also
+        supported as a shorthand for AST Name nodes with the respective ID.
 
   Returns:
-    body: An AST node or list of AST nodes with the replacements made. If the
-        template was a function, a list will be returned. If the template was a
-        node, the same node will be returned. If the template was a string, an
-        AST node will be returned (a `Module` node in the case of a multi-line
-        string, an `Expr` node otherwise).
+    An AST node or list of AST nodes with the replacements made. If the
+    template was a function, a list will be returned. If the template was a
+    node, the same node will be returned. If the template was a string, an
+    AST node will be returned (a `Module` node in the case of a multi-line
+    string, an `Expr` node otherwise).
 
   Raises:
-    ValueError: If a function is used as a template and an incorrect set of
-        replacements was passed.
+    ValueError: if the arguments are incorrect.
   """
-  tree = parser.parse_object(template).body[0]
-  placeholders = set(arg.id for arg in tree.args.args)
-  tree.args.args = []
-  if tree.args.vararg:
-    placeholders.add(tree.args.vararg)
-    tree.args.vararg = None
-  if set(replacements.keys()) != placeholders:
-    raise ValueError(
-        'too many or few replacements. replacements: %s; placeholders: %s' %
-        (replacements.keys(), placeholders))
-
-  # Perform the replacement, stripping the function into which the template was
-  # wrapped.
+  if not isinstance(template, str):
+    raise ValueError('Expected string template, got %s' % type(template))
+  tree = parser.parse_str(template)
+  for k in replacements:
+    replacements[k] = _strings_to_names(replacements[k])
   return ReplaceTransformer(replacements).visit(tree).body
index 2ad8b9317b67c7ae18a16efac745138e14101e6a..1143131283cd92c42abfc73d5728fac96cc31c23 100644 (file)
@@ -28,46 +28,42 @@ from tensorflow.python.platform import test
 class TemplatesTest(test.TestCase):
 
   def test_replace_variable(self):
-    def template(a):  # pylint:disable=unused-argument
-      def test_fn(a):  # pylint:disable=unused-variable
+    template = """
+      def test_fn(a):
         a += 1
         a = 2 * a + 1
-        return b  # pylint:disable=undefined-variable
+        return b
+    """
 
-    node = templates.replace(
-        template, a=gast.Name('b', gast.Load(), None))[0]
+    node = templates.replace(template, a='b')[0]
     result = compiler.ast_to_object(node)
     self.assertEquals(7, result.test_fn(2))
 
   def test_replace_function_name(self):
-    def template(fname):  # pylint:disable=unused-argument
-      def fname(a):  # pylint:disable=function-redefined
+    template = """
+      def fname(a):
         a += 1
         a = 2 * a + 1
         return a
+    """
 
-    node = templates.replace(
-        template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+    node = templates.replace(template, fname='test_fn')[0]
     result = compiler.ast_to_object(node)
     self.assertEquals(7, result.test_fn(2))
 
   def test_code_block(self):
-    def template(block):  # pylint:disable=unused-argument
-      def test_fn(a):  # pylint:disable=unused-variable
-        block  # pylint:disable=pointless-statement
+    template = """
+      def test_fn(a):
+        block
         return a
+    """
 
     node = templates.replace(
         template,
         block=[
-            gast.Assign(
-                [
-                    gast.Name('a', gast.Store(), None)
-                ],
-                gast.BinOp(
-                    gast.Name('a', gast.Load(), None),
-                    gast.Add(),
-                    gast.Num(1))),
+            gast.Assign([
+                gast.Name('a', None, None)
+            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
         ] * 2)[0]
     result = compiler.ast_to_object(node)
     self.assertEquals(3, result.test_fn(1))