Enable dynamic function calls. These are compiled just in time by inserting a call...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Feb 2018 14:00:21 +0000 (06:00 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187165096

tensorflow/contrib/py2tf/__init__.py
tensorflow/contrib/py2tf/converters/BUILD
tensorflow/contrib/py2tf/converters/call_trees.py
tensorflow/contrib/py2tf/converters/call_trees_test.py
tensorflow/contrib/py2tf/converters/converter_test_base.py
tensorflow/contrib/py2tf/impl/api.py

index 379fa7f..6531183 100644 (file)
@@ -23,6 +23,7 @@ from __future__ import print_function
 
 from tensorflow.contrib.py2tf import utils
 from tensorflow.contrib.py2tf.impl.api import convert
+from tensorflow.contrib.py2tf.impl.api import converted_call
 from tensorflow.contrib.py2tf.impl.api import graph_ready
 from tensorflow.contrib.py2tf.impl.api import to_code
 from tensorflow.contrib.py2tf.impl.api import to_graph
@@ -30,7 +31,8 @@ from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
-    'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError'
+    'to_graph', 'to_code', 'convert', 'graph_ready', 'converted_call', 'utils',
+    'PyFlowParseError'
 ]
 
 remove_undocumented(__name__, _allowed_symbols)
index 42baaaa..78f46bc 100644 (file)
@@ -46,6 +46,7 @@ py_library(
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         ":converters",
+        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/contrib/py2tf/pyct/static_analysis",
         "//tensorflow/contrib/py2tf/utils",
         "@gast_archive//:gast",
@@ -59,7 +60,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -70,7 +70,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -81,7 +80,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -92,7 +90,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/impl",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -103,7 +101,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -114,7 +111,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -125,7 +121,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -136,7 +131,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -157,7 +151,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -168,7 +161,6 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
@@ -184,7 +176,6 @@ py_test(
     ],
     deps = [
         ":test_lib",
-        "//tensorflow/contrib/py2tf/pyct",
         "//tensorflow/python:client_testlib",
     ],
 )
index 1050ba6..f18f9f6 100644 (file)
@@ -27,6 +27,7 @@ import types
 import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import inspect_utils
 from tensorflow.contrib.py2tf.pyct import parser
 from tensorflow.contrib.py2tf.pyct import templates
 from tensorflow.contrib.py2tf.pyct import transformer
@@ -72,9 +73,8 @@ class CallTreeTransformer(transformer.Base):
     self.uncompiled_modules = uncompiled_modules
     self.nocompile_decorators = nocompile_decorators
 
-  # pylint:disable=invalid-name
-
   def _resolve_name(self, node):
+    """Used to resolve decorator info."""
     if isinstance(node, gast.Call):
       return self._resolve_name(node.func)
     if isinstance(node, gast.Name):
@@ -99,7 +99,13 @@ class CallTreeTransformer(transformer.Base):
                          (owner_type, node.attr))
     return None
 
+  def _function_is_compilable(self, target_entity):
+    """Determines whether an entity can be compiled at all."""
+    # TODO(mdan): This is just a placeholder. Implement.
+    return not isinstance(target_entity, types.BuiltinFunctionType)
+
   def _should_compile(self, node, fqn):
+    """Determines whether an entity should be compiled in the context."""
     for i in range(1, len(fqn)):
       if fqn[:i] in self.uncompiled_modules:
         return False
@@ -141,33 +147,6 @@ class CallTreeTransformer(transformer.Base):
 
     return True
 
-  def _determine_function_owner(self, m):
-    # TODO(mdan): The parent type should be known at analysis. Use that instead.
-    if hasattr(m, 'im_class'):  # Python 2
-      return m.im_class
-    if hasattr(m, '__qualname__'):  # Python 3
-      # Object attributes: should be bound to "self".
-      if hasattr(m, '__self__'):
-        return type(m.__self__)
-
-      # Class attributes: should have the owner name in their namespace.
-      qn = m.__qualname__.split('.')
-      if len(qn) < 2:
-        return None
-      owner_name, func_name = qn[-2:]
-      if func_name != m.__name__:
-        raise ValueError('Inconsistent names detected '
-                         '(__qualname__[1] = "%s", __name__ = "%s") for %s.' %
-                         (func_name, m.__name__, m))
-      if owner_name == '<locals>':
-        return None
-      if owner_name not in self.context.namespace:
-        raise ValueError(
-            'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' %
-            (owner_name, m, self.context.namespace))
-      return self.context.namespace[owner_name]
-    return None
-
   def _rename_compilable_function(self, node):
     assert anno.hasanno(node.func, 'live_val')
     assert anno.hasanno(node.func, 'fqn')
@@ -182,7 +161,11 @@ class CallTreeTransformer(transformer.Base):
           target_fqn, live_entity=target_entity)
       do_rename = True
     else:
-      owner_type = self._determine_function_owner(target_entity)
+      if anno.hasanno(node.func, 'parent_type'):
+        owner_type = anno.getanno(node.func, 'parent_type')
+      else:
+        # Fallback - not reliable.
+        owner_type = inspect_utils.getmethodclass(target_entity)
       new_name, do_rename = self.context.namer.compiled_function_name(
           target_fqn, live_entity=target_entity, owner_type=owner_type)
 
@@ -202,9 +185,32 @@ class CallTreeTransformer(transformer.Base):
     """
     return templates.replace(template, func=node.func, original_args=node.args)
 
-  def _function_is_compilable(self, target_entity):
-    # TODO(mdan): This is just a placeholder. Implement.
-    return not isinstance(target_entity, types.BuiltinFunctionType)
+  def _converted_call(self, node):
+    """Inlines a dynamic conversion for a dynamic function."""
+    # TODO(mdan): Pass information on the statically compiled functions.
+    # Having access to the statically compiled functions can help avoid
+    # unnecessary compilation.
+    # For example, this would lead to function `a` being compiled twice:
+    #
+    #   def a():
+    #     v = b
+    #     b()
+    #   def b():
+    #     a()
+    #
+    # This is really a problem with recursive calls, which currently can
+    # only be gated by a static condition, and should be rare.
+    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
+    # Before we could convert all the time though, we'd need a reasonable
+    # caching mechanism.
+    template = """
+      py2tf_api.converted_call(func, True, False, {}, original_args)
+    """
+    call_expr = templates.replace(
+        template, func=node.func, original_args=node.args)
+    return call_expr[0].value
+
+  # pylint:disable=invalid-name
 
   def visit_Expr(self, node):
     if isinstance(node.value, gast.Call):
@@ -245,9 +251,9 @@ class CallTreeTransformer(transformer.Base):
         raise NotImplementedError('py_func with return values')
     else:
       if self.context.recursive:
-        raise NotImplementedError('Could not resolve target function.')
+        node = self._converted_call(node)
       else:
-        # TODO(mdan): Double check. Is this reachable code?
+        # Unresolved functions are allowed in non-recursive mode.
         pass
     return node
 
index 777648d..d482a9e 100644 (file)
@@ -47,6 +47,21 @@ class CallTreesTest(converter_test_base.TestCase):
       result.renamed_test_fn_1 = renamed_test_fn_1
       self.assertEquals(3, result.test_fn_2(1))
 
+  def test_dynamic_function(self):
+
+    def test_fn_1():
+      raise ValueError('This should be masked by the mock.')
+
+    def test_fn_2(f):
+      return f() + 3
+
+    node = self.parse_and_analyze(test_fn_2, {})
+    node = call_trees.transform(node, self.ctx, (), ())
+
+    with self.compiled(node) as result:
+      # 10 = 7 (from the mock) + 3 (from test_fn_2)
+      self.assertEquals(10, result.test_fn_2(test_fn_1))
+
   def test_simple_methods(self):
 
     class TestClass(object):
@@ -59,6 +74,7 @@ class CallTreesTest(converter_test_base.TestCase):
 
     node = self.parse_and_analyze(
         TestClass.test_fn_2, {'TestClass': TestClass},
+        namer=converter_test_base.FakeNoRenameNamer(),
         arg_types={'self': (TestClass.__name__, TestClass)})
     node = call_trees.transform(node, self.ctx, (), ())
 
index afa5c2f..1f98d84 100644 (file)
@@ -25,6 +25,7 @@ from tensorflow.contrib.py2tf import utils
 from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.contrib.py2tf.pyct import context
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import pretty_printer
 from tensorflow.contrib.py2tf.pyct import qual_names
 from tensorflow.contrib.py2tf.pyct.static_analysis import activity
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
@@ -52,26 +53,43 @@ class FakeNamer(object):
     return ('renamed_%s' % '_'.join(original_fqn)), True
 
 
+class FakeNoRenameNamer(FakeNamer):
+
+  def compiled_function_name(self, original_fqn, **_):
+    return str(original_fqn), False
+
+
 class TestCase(test.TestCase):
   """Base class for unit tests in this module. Contains relevant utilities."""
 
   @contextlib.contextmanager
   def compiled(self, node, *symbols):
-    source = '<compile failed>'
+    source = None
+
+    self.dynamic_calls = []
+    def converted_call(*args):
+      """Mock version of api.converted_call."""
+      self.dynamic_calls.append(args)
+      return 7
+
     try:
       result, source = compiler.ast_to_object(node)
-      result.tf = self.make_fake_tf(*symbols)
+      result.tf = self.make_fake_mod('fake_tf', *symbols)
       result.py2tf_utils = utils
+      result.py2tf_api = self.make_fake_mod('fake_api', converted_call)
       yield result
     except Exception:  # pylint:disable=broad-except
-      print('Offending compiled code:\n%s' % source)
+      if source is None:
+        print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
+      else:
+        print('Offending compiled code:\n%s' % source)
       raise
 
-  def make_fake_tf(self, *symbols):
-    fake_tf = imp.new_module('fake_tf')
+  def make_fake_mod(self, name, *symbols):
+    fake_mod = imp.new_module(name)
     for s in symbols:
-      setattr(fake_tf, s.__name__, s)
-    return fake_tf
+      setattr(fake_mod, s.__name__, s)
+    return fake_mod
 
   def attach_namespace(self, module, **ns):
     for k, v in ns.items():
index 29d2e03..48100aa 100644 (file)
@@ -26,7 +26,9 @@ import six
 from tensorflow.contrib.py2tf.impl import config
 from tensorflow.contrib.py2tf.impl import conversion
 from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import inspect_utils
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.utils import builtins
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_inspect
 
@@ -110,28 +112,7 @@ def convert(recursive=False, verbose=False, arg_types=None):
 
     @wraps(f)
     def wrapper(*args, **kwargs):
-      """Wrapper that calls the compiled version of the wrapped function."""
-      partial_types = ()
-      arg_values = {}
-      arg_names = tf_inspect.getargspec(f)[0]
-      for name, arg in zip(arg_names, args):
-        arg_values[name] = arg
-        arg_class = arg.__class__
-        # If arg_value_hints specifies any name, use that instead.
-        if name not in arg_types:
-          arg_types[name] = (arg_class.__name__, arg_class)
-        if name == 'self' and tf_inspect.isclass(arg_class):
-          # Annotated methods need to specify that their owner type is partial,
-          # otherwise other members they call will not be converted.
-          partial_types = (arg_class,)
-      wrapped = to_graph(
-          f,
-          recursive=recursive,
-          verbose=verbose,
-          arg_values=arg_values,
-          arg_types=arg_types,
-          partial_types=partial_types)
-      return wrapped(*args, **kwargs)
+      return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)
 
     # Sometimes the decorator is just desugared, making it impossible to detect.
     # This attribute makes detection easier.
@@ -141,6 +122,78 @@ def convert(recursive=False, verbose=False, arg_types=None):
   return decorator
 
 
+def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
+  """Compiles a function call inline."""
+  # TODO(mdan): This needs cleanup.
+  # In particular, we may want to avoid renaming functions altogether.
+
+  if conversion.is_whitelisted_for_graph(f):
+    return f(*args, **kwargs)
+
+  unknown_arg_value = object()  # Sentinel for arguments of unknown value
+
+  if tf_inspect.isbuiltin(f):
+    return builtins.dynamic_builtin(f, *args, **kwargs)
+
+  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
+    # Regular functions
+    target_entity = f
+    arg_map_target = f
+    effective_args = args
+    f_class = inspect_utils.getmethodclass(f)
+
+    if f_class is not None:
+      partial_types = (f_class,)
+    else:
+      partial_types = ()
+
+  elif tf_inspect.isclass(f):
+    # Constructors
+    target_entity = f
+    arg_map_target = f.__init__
+    effective_args = (unknown_arg_value,) + args
+    partial_types = ()
+
+  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
+    # Callable objects
+    target_entity = f.__call__
+    arg_map_target = f.__call__
+    effective_args = (f,) + args
+    partial_types = (f.__class__,)
+
+  else:
+    NotImplementedError('unknown callable type "%s"' % type(f))
+
+  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
+  for name, arg in arg_values.items():
+    if arg is unknown_arg_value:
+      continue
+    arg_class = arg.__class__
+    # If arg_value_hints specifies any name, use that instead.
+    if name not in arg_types:
+      arg_types[name] = (arg_class.__name__, arg_class)
+
+  # When called from within a decorator, this is the only indication that
+  # the function is a method - it appears that the decorator is applied
+  # before the method is bound.
+  if not partial_types:
+    if 'self' in arg_values:
+      if tf_inspect.isclass(arg_values['self'].__class__):
+        partial_types = (arg_values['self'].__class__,)
+    elif 'cls' in arg_values:
+      if tf_inspect.isclass(arg_values['cls']):
+        partial_types = (arg_values['cls'],)
+
+  converted_f = to_graph(
+      target_entity,
+      recursive=recursive,
+      verbose=verbose,
+      arg_values=arg_values,
+      arg_types=arg_types,
+      partial_types=partial_types)
+  return converted_f(*effective_args, **kwargs)
+
+
 def to_graph(e,
              recursive=True,
              verbose=False,
@@ -189,7 +242,7 @@ def to_graph(e,
   # The compiled code should see everything the entry function saw.
   # TODO(mdan): This might not work well if the call tree spans modules?
   if tf_inspect.isfunction(e):
-    compiled_node.__dict__.update(six.get_function_globals(e))
+    compiled_node.__dict__.update(inspect_utils.getnamespace(e))
   compiled_fn = getattr(compiled_node, name)
 
   if verbose: