def __init__(self, remove_decorators):
self.remove_decorators = remove_decorators
+ self.additional_dependencies = set()
# pylint:disable=invalid-name
dec_func = dec.func
else:
dec_func = dec
+
+ # Special cases.
+ # TODO(mdan): Is there any way we can treat these more generically?
+ # We may want to forego using decorators altogether if we can't
+ # properly support them.
+ if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',):
+ # Assumption: decorators are only visible in the AST when converting
+ # a function inline (via another decorator).
+ # In that case, the converted function is no longer part of the
+ # original object that it was declared into.
+ # This is currently verified by tests.
+ continue
+
if not anno.hasanno(dec_func, 'live_val'):
raise ValueError(
'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func))
+
dec_value = anno.getanno(dec_func, 'live_val')
if dec_value not in self.remove_decorators:
- kept_decorators.append(dec)
- node.decorator_list = kept_decorators
+ kept_decorators.append((dec, dec_value))
+
+ for _, dec_value in kept_decorators:
+ if dec_value.__module__ == '__main__':
+ raise ValueError(
+ 'decorator "%s" was not allowed because it is declared '
+ 'in the module "%s". To fix this, declare it in a separate '
+ 'module that we can import it from.' % (dec_value,
+ dec_value.__module__))
+ else:
+ self.additional_dependencies.add(dec_value)
+
+ node.decorator_list = [dec for dec, _ in kept_decorators]
return node
# pylint:enable=invalid-name
def transform(node, remove_decorators):
transformer = DecoratorsTransformer(remove_decorators)
node = transformer.visit(node)
- return node
+ return node, transformer.additional_dependencies
from __future__ import division
from __future__ import print_function
-import textwrap
+from functools import wraps
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.converters import decorators
from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.platform import test
-from tensorflow.python.util import tf_inspect
+
+
+# The Python parser only briefly captures decorators into the AST.
+# The interpreter desugars them on load, and the decorated function loses any
+# trace of the decorator (which is notmally what you would expect, since
+# they are meant to be transparent).
+# However, decorators are still visible when you analyze the function
+# from inside a decorator, before it was applied - as is the case
+# with our conversion decorators.
+
+
+def simple_decorator(f):
+ return lambda a: f(a) + 1
+
+
+def self_removing_decorator(removing_wrapper):
+ def decorator(f):
+ @wraps(f)
+ def wrapper(*args):
+ # This removing wrapper is defined in the test below. This setup is so
+ # intricate just to simulate how we use the transformer in practice.
+ transformed_f = removing_wrapper(f, (self_removing_decorator,))
+ return transformed_f(*args) + 1
+ return wrapper
+ return decorator
class DecoratorsTest(converter_test_base.TestCase):
- def test_function_decorator(self):
+ def _remover_wrapper(self, f, remove_decorators):
+ namespace = {
+ 'self_removing_decorator': self_removing_decorator,
+ 'simple_decorator': simple_decorator
+ }
+ node = self.parse_and_analyze(f, namespace)
+ node, _ = decorators.transform(node, remove_decorators=remove_decorators)
+ result, _ = compiler.ast_to_object(node)
+ return getattr(result, f.__name__)
- def function_decorator():
+ def test_noop(self):
- def decorator(f):
- return lambda a: f(a) + 1
+ def test_fn(a):
+ return a
- return decorator
+ node = self.parse_and_analyze(test_fn, {})
+ node, deps = decorators.transform(node, remove_decorators=())
+ result, _ = compiler.ast_to_object(node)
- # The Python parser does capture decorators into the AST.
- # However, the interpreter desugars them on load, and refering to the
- # decorated function at runtime usually loses any trace of the decorator.
- # Below is an example when that doesn't happen.
- def static_wrapper():
+ self.assertFalse(deps)
+ self.assertEqual(1, result.test_fn(1))
- @function_decorator()
- def test_fn(a): # pylint:disable=unused-variable
- return a
+ def test_function(self):
- node = self.parse_and_analyze(static_wrapper,
- {'function_decorator': function_decorator})
- node = node.body[0].body[0]
+ @self_removing_decorator(self._remover_wrapper)
+ def test_fn(a):
+ return a
- node = decorators.transform(node, remove_decorators=())
- # Since the decorator is not removed, we need to include its source
- # code. We cannot do it after the fact because decorators are executed
- # on load.
- result, _ = compiler.ast_to_object(
- node,
- source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator)))
- self.assertEqual(2, result.test_fn(1))
+ # 2 = 1 (a) + 1 (decorator applied exactly once)
+ self.assertEqual(2, test_fn(1))
- node = decorators.transform(node, remove_decorators=(function_decorator,))
- with self.compiled(node) as result:
- self.assertEqual(1, result.test_fn(1))
+ def test_method(self):
- def test_simple_decorator(self):
+ class TestClass(object):
- def simple_decorator(f):
- return lambda a: f(a) + 1
+ @self_removing_decorator(self._remover_wrapper)
+ def test_fn(self, a):
+ return a
- # The Python parser does capture decorators into the AST.
- # However, the interpreter desugars them upon load, and refering to the
- # decorated function at runtime usually loses any trace of the decorator.
- # Below is an example when that doesn't happen.
- def static_wrapper():
+ # 2 = 1 (a) + 1 (decorator applied exactly once)
+ self.assertEqual(2, TestClass().test_fn(1))
- @simple_decorator
- def test_fn(a): # pylint:disable=unused-variable
+ def test_multiple_decorators(self):
+
+ class TestClass(object):
+
+ # Note that reversing the order of this two doesn't work.
+ @classmethod
+ @self_removing_decorator(self._remover_wrapper)
+ def test_fn(cls, a):
return a
- node = self.parse_and_analyze(static_wrapper,
- {'simple_decorator': simple_decorator})
- node = node.body[0].body[0]
-
- node = decorators.transform(node, remove_decorators=())
- # Since the decorator is not removed, we need to include its source
- # code. We cannot do it after the fact because decorators are executed
- # on load.
- result, _ = compiler.ast_to_object(
- node,
- source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator)))
- self.assertEqual(2, result.test_fn(1))
-
- node = decorators.transform(node, remove_decorators=(simple_decorator,))
- with self.compiled(node) as result:
- self.assertEqual(1, result.test_fn(1))
+ # 2 = 1 (a) + 1 (decorator applied exactly once)
+ self.assertEqual(2, TestClass.test_fn(1))
+
+ def test_nested_decorators(self):
+
+ @self_removing_decorator(self._remover_wrapper)
+ def test_fn(a):
+ @simple_decorator
+ def inner_fn(b):
+ return b + 11
+ return inner_fn(a)
+
+ with self.assertRaises(ValueError):
+ test_fn(1)
+
+ # TODO(mdan): Uncomment this test once converter_test_base is updated.
+ # (can't do it now because it has unrelated pending changes)
+ # def test_nested_decorators(self):
+ #
+ # @self_removing_decorator(self._remover_wrapper)
+ # def test_fn(a):
+ # @imported_decorator
+ # def inner_fn(b):
+ # return b + 11
+ # return inner_fn(a)
+ #
+ # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
+ # self.assertEqual(14, test_fn(1))
if __name__ == '__main__':
off.
dependency_cache: dict[object]: ast; maps original entities to their
converted AST
+ additional_imports: set(object); additional entities which for any reason
+ cannot be attached after loading and need to be explicitly imported
+ in the generated code
name_map: dict[string]: string; maps original entities to the name of
their converted counterparts
api_module: A reference to the api module. The reference needs to be passed
self.nocompile_decorators = nocompile_decorators
self.partial_types = partial_types if partial_types else ()
self.dependency_cache = {}
+ self.additional_imports = set()
self.name_map = {}
self.api_module = api_module
arg_values=arg_values,
arg_types=arg_types,
recursive=conversion_map.recursive)
- node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
+ node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
# TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
node.name = new_name
conversion_map.update_name_map(namer)
+ # TODO(mdan): Use this at compilation.
+ conversion_map.additional_imports.update(deps)
+
return node, new_name
# source.
# TODO(mdan): Is it feasible to reconstruct intermediate source code?
ctx.source_code = None
- node = decorators.transform(node, nocompile_decorators)
+ node, deps = decorators.transform(node, nocompile_decorators)
node = break_statements.transform(node, ctx)
node = asserts.transform(node, ctx)
node = logical_expressions.transform(node)
node = side_effect_guards.transform(node, ctx)
- return node
+ return node, deps