More updates:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Feb 2018 17:03:30 +0000 (09:03 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 17:15:21 +0000 (09:15 -0800)
 * add a verbose flag, useful for debugging
 * fix the canonicalization of old-style print statement
 * fix broken py_func wrapper
 * expand the conditional statement to return a dummy value if we cannot fine any return values, and call without an assignment so that the side effect guards catch it instead
 * streamline the converter tests a bit more
 * avoid aliasing "tf" and "self" in the side effect guards
 * improve the namer to generate shorter names

PiperOrigin-RevId: 185002802

25 files changed:
tensorflow/contrib/py2tf/converters/BUILD
tensorflow/contrib/py2tf/converters/asserts.py
tensorflow/contrib/py2tf/converters/break_canonicalization_test.py
tensorflow/contrib/py2tf/converters/builtin_functions.py
tensorflow/contrib/py2tf/converters/builtin_functions_test.py
tensorflow/contrib/py2tf/converters/call_trees.py
tensorflow/contrib/py2tf/converters/call_trees_test.py
tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py
tensorflow/contrib/py2tf/converters/control_flow.py
tensorflow/contrib/py2tf/converters/control_flow_test.py
tensorflow/contrib/py2tf/converters/converter_test_base.py
tensorflow/contrib/py2tf/converters/decorators_test.py
tensorflow/contrib/py2tf/converters/for_canonicalization_test.py
tensorflow/contrib/py2tf/converters/logical_expressions_test.py
tensorflow/contrib/py2tf/converters/side_effect_guards.py
tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
tensorflow/contrib/py2tf/impl/api.py
tensorflow/contrib/py2tf/impl/naming.py
tensorflow/contrib/py2tf/pyct/BUILD
tensorflow/contrib/py2tf/pyct/compiler.py
tensorflow/contrib/py2tf/pyct/compiler_test.py
tensorflow/contrib/py2tf/pyct/qual_names.py
tensorflow/contrib/py2tf/pyct/qual_names_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/pyct/templates.py
tensorflow/contrib/py2tf/pyct/templates_test.py

index 68ea2477780cc7c8c54b228ed1ca5bac78c9622a..62de8107a3db562bcf5cd57067d5176c2b71aa27 100644 (file)
@@ -47,6 +47,7 @@ py_library(
         "//tensorflow/contrib/py2tf/pyct/static_analysis",
         "//tensorflow/contrib/py2tf/utils",
         "@gast_archive//:gast",
+        "@six_archive//:six",
     ],
 )
 
@@ -73,8 +74,8 @@ py_test(
 )
 
 py_test(
-    name = "call_trees_test",
-    srcs = ["call_trees_test.py"],
+    name = "builtin_functions_test",
+    srcs = ["builtin_functions_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
@@ -84,8 +85,8 @@ py_test(
 )
 
 py_test(
-    name = "decorators_test",
-    srcs = ["decorators_test.py"],
+    name = "call_trees_test",
+    srcs = ["call_trees_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
@@ -117,8 +118,8 @@ py_test(
 )
 
 py_test(
-    name = "builtin_functions_test",
-    srcs = ["builtin_functions_test.py"],
+    name = "decorators_test",
+    srcs = ["decorators_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
index 2d6ee1d09829b538815dbb9794868c13f51578fc..5b9b8e772bed82df2429fd6cb94dbf7b565e22b3 100644 (file)
@@ -35,7 +35,7 @@ class AssertsTransformer(transformer.Base):
     # Note: The lone tf.Assert call will be wrapped with control_dependencies
     # by side_effect_guards.
     template = """
-      tf.Assert(test, [tf.constant(msg)])
+      tf.Assert(test, [msg])
     """
 
     if node.msg is None:
index 54c4d99361f00ba2b5b79323f5feddcbbdfc99e8..2243398100880483d40a1ba7451a229e0dbe115b 100644 (file)
@@ -19,18 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.py2tf.converters import break_canonicalization
-from tensorflow.contrib.py2tf.converters import control_flow
 from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.platform import test
 
 
-class TestNamer(control_flow.SymbolNamer):
-
-  def new_symbol(self, name_root, _):
-    return name_root
-
-
 class BreakCanonicalizationTest(converter_test_base.TestCase):
 
   def test_basic_break(self):
@@ -44,15 +36,15 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = break_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    self.assertEqual(test_fn(0), result.test_fn(0))
-    self.assertEqual(test_fn(1), result.test_fn(1))
-    self.assertEqual(test_fn(2), result.test_fn(2))
-    self.assertEqual(test_fn(3), result.test_fn(3))
-    self.assertEqual(test_fn(4), result.test_fn(4))
+    with self.compiled(node) as result:
+      self.assertEqual(test_fn(0), result.test_fn(0))
+      self.assertEqual(test_fn(1), result.test_fn(1))
+      self.assertEqual(test_fn(2), result.test_fn(2))
+      self.assertEqual(test_fn(3), result.test_fn(3))
+      self.assertEqual(test_fn(4), result.test_fn(4))
 
   def test_basic_break_for_loop(self):
 
@@ -76,16 +68,17 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = break_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    # The break is incompletely canonicalized. Everything is in place, but
-    # the loop does not break.
-    self.assertEqual(test_equiv_fn([]), result.test_fn([]))
-    self.assertEqual(test_equiv_fn([1]), result.test_fn([1]))
-    self.assertEqual(test_equiv_fn([2]), result.test_fn([2]))
-    self.assertEqual(test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4]))
+    with self.compiled(node) as result:
+      # The break is incompletely canonicalized. Everything is in place, but
+      # the loop does not break.
+      self.assertEqual(test_equiv_fn([]), result.test_fn([]))
+      self.assertEqual(test_equiv_fn([1]), result.test_fn([1]))
+      self.assertEqual(test_equiv_fn([2]), result.test_fn([2]))
+      self.assertEqual(
+          test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4]))
 
   def test_continue_deeply_nested(self):
 
@@ -104,15 +97,15 @@ class BreakCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v, u, w
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = break_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    self.assertEqual(test_fn(0), result.test_fn(0))
-    self.assertEqual(test_fn(1), result.test_fn(1))
-    self.assertEqual(test_fn(2), result.test_fn(2))
-    self.assertEqual(test_fn(3), result.test_fn(3))
-    self.assertEqual(test_fn(4), result.test_fn(4))
+    with self.compiled(node) as result:
+      self.assertEqual(test_fn(0), result.test_fn(0))
+      self.assertEqual(test_fn(1), result.test_fn(1))
+      self.assertEqual(test_fn(2), result.test_fn(2))
+      self.assertEqual(test_fn(3), result.test_fn(3))
+      self.assertEqual(test_fn(4), result.test_fn(4))
 
 
 if __name__ == '__main__':
index 3e56634106c2c9c1e4c334d4e61cedee395853a9..310681dd016ca94bf2b28d27a4968cc0c10a5842 100644 (file)
@@ -25,7 +25,18 @@ from tensorflow.contrib.py2tf.pyct import transformer
 
 
 class BuiltinFunctionTransformer(transformer.Base):
-  """Transforms Print nodes to Call so they can be handled as functions."""
+  """Handles builtin functions and canonicalizes old-style print statement.
+
+  This transformer only covers functions that are translated into a
+  TF equivalent, like `len`.
+  Note that the `print` statement is converted to a function call here, but
+  wrapping the print function to a `py_func` is done by `call_trees` as a
+  generic uncompilable function wrap.
+  """
+
+  # TODO(mdan): Handle print entirely in here.
+  # Fully handling print here makes sense especially since we're considering
+  # using tf.Print instead.
 
   def __init__(self, context):
     super(BuiltinFunctionTransformer, self).__init__(context)
@@ -48,10 +59,14 @@ class BuiltinFunctionTransformer(transformer.Base):
 
   def visit_Print(self, node):
     self.generic_visit(node)
+    args = node.values
+    # Following is the case when calling print(a, b)
+    if len(args) == 1 and isinstance(args[0], gast.Tuple):
+      args = args[0].elts
     template = """
       fname(args)
     """
-    return templates.replace(template, fname='print', args=node.values)
+    return templates.replace(template, fname='print', args=args)
 
   # pylint:enable=invalid-name
 
index be76066242856c85784221166a70299187b11b14..983d1ffc03466ab3e2148e8cdf6e54050b9d3947 100644 (file)
@@ -18,11 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gast
+import sys
+
+import six
 
 from tensorflow.contrib.py2tf.converters import builtin_functions
 from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
@@ -37,13 +38,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(test_fn, {'len': len})
     node = builtin_functions.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', array_ops)
 
-    with self.test_session() as sess:
-      self.assertEqual(3,
-                       sess.run(
-                           result.test_fn(constant_op.constant([0, 0, 0]))))
+    with self.compiled(node, array_ops.shape) as result:
+      with self.test_session() as sess:
+        self.assertEqual(3,
+                         sess.run(
+                             result.test_fn(constant_op.constant([0, 0, 0]))))
 
   def test_print(self):
 
@@ -52,10 +52,36 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(test_fn, {'print': print})
     node = builtin_functions.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    result.test_fn('a')
-    self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+    with self.compiled(node) as result:
+      try:
+        out_capturer = six.StringIO()
+        sys.stdout = out_capturer
+        result.test_fn('a')
+        self.assertEqual(out_capturer.getvalue(), 'a\n')
+      finally:
+        sys.stdout = sys.__stdout__
+
+  def test_print_tuple(self):
+
+    def test_fn(a, b, c):
+      print(a, b, c)
+
+    node = self.parse_and_analyze(test_fn, {'print': print})
+    node = builtin_functions.transform(node, self.ctx)
+
+    with self.compiled(node) as result:
+      try:
+        out_capturer = six.StringIO()
+        sys.stdout = out_capturer
+        result.test_fn('a', 1, [2, 3])
+        # It appears that the print output looks odd only under Python 2.
+        if six.PY2:
+          self.assertEqual(out_capturer.getvalue(), "('a', 1, [2, 3])\n")
+        else:
+          self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
+      finally:
+        sys.stdout = sys.__stdout__
 
 
 if __name__ == '__main__':
index 834baf258d3b5a0eee02e79ccdf0ebfc61b2c9df..60096d5a7b7c6eea9f3ade75ba492d54732b3550 100644 (file)
@@ -28,6 +28,7 @@ import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
 from tensorflow.contrib.py2tf.pyct import templates
 from tensorflow.contrib.py2tf.pyct import transformer
 from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
@@ -197,21 +198,33 @@ class CallTreeTransformer(transformer.Base):
     return node
 
   def _wrap_to_py_func_no_return(self, node):
+    func_qn = anno.getanno(node.func, anno.Basic.QN)
     args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+    wrapper_name = self.context.namer.new_symbol(func_qn.ssf(),
+                                                 args_scope.referenced)
+    wrapper_args = []
+    for arg in node.args:
+      if anno.hasanno(arg, anno.Basic.QN):
+        arg_qn = anno.getanno(arg, anno.Basic.QN)
+      else:
+        arg_qn = qual_names.QN('arg')
+      wrapper_args.append(
+          self.context.namer.new_symbol(arg_qn.ssf(), args_scope.referenced))
     # TODO(mdan): Properly handle varargs, kwargs, etc.
+    # TODO(mdan): This is best handled as a dynamic dispatch.
+    # That way we can separate tensors from non-tensor args.
     template = """
-      def wrapper(args):
-        call(args)
+      def wrapper(wrapper_args):
+        call(wrapper_args)
         return 1
-      tf.py_func(wrapper, [args], [tf.int64])
+      tf.py_func(wrapper, original_args, [tf.int64])
     """
     wrapper_def, call_expr = templates.replace(
         template,
         call=node.func,
-        wrapper=self.context.namer.compiled_function_name(node.func.id)[0],
-        args=tuple(args_scope.used))
-    anno.setanno(call_expr.value, NodeAnno.ARGS_SCOPE, args_scope)
-    # TODO(mdan): Rename this annotation to 'graph_ready'
+        wrapper=wrapper_name,
+        original_args=gast.List(elts=node.args, ctx=None),
+        wrapper_args=wrapper_args)
     anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
 
     return (wrapper_def, call_expr)
index e63c10de0fed72333a6d571f9b9a4f1cb50b5f1d..18a5c1e6e35f49a01649c17e6cd5647389e1f526 100644 (file)
@@ -20,23 +20,11 @@ from __future__ import print_function
 
 from tensorflow.contrib.py2tf.converters import call_trees
 from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
-class TestNamer(call_trees.FunctionNamer):
-
-  def compiled_function_name(self,
-                             original_fqn,
-                             live_entity=None,
-                             owner_type=None):
-    if owner_type is not None:
-      return None, False
-    return ('renamed_%s' % '_'.join(original_fqn)), True
-
-
 class CallTreesTest(converter_test_base.TestCase):
 
   def test_basic(self):
@@ -50,14 +38,14 @@ class CallTreesTest(converter_test_base.TestCase):
     def test_fn_2(a):
       return test_fn_1(a) + 1
 
-    node = self.parse_and_analyze(
-        test_fn_2, {'test_fn_1': test_fn_1}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
     node = call_trees.transform(node, self.ctx, (), ())
-    result = compiler.ast_to_object(node)
-    # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
-    setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
 
-    self.assertEquals(3, result.test_fn_2(1))
+    with self.compiled(node) as result:
+      # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1
+      # manually.
+      result.renamed_test_fn_1 = renamed_test_fn_1
+      self.assertEquals(3, result.test_fn_2(1))
 
   def test_simple_methods(self):
 
@@ -71,13 +59,12 @@ class CallTreesTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(
         TestClass.test_fn_2, {'TestClass': TestClass},
-        namer=TestNamer(),
         arg_types={'self': (TestClass.__name__, TestClass)})
     node = call_trees.transform(node, self.ctx, (), ())
-    result = compiler.ast_to_object(node)
 
-    tc = TestClass()
-    self.assertEquals(3, result.test_fn_2(tc, 1))
+    with self.compiled(node) as result:
+      tc = TestClass()
+      self.assertEquals(3, result.test_fn_2(tc, 1))
 
   def test_uncompiled_modules(self):
 
@@ -86,26 +73,22 @@ class CallTreesTest(converter_test_base.TestCase):
       a = math_ops.add(a, constant_op.constant(1))
       return a
 
-    node = self.parse_and_analyze(
-        test_fn, {
-            'math_ops': math_ops,
-            'constant_op': constant_op
-        },
-        namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {
+        'math_ops': math_ops,
+        'constant_op': constant_op
+    })
     node = call_trees.transform(node, self.ctx,
                                 set(((math_ops.__name__,),
                                      (constant_op.__name__,))), ())
-    result = compiler.ast_to_object(node)
-    setattr(result, 'math_ops', math_ops)
-    setattr(result, 'constant_op', constant_op)
-
-    with self.test_session() as sess:
-      # Not renamed, because the converter doesn't rename the definition itself.
-      # (the caller is responsible for that).
-      result_tensor = result.test_fn(constant_op.constant(1))
-      result_val = sess.run(result_tensor)
 
-    self.assertEquals(3, result_val)
+    with self.compiled(node) as result:
+      result.math_ops = math_ops
+      result.constant_op = constant_op
+      with self.test_session() as sess:
+        # Not renamed, because the converter doesn't rename the definition
+        # itself (the caller is responsible for that).
+        result_tensor = result.test_fn(constant_op.constant(1))
+        self.assertEquals(3, sess.run(result_tensor))
 
 
 if __name__ == '__main__':
index 4b188195595e4ea4be4359eb87830d587c1de1de..2a0fb2d88b54114d558f1ea4cf9b1dc53b21e5cf 100644 (file)
@@ -19,18 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.py2tf.converters import continue_canonicalization
-from tensorflow.contrib.py2tf.converters import control_flow
 from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.platform import test
 
 
-class TestNamer(control_flow.SymbolNamer):
-
-  def new_symbol(self, name_root, _):
-    return name_root
-
-
 class ContinueCanonicalizationTest(converter_test_base.TestCase):
 
   def test_basic_continue(self):
@@ -44,15 +36,15 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = continue_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    self.assertEqual(test_fn(0), result.test_fn(0))
-    self.assertEqual(test_fn(1), result.test_fn(1))
-    self.assertEqual(test_fn(2), result.test_fn(2))
-    self.assertEqual(test_fn(3), result.test_fn(3))
-    self.assertEqual(test_fn(4), result.test_fn(4))
+    with self.compiled(node) as result:
+      self.assertEqual(test_fn(0), result.test_fn(0))
+      self.assertEqual(test_fn(1), result.test_fn(1))
+      self.assertEqual(test_fn(2), result.test_fn(2))
+      self.assertEqual(test_fn(3), result.test_fn(3))
+      self.assertEqual(test_fn(4), result.test_fn(4))
 
   def test_basic_continue_for_loop(self):
 
@@ -65,14 +57,14 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = continue_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    self.assertEqual(test_fn([]), result.test_fn([]))
-    self.assertEqual(test_fn([1]), result.test_fn([1]))
-    self.assertEqual(test_fn([2]), result.test_fn([2]))
-    self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3]))
+    with self.compiled(node) as result:
+      self.assertEqual(test_fn([]), result.test_fn([]))
+      self.assertEqual(test_fn([1]), result.test_fn([1]))
+      self.assertEqual(test_fn([2]), result.test_fn([2]))
+      self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3]))
 
   def test_continue_deeply_nested(self):
 
@@ -91,15 +83,15 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase):
         v.append(x)
       return v, u, w
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = continue_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    self.assertEqual(test_fn(0), result.test_fn(0))
-    self.assertEqual(test_fn(1), result.test_fn(1))
-    self.assertEqual(test_fn(2), result.test_fn(2))
-    self.assertEqual(test_fn(3), result.test_fn(3))
-    self.assertEqual(test_fn(4), result.test_fn(4))
+    with self.compiled(node) as result:
+      self.assertEqual(test_fn(0), result.test_fn(0))
+      self.assertEqual(test_fn(1), result.test_fn(1))
+      self.assertEqual(test_fn(2), result.test_fn(2))
+      self.assertEqual(test_fn(3), result.test_fn(3))
+      self.assertEqual(test_fn(4), result.test_fn(4))
 
 
 if __name__ == '__main__':
index 46316919c867eeae5e63ff5a986de3aae5138ddf..d53e3e4fd6d87004cbe55bd430346ad263e898ea 100644 (file)
@@ -54,6 +54,49 @@ class ControlFlowTransformer(transformer.Base):
   def visit_For(self, node):
     assert False, 'for statement should have been canonicalized at this point'
 
+  def _create_cond_branch(self, body_name, aliased_orig_names,
+                          aliased_new_names, body, returns):
+    if aliased_orig_names:
+      template = """
+        def body_name():
+          aliased_new_names, = aliased_orig_names,
+          body
+          return (returns,)
+      """
+      return templates.replace(
+          template,
+          body_name=body_name,
+          body=body,
+          aliased_orig_names=aliased_orig_names,
+          aliased_new_names=aliased_new_names,
+          returns=returns)
+    else:
+      template = """
+        def body_name():
+          body
+          return (returns,)
+      """
+      return templates.replace(
+          template, body_name=body_name, body=body, returns=returns)
+
+  def _create_cond_expr(self, results, test, body_name, orelse_name):
+    if results is not None:
+      template = """
+        results = py2tf_utils.run_cond(test, body_name, orelse_name)
+      """
+      return templates.replace(
+          template,
+          test=test,
+          results=results,
+          body_name=body_name,
+          orelse_name=orelse_name)
+    else:
+      template = """
+        py2tf_utils.run_cond(test, body_name, orelse_name)
+      """
+      return templates.replace(
+          template, test=test, body_name=body_name, orelse_name=orelse_name)
+
   def visit_If(self, node):
     self.generic_visit(node)
 
@@ -67,7 +110,7 @@ class ControlFlowTransformer(transformer.Base):
       raise ValueError(
           'The else branch creates new symbols that the if branch does not.')
 
-    all_modified = tuple(body_scope.modified | orelse_scope.modified)
+    modified = tuple(body_scope.modified | orelse_scope.modified)
     all_referenced = body_scope.referenced | orelse_scope.referenced
 
     # Alias the closure variables inside the conditional functions
@@ -84,56 +127,41 @@ class ControlFlowTransformer(transformer.Base):
     node_body = ast_util.rename_symbols(node.body, alias_map)
     node_orelse = ast_util.rename_symbols(node.orelse, alias_map)
 
-    if len(all_modified) == 1:
-      results = all_modified[0]
+    if not modified:
+      # When the cond would return no value, we leave the cond called without
+      # results. That in turn should trigger the side effect guards. The
+      # branch functions will return a dummy value that ensures cond
+      # actually has some return value as well.
+      results = None
+    elif len(modified) == 1:
+      results = modified[0]
     else:
-      results = gast.Tuple([s.ast() for s in all_modified], None)
+      results = gast.Tuple([s.ast() for s in modified], None)
 
-    if aliased_orig_names:
-      template = """
-        def body_name():
-          aliased_new_names, = aliased_orig_names,
-          body
-          return (all_results,)
-        def orelse_name():
-          aliased_new_names, = aliased_orig_names,
-          orelse
-          return (all_results,)
-        results = py2tf_utils.run_cond(test, body_name, orelse_name)
-      """
-      body_name = self.context.namer.new_symbol('if_true', all_referenced)
-      return templates.replace(
-          template,
-          test=node.test,
-          body_name=body_name,
-          body=node_body,
-          orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
-          orelse=node_orelse,
-          aliased_orig_names=tuple(aliased_orig_names),
-          aliased_new_names=tuple(aliased_new_names),
-          all_results=tuple(alias_map[s] if s in aliased_orig_names else s
-                            for s in all_modified),
-          results=results)
+    body_name = self.context.namer.new_symbol('if_true', all_referenced)
+    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
+    if modified:
+      body_returns = tuple(
+          alias_map[s] if s in aliased_orig_names else s for s in modified)
     else:
-      template = """
-        def body_name():
-          body
-          return (all_results,)
-        def orelse_name():
-          orelse
-          return (all_results,)
-        results = py2tf_utils.run_cond(test, body_name, orelse_name)
-      """
-      body_name = self.context.namer.new_symbol('if_true', all_referenced)
-      return templates.replace(
-          template,
-          test=node.test,
-          body_name=body_name,
-          body=node_body,
-          orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
-          orelse=node_orelse,
-          all_results=tuple(s for s in all_modified),
-          results=results)
+      body_returns = templates.replace('tf.ones(())')[0].value
+
+    body_def = self._create_cond_branch(
+        body_name,
+        aliased_orig_names=tuple(aliased_orig_names),
+        aliased_new_names=tuple(aliased_new_names),
+        body=node_body,
+        returns=body_returns)
+    orelse_def = self._create_cond_branch(
+        orelse_name,
+        aliased_orig_names=tuple(aliased_orig_names),
+        aliased_new_names=tuple(aliased_new_names),
+        body=node_orelse,
+        returns=body_returns)
+    cond_expr = self._create_cond_expr(results, node.test, body_name,
+                                       orelse_name)
+
+    return body_def + orelse_def + cond_expr
 
   def visit_While(self, node):
     self.generic_visit(node)
index 677d60e0af8aaa475f066537c04b8b38f8bc0c9e..b785b284a7fb7a0257551326c88b44a341b295ba 100644 (file)
@@ -18,26 +18,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.py2tf import utils
 from tensorflow.contrib.py2tf.converters import control_flow
 from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.platform import test
 
 
-class TestNamer(control_flow.SymbolNamer):
-
-  def new_symbol(self, name_root, used):
-    i = 0
-    while True:
-      name = '%s%d' % (name_root, i)
-      if name not in used:
-        return name
-      i += 1
-
-
 class ControlFlowTest(converter_test_base.TestCase):
 
   def test_simple_while(self):
@@ -50,15 +37,13 @@ class ControlFlowTest(converter_test_base.TestCase):
         i += 1
       return s, i, n
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = control_flow.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', control_flow_ops)
-    setattr(result, 'py2tf_utils', utils)
 
-    with self.test_session() as sess:
-      self.assertEqual((10, 5, 5),
-                       sess.run(result.test_fn(constant_op.constant(5))))
+    with self.compiled(node, control_flow_ops.while_loop) as result:
+      with self.test_session() as sess:
+        self.assertEqual((10, 5, 5),
+                         sess.run(result.test_fn(constant_op.constant(5))))
 
   def test_while_single_var(self):
 
@@ -67,14 +52,12 @@ class ControlFlowTest(converter_test_base.TestCase):
         n -= 1
       return n
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = control_flow.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', control_flow_ops)
-    setattr(result, 'py2tf_utils', utils)
 
-    with self.test_session() as sess:
-      self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
+    with self.compiled(node, control_flow_ops.while_loop) as result:
+      with self.test_session() as sess:
+        self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
 
   def test_simple_if(self):
 
@@ -87,17 +70,15 @@ class ControlFlowTest(converter_test_base.TestCase):
         b = 2 * n
       return a, b
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = control_flow.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', control_flow_ops)
-    setattr(result, 'py2tf_utils', utils)
 
-    with self.test_session() as sess:
-      self.assertEqual((-1, 0), sess.run(
-          result.test_fn(constant_op.constant(1))))
-      self.assertEqual((0, -2),
-                       sess.run(result.test_fn(constant_op.constant(-1))))
+    with self.compiled(node, control_flow_ops.cond) as result:
+      with self.test_session() as sess:
+        self.assertEqual((-1, 0),
+                         sess.run(result.test_fn(constant_op.constant(1))))
+        self.assertEqual((0, -2),
+                         sess.run(result.test_fn(constant_op.constant(-1))))
 
   def test_if_single_var(self):
 
@@ -106,14 +87,12 @@ class ControlFlowTest(converter_test_base.TestCase):
         n = -n
       return n
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = control_flow.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', control_flow_ops)
-    setattr(result, 'py2tf_utils', utils)
 
-    with self.test_session() as sess:
-      self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+    with self.compiled(node, control_flow_ops.cond) as result:
+      with self.test_session() as sess:
+        self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
 
 
 if __name__ == '__main__':
index 5b23db33e16e518cbbc6dcfe8137645c77439500..67747183dd323a799a04943ce4c7fe8c4093d002 100644 (file)
@@ -18,8 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import imp
 
+from tensorflow.contrib.py2tf import utils
+from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
 from tensorflow.contrib.py2tf.pyct import qual_names
@@ -29,9 +32,41 @@ from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
 from tensorflow.python.platform import test
 
 
+class FakeNamer(object):
+
+  def new_symbol(self, name_root, used):
+    i = 0
+    while True:
+      name = '%s%d' % (name_root, i)
+      if name not in used:
+        return name
+      i += 1
+
+  def compiled_function_name(self,
+                             original_fqn,
+                             live_entity=None,
+                             owner_type=None):
+    del live_entity
+    if owner_type is not None:
+      return None, False
+    return ('renamed_%s' % '_'.join(original_fqn)), True
+
+
 class TestCase(test.TestCase):
   """Base class for unit tests in this module. Contains relevant utilities."""
 
+  @contextlib.contextmanager
+  def compiled(self, node, *symbols):
+    source = '<compile failed>'
+    try:
+      result, source = compiler.ast_to_object(node)
+      result.tf = self.make_fake_tf(*symbols)
+      result.py2tf_utils = utils
+      yield result
+    except Exception:  # pylint:disable=broad-except
+      print('Offending compiled code:\n%s' % source)
+      raise
+
   def make_fake_tf(self, *symbols):
     fake_tf = imp.new_module('fake_tf')
     for s in symbols:
@@ -51,7 +86,7 @@ class TestCase(test.TestCase):
                         recursive=True):
     node, source = parser.parse_entity(test_fn)
     ctx = context.EntityContext(
-        namer=namer,
+        namer=namer or FakeNamer(),
         source_code=source,
         source_file=None,
         namespace=namespace,
index f50d593043aeb76d63beb3cb6c301122c9ed8948..402fa0dda28e696f70d0354ca4abf3a6c83506d9 100644 (file)
@@ -53,14 +53,17 @@ class DecoratorsTest(converter_test_base.TestCase):
     node = node.body[0].body[0]
 
     node = decorators.transform(node, remove_decorators=())
-    result = compiler.ast_to_object(
+    # Since the decorator is not removed, we need to include its source
+    # code. We cannot do it after the fact because decorators are executed
+    # on load.
+    result, _ = compiler.ast_to_object(
         node,
         source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator)))
     self.assertEqual(2, result.test_fn(1))
 
     node = decorators.transform(node, remove_decorators=(function_decorator,))
-    result = compiler.ast_to_object(node)
-    self.assertEqual(1, result.test_fn(1))
+    with self.compiled(node) as result:
+      self.assertEqual(1, result.test_fn(1))
 
   def test_simple_decorator(self):
 
@@ -82,14 +85,17 @@ class DecoratorsTest(converter_test_base.TestCase):
     node = node.body[0].body[0]
 
     node = decorators.transform(node, remove_decorators=())
-    result = compiler.ast_to_object(
+    # Since the decorator is not removed, we need to include its source
+    # code. We cannot do it after the fact because decorators are executed
+    # on load.
+    result, _ = compiler.ast_to_object(
         node,
         source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator)))
     self.assertEqual(2, result.test_fn(1))
 
     node = decorators.transform(node, remove_decorators=(simple_decorator,))
-    result = compiler.ast_to_object(node)
-    self.assertEqual(1, result.test_fn(1))
+    with self.compiled(node) as result:
+      self.assertEqual(1, result.test_fn(1))
 
 
 if __name__ == '__main__':
index 142bd4aea12fa26468372472cb8a08e1f6b0e8ac..910c4dcc0081a5632e5324268c15fd3bde5d875b 100644 (file)
@@ -18,19 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.py2tf.converters import control_flow
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.contrib.py2tf.converters import for_canonicalization
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.platform import test
 
 
-class TestNamer(control_flow.SymbolNamer):
-
-  def new_symbol(self, name_root, _):
-    return name_root
-
-
 class ControlFlowTest(converter_test_base.TestCase):
 
   def test_basic_for(self):
@@ -41,14 +33,14 @@ class ControlFlowTest(converter_test_base.TestCase):
         s += e
       return s
 
-    node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+    node = self.parse_and_analyze(test_fn, {})
     node = for_canonicalization.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
 
-    l = [1, 2, 3]
-    self.assertEqual(test_fn(l), result.test_fn(l))
-    l = []
-    self.assertEqual(test_fn(l), result.test_fn(l))
+    with self.compiled(node) as result:
+      l = [1, 2, 3]
+      self.assertEqual(test_fn(l), result.test_fn(l))
+      l = []
+      self.assertEqual(test_fn(l), result.test_fn(l))
 
 
 if __name__ == '__main__':
index d711065099b24ad814104e6460e6ca551b31b3e6..a28326c517d468230f35e45f0fbfe5257d769895 100644 (file)
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.contrib.py2tf.converters import logical_expressions
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
@@ -34,12 +33,11 @@ class GradientsFunctionTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(test_fn, {})
     node = logical_expressions.transform(node)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', math_ops)
 
-    with self.test_session() as sess:
-      self.assertTrue(sess.run(result.test_fn(1, 1)))
-      self.assertFalse(sess.run(result.test_fn(1, 2)))
+    with self.compiled(node, math_ops.equal) as result:
+      with self.test_session() as sess:
+        self.assertTrue(sess.run(result.test_fn(1, 1)))
+        self.assertFalse(sess.run(result.test_fn(1, 2)))
 
   def test_bool_ops(self):
 
@@ -48,11 +46,11 @@ class GradientsFunctionTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(test_fn, {})
     node = logical_expressions.transform(node)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', math_ops)
 
-    with self.test_session() as sess:
-      self.assertTrue(sess.run(result.test_fn(True, False, True)))
+    with self.compiled(node, math_ops.logical_or,
+                       math_ops.logical_and) as result:
+      with self.test_session() as sess:
+        self.assertTrue(sess.run(result.test_fn(True, False, True)))
 
 
 if __name__ == '__main__':
index 948cb96c3fe6245f6eee8dcf9d8cd0a0f789319a..5895dc495482b5d1206b85ff7b66177e6260626b 100644 (file)
@@ -128,8 +128,11 @@ class SideEffectGuardTransformer(transformer.Base):
       # _visit_and_reindent.
       args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
       # NOTE: We can't guard object attributes because they may not be writable.
-      guarded_args = tuple(
-          s for s in args_scope.used if not s.is_composite())
+      # In addition, avoid renaming well-known names.
+      # TODO(mdan): Move these names into config.
+      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
+      guarded_args = tuple(s for s in args_scope.used
+                           if not s.is_composite() and s not in unguarded_names)
 
       # TODO(mdan): Include all arguments which depended on guarded_args too.
       # For example, the following will still cause a race:
index 409c8b02c59d288856f84fba8b74ea7d9d664358..b41c6fa5b9564a8459aa9a223ac157f8efcb8826 100644 (file)
@@ -18,153 +18,155 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.py2tf import utils
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.contrib.py2tf.converters import side_effect_guards
-from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-class TestNamer(side_effect_guards.SymbolNamer):
-
-  def new_symbol(self, name_root, _):
-    return 'renamed_%s' % name_root
-
-
 class SideEffectGuardsTest(converter_test_base.TestCase):
 
-  def _transform_and_compile(self, test_fn):
-    ns = {
-        'control_flow_ops': control_flow_ops,
-        'constant_op': constant_op,
-        'gen_math_ops': gen_math_ops,
-        'ops': ops,
-        'state_ops': state_ops,
-    }
-    node = self.parse_and_analyze(
-        test_fn, ns,
-        namer=TestNamer())
-    node = side_effect_guards.transform(node, self.ctx)
-    result = compiler.ast_to_object(node)
-    self.attach_namespace(result, **ns)
-    result.tf = self.make_fake_tf(array_ops.identity, control_flow_ops.Assert,
-                                  gen_math_ops.greater,
-                                  ops.control_dependencies, ops.Tensor)
-    result.py2tf_utils = utils
-    return result.test_fn, node
-
   def test_side_effect_on_return_only_variable(self):
 
+    tf = None
+
     def test_fn(a):
-      state_ops.assign(a, a + 1)
+      tf.assign(a, a + 1)
       return a
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body), 1)
-    with self.test_session() as sess:
-      v = variables.Variable(2)
-      sess.run(v.initializer)
-      # NOTE: We don't expect the assignment to execute in this case, because
-      # variables cannot be reliably guarded.
-      self.assertEqual(2, sess.run(tf_test_fn(v)))
+    with self.compiled(node, state_ops.assign, ops.control_dependencies,
+                       ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body), 1)
+      with self.test_session() as sess:
+        v = variables.Variable(2)
+        sess.run(v.initializer)
+        # NOTE: We don't expect the assignment to execute in this case, because
+        # variables cannot be reliably guarded.
+        self.assertEqual(2, sess.run(result.test_fn(v)))
 
   def test_side_effect_on_used_variable(self):
 
+    tf = None
+
     def test_fn(a):
-      state_ops.assign(a, a + 1)
+      tf.assign(a, a + 1)
       return a + 1
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body), 1)
-    with self.test_session() as sess:
-      v = variables.Variable(2)
-      sess.run(v.initializer)
-      # NOTE: Unlike test_side_effect_on_return_only_variable, the variable was
-      # used in the local scope and so we could catch the assign's side effect.
-      self.assertEqual(4, sess.run(tf_test_fn(v)))
+    with self.compiled(node, state_ops.assign, ops.control_dependencies,
+                       ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body), 1)
+      with self.test_session() as sess:
+        v = variables.Variable(2)
+        sess.run(v.initializer)
+        # NOTE: Unlike test_side_effect_on_return_only_variable, the variable
+        # was used in the local scope and so we could catch the assign's side
+        # effect.
+        self.assertEqual(4, sess.run(result.test_fn(v)))
 
   def test_side_effect_on_tensor(self):
 
+    tf = None
+
     def test_fn(a):
-      control_flow_ops.Assert(gen_math_ops.greater(a, 0), ['expected in throw'])
+      tf.Assert(a > 0, ['expected in throw'])
       return a
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body), 1)
-    with self.test_session() as sess:
-      # NOTE: In this case we can also capture the side effect because the
-      # argument is a tensor ans we can wrap it inside an identity.
-      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                   'expected in throw'):
-        sess.run(tf_test_fn(constant_op.constant(-1)))
+    with self.compiled(node, array_ops.identity, control_flow_ops.Assert,
+                       ops.control_dependencies, ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body), 1)
+      with self.test_session() as sess:
+        # NOTE: In this case we can also capture the side effect because the
+        # argument is a tensor ans we can wrap it inside an identity.
+        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                     'expected in throw'):
+          sess.run(result.test_fn(constant_op.constant(-1)))
 
   def test_multiline_block(self):
 
+    tf = None
+
     def test_fn(a):
-      state_ops.assign(a, a + 1)
+      tf.assign(a, a + 1)
       b = a + 1
-      state_ops.assign(a, b + 1)
+      tf.assign(a, b + 1)
       c = b + 1
       d = c + 1
       return d
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body), 1)
-    with self.test_session() as sess:
-      v = variables.Variable(2)
-      sess.run(v.initializer)
-      self.assertEqual(6, sess.run(tf_test_fn(v)))
+    with self.compiled(node, array_ops.identity, state_ops.assign,
+                       ops.control_dependencies, ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body), 1)
+      with self.test_session() as sess:
+        v = variables.Variable(2)
+        sess.run(v.initializer)
+        self.assertEqual(6, sess.run(result.test_fn(v)))
 
   def test_multiline_nested_block(self):
 
+    tf = None
+
     def test_fn(a):
-      with ops.name_scope('foo'):
-        state_ops.assign(a, a + 1)
+      with tf.name_scope('foo'):
+        tf.assign(a, a + 1)
         b = a + 1
-        # state_ops.assign(a, b + 1)
         c = b + 1
         d = c + 1
       return d
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body[0].body), 1)
-    with self.test_session() as sess:
-      v = variables.Variable(2)
-      sess.run(v.initializer)
-      self.assertEqual(6, sess.run(tf_test_fn(v)))
+    with self.compiled(node, array_ops.identity, state_ops.assign,
+                       ops.control_dependencies, ops.name_scope,
+                       ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body[0].body), 1)
+      with self.test_session() as sess:
+        v = variables.Variable(2)
+        sess.run(v.initializer)
+        self.assertEqual(6, sess.run(result.test_fn(v)))
 
   def test_multiline_block_unsafe(self):
 
+    tf = None
+
     def test_fn(a):
-      state_ops.assign(a, a + 1)
+      tf.assign(a, a + 1)
       b = a + 1
-      state_ops.assign(a, a + 1)
+      tf.assign(a, a + 1)
       c = b + 1
       d = c + 1
       return d
 
-    tf_test_fn, node = self._transform_and_compile(test_fn)
+    node = self.parse_and_analyze(test_fn, {})
+    node = side_effect_guards.transform(node, self.ctx)
 
-    self.assertEqual(len(node.body[0].body), 1)
-    with self.test_session() as sess:
-      v = variables.Variable(2)
-      sess.run(v.initializer)
-      # NOTE: This intentionally highlights the flakiness. The test should be
-      # tightened down once that is solved.
-      self.assertTrue(sess.run(tf_test_fn(v)) in (6, 7))
+    with self.compiled(node, array_ops.identity, state_ops.assign,
+                       ops.control_dependencies, ops.Tensor) as result:
+      self.assertEqual(len(node.body[0].body), 1)
+      with self.test_session() as sess:
+        v = variables.Variable(2)
+        sess.run(v.initializer)
+        # NOTE: This intentionally highlights the flakiness. The test should be
+        # tightened down once that is solved.
+        self.assertTrue(sess.run(result.test_fn(v)) in (6, 7))
 
 
 if __name__ == '__main__':
index 85d40f31580d156bf719e059bb3580a068595cb5..8ae1c701698ae9a4efbde45222ff6c3db6e92521 100644 (file)
@@ -27,6 +27,7 @@ from tensorflow.contrib.py2tf.impl import config
 from tensorflow.contrib.py2tf.impl import conversion
 from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_inspect
 
 # TODO(mdan): Properly document the type hints.
@@ -83,7 +84,7 @@ def convert_inline(f, *args, **kwargs):
   return convert(arg_value_hints)(f)(*args, **kwargs)
 
 
-def convert(recursive=False, arg_types=None):
+def convert(recursive=False, verbose=False, arg_types=None):
   """Decorator that compiles a function to graph mode.
 
   The decorator is dynamic - invoking compilation whenever the decorated
@@ -92,6 +93,7 @@ def convert(recursive=False, arg_types=None):
   Args:
     recursive: Whether to recusrively convert any functions that the decorator
         function may call.
+    verbose: Whether to output the compiled code in the logs.
     arg_types: See to_graph.
 
   Returns:
@@ -125,6 +127,7 @@ def convert(recursive=False, arg_types=None):
       wrapped = to_graph(
           f,
           recursive=recursive,
+          verbose=verbose,
           arg_values=arg_values,
           arg_types=arg_types,
           partial_types=partial_types)
@@ -140,6 +143,7 @@ def convert(recursive=False, arg_types=None):
 
 def to_graph(e,
              recursive=True,
+             verbose=False,
              arg_values=None,
              arg_types=None,
              partial_types=None):
@@ -155,6 +159,7 @@ def to_graph(e,
     e: A Python entity.
     recursive: Whether to recusrively convert any functions that the decorator
         function may call.
+    verbose: Whether to output the compiled code in the logs.
     arg_values: A dict containing value hints for symbols like function
         parameters.
     arg_types: A dict containing type hints for symbols like function
@@ -178,14 +183,17 @@ def to_graph(e,
     module.body.append(parser.parse_str(import_line))
   for dep in conversion_map.dependency_cache.values():
     module.body.append(dep)
-  compiled_node = compiler.ast_to_object(module)
+  compiled_node, compiled_src = compiler.ast_to_object(module)
 
   # The compiled code should see everything the entry function saw.
   # TODO(mdan): This might not work well if the call tree spans modules?
   if tf_inspect.isfunction(e):
     compiled_node.__dict__.update(six.get_function_globals(e))
-
   compiled_fn = getattr(compiled_node, name)
+
+  if verbose:
+    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
+
   return compiled_fn
 
 
index d31462cba060bf6c04eefbad3ce7f166db994ab3..51326091de13715c32d0a79279f1d3274e48ad10 100644 (file)
@@ -115,8 +115,14 @@ class Namer(object):
       else:
         raise ValueError('Unexpected symbol type "%s"' % type(s))
 
+    pieces = name_root.split('_')
+    if pieces[-1].isdigit():
+      name_root = '_'.join(pieces[:-1])
+      n = int(pieces[-1])
+    else:
+      n = 0
     new_name = name_root
-    n = 0
+
     while (new_name in self.global_namespace or
            new_name in all_reserved_locals or new_name in self.generated_names):
       n += 1
index 91054fe61df283d45cc84b4ed675e2a6c844e3b6..e3c0da4b10f9ffbee1b2a906b64d4762f41d97b4 100644 (file)
@@ -92,6 +92,16 @@ py_test(
     ],
 )
 
+py_test(
+    name = "qual_names_test",
+    srcs = ["qual_names_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":pyct",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
 py_test(
     name = "templates_test",
     srcs = ["templates_test.py"],
index fc71469d1eaeb92352e3b50cb743621d7e5eb1d5..0caadf18c0db2a5e557c94f4df7a3f7a7321bd60 100644 (file)
@@ -63,4 +63,4 @@ def ast_to_object(node, indentation='  ', source_prefix=None):
       f.write(source_prefix)
       f.write('\n')
     f.write(source)
-  return imp.load_source(module_name, f.name)
+  return imp.load_source(module_name, f.name), source
index e0cde43566310b99bac5035285154fde906fa127..c1f84238efa7dd6fc0748748a2cb4f074572b4c6 100644 (file)
@@ -41,6 +41,7 @@ class CompilerTest(test.TestCase):
                 targets=[gast.Name('a', gast.Store(), None)],
                 value=gast.Str('c'))
         ])
+
     self.assertEqual(
         textwrap.dedent("""
             if 1:
@@ -70,15 +71,19 @@ class CompilerTest(test.TestCase):
         decorator_list=[],
         returns=None)
 
-    mod = compiler.ast_to_object(node)
+    module, source = compiler.ast_to_object(node)
 
-    self.assertEqual(2, mod.f(1))
-    with open(mod.__file__, 'r') as temp_output:
+    expected_source = """
+      def f(a):
+        return a + 1
+    """
+    self.assertEqual(
+        textwrap.dedent(expected_source).strip(),
+        source.strip())
+    self.assertEqual(2, module.f(1))
+    with open(module.__file__, 'r') as temp_output:
       self.assertEqual(
-          textwrap.dedent("""
-              def f(a):
-                return a + 1
-          """).strip(),
+          textwrap.dedent(expected_source).strip(),
           temp_output.read().strip())
 
 
index 11e3838467783807530576413bec20c9904f873f..8717ee6cff198ff31f6cbdb7213e5a8dd3df1149 100644 (file)
@@ -31,9 +31,7 @@ from tensorflow.contrib.py2tf.pyct import anno
 
 
 class QN(object):
-  """Represents a qualified name.
-
-  """
+  """Represents a qualified name."""
 
   def __init__(self, base, attr=None):
     if attr:
@@ -42,8 +40,15 @@ class QN(object):
       self._parent = base
       self.qn = base.qn + (attr,)
     else:
-      self._parent = None
-      self.qn = tuple(base.split('.'))
+      if isinstance(base, QN):
+        if base.is_composite():
+          self._parent = base.parent
+        else:
+          self._parent = None
+        self.qn = base.qn
+      else:
+        self._parent = None
+        self.qn = tuple(base.split('.'))
 
   def is_composite(self):
     return len(self.qn) > 1
diff --git a/tensorflow/contrib/py2tf/pyct/qual_names_test.py b/tensorflow/contrib/py2tf/pyct/qual_names_test.py
new file mode 100644 (file)
index 0000000..1b1eee2
--- /dev/null
@@ -0,0 +1,108 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for qual_names module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.python.platform import test
+
+
+class QNTest(test.TestCase):
+
+  def test_basic(self):
+    a = qual_names.QN('a')
+    self.assertEqual(a.qn, ('a',))
+    self.assertEqual(str(a), 'a')
+    self.assertEqual(a.ssf(), 'a')
+    self.assertEqual(a.ast().id, 'a')
+    self.assertFalse(a.is_composite())
+    with self.assertRaises(ValueError):
+      _ = a.parent
+
+    a_b = qual_names.QN(a, 'b')
+    self.assertEqual(a_b.qn, ('a', 'b'))
+    self.assertEqual(str(a_b), 'a.b')
+    self.assertEqual(a_b.ssf(), 'a_b')
+    self.assertEqual(a_b.ast().value.id, 'a')
+    self.assertEqual(a_b.ast().attr, 'b')
+    self.assertTrue(a_b.is_composite())
+    self.assertEqual(a_b.parent.qn, ('a',))
+
+    a2 = qual_names.QN(a)
+    self.assertEqual(a2.qn, ('a',))
+    with self.assertRaises(ValueError):
+      _ = a.parent
+
+    a_b2 = qual_names.QN(a_b)
+    self.assertEqual(a_b2.qn, ('a', 'b'))
+    self.assertEqual(a_b2.parent.qn, ('a',))
+
+    self.assertTrue(a2 == a)
+    self.assertFalse(a2 is a)
+
+    self.assertTrue(a_b.parent == a)
+    self.assertTrue(a_b2.parent == a)
+
+    self.assertTrue(a_b2 == a_b)
+    self.assertFalse(a_b2 is a_b)
+    self.assertFalse(a_b2 == a)
+
+    with self.assertRaises(ValueError):
+      qual_names.QN('a', 'b')
+
+  def test_hashable(self):
+    d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'}
+
+    self.assertEqual(d[qual_names.QN('a')], 'a')
+    self.assertEqual(d[qual_names.QN('b')], 'b')
+    self.assertTrue(qual_names.QN('c') not in d)
+
+
+class QNResolverTest(test.TestCase):
+
+  def assertQNStringIs(self, node, qn_str):
+    self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str)
+
+  def test_resolve(self):
+    samples = """
+      a
+      a.b
+      (c, d.e)
+      [f, (g.h.i)]
+      j(k, l)
+    """
+    nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples)))
+    nodes = tuple(n.value for n in nodes.body)
+
+    self.assertQNStringIs(nodes[0], 'a')
+    self.assertQNStringIs(nodes[1], 'a.b')
+    self.assertQNStringIs(nodes[2].elts[0], 'c')
+    self.assertQNStringIs(nodes[2].elts[1], 'd.e')
+    self.assertQNStringIs(nodes[3].elts[0], 'f')
+    self.assertQNStringIs(nodes[3].elts[1], 'g.h.i')
+    self.assertQNStringIs(nodes[4].func, 'j')
+    self.assertQNStringIs(nodes[4].args[0], 'k')
+    self.assertQNStringIs(nodes[4].args[1], 'l')
+
+
+if __name__ == '__main__':
+  test.main()
index 5fd5252619f104ad2bb95863533c3f274819ffe0..c40e4d0fb783191705a412ab2728daabb61eda0f 100644 (file)
@@ -59,7 +59,7 @@ class ReplaceTransformer(gast.NodeTransformer):
       repl = self.replacements[node.name]
       if not isinstance(repl, (gast.Name, ast.Name)):
         raise ValueError(
-            'A function name can only be replaced by a Name node. Found: %s',
+            'A function name can only be replaced by a Name node. Found: %s' %
             repl)
       node.name = repl.id
     return node
@@ -70,6 +70,8 @@ class ReplaceTransformer(gast.NodeTransformer):
       node.ctx = gast.Load()
     elif isinstance(node, gast.Name):
       node.ctx = ctx
+    elif isinstance(node, (gast.Str, gast.Num)):
+      pass
     else:
       raise ValueError('unexpected node type "%s"' % node)
 
index 0e3d07e378972de67d67d9ef7fd9bef351c5b5be..8ccfde8573724741b0bbe4eacb3c54beb381ee7e 100644 (file)
@@ -34,7 +34,8 @@ class TemplatesTest(test.TestCase):
     """
 
     node = templates.replace(template, b=('a', 'c'))[0]
-    result = compiler.ast_to_object(node)
+    result, _ = compiler.ast_to_object(node)
+
     self.assertEquals((2, 3), result.test_fn(2, 3))
 
   def test_replace_variable(self):
@@ -46,7 +47,7 @@ class TemplatesTest(test.TestCase):
     """
 
     node = templates.replace(template, a='b')[0]
-    result = compiler.ast_to_object(node)
+    result, _ = compiler.ast_to_object(node)
     self.assertEquals(7, result.test_fn(2))
 
   def test_replace_function_name(self):
@@ -58,7 +59,7 @@ class TemplatesTest(test.TestCase):
     """
 
     node = templates.replace(template, fname='test_fn')[0]
-    result = compiler.ast_to_object(node)
+    result, _ = compiler.ast_to_object(node)
     self.assertEquals(7, result.test_fn(2))
 
   def test_code_block(self):
@@ -75,7 +76,7 @@ class TemplatesTest(test.TestCase):
                 gast.Name('a', None, None)
             ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
         ] * 2)[0]
-    result = compiler.ast_to_object(node)
+    result, _ = compiler.ast_to_object(node)
     self.assertEquals(3, result.test_fn(1))