Expand the activity analysis to composite names.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Feb 2018 15:56:27 +0000 (07:56 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 16:03:03 +0000 (08:03 -0800)
Fix a bug in the cond template that caused bad syntax when there it no symbol that needs aliasing.
More refactoring in the process, including:
 * introduce the QN (qualified name) class to hold symbol information; it has value semantics and can generate the original symbol, a corresponding AST tree or a single-symbol form (e.g. "a.b" -> a_b)
 * allow the template mechanism to use QNs for substitutions
 * annotate *all* symbol nodes with their corresponding QN object; this is done as first step during static analysis, and automatically performed on all template expansions
 * start using typed annotation keys (Enum values) instead of plain strings
 * rename access.py to activity.py
 * sanitize nodes in template expansion by deep copying the AST without annotations, to avoid common references

PiperOrigin-RevId: 184528586

39 files changed:
tensorflow/contrib/py2tf/BUILD
tensorflow/contrib/py2tf/__init__.py
tensorflow/contrib/py2tf/converters/BUILD
tensorflow/contrib/py2tf/converters/break_canonicalization.py
tensorflow/contrib/py2tf/converters/break_canonicalization_test.py
tensorflow/contrib/py2tf/converters/builtin_functions.py
tensorflow/contrib/py2tf/converters/builtin_functions_test.py
tensorflow/contrib/py2tf/converters/call_trees.py
tensorflow/contrib/py2tf/converters/continue_canonicalization.py
tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py
tensorflow/contrib/py2tf/converters/control_flow.py
tensorflow/contrib/py2tf/converters/control_flow_test.py
tensorflow/contrib/py2tf/converters/converter_test_base.py
tensorflow/contrib/py2tf/converters/for_canonicalization.py
tensorflow/contrib/py2tf/converters/for_canonicalization_test.py
tensorflow/contrib/py2tf/converters/print_functions.py [deleted file]
tensorflow/contrib/py2tf/converters/side_effect_guards.py
tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
tensorflow/contrib/py2tf/impl/conversion.py
tensorflow/contrib/py2tf/impl/naming.py
tensorflow/contrib/py2tf/pyct/BUILD
tensorflow/contrib/py2tf/pyct/anno.py
tensorflow/contrib/py2tf/pyct/copier.py [new file with mode: 0644]
tensorflow/contrib/py2tf/pyct/copier_test.py [moved from tensorflow/contrib/py2tf/converters/print_functions_test.py with 53% similarity]
tensorflow/contrib/py2tf/pyct/pretty_printer.py
tensorflow/contrib/py2tf/pyct/pretty_printer_test.py
tensorflow/contrib/py2tf/pyct/qual_names.py [new file with mode: 0644]
tensorflow/contrib/py2tf/pyct/static_analysis/BUILD
tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py [deleted file]
tensorflow/contrib/py2tf/pyct/static_analysis/activity.py [moved from tensorflow/contrib/py2tf/pyct/static_analysis/access.py with 75% similarity]
tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/pyct/static_analysis/annos.py [new file with mode: 0644]
tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
tensorflow/contrib/py2tf/pyct/templates.py
tensorflow/contrib/py2tf/pyct/templates_test.py
tensorflow/contrib/py2tf/pyct/transformer.py

index 479ea9b..d91220f 100644 (file)
@@ -23,6 +23,7 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/contrib/py2tf/impl",
+        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/contrib/py2tf/utils",
         "@gast_archive//:gast",
         "@six_archive//:six",
index 0d51bf0..379fa7f 100644 (file)
@@ -26,8 +26,11 @@ from tensorflow.contrib.py2tf.impl.api import convert
 from tensorflow.contrib.py2tf.impl.api import graph_ready
 from tensorflow.contrib.py2tf.impl.api import to_code
 from tensorflow.contrib.py2tf.impl.api import to_graph
+from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError
 from tensorflow.python.util.all_util import remove_undocumented
 
-_allowed_symbols = ['to_graph', 'to_code', 'convert', 'graph_ready', 'utils']
+_allowed_symbols = [
+    'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError'
+]
 
 remove_undocumented(__name__, _allowed_symbols)
index 3853c60..68ea247 100644 (file)
@@ -26,7 +26,6 @@ py_library(
         "decorators.py",
         "for_canonicalization.py",
         "logical_expressions.py",
-        "print_functions.py",
         "side_effect_guards.py",
     ],
     srcs_version = "PY2AND3",
@@ -150,18 +149,6 @@ py_test(
 )
 
 py_test(
-    name = "print_functions_test",
-    srcs = ["print_functions_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
-        "//tensorflow/python:client_testlib",
-        "@gast_archive//:gast",
-    ],
-)
-
-py_test(
     name = "side_effect_guards_test",
     srcs = ["side_effect_guards_test.py"],
     srcs_version = "PY2AND3",
index 2ae65e3..bfb709c 100644 (file)
@@ -22,13 +22,15 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
-class BreakCanonicalizationTransformer(gast.NodeTransformer):
+class BreakCanonicalizationTransformer(transformer.Base):
   """Canonicalizes continue statements into additional conditionals."""
 
-  def __init__(self, namer):
-    self.namer = namer
+  def __init__(self, context):
+    super(BreakCanonicalizationTransformer, self).__init__(context)
     # This is a stack structure, to correctly process nested loops.
     self.break_uses = []
 
@@ -67,9 +69,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
 
   def visit_While(self, node):
     self.generic_visit(node.test)
-    scope = anno.getanno(node, 'body_scope')
+    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
 
-    break_var = self.namer.new_symbol('break_requested', scope.referenced)
+    break_var = self.context.namer.new_symbol('break_requested',
+                                              scope.referenced)
     self.break_uses.append([False, break_var])
     node.body = self._manual_visit_list(node.body)
     if self.break_uses[-1][0]:
@@ -89,9 +92,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
   def visit_For(self, node):
     self.generic_visit(node.target)
     self.generic_visit(node.iter)
-    scope = anno.getanno(node, 'body_scope')
+    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
 
-    break_var = self.namer.new_symbol('break_requested', scope.referenced)
+    break_var = self.context.namer.new_symbol('break_requested',
+                                              scope.referenced)
     self.break_uses.append([False, break_var])
     node.body = self._manual_visit_list(node.body)
     if self.break_uses[-1][0]:
@@ -112,7 +116,5 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
     return self._create_break_trigger()
 
 
-def transform(node, namer):
-  transformer = BreakCanonicalizationTransformer(namer)
-  node = transformer.visit(node)
-  return node
+def transform(node, context):
+  return BreakCanonicalizationTransformer(context).visit(node)
index b5ba2ad..54c4d99 100644 (file)
@@ -44,8 +44,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = break_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = break_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     self.assertEqual(test_fn(0), result.test_fn(0))
@@ -76,8 +76,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = break_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = break_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     # The break is incompletely canonicalized. Everything is in place, but
@@ -104,8 +104,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v, u, w
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = break_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = break_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     self.assertEqual(test_fn(0), result.test_fn(0))
index 7f6b64a..3e56634 100644 (file)
@@ -21,12 +21,14 @@ from __future__ import print_function
 import gast
 
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
 
 
-class BuiltinFunctionTransformer(gast.NodeTransformer):
+class BuiltinFunctionTransformer(transformer.Base):
   """Transforms Print nodes to Call so they can be handled as functions."""
 
-  # TODO(mdan): Bring print_functions in here.
+  def __init__(self, context):
+    super(BuiltinFunctionTransformer, self).__init__(context)
 
   def _convert_len(self, node):
     template = """
@@ -44,10 +46,15 @@ class BuiltinFunctionTransformer(gast.NodeTransformer):
       return self._convert_len(node)
     return node
 
+  def visit_Print(self, node):
+    self.generic_visit(node)
+    template = """
+      fname(args)
+    """
+    return templates.replace(template, fname='print', args=node.values)
+
   # pylint:enable=invalid-name
 
 
-def transform(node):
-  transformer = BuiltinFunctionTransformer()
-  node = transformer.visit(node)
-  return node
+def transform(node, context):
+  return BuiltinFunctionTransformer(context).visit(node)
index b5358da..be76066 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gast
+
 from tensorflow.contrib.py2tf.converters import builtin_functions
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.contrib.py2tf.pyct import compiler
@@ -34,7 +36,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
       return len(a)
 
     node = self.parse_and_analyze(test_fn, {'len': len})
-    node = builtin_functions.transform(node)
+    node = builtin_functions.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'tf', array_ops)
 
@@ -43,6 +45,18 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
                        sess.run(
                            result.test_fn(constant_op.constant([0, 0, 0]))))
 
+  def test_print(self):
+
+    def test_fn(a):
+      print(a)
+
+    node = self.parse_and_analyze(test_fn, {'print': print})
+    node = builtin_functions.transform(node, self.ctx)
+    result = compiler.ast_to_object(node)
+
+    result.test_fn('a')
+    self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+
 
 if __name__ == '__main__':
   test.main()
index 4c238b7..834baf2 100644 (file)
@@ -30,6 +30,7 @@ from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import parser
 from tensorflow.contrib.py2tf.pyct import templates
 from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 from tensorflow.python.util import tf_inspect
 
 
@@ -192,11 +193,11 @@ class CallTreeTransformer(transformer.Base):
           # The renaming process will transform it into a regular function.
           # TODO(mdan): Is this complete? How does it work with nested members?
           node.args = [node.func.value] + node.args
-      node.func = gast.Name(new_name, gast.Load(), None)
+      node.func = templates.replace('func_name', func_name=new_name)[0]
     return node
 
   def _wrap_to_py_func_no_return(self, node):
-    args_scope = anno.getanno(node, 'args_scope')
+    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
     # TODO(mdan): Properly handle varargs, kwargs, etc.
     template = """
       def wrapper(args):
@@ -208,10 +209,10 @@ class CallTreeTransformer(transformer.Base):
         template,
         call=node.func,
         wrapper=self.context.namer.compiled_function_name(node.func.id)[0],
-        args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
-    anno.setanno(call_expr.value, 'args_scope', args_scope)
+        args=tuple(args_scope.used))
+    anno.setanno(call_expr.value, NodeAnno.ARGS_SCOPE, args_scope)
     # TODO(mdan): Rename this annotation to 'graph_ready'
-    anno.setanno(wrapper_def, 'skip_processing', True)
+    anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
 
     return (wrapper_def, call_expr)
 
index 486f0f6..4069a67 100644 (file)
@@ -18,17 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gast
-
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
-class ContinueCanonicalizationTransformer(gast.NodeTransformer):
+class ContinueCanonicalizationTransformer(transformer.Base):
   """Canonicalizes continue statements into additional conditionals."""
 
-  def __init__(self, namer):
-    self.namer = namer
+  def __init__(self, context):
+    super(ContinueCanonicalizationTransformer, self).__init__(context)
     # This is a stack structure, to correctly process nested loops.
     self.continuation_uses = []
 
@@ -76,7 +76,7 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
     return reorganized_nodes
 
   def _process_loop_block(self, block, scope):
-    cont_var = self.namer.new_symbol('cont_requested', scope.referenced)
+    cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
     self.continuation_uses.append([False, cont_var])
     block = self._visit_and_reindent_if_necessary(block)
     if self.continuation_uses[-1][0]:
@@ -87,7 +87,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
   def visit_While(self, node):
     self.generic_visit(node.test)
     node.body = self._process_loop_block(node.body,
-                                         anno.getanno(node, 'body_scope'))
+                                         anno.getanno(node,
+                                                      NodeAnno.BODY_SCOPE))
     for n in node.orelse:
       self.generic_visit(n)
     return node
@@ -96,7 +97,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
     self.generic_visit(node.target)
     self.generic_visit(node.iter)
     node.body = self._process_loop_block(node.body,
-                                         anno.getanno(node, 'body_scope'))
+                                         anno.getanno(node,
+                                                      NodeAnno.BODY_SCOPE))
     for n in node.orelse:
       self.generic_visit(n)
     return node
@@ -122,6 +124,4 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
 
 
 def transform(node, namer):
-  transformer = ContinueCanonicalizationTransformer(namer)
-  node = transformer.visit(node)
-  return node
+  return ContinueCanonicalizationTransformer(namer).visit(node)
index c1fe903..4b18819 100644 (file)
@@ -44,8 +44,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = continue_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = continue_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     self.assertEqual(test_fn(0), result.test_fn(0))
@@ -65,8 +65,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = continue_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = continue_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     self.assertEqual(test_fn([]), result.test_fn([]))
@@ -91,8 +91,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v, u, w
 
-    node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
-    node = continue_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = continue_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     self.assertEqual(test_fn(0), result.test_fn(0))
index a40c7b2..a256c07 100644 (file)
@@ -22,6 +22,8 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
 class SymbolNamer(object):
@@ -41,21 +43,29 @@ class SymbolNamer(object):
 
 
 class SymbolRenamer(gast.NodeTransformer):
+  """Transformer that can rename symbols to a simple names."""
 
   def __init__(self, name_map):
     self.name_map = name_map
 
-  def visit_Name(self, node):
-    if node.id in self.name_map:
-      node.id = self.name_map[node.id]
+  def _process(self, node):
+    qn = anno.getanno(node, anno.Basic.QN)
+    if qn in self.name_map:
+      return gast.Name(self.name_map[qn], node.ctx, None)
     return node
 
+  def visit_Name(self, node):
+    return self._process(node)
+
+  def visit_Attribute(self, node):
+    return self._process(node)
+
 
-class ControlFlowTransformer(gast.NodeTransformer):
+class ControlFlowTransformer(transformer.Base):
   """Transforms control flow structures like loops an conditionals."""
 
-  def __init__(self, namer):
-    self.namer = namer
+  def __init__(self, context):
+    super(ControlFlowTransformer, self).__init__(context)
 
   # pylint:disable=invalid-name
 
@@ -65,8 +75,8 @@ class ControlFlowTransformer(gast.NodeTransformer):
   def visit_If(self, node):
     self.generic_visit(node)
 
-    body_scope = anno.getanno(node, 'body_scope')
-    orelse_scope = anno.getanno(node, 'orelse_scope')
+    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
 
     if body_scope.created - orelse_scope.created:
       raise ValueError(
@@ -86,7 +96,8 @@ class ControlFlowTransformer(gast.NodeTransformer):
         (body_scope.created | orelse_scope.created))
     aliased_orig_names = tuple(need_alias)
     aliased_new_names = tuple(
-        self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+        self.context.namer.new_symbol(s.ssf(), all_referenced)
+        for s in aliased_orig_names)
     alias_map = dict(zip(aliased_orig_names, aliased_new_names))
     node_body = node.body
     node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
@@ -94,72 +105,112 @@ class ControlFlowTransformer(gast.NodeTransformer):
     node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse]
 
     if len(all_modified) == 1:
-      results = gast.Name(all_modified[0], None, None)
+      results = all_modified[0]
     else:
-      results = gast.Tuple(
-          tuple(gast.Name(s, None, None) for s in all_modified), None)
-
-    template = """
-      def body_name():
-        aliased_new_names, = aliased_orig_names,
-        body
-        return (all_results,)
-      def orelse_name():
-        aliased_new_names, = aliased_orig_names,
-        orelse
-        return (all_results,)
-      results = tf.cond(test, body_name, orelse_name)
-    """
-    body_name = self.namer.new_symbol('if_true', all_referenced)
-    return templates.replace(
-        template,
-        test=node.test,
-        body_name=body_name,
-        body=node_body,
-        orelse_name=self.namer.new_symbol('if_false', all_referenced),
-        orelse=node_orelse,
-        aliased_orig_names=tuple(aliased_orig_names),
-        aliased_new_names=tuple(aliased_new_names),
-        all_results=tuple(alias_map[s] if s in aliased_orig_names else s
-                          for s in all_modified),
-        results=results)
+      results = gast.Tuple([s.ast() for s in all_modified], None)
+
+    if aliased_orig_names:
+      template = """
+        def body_name():
+          aliased_new_names, = aliased_orig_names,
+          body
+          return (all_results,)
+        def orelse_name():
+          aliased_new_names, = aliased_orig_names,
+          orelse
+          return (all_results,)
+        results = tf.cond(test, body_name, orelse_name)
+      """
+      body_name = self.context.namer.new_symbol('if_true', all_referenced)
+      return templates.replace(
+          template,
+          test=node.test,
+          body_name=body_name,
+          body=node_body,
+          orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
+          orelse=node_orelse,
+          aliased_orig_names=tuple(aliased_orig_names),
+          aliased_new_names=tuple(aliased_new_names),
+          all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+                            for s in all_modified),
+          results=results)
+    else:
+      template = """
+        def body_name():
+          body
+          return (all_results,)
+        def orelse_name():
+          orelse
+          return (all_results,)
+        results = tf.cond(test, body_name, orelse_name)
+      """
+      body_name = self.context.namer.new_symbol('if_true', all_referenced)
+      return templates.replace(
+          template,
+          test=node.test,
+          body_name=body_name,
+          body=node_body,
+          orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
+          orelse=node_orelse,
+          all_results=tuple(s for s in all_modified),
+          results=results)
 
   def visit_While(self, node):
     self.generic_visit(node)
 
-    body_scope = anno.getanno(node, 'body_scope')
-    body_closure = tuple(body_scope.modified - body_scope.created)
-
-    if len(body_closure) == 1:
-      state = body_closure[0]
+    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+    body_closure = body_scope.modified - body_scope.created
+    all_referenced = body_scope.referenced
+
+    state = list(body_closure)
+    state_ssf = [
+        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+    ]
+    ssf_map = {
+        name: ssf
+        for name, ssf in zip(state, state_ssf)
+        if str(name) != ssf
+    }
+
+    if len(state) == 1:
+      state = state[0]
+      state_ssf = state_ssf[0]
       state_ast_tuple = state
     else:
-      state = tuple(body_closure)
-      state_ast_tuple = gast.Tuple(
-          tuple(gast.Name(n, None, None) for n in state), None)
+      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+    node_body = node.body
+    node_body = [SymbolRenamer(ssf_map).visit(n) for n in node_body]
+
+    test = node.test
+    test = SymbolRenamer(ssf_map).visit(test)
+
     template = """
-      def test_name(state):
+      def test_name(state_ssf):
         return test
-      def body_name(state):
+      def body_name(state_ssf):
         body
-        return state,
+        return state_ssf,
       state_ast_tuple = tf.while_loop(test_name, body_name, [state])
     """
     node = templates.replace(
         template,
         state=state,
+        state_ssf=state_ssf,
         state_ast_tuple=state_ast_tuple,
-        test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
-        test=node.test,
-        body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
-        body=node.body)
+        test_name=self.context.namer.new_symbol('loop_test',
+                                                body_scope.referenced),
+        test=test,
+        body_name=self.context.namer.new_symbol('loop_body',
+                                                body_scope.referenced),
+        body=node_body)
 
     return node
 
   # pylint:enable=invalid-name
 
 
-def transform(node, namer):
-  transformer = ControlFlowTransformer(namer)
-  node = transformer.visit(node)
+def transform(node, context):
+  t = ControlFlowTransformer(context)
+  node = t.visit(node)
   return node
index 054e337..f192bf1 100644 (file)
@@ -49,8 +49,8 @@ class ControlFlowTest(converter_test_base.TestCase):
         i += 1
       return s, i, n
 
-    node = self.parse_and_analyze(test_fn, {})
-    node = control_flow.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = control_flow.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'tf', control_flow_ops)
 
@@ -65,8 +65,8 @@ class ControlFlowTest(converter_test_base.TestCase):
         n -= 1
       return n
 
-    node = self.parse_and_analyze(test_fn, {})
-    node = control_flow.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = control_flow.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'tf', control_flow_ops)
 
@@ -84,8 +84,8 @@ class ControlFlowTest(converter_test_base.TestCase):
         b = 2 * n
       return a, b
 
-    node = self.parse_and_analyze(test_fn, {})
-    node = control_flow.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = control_flow.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'tf', control_flow_ops)
 
@@ -102,8 +102,8 @@ class ControlFlowTest(converter_test_base.TestCase):
         n = -n
       return n
 
-    node = self.parse_and_analyze(test_fn, {})
-    node = control_flow.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = control_flow.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'tf', control_flow_ops)
 
index 6bfa554..bcb96c8 100644 (file)
@@ -20,7 +20,8 @@ from __future__ import print_function
 
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
 from tensorflow.python.platform import test
@@ -44,7 +45,8 @@ class TestCase(test.TestCase):
         arg_values=None,
         arg_types=arg_types,
         recursive=recursive)
-    node = access.resolve(node, ctx)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, ctx)
     node = live_values.resolve(node, ctx, {})
     if include_type_analysis:
       node = type_info.resolve(node, ctx)
index c284689..935dade 100644 (file)
@@ -22,24 +22,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gast
-
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
-class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
+class ForLoopCanonicalizationTransformer(transformer.Base):
   """Canonicalizes for loops (e.g. into while loops)."""
 
-  def __init__(self, namer):
-    self.namer = namer
+  def __init__(self, context):
+    super(ForLoopCanonicalizationTransformer, self).__init__(context)
 
   def visit_For(self, node):
     self.generic_visit(node)
-    body_scope = anno.getanno(node, 'body_scope')
-
-    # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
-    # Or maybe we should replace range with tf.range?
+    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
 
     if anno.hasanno(node, 'extra_cond'):
       template = """
@@ -56,8 +53,8 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
           loop_iter=node.iter,
           target=node.target,
           body=node.body,
-          i=self.namer.new_symbol('i', body_scope.referenced),
-          n=self.namer.new_symbol('n', body_scope.referenced),
+          i=self.context.namer.new_symbol('i', body_scope.referenced),
+          n=self.context.namer.new_symbol('n', body_scope.referenced),
           extra_cond=anno.getanno(node, 'extra_cond'))
     else:
       template = """
@@ -69,13 +66,14 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
           body  # pylint:disable=pointless-statement
           i += 1
       """
-      return templates.replace(
+      repl = templates.replace(
           template,
           loop_iter=node.iter,
           target=node.target,
           body=node.body,
-          i=self.namer.new_symbol('i', body_scope.referenced),
-          n=self.namer.new_symbol('n', body_scope.referenced))
+          i=self.context.namer.new_symbol('i', body_scope.referenced),
+          n=self.context.namer.new_symbol('n', body_scope.referenced))
+      return repl
 
   def visit_Continue(self, node):
     assert False, 'continue statement should be desugared at this point'
@@ -84,7 +82,5 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
     assert False, 'break statement should be desugared at this point'
 
 
-def transform(node, namer):
-  transformer = ForLoopCanonicalizationTransformer(namer)
-  node = transformer.visit(node)
-  return node
+def transform(node, context):
+  return ForLoopCanonicalizationTransformer(context).visit(node)
index a6e6350..142bd4a 100644 (file)
@@ -41,8 +41,8 @@ class ControlFlowTest(converter_test_base.TestCase):
         s += e
       return s
 
-    node = self.parse_and_analyze(test_fn, {})
-    node = for_canonicalization.transform(node, TestNamer())
+    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = for_canonicalization.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
 
     l = [1, 2, 3]
diff --git a/tensorflow/contrib/py2tf/converters/print_functions.py b/tensorflow/contrib/py2tf/converters/print_functions.py
deleted file mode 100644 (file)
index 5da738c..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Compatibility support. Converts Print nodes to function calls."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.py2tf.pyct import anno
-
-
-class PrintFunctionTransformer(gast.NodeTransformer):
-  """Transforms Print nodes to Call so they can be handled as functions."""
-
-  # pylint:disable=invalid-name
-
-  def visit_Print(self, node):
-    self.generic_visit(node)
-    for n in node.values:
-      n.ctx = gast.Param()
-    call_node = gast.Call(
-        func=gast.Name('print', gast.Load(), None),
-        args=node.values,
-        keywords=[])
-    anno.setanno(call_node.func, 'live_val', print)
-    anno.setanno(call_node.func, 'fqn', 'print')
-    anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
-    node = gast.Expr(call_node)
-    return node
-
-  # pylint:enable=invalid-name
-
-
-def transform(node):
-  transformer = PrintFunctionTransformer()
-  node = transformer.visit(node)
-  return node
index ffca743..c73388d 100644 (file)
@@ -38,6 +38,8 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
 class SymbolNamer(object):
@@ -55,11 +57,11 @@ class SymbolNamer(object):
     raise NotImplementedError()
 
 
-class SideEffectGuardTransformer(gast.NodeTransformer):
+class SideEffectGuardTransformer(transformer.Base):
   """Adds control dependencies to functions with side effects."""
 
-  def __init__(self, namer):
-    self.namer = namer
+  def __init__(self, context):
+    super(SideEffectGuardTransformer, self).__init__(context)
     self.indent_next = False
     self.next_indent_owner = None
 
@@ -88,8 +90,6 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
     return new_nodes
 
   def visit_FunctionDef(self, node):
-    if anno.hasanno(node, 'skip_processing'):
-      return node
     node.body = self._visit_and_reindent(node.body)
     return node
 
@@ -122,7 +122,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
       # First, attempt to gate future evaluation of args. If that's not
       # possible, gate all remaining statements (and that may fail too, see
       # _visit_and_reindent.
-      args_scope = anno.getanno(node.value, 'args_scope')
+      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
       guarded_args = tuple(args_scope.used & (args_scope.parent.modified
                                               | args_scope.parent.returned))
       if guarded_args:
@@ -138,6 +138,5 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
   # pylint:enable=invalid-name
 
 
-def transform(node, namer):
-  transformer = SideEffectGuardTransformer(namer)
-  return transformer.visit(node)
+def transform(node, context):
+  return SideEffectGuardTransformer(context).visit(node)
index 452d7ab..dea09ec 100644 (file)
@@ -43,8 +43,9 @@ class SideEffectGuardsTest(converter_test_base.TestCase):
       state_ops.assign(a, a + 1)
       return a
 
-    node = self.parse_and_analyze(test_fn, {'state_ops': state_ops})
-    node = side_effect_guards.transform(node, TestNamer())
+    node = self.parse_and_analyze(
+        test_fn, {'state_ops': state_ops}, namer=TestNamer())
+    node = side_effect_guards.transform(node, self.ctx)
     result = compiler.ast_to_object(node)
     setattr(result, 'state_ops', state_ops)
     setattr(result, 'py2tf_utils', utils)
index ed71ff5..ff4f159 100644 (file)
@@ -30,13 +30,13 @@ from tensorflow.contrib.py2tf.converters import control_flow
 from tensorflow.contrib.py2tf.converters import decorators
 from tensorflow.contrib.py2tf.converters import for_canonicalization
 from tensorflow.contrib.py2tf.converters import logical_expressions
-from tensorflow.contrib.py2tf.converters import print_functions
 from tensorflow.contrib.py2tf.converters import side_effect_guards
 from tensorflow.contrib.py2tf.impl import config
 from tensorflow.contrib.py2tf.impl import naming
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
 from tensorflow.python.util import tf_inspect
@@ -208,7 +208,8 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
 
 
 def _static_analysis_pass(node, ctx):
-  node = access.resolve(node, ctx)
+  node = qual_names.resolve(node)
+  node = activity.resolve(node, ctx, None)
   node = live_values.resolve(node, ctx, config.PYTHON_LITERALS)
   node = type_info.resolve(node, ctx)
   return node
@@ -233,10 +234,7 @@ def node_to_graph(node, ctx, nocompile_decorators):
 
   # TODO(mdan): Factor out common elements.
   # These include:
-  #   * keeping track of symbols that have been created
-  #   * marking nodes (e.g. py_func wrappers) to suppress further processing
   #   * code move between blocks
-  #   * insertion of new global references
   #   * visiting blocks in transformers
 
   # Certain steps, especially canonicalization, insert new symbols into the
@@ -244,29 +242,35 @@ def node_to_graph(node, ctx, nocompile_decorators):
   # to re-run the analysis.
 
   node = _static_analysis_pass(node, ctx)
+  # Past this point, line numbers are no longer accurate so we ignore the
+  # source.
+  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
+  ctx.source_code = None
   node = decorators.transform(node, nocompile_decorators)
-  node = break_canonicalization.transform(node, ctx.namer)
+  node = break_canonicalization.transform(node, ctx)
   node = asserts.transform(node, ctx)
 
   # Note: sequencing continue canonicalization before for loop one avoids
   # dealing with the extra loop increment operation that the for
   # canonicalization creates.
-  node = continue_canonicalization.transform(node, ctx.namer)
+  node = continue_canonicalization.transform(node, ctx)
   ctx.namespace['len'] = len
 
   node = _static_analysis_pass(node, ctx)
-  node = for_canonicalization.transform(node, ctx.namer)
+  node = for_canonicalization.transform(node, ctx)
   # for_canonicalization may insert new global references.
-  node = builtin_functions.transform(node)
+  node = builtin_functions.transform(node, ctx)
   # builtin_functions may insert new global references.
   ctx.namespace['print'] = print
 
   node = _static_analysis_pass(node, ctx)
-  node = print_functions.transform(node)
   node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
                               nocompile_decorators)
-  node = control_flow.transform(node, ctx.namer)
+  node = control_flow.transform(node, ctx)
+
+  # control_flow may create new symbols and change scopes.
+  node = _static_analysis_pass(node, ctx)
   node = logical_expressions.transform(node)
-  node = side_effect_guards.transform(node, ctx.namer)
+  node = side_effect_guards.transform(node, ctx)
 
   return node
index 5c7e4c5..d31462c 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.py2tf.pyct import qual_names
+
 
 class Namer(object):
   """Implementation of the namer interfaces required by various converters.
@@ -103,11 +105,20 @@ class Namer(object):
 
   def new_symbol(self, name_root, reserved_locals):
     """See control_flow.SymbolNamer.new_symbol."""
+    # reserved_locals may contain QNs.
+    all_reserved_locals = set()
+    for s in reserved_locals:
+      if isinstance(s, qual_names.QN):
+        all_reserved_locals.update(s.qn)
+      elif isinstance(s, str):
+        all_reserved_locals.add(s)
+      else:
+        raise ValueError('Unexpected symbol type "%s"' % type(s))
+
     new_name = name_root
     n = 0
-    while (new_name in self.global_namespace
-           or new_name in reserved_locals
-           or new_name in self.generated_names):
+    while (new_name in self.global_namespace or
+           new_name in all_reserved_locals or new_name in self.generated_names):
       n += 1
       new_name = '%s_%d' % (name_root, n)
 
index 1b2408b..054eb17 100644 (file)
@@ -21,8 +21,10 @@ py_library(
         "anno.py",
         "compiler.py",
         "context.py",
+        "copier.py",
         "parser.py",
         "pretty_printer.py",
+        "qual_names.py",
         "templates.py",
         "transformer.py",
     ],
@@ -58,6 +60,17 @@ py_test(
 )
 
 py_test(
+    name = "copier_test",
+    srcs = ["copier_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":pyct",
+        "//tensorflow/python:client_testlib",
+        "@gast_archive//:gast",
+    ],
+)
+
+py_test(
     name = "parser_test",
     srcs = ["parser_test.py"],
     srcs_version = "PY2AND3",
index 889e4ba..c6d41f9 100644 (file)
@@ -21,6 +21,25 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from enum import Enum
+
+
+class NoValue(Enum):
+
+  def __repr__(self):
+    return self.name
+
+
+class Basic(NoValue):
+  """Container for annotation keys.
+
+  The enum values are used strictly for documentation purposes.
+  """
+
+  QN = 'Qualified name, as it appeared in the code.'
+  SKIP_PROCESSING = (
+      'This node should be preserved as is and not processed any further.')
+
 
 def getanno(node, key, field_name='___pyct_anno'):
   return getattr(node, field_name)[key]
diff --git a/tensorflow/contrib/py2tf/pyct/copier.py b/tensorflow/contrib/py2tf/pyct/copier.py
new file mode 100644 (file)
index 0000000..41598fd
--- /dev/null
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Copy an AST tree, discarding annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class CleanCopier(gast.NodeVisitor):
+  """Copy AST nodes.
+
+  The copied nodes will ignore almost all fields that prefixed by '__'.
+  Exceptions make some annotations.
+  """
+
+  # TODO(mdan): Parametrize which annotations get carried over.
+
+  def generic_visit(self, node):
+    new_fields = {}
+    for f in node._fields:
+      if f.startswith('__'):
+        continue
+      if not hasattr(node, f):
+        continue
+      v = getattr(node, f)
+      if isinstance(v, list):
+        v = [self.generic_visit(n) for n in v]
+      elif isinstance(v, tuple):
+        v = tuple(self.generic_visit(n) for n in v)
+      elif isinstance(v, (gast.AST, ast.AST)):
+        v = self.generic_visit(v)
+      else:
+        # Assume everything else is a value type.
+        pass
+      new_fields[f] = v
+    new_node = type(node)(**new_fields)
+    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+      anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True)
+    return new_node
+
+
+def copy_clean(node):
+  copier = CleanCopier()
+  if isinstance(node, list):
+    return [copier.visit(n) for n in node]
+  elif isinstance(node, tuple):
+    return tuple(copier.visit(n) for n in node)
+  else:
+    return copier.visit(node)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for print_functions module."""
+"""Tests for copier module."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gast
+import ast
 
-from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.converters import print_functions
-from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import copier
 from tensorflow.python.platform import test
 
 
-class PrintFunctionsTest(converter_test_base.TestCase):
-
-  def test_transform(self):
-
-    def test_fn(a):
-      print(a)
-
-    node = self.parse_and_analyze(test_fn, {'print': print})
-    node = print_functions.transform(node)
-    result = compiler.ast_to_object(node)
-
-    result.test_fn('a')
-    self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+class CopierTest(test.TestCase):
+
+  def test_copy_clean(self):
+    ret = ast.Return(
+        ast.BinOp(
+            op=ast.Add(),
+            left=ast.Name(id='a', ctx=ast.Load()),
+            right=ast.Num(1)))
+    setattr(ret, '__foo', 'bar')
+    node = ast.FunctionDef(
+        name='f',
+        args=ast.arguments(
+            args=[ast.Name(id='a', ctx=ast.Param())],
+            vararg=None,
+            kwarg=None,
+            defaults=[]),
+        body=[ret],
+        decorator_list=[],
+        returns=None)
+    new_node = copier.copy_clean(node)
+    self.assertFalse(node is new_node)
+    self.assertFalse(ret is new_node.body[0])
+    self.assertFalse(hasattr(new_node.body[0], '__foo'))
 
 
 if __name__ == '__main__':
index 5e70c0e..bacc1e4 100644 (file)
@@ -25,24 +25,30 @@ import termcolor
 class PrettyPrinter(gast.NodeVisitor):
   """Print AST nodes."""
 
-  def __init__(self):
+  def __init__(self, color):
     self.indent_lvl = 0
     self.result = ''
+    self.color = color
+
+  def _color(self, string, color, attrs=None):
+    if self.color:
+      return termcolor.colored(string, color, attrs=attrs)
+    return string
 
   def _type(self, node):
-    return termcolor.colored(node.__class__.__name__, None, attrs=['bold'])
+    return self._color(node.__class__.__name__, None, ['bold'])
 
   def _field(self, name):
-    return termcolor.colored(name, 'blue')
+    return self._color(name, 'blue')
 
   def _value(self, name):
-    return termcolor.colored(name, 'magenta')
+    return self._color(name, 'magenta')
 
   def _warning(self, name):
-    return termcolor.colored(name, 'red')
+    return self._color(name, 'red')
 
   def _indent(self):
-    return termcolor.colored('| ' * self.indent_lvl, None, attrs=['dark'])
+    return self._color('| ' * self.indent_lvl, None, ['dark'])
 
   def _print(self, s):
     self.result += s
@@ -76,6 +82,16 @@ class PrettyPrinter(gast.NodeVisitor):
           self._print('%s]' % (self._indent()))
         else:
           self._print('%s%s=[]' % (self._indent(), self._field(f)))
+      elif isinstance(v, tuple):
+        if v:
+          self._print('%s%s=(' % (self._indent(), self._field(f)))
+          self.indent_lvl += 1
+          for n in v:
+            self.generic_visit(n)
+          self.indent_lvl -= 1
+          self._print('%s)' % (self._indent()))
+        else:
+          self._print('%s%s=()' % (self._indent(), self._field(f)))
       elif isinstance(v, gast.AST):
         self.generic_visit(v, f)
       elif isinstance(v, str):
@@ -87,8 +103,8 @@ class PrettyPrinter(gast.NodeVisitor):
     self.indent_lvl -= 1
 
 
-def fmt(node):
-  printer = PrettyPrinter()
+def fmt(node, color=True):
+  printer = PrettyPrinter(color)
   if isinstance(node, (list, tuple)):
     for n in node:
       printer.visit(n)
index 65e5b1d..81e3f47 100644 (file)
@@ -24,10 +24,6 @@ from tensorflow.contrib.py2tf.pyct import pretty_printer
 from tensorflow.python.platform import test
 
 
-def f(x):
-  return x + 1
-
-
 class PrettyPrinterTest(test.TestCase):
 
   def test_format(self):
diff --git a/tensorflow/contrib/py2tf/pyct/qual_names.py b/tensorflow/contrib/py2tf/pyct/qual_names.py
new file mode 100644 (file)
index 0000000..11e3838
--- /dev/null
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for manipulating qualified names.
+
+A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite
+(e.g. 'foo.bar') syntactic symbols.
+
+This is *not* related to the __qualname__ attribute used by inspect, which
+refers to scopes.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class QN(object):
+  """Represents a qualified name.
+
+  """
+
+  def __init__(self, base, attr=None):
+    if attr:
+      if not isinstance(base, QN):
+        raise ValueError('For attribute QNs, base must be a QN.')
+      self._parent = base
+      self.qn = base.qn + (attr,)
+    else:
+      self._parent = None
+      self.qn = tuple(base.split('.'))
+
+  def is_composite(self):
+    return len(self.qn) > 1
+
+  @property
+  def parent(self):
+    if self._parent is None:
+      raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
+    return self._parent
+
+  def __hash__(self):
+    return hash(self.qn)
+
+  def __eq__(self, other):
+    return self.qn == other.qn
+
+  def __str__(self):
+    return '.'.join(self.qn)
+
+  def __repr__(self):
+    return str(self)
+
+  def ssf(self):
+    """Simple symbol form."""
+    return '_'.join(self.qn)
+
+  def ast(self):
+    # The caller must adjust the context appropriately.
+    if self.is_composite():
+      return gast.Attribute(self.parent.ast(), self.qn[-1], None)
+    return gast.Name(self.qn[0], None, None)
+
+
+class QnResolver(gast.NodeTransformer):
+  """Annotates nodes with QN information.
+
+  Note: Not using NodeAnnos to avoid circular dependencies.
+  """
+
+  def visit_Name(self, node):
+    self.generic_visit(node)
+    anno.setanno(node, anno.Basic.QN, QN(node.id))
+    return node
+
+  def visit_Attribute(self, node):
+    self.generic_visit(node)
+    anno.setanno(node, anno.Basic.QN,
+                 QN(anno.getanno(node.value, anno.Basic.QN), node.attr))
+    return node
+
+
+def resolve(node):
+  return QnResolver().visit(node)
index 32e2954..fbfce18 100644 (file)
@@ -17,7 +17,8 @@ filegroup(
 py_library(
     name = "static_analysis",
     srcs = [
-        "access.py",
+        "activity.py",
+        "annos.py",
         "live_values.py",
         "type_info.py",
     ],
@@ -30,8 +31,8 @@ py_library(
 )
 
 py_test(
-    name = "access_test",
-    srcs = ["access_test.py"],
+    name = "activity_test",
+    srcs = ["activity_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":static_analysis",
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
deleted file mode 100644 (file)
index df0283b..0000000
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for access module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.py2tf.pyct import anno
-from tensorflow.contrib.py2tf.pyct import context
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.python.platform import test
-
-
-class ScopeTest(test.TestCase):
-
-  def test_basic(self):
-    scope = access.Scope(None)
-    self.assertFalse(scope.has('foo'))
-
-    scope.mark_read('foo')
-    self.assertFalse(scope.has('foo'))
-
-    scope.mark_write('foo')
-    self.assertTrue(scope.has('foo'))
-
-    scope.mark_read('bar')
-    self.assertFalse(scope.has('bar'))
-
-  def test_copy(self):
-    scope = access.Scope(None)
-    scope.mark_write('foo')
-
-    other = access.Scope(None)
-    other.copy_from(scope)
-
-    self.assertTrue('foo' in other.created)
-
-    scope.mark_write('bar')
-    scope.copy_from(other)
-
-    self.assertFalse('bar' in scope.created)
-
-    scope.mark_write('bar')
-    scope.merge_from(other)
-
-    self.assertTrue('bar' in scope.created)
-    self.assertFalse('bar' in other.created)
-
-  def test_nesting(self):
-    scope = access.Scope(None)
-    scope.mark_write('foo')
-    scope.mark_read('bar')
-
-    child = access.Scope(scope)
-    self.assertTrue(child.has('foo'))
-    self.assertTrue(scope.has('foo'))
-
-    child.mark_write('bar')
-    self.assertTrue(child.has('bar'))
-    self.assertFalse(scope.has('bar'))
-
-  def test_referenced(self):
-    scope = access.Scope(None)
-    scope.mark_read('a')
-
-    child = access.Scope(scope)
-    child.mark_read('b')
-
-    child2 = access.Scope(child, isolated=False)
-    child2.mark_read('c')
-
-    self.assertTrue('c' in child2.referenced)
-    self.assertTrue('b' in child2.referenced)
-    self.assertFalse('a' in child2.referenced)
-
-    self.assertTrue('c' in child.referenced)
-    self.assertTrue('b' in child.referenced)
-    self.assertFalse('a' in child.referenced)
-
-
-class AccessResolverTest(test.TestCase):
-
-  def _parse_and_analyze(self, test_fn):
-    node, source = parser.parse_entity(test_fn)
-    ctx = context.EntityContext(
-        namer=None,
-        source_code=source,
-        source_file=None,
-        namespace={},
-        arg_values=None,
-        arg_types=None,
-        recursive=True)
-    node = access.resolve(node, ctx)
-    return node
-
-  def test_local_markers(self):
-
-    def test_fn(a):  # pylint:disable=unused-argument
-      b = c  # pylint:disable=undefined-variable
-      while b > 0:
-        b -= 1
-      return b
-
-    node = self._parse_and_analyze(test_fn)
-    self.assertFalse(anno.getanno(node.body[0].body[0].value,
-                                  'is_local'))  # c in b = c
-    self.assertTrue(anno.getanno(node.body[0].body[1].test.left,
-                                 'is_local'))  # b in b > 0
-    self.assertTrue(anno.getanno(node.body[0].body[2].value,
-                                 'is_local'))  # b in return b
-
-  def assertScopeIs(self, scope, used, modified, created):
-    self.assertItemsEqual(used, scope.used)
-    self.assertItemsEqual(modified, scope.modified)
-    self.assertItemsEqual(created, scope.created)
-
-  def test_print_statement(self):
-
-    def test_fn(a):
-      b = 0
-      c = 1
-      print(a, b)
-      return c
-
-    node = self._parse_and_analyze(test_fn)
-    print_node = node.body[0].body[2]
-    if isinstance(print_node, gast.Print):
-      # Python 2
-      print_args_scope = anno.getanno(print_node, 'args_scope')
-    else:
-      # Python 3
-      assert isinstance(print_node, gast.Expr)
-      # The call node should be the one being annotated.
-      print_node = print_node.value
-      print_args_scope = anno.getanno(print_node, 'args_scope')
-    # We basically need to detect which variables are captured by the call
-    # arguments.
-    self.assertScopeIs(print_args_scope, ('a', 'b'), (), ())
-
-  def test_call(self):
-
-    def test_fn(a):
-      b = 0
-      c = 1
-      foo(a, b)  # pylint:disable=undefined-variable
-      return c
-
-    node = self._parse_and_analyze(test_fn)
-    call_node = node.body[0].body[2].value
-    # We basically need to detect which variables are captured by the call
-    # arguments.
-    self.assertScopeIs(
-        anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ())
-
-  def test_while(self):
-
-    def test_fn(a):
-      b = a
-      while b > 0:
-        c = b
-        b -= 1
-      return b, c
-
-    node = self._parse_and_analyze(test_fn)
-    while_node = node.body[0].body[1]
-    self.assertScopeIs(
-        anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
-    self.assertScopeIs(
-        anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'),
-        ('b', 'c'), ('a', 'b', 'c'))
-
-  def test_for(self):
-
-    def test_fn(a):
-      b = a
-      for _ in a:
-        c = b
-        b -= 1
-      return b, c
-
-    node = self._parse_and_analyze(test_fn)
-    for_node = node.body[0].body[1]
-    self.assertScopeIs(
-        anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
-    self.assertScopeIs(
-        anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'),
-        ('b', 'c', '_'), ('a', 'b', 'c', '_'))
-
-  def test_if(self):
-
-    def test_fn(x):
-      if x > 0:
-        x = -x
-        y = 2 * x
-        z = -y
-      else:
-        x = 2 * x
-        y = -x
-        u = -y
-      return z, u
-
-    node = self._parse_and_analyze(test_fn)
-    if_node = node.body[0].body[0]
-    self.assertScopeIs(
-        anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
-        ('y', 'z'))
-    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
-    self.assertScopeIs(
-        anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
-        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
-    self.assertScopeIs(
-        anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
-        ('y', 'u'))
-    self.assertScopeIs(
-        anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
-        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
-
-
-if __name__ == '__main__':
-  test.main()
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Access information (reads, writes) resolution."""
+"""Activity analysis."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -24,6 +24,7 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 # TODO(mdan): Add support for PY3 (e.g. Param vs arg).
 
@@ -112,6 +113,14 @@ class Scope(object):
     self.params.add(name)
 
   def mark_creation(self, name):
+    if name.is_composite():
+      parent = name.parent
+      if self.has(parent):
+        # This is considered mutation of the parent, not creation.
+        # TODO(mdan): Is that really so?
+        return
+      else:
+        raise ValueError('Unknown symbol "%s".' % parent)
     self.created.add(name)
 
   def mark_write(self, name):
@@ -132,39 +141,48 @@ class Scope(object):
       self.parent.mark_returned(name)
 
 
-class AccessResolver(transformer.Base):
+class ActivityAnalizer(transformer.Base):
   """Annotates nodes with local scope information. See Scope."""
 
-  def __init__(self, context):
-    super(AccessResolver, self).__init__(context)
-    self.scope = Scope(None)
+  def __init__(self, context, parent_scope):
+    super(ActivityAnalizer, self).__init__(context)
+    self.scope = Scope(parent_scope)
     self._in_return_statement = False
 
-  def visit_Name(self, node):
-    # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate.
-    self.generic_visit(node)
+  def _track_symbol(self, node):
+    qn = anno.getanno(node, anno.Basic.QN)
+
     if isinstance(node.ctx, gast.Store):
-      self.scope.mark_write(node.id)
+      self.scope.mark_write(qn)
     elif isinstance(node.ctx, gast.Load):
-      anno.setanno(node, 'is_local', self.scope.has(node.id))
-      self.scope.mark_read(node.id)
+      self.scope.mark_read(qn)
     elif isinstance(node.ctx, gast.Param):
       # Param contexts appear in function defs, so they have the meaning of
       # defining a variable.
       # TODO(mdan): This bay be incorrect with nested functions.
       # For nested functions, we'll have to add the notion of hiding args from
       # the parent scope, not writing to them.
-      self.scope.mark_creation(node.id)
-      self.scope.mark_param(node.id)
+      self.scope.mark_creation(qn)
+      self.scope.mark_param(qn)
     else:
-      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx),
-                                                            node.id))
-    anno.setanno(node, 'is_modified_since_entry',
-                 self.scope.is_modified_since_entry(node.id))
-    anno.setanno(node, 'is_param', self.scope.is_param(node.id))
+      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+
+    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
+    anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
+                 self.scope.is_modified_since_entry(qn))
+    anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))
 
     if self._in_return_statement:
-      self.scope.mark_returned(node.id)
+      self.scope.mark_returned(qn)
+
+  def visit_Name(self, node):
+    self.generic_visit(node)
+    self._track_symbol(node)
+    return node
+
+  def visit_Attribute(self, node):
+    self.generic_visit(node)
+    self._track_symbol(node)
     return node
 
   def visit_Print(self, node):
@@ -173,7 +191,7 @@ class AccessResolver(transformer.Base):
     self.scope = args_scope
     for n in node.values:
       self.visit(n)
-    anno.setanno(node, 'args_scope', args_scope)
+    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
     self.scope = current_scope
     return node
 
@@ -186,7 +204,7 @@ class AccessResolver(transformer.Base):
     # TODO(mdan): Account starargs, kwargs
     for n in node.keywords:
       self.visit(n)
-    anno.setanno(node, 'args_scope', args_scope)
+    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
     self.scope = current_scope
     self.visit(node.func)
     return node
@@ -197,7 +215,7 @@ class AccessResolver(transformer.Base):
     self.scope = block_scope
     for n in block:
       self.visit(n)
-    anno.setanno(node, '%s_scope' % scope_name, block_scope)
+    anno.setanno(node, scope_name, block_scope)
     self.scope = current_scope
     return node
 
@@ -209,36 +227,36 @@ class AccessResolver(transformer.Base):
     before_parent = Scope(None)
     before_parent.copy_from(self.scope)
     after_children = []
-    for child, name in children:
+    for child, scope_name in children:
       self.scope.copy_from(before_parent)
-      parent = self._process_block_node(parent, child, name)
+      parent = self._process_block_node(parent, child, scope_name)
       after_child = Scope(None)
       after_child.copy_from(self.scope)
       after_children.append(after_child)
     for after_child in after_children:
       self.scope.merge_from(after_child)
-    for child, name in children:
-      # TODO(mdan): We don't need this - we have the parent link from scope.
-      anno.setanno(parent, '%s_parent_scope' % name, self.scope)
     return parent
 
   def visit_If(self, node):
     self.visit(node.test)
-    node = self._process_parallel_blocks(
-        node, ((node.body, 'body'), (node.orelse, 'orelse')))
+    node = self._process_parallel_blocks(node,
+                                         ((node.body, NodeAnno.BODY_SCOPE),
+                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
     return node
 
   def visit_For(self, node):
     self.visit(node.target)
     self.visit(node.iter)
-    node = self._process_parallel_blocks(
-        node, ((node.body, 'body'), (node.orelse, 'orelse')))
+    node = self._process_parallel_blocks(node,
+                                         ((node.body, NodeAnno.BODY_SCOPE),
+                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
     return node
 
   def visit_While(self, node):
     self.visit(node.test)
-    node = self._process_parallel_blocks(
-        node, ((node.body, 'body'), (node.orelse, 'orelse')))
+    node = self._process_parallel_blocks(node,
+                                         ((node.body, NodeAnno.BODY_SCOPE),
+                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
     return node
 
   def visit_Return(self, node):
@@ -248,5 +266,5 @@ class AccessResolver(transformer.Base):
     return node
 
 
-def resolve(node, context):
-  return AccessResolver(context).visit(node)
+def resolve(node, context, parent_scope=None):
+  return ActivityAnalizer(context, parent_scope).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py
new file mode 100644 (file)
index 0000000..e1eb954
--- /dev/null
@@ -0,0 +1,271 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for activity module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import context
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.qual_names import QN
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.platform import test
+
+
+class ScopeTest(test.TestCase):
+
+  def test_basic(self):
+    scope = activity.Scope(None)
+    self.assertFalse(scope.has(QN('foo')))
+
+    scope.mark_read(QN('foo'))
+    self.assertFalse(scope.has(QN('foo')))
+
+    scope.mark_write(QN('foo'))
+    self.assertTrue(scope.has(QN('foo')))
+
+    scope.mark_read(QN('bar'))
+    self.assertFalse(scope.has(QN('bar')))
+
+  def test_copy(self):
+    scope = activity.Scope(None)
+    scope.mark_write(QN('foo'))
+
+    other = activity.Scope(None)
+    other.copy_from(scope)
+
+    self.assertTrue(QN('foo') in other.created)
+
+    scope.mark_write(QN('bar'))
+    scope.copy_from(other)
+
+    self.assertFalse(QN('bar') in scope.created)
+
+    scope.mark_write(QN('bar'))
+    scope.merge_from(other)
+
+    self.assertTrue(QN('bar') in scope.created)
+    self.assertFalse(QN('bar') in other.created)
+
+  def test_nesting(self):
+    scope = activity.Scope(None)
+    scope.mark_write(QN('foo'))
+    scope.mark_read(QN('bar'))
+
+    child = activity.Scope(scope)
+    self.assertTrue(child.has(QN('foo')))
+    self.assertTrue(scope.has(QN('foo')))
+
+    child.mark_write(QN('bar'))
+    self.assertTrue(child.has(QN('bar')))
+    self.assertFalse(scope.has(QN('bar')))
+
+  def test_referenced(self):
+    scope = activity.Scope(None)
+    scope.mark_read(QN('a'))
+
+    child = activity.Scope(scope)
+    child.mark_read(QN('b'))
+
+    child2 = activity.Scope(child, isolated=False)
+    child2.mark_read(QN('c'))
+
+    self.assertTrue(QN('c') in child2.referenced)
+    self.assertTrue(QN('b') in child2.referenced)
+    self.assertFalse(QN('a') in child2.referenced)
+
+    self.assertTrue(QN('c') in child.referenced)
+    self.assertTrue(QN('b') in child.referenced)
+    self.assertFalse(QN('a') in child.referenced)
+
+
+class ActivityAnalizerTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn):
+    node, source = parser.parse_entity(test_fn)
+    ctx = context.EntityContext(
+        namer=None,
+        source_code=source,
+        source_file=None,
+        namespace={},
+        arg_values=None,
+        arg_types=None,
+        recursive=True)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, ctx)
+    return node
+
+  def test_local_markers(self):
+
+    def test_fn(a):  # pylint:disable=unused-argument
+      b = c  # pylint:disable=undefined-variable
+      while b > 0:
+        b -= 1
+      return b
+
+    node = self._parse_and_analyze(test_fn)
+    self.assertFalse(
+        anno.getanno(node.body[0].body[0].value,
+                     NodeAnno.IS_LOCAL))  # c in b = c
+    self.assertTrue(
+        anno.getanno(node.body[0].body[1].test.left,
+                     NodeAnno.IS_LOCAL))  # b in b > 0
+    self.assertTrue(
+        anno.getanno(node.body[0].body[2].value,
+                     NodeAnno.IS_LOCAL))  # b in return b
+
+  def assertScopeIs(self, scope, used, modified, created):
+    self.assertItemsEqual(used, tuple(str(s) for s in scope.used))
+    self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified))
+    self.assertItemsEqual(created, tuple(str(s) for s in scope.created))
+
+  def test_print_statement(self):
+
+    def test_fn(a):
+      b = 0
+      c = 1
+      print(a, b)
+      return c
+
+    node = self._parse_and_analyze(test_fn)
+    print_node = node.body[0].body[2]
+    if isinstance(print_node, gast.Print):
+      # Python 2
+      print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+    else:
+      # Python 3
+      assert isinstance(print_node, gast.Expr)
+      # The call node should be the one being annotated.
+      print_node = print_node.value
+      print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+    # We basically need to detect which variables are captured by the call
+    # arguments.
+    self.assertScopeIs(print_args_scope, ('a', 'b'), (), ())
+
+  def test_call(self):
+
+    def test_fn(a):
+      b = 0
+      c = 1
+      foo(a, b)  # pylint:disable=undefined-variable
+      return c
+
+    node = self._parse_and_analyze(test_fn)
+    call_node = node.body[0].body[2].value
+    # We basically need to detect which variables are captured by the call
+    # arguments.
+    self.assertScopeIs(
+        anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+
+  def test_while(self):
+
+    def test_fn(a):
+      b = a
+      while b > 0:
+        c = b
+        b -= 1
+      return b, c
+
+    node = self._parse_and_analyze(test_fn)
+    while_node = node.body[0].body[1]
+    self.assertScopeIs(
+        anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
+        ('c',))
+    self.assertScopeIs(
+        anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+        ('b', 'c'), ('a', 'b', 'c'))
+
+  def test_for(self):
+
+    def test_fn(a):
+      b = a
+      for _ in a:
+        c = b
+        b -= 1
+      return b, c
+
+    node = self._parse_and_analyze(test_fn)
+    for_node = node.body[0].body[1]
+    self.assertScopeIs(
+        anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
+    self.assertScopeIs(
+        anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+        ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+
+  def test_if(self):
+
+    def test_fn(x):
+      if x > 0:
+        x = -x
+        y = 2 * x
+        z = -y
+      else:
+        x = 2 * x
+        y = -x
+        u = -y
+      return z, u
+
+    node = self._parse_and_analyze(test_fn)
+    if_node = node.body[0].body[0]
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
+        ('y', 'z'))
+    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
+        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
+        ('x', 'y', 'u'), ('y', 'u'))
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+
+  def test_call_with_composite_names(self):
+
+    def foo(*_):
+      pass
+
+    def test_fn(a):
+      foo(a.b, a.c)
+      if a > 0:
+        a.b = 2
+      else:
+        d = 2
+        d.e = a.c
+        f = d.e + 1
+        a.c = f
+
+    node = self._parse_and_analyze(test_fn)
+    call_node = node.body[0].body[0].value
+    self.assertScopeIs(
+        anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (),
+        ())
+    if_node = node.body[0].body[1]
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
+    self.assertScopeIs(
+        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+        ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f'))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py
new file mode 100644 (file)
index 0000000..2d8e494
--- /dev/null
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Annotations used by the static analizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from enum import Enum
+
+
+class NoValue(Enum):
+
+  def __repr__(self):
+    return self.name
+
+
+class NodeAnno(NoValue):
+  """Additionnal annotations used by the static analyzer.
+
+  These are in addition to the basic annotations declared in anno.py.
+  """
+
+  # Symbols
+
+  IS_LOCAL = 'Symbol is local to the function scope being analized.'
+  IS_PARAM = 'Symbol is a parameter to the function being analized.'
+  IS_MODIFIED_SINCE_ENTRY = (
+      'Symbol has been explicitly replaced in the current function scope.')
+
+  # Scopes
+  ARGS_SCOPE = 'The scope for the argument list of a function call.'
+  BODY_SCOPE = (
+      'The scope for the main body of a statement (True branch for if '
+      'statements, main body for loops).')
+  ORELSE_SCOPE = (
+      'The scope for the orelse body of a statement (False branch for if '
+      'statements, orelse body for loops).')
index 5a2903e..9c0a9a9 100644 (file)
@@ -16,7 +16,7 @@
 
 Live values are extracted from the known execution context.
 
-Requires annotations generated by AccessResolver.
+Requires activity analysis annotations.
 """
 
 from __future__ import absolute_import
@@ -27,6 +27,7 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
 
 
 class LiveValueResolver(transformer.Base):
@@ -44,12 +45,12 @@ class LiveValueResolver(transformer.Base):
   def visit_Name(self, node):
     self.generic_visit(node)
     if isinstance(node.ctx, gast.Load):
-      assert anno.hasanno(node, 'is_local'), node
-      symbol_is_local = anno.getanno(node, 'is_local')
-      assert anno.hasanno(node, 'is_modified_since_entry'), node
-      symbol_is_modified = anno.getanno(node, 'is_modified_since_entry')
-      assert anno.hasanno(node, 'is_param'), node
-      symbol_is_param = anno.getanno(node, 'is_param')
+      assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
+      symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
+      assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
+      symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
+      assert anno.hasanno(node, NodeAnno.IS_PARAM), node
+      symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)
 
       if not symbol_is_local and not symbol_is_param:
         if node.id in self.literals:
@@ -60,7 +61,11 @@ class LiveValueResolver(transformer.Base):
           anno.setanno(node, 'live_val', obj)
           anno.setanno(node, 'fqn', (obj.__name__,))
         else:
-          raise ValueError('Could not resolve symbol "%s".' % node.id)
+          pass
+          # TODO(mdan): Should we raise an error here?
+          # Can encounter this when:
+          #  * a symbol truly lacks reference
+          #  * a symbol is new, like the new name of a function we just renamed.
       else:
         pass
         # TODO(mdan): Attempt to trace its value through the local chain.
@@ -97,7 +102,7 @@ class LiveValueResolver(transformer.Base):
     elif isinstance(node.value, gast.Name):
       stem_name = node.value
       # All nonlocal symbols should be fully resolved.
-      assert anno.hasanno(stem_name, 'is_local'), stem_name
+      assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
       # TODO(mdan): Figure out what to do when calling attribute on local object
       # Maybe just leave as-is?
     return node
index f3057b3..9f64689 100644 (file)
@@ -21,7 +21,8 @@ from __future__ import print_function
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
 from tensorflow.python.framework import constant_op
@@ -46,7 +47,8 @@ class LiveValuesResolverTest(test.TestCase):
         arg_values=None,
         arg_types=arg_types,
         recursive=True)
-    node = access.resolve(node, ctx)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, ctx)
     node = live_values.resolve(node, ctx, literals)
     node = type_info.resolve(node, ctx)
     node = live_values.resolve(node, ctx, literals)
index cf74142..8203bda 100644 (file)
@@ -116,28 +116,30 @@ class TypeInfoResolver(transformer.Base):
     return node
 
   def _process_function_arg(self, arg_name):
-    if self.function_level == 1 and arg_name in self.context.arg_types:
+    str_name = str(arg_name)
+    if self.function_level == 1 and str_name in self.context.arg_types:
       # Forge a node to hold the type information, so that method calls on
       # it can resolve the type.
-      type_holder = gast.Name(arg_name, gast.Load(), None)
-      type_string, type_obj = self.context.arg_types[arg_name]
+      type_holder = arg_name.ast()
+      type_string, type_obj = self.context.arg_types[str_name]
       anno.setanno(type_holder, 'type', type_obj)
       anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
       self.scope.setval(arg_name, type_holder)
 
   def visit_arg(self, node):
-    self._process_function_arg(node.arg)
+    self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
     return node
 
   def visit_Name(self, node):
     self.generic_visit(node)
+    qn = anno.getanno(node, anno.Basic.QN)
     if isinstance(node.ctx, gast.Param):
-      self._process_function_arg(node.id)
-    elif isinstance(node.ctx, gast.Load) and self.scope.hasval(node.id):
+      self._process_function_arg(qn)
+    elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
       # E.g. if we had
       # a = b
       # then for future references to `a` we should have traced_source = `b`
-      traced_source = self.scope.getval(node.id)
+      traced_source = self.scope.getval(qn)
       if anno.hasanno(traced_source, 'type'):
         anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
         anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
@@ -159,16 +161,11 @@ class TypeInfoResolver(transformer.Base):
     for t in targets:
       if isinstance(t, gast.Tuple):
         for i, e in enumerate(t.elts):
-          self.scope.setval(e.id,
-                            gast.Subscript(
-                                source, gast.Index(i), ctx=gast.Store()))
-      elif isinstance(t, gast.Name):
-        self.scope.setval(t.id, source)
-      elif isinstance(t, gast.Attribute):
-        if not (isinstance(t.value, gast.Name) and t.value.id == 'self'):
-          raise ValueError(
-              'Dont know how to handle assignment to attributes of objects'
-              ' other than "self": [%s].%s' % (t.value, t.attr))
+          self.scope.setval(
+              anno.getanno(e, anno.Basic.QN),
+              gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+      elif isinstance(t, (gast.Name, gast.Attribute)):
+        self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
       else:
         raise ValueError('Dont know how to handle assignment to %s' % t)
 
index 68fa1ee..3659f94 100644 (file)
@@ -21,7 +21,8 @@ from __future__ import print_function
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
 from tensorflow.python.client import session
@@ -65,7 +66,8 @@ class TypeInfoResolverTest(test.TestCase):
         arg_values=None,
         arg_types=arg_types,
         recursive=True)
-    node = access.resolve(node, ctx)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, ctx)
     node = live_values.resolve(node, ctx, {})
     node = type_info.resolve(node, ctx)
     node = live_values.resolve(node, ctx, {})
index 6be526f..1039fc8 100644 (file)
@@ -22,12 +22,13 @@ from __future__ import division
 from __future__ import print_function
 
 import ast
-import copy
 import textwrap
 
 import gast
 
+from tensorflow.contrib.py2tf.pyct import copier
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
 
 
 class ReplaceTransformer(gast.NodeTransformer):
@@ -41,6 +42,7 @@ class ReplaceTransformer(gast.NodeTransformer):
           that these placeholders will be replaced by.
     """
     self.replacements = replacements
+    self.in_replacements = False
 
   # TODO(mdan): Make a more detailed pass and clean up if needed.
 
@@ -62,34 +64,53 @@ class ReplaceTransformer(gast.NodeTransformer):
       node.name = repl.id
     return node
 
-  def visit_Name(self, node):
-    if node.id in self.replacements:
-      # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
-      new_nodes = copy.copy(self.replacements[node.id])
-      if isinstance(new_nodes, gast.AST):
-        new_nodes = [new_nodes]
-      # Preserve the target context.
-      for n in new_nodes:
-        if isinstance(n, gast.Tuple):
-          for e in n.elts:
-            e.ctx = node.ctx
-        n.ctx = node.ctx
-      if len(new_nodes) == 1:
-        new_nodes, = new_nodes
-      return new_nodes
+  def _set_inner_child_context(self, node, ctx):
+    if isinstance(node, gast.Attribute):
+      self._set_inner_child_context(node.value, ctx)
+      node.ctx = gast.Load()
+    elif isinstance(node, gast.Name):
+      node.ctx = ctx
     else:
+      raise ValueError('unexpected node type "%s"' % node)
+
+  def visit_Name(self, node):
+    if node.id not in self.replacements:
       return node
 
+    new_nodes = copier.copy_clean(self.replacements[node.id])
+    if isinstance(new_nodes, gast.AST):
+      new_nodes = [new_nodes]
+
+    # Preserve the target context.
+    for n in new_nodes:
+      if isinstance(n, gast.Tuple):
+        for e in n.elts:
+          self._set_inner_child_context(e, node.ctx)
+      if isinstance(n, gast.Attribute):
+        # For attributes, the inner Name node receives the context, while the
+        # outer ones have it set to Load.
+        self._set_inner_child_context(n, node.ctx)
+      else:
+        n.ctx = node.ctx
+
+    if len(new_nodes) == 1:
+      new_nodes, = new_nodes
+
+    return new_nodes
+
 
-def _strings_to_names(n):
+def _convert_to_ast(n):
+  """Convert from a known data type to AST."""
   if isinstance(n, str):
     # Note: the node will receive the ctx value from the template, see
     # ReplaceTransformer.visit_Name.
     return gast.Name(id=n, ctx=None, annotation=None)
+  if isinstance(n, qual_names.QN):
+    return n.ast()
   if isinstance(n, list):
-    return [_strings_to_names(e) for e in n]
+    return [_convert_to_ast(e) for e in n]
   if isinstance(n, tuple):
-    return tuple(_strings_to_names(e) for e in n)
+    return tuple(_convert_to_ast(e) for e in n)
   return n
 
 
@@ -122,5 +143,8 @@ def replace(template, **replacements):
     raise ValueError('Expected string template, got %s' % type(template))
   tree = parser.parse_str(textwrap.dedent(template))
   for k in replacements:
-    replacements[k] = _strings_to_names(replacements[k])
-  return ReplaceTransformer(replacements).visit(tree).body
+    replacements[k] = _convert_to_ast(replacements[k])
+  results = ReplaceTransformer(replacements).visit(tree).body
+  if isinstance(results, list):
+    return [qual_names.resolve(r) for r in results]
+  return qual_names.resolve(results)
index 1143131..0e3d07e 100644 (file)
@@ -27,6 +27,16 @@ from tensorflow.python.platform import test
 
 class TemplatesTest(test.TestCase):
 
+  def test_replace_tuple(self):
+    template = """
+      def test_fn(a, c):
+        return b,
+    """
+
+    node = templates.replace(template, b=('a', 'c'))[0]
+    result = compiler.ast_to_object(node)
+    self.assertEquals((2, 3), result.test_fn(2, 3))
+
   def test_replace_variable(self):
     template = """
       def test_fn(a):
index 8a836b7..877d52a 100644 (file)
@@ -23,6 +23,7 @@ import sys
 import gast
 import six
 
+from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import pretty_printer
 
 
@@ -44,16 +45,19 @@ class Base(gast.NodeTransformer):
     self.context = context
 
   def visit(self, node):
+    source_code = self.context.source_code
+    source_file = self.context.source_file
     try:
-      source_code = self.context.source_code
-      source_file = self.context.source_file
       if source_code and hasattr(node, 'lineno'):
         self._lineno = node.lineno
         self._col_offset = node.col_offset
+      if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+        return node
       return super(Base, self).visit(node)
-    except (ValueError, AttributeError, NotImplementedError) as e:
-      msg = '%s: %s\nOccurred at node:\n%s' % (e.__class__.__name__, str(e),
-                                               pretty_printer.fmt(node))
+    except (ValueError, AttributeError, KeyError, NotImplementedError,
+            AssertionError) as e:
+      msg = '%s: %s\nOccurred at node:\n%s' % (
+          e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False))
       if source_code:
         line = source_code.splitlines()[self._lineno - 1]
       else: