Update decorators transformer with additional clarifications in the tests and handlin...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Feb 2018 01:40:17 +0000 (17:40 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 01:48:41 +0000 (17:48 -0800)
PiperOrigin-RevId: 186390564

tensorflow/contrib/py2tf/converters/decorators.py
tensorflow/contrib/py2tf/converters/decorators_test.py
tensorflow/contrib/py2tf/impl/conversion.py

index 3f620c1cd2d9b75f82410754a7e812e13eabe3ae..68bf241ef33292f0581ccb3c44f313f853c92ba7 100644 (file)
@@ -33,6 +33,7 @@ class DecoratorsTransformer(gast.NodeTransformer):
 
   def __init__(self, remove_decorators):
     self.remove_decorators = remove_decorators
+    self.additional_dependencies = set()
 
   # pylint:disable=invalid-name
 
@@ -44,13 +45,38 @@ class DecoratorsTransformer(gast.NodeTransformer):
         dec_func = dec.func
       else:
         dec_func = dec
+
+      # Special cases.
+      # TODO(mdan): Is there any way we can treat these more generically?
+      # We may want to forego using decorators altogether if we can't
+      # properly support them.
+      if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',):
+        # Assumption: decorators are only visible in the AST when converting
+        # a function inline (via another decorator).
+        # In that case, the converted function is no longer part of the
+        # original object that it was declared into.
+        # This is currently verified by tests.
+        continue
+
       if not anno.hasanno(dec_func, 'live_val'):
         raise ValueError(
             'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func))
+
       dec_value = anno.getanno(dec_func, 'live_val')
       if dec_value not in self.remove_decorators:
-        kept_decorators.append(dec)
-    node.decorator_list = kept_decorators
+        kept_decorators.append((dec, dec_value))
+
+    for _, dec_value in kept_decorators:
+      if dec_value.__module__ == '__main__':
+        raise ValueError(
+            'decorator "%s" was not allowed because it is declared '
+            'in the module "%s". To fix this, declare it in a separate '
+            'module that we can import it from.' % (dec_value,
+                                                    dec_value.__module__))
+      else:
+        self.additional_dependencies.add(dec_value)
+
+    node.decorator_list = [dec for dec, _ in kept_decorators]
     return node
 
   # pylint:enable=invalid-name
@@ -59,4 +85,4 @@ class DecoratorsTransformer(gast.NodeTransformer):
 def transform(node, remove_decorators):
   transformer = DecoratorsTransformer(remove_decorators)
   node = transformer.visit(node)
-  return node
+  return node, transformer.additional_dependencies
index 402fa0dda28e696f70d0354ca4abf3a6c83506d9..c75e5461746f27d14a54b7ac06e7f77d868372c8 100644 (file)
@@ -18,84 +18,121 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import textwrap
+from functools import wraps
 
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.contrib.py2tf.converters import decorators
 from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.python.platform import test
-from tensorflow.python.util import tf_inspect
+
+
+# The Python parser only briefly captures decorators into the AST.
+# The interpreter desugars them on load, and the decorated function loses any
+# trace of the decorator (which is notmally what you would expect, since
+# they are meant to be transparent).
+# However, decorators are still visible when you analyze the function
+# from inside a decorator, before it was applied - as is the case
+# with our conversion decorators.
+
+
+def simple_decorator(f):
+  return lambda a: f(a) + 1
+
+
+def self_removing_decorator(removing_wrapper):
+  def decorator(f):
+    @wraps(f)
+    def wrapper(*args):
+      # This removing wrapper is defined in the test below. This setup is so
+      # intricate just to simulate how we use the transformer in practice.
+      transformed_f = removing_wrapper(f, (self_removing_decorator,))
+      return transformed_f(*args) + 1
+    return wrapper
+  return decorator
 
 
 class DecoratorsTest(converter_test_base.TestCase):
 
-  def test_function_decorator(self):
+  def _remover_wrapper(self, f, remove_decorators):
+    namespace = {
+        'self_removing_decorator': self_removing_decorator,
+        'simple_decorator': simple_decorator
+    }
+    node = self.parse_and_analyze(f, namespace)
+    node, _ = decorators.transform(node, remove_decorators=remove_decorators)
+    result, _ = compiler.ast_to_object(node)
+    return getattr(result, f.__name__)
 
-    def function_decorator():
+  def test_noop(self):
 
-      def decorator(f):
-        return lambda a: f(a) + 1
+    def test_fn(a):
+      return a
 
-      return decorator
+    node = self.parse_and_analyze(test_fn, {})
+    node, deps = decorators.transform(node, remove_decorators=())
+    result, _ = compiler.ast_to_object(node)
 
-    # The Python parser does capture decorators into the AST.
-    # However, the interpreter desugars them on load, and refering to the
-    # decorated function at runtime usually loses any trace of the decorator.
-    # Below is an example when that doesn't happen.
-    def static_wrapper():
+    self.assertFalse(deps)
+    self.assertEqual(1, result.test_fn(1))
 
-      @function_decorator()
-      def test_fn(a):  # pylint:disable=unused-variable
-        return a
+  def test_function(self):
 
-    node = self.parse_and_analyze(static_wrapper,
-                                  {'function_decorator': function_decorator})
-    node = node.body[0].body[0]
+    @self_removing_decorator(self._remover_wrapper)
+    def test_fn(a):
+      return a
 
-    node = decorators.transform(node, remove_decorators=())
-    # Since the decorator is not removed, we need to include its source
-    # code. We cannot do it after the fact because decorators are executed
-    # on load.
-    result, _ = compiler.ast_to_object(
-        node,
-        source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator)))
-    self.assertEqual(2, result.test_fn(1))
+    # 2 = 1 (a) + 1 (decorator applied exactly once)
+    self.assertEqual(2, test_fn(1))
 
-    node = decorators.transform(node, remove_decorators=(function_decorator,))
-    with self.compiled(node) as result:
-      self.assertEqual(1, result.test_fn(1))
+  def test_method(self):
 
-  def test_simple_decorator(self):
+    class TestClass(object):
 
-    def simple_decorator(f):
-      return lambda a: f(a) + 1
+      @self_removing_decorator(self._remover_wrapper)
+      def test_fn(self, a):
+        return a
 
-    # The Python parser does capture decorators into the AST.
-    # However, the interpreter desugars them upon load, and refering to the
-    # decorated function at runtime usually loses any trace of the decorator.
-    # Below is an example when that doesn't happen.
-    def static_wrapper():
+    # 2 = 1 (a) + 1 (decorator applied exactly once)
+    self.assertEqual(2, TestClass().test_fn(1))
 
-      @simple_decorator
-      def test_fn(a):  # pylint:disable=unused-variable
+  def test_multiple_decorators(self):
+
+    class TestClass(object):
+
+      # Note that reversing the order of this two doesn't work.
+      @classmethod
+      @self_removing_decorator(self._remover_wrapper)
+      def test_fn(cls, a):
         return a
 
-    node = self.parse_and_analyze(static_wrapper,
-                                  {'simple_decorator': simple_decorator})
-    node = node.body[0].body[0]
-
-    node = decorators.transform(node, remove_decorators=())
-    # Since the decorator is not removed, we need to include its source
-    # code. We cannot do it after the fact because decorators are executed
-    # on load.
-    result, _ = compiler.ast_to_object(
-        node,
-        source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator)))
-    self.assertEqual(2, result.test_fn(1))
-
-    node = decorators.transform(node, remove_decorators=(simple_decorator,))
-    with self.compiled(node) as result:
-      self.assertEqual(1, result.test_fn(1))
+    # 2 = 1 (a) + 1 (decorator applied exactly once)
+    self.assertEqual(2, TestClass.test_fn(1))
+
+  def test_nested_decorators(self):
+
+    @self_removing_decorator(self._remover_wrapper)
+    def test_fn(a):
+      @simple_decorator
+      def inner_fn(b):
+        return b + 11
+      return inner_fn(a)
+
+    with self.assertRaises(ValueError):
+      test_fn(1)
+
+  # TODO(mdan): Uncomment this test once converter_test_base is updated.
+  # (can't do it now because it has unrelated pending changes)
+  # def test_nested_decorators(self):
+  #
+  #   @self_removing_decorator(self._remover_wrapper)
+  #   def test_fn(a):
+  #     @imported_decorator
+  #     def inner_fn(b):
+  #       return b + 11
+  #     return inner_fn(a)
+  #
+  #   # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
+  #   self.assertEqual(14, test_fn(1))
 
 
 if __name__ == '__main__':
index 7610f0427be45832dcc12e8f32b65292254eadfd..f3dc6b4d06aef8b5a7d926578206dc8315dd3749 100644 (file)
@@ -56,6 +56,9 @@ class ConversionMap(object):
         off.
     dependency_cache: dict[object]: ast; maps original entities to their
         converted AST
+    additional_imports: set(object); additional entities which for any reason
+        cannot be attached after loading and need to be explicitly imported
+        in the generated code
     name_map: dict[string]: string; maps original entities to the name of
         their converted counterparts
     api_module: A reference to the api module. The reference needs to be passed
@@ -70,6 +73,7 @@ class ConversionMap(object):
     self.nocompile_decorators = nocompile_decorators
     self.partial_types = partial_types if partial_types else ()
     self.dependency_cache = {}
+    self.additional_imports = set()
     self.name_map = {}
     self.api_module = api_module
 
@@ -218,7 +222,7 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
       arg_values=arg_values,
       arg_types=arg_types,
       recursive=conversion_map.recursive)
-  node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
+  node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
 
   # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
   new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
@@ -229,6 +233,9 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
 
   node.name = new_name
   conversion_map.update_name_map(namer)
+  # TODO(mdan): Use this at compilation.
+  conversion_map.additional_imports.update(deps)
+
   return node, new_name
 
 
@@ -271,7 +278,7 @@ def node_to_graph(node, ctx, nocompile_decorators):
   # source.
   # TODO(mdan): Is it feasible to reconstruct intermediate source code?
   ctx.source_code = None
-  node = decorators.transform(node, nocompile_decorators)
+  node, deps = decorators.transform(node, nocompile_decorators)
   node = break_statements.transform(node, ctx)
   node = asserts.transform(node, ctx)
 
@@ -296,4 +303,4 @@ def node_to_graph(node, ctx, nocompile_decorators):
   node = logical_expressions.transform(node)
   node = side_effect_guards.transform(node, ctx)
 
-  return node
+  return node, deps