Add another utility that captures a function's namespace as a mapping from symbol...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 24 Feb 2018 01:22:37 +0000 (17:22 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 24 Feb 2018 01:27:52 +0000 (17:27 -0800)
Update getmethodclass with a hopefully more robust method.

PiperOrigin-RevId: 186847003

tensorflow/contrib/py2tf/pyct/inspect_utils.py
tensorflow/contrib/py2tf/pyct/inspect_utils_test.py

index 86cf52a..c1af95e 100644 (file)
@@ -21,22 +21,53 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
 import six
 
 from tensorflow.python.util import tf_inspect
 
 
+def getnamespace(f):
+  """Returns the complete namespace of a function.
+
+  Namespace is defined here as the mapping of all non-local variables to values.
+  This includes the globals and the closure variables. Note that this captures
+  the entire globals collection of the function, and may contain extra symbols
+  that it does not actually use.
+
+  Args:
+    f: User defined function.
+  Returns:
+    A dict mapping symbol names to values.
+  """
+  namespace = dict(six.get_function_globals(f))
+  closure = six.get_function_closure(f)
+  freevars = six.get_function_code(f).co_freevars
+  if freevars and closure:
+    for name, cell in zip(freevars, closure):
+      namespace[name] = cell.cell_contents
+  return namespace
+
+
 def getcallargs(c, *args, **kwargs):
   """Extension of getcallargs to non-function callables."""
-  if tf_inspect.isfunction(c):
+  if tf_inspect.isfunction(c) or tf_inspect.ismethod(c):
     # The traditional getcallargs
     return tf_inspect.getcallargs(c, *args, **kwargs)
 
   if tf_inspect.isclass(c):
-    # Constructors: pass a fake None for self, then remove it.
-    arg_map = tf_inspect.getcallargs(c.__init__, None, *args, **kwargs)
-    assert 'self' in arg_map, 'no "self" argument, is this not a constructor?'
-    del arg_map['self']
+    # Constructors: use a sentinel to remove the self argument.
+    self_sentinel = object()
+    arg_map = tf_inspect.getcallargs(
+        c.__init__, self_sentinel, *args, **kwargs)
+    # Find and remove the self arg. We cannot assume it's called 'self'.
+    self_arg_name = None
+    for name, value in arg_map.items():
+      if value is self_sentinel:
+        self_arg_name = name
+        break
+    del arg_map[self_arg_name]
     return arg_map
 
   if hasattr(c, '__call__'):
@@ -46,8 +77,29 @@ def getcallargs(c, *args, **kwargs):
   raise NotImplementedError('unknown callable "%s"' % type(c))
 
 
-def getmethodclass(m, namespace):
-  """Resolves a function's owner, e.g. a method's class."""
+def getmethodclass(m):
+  """Resolves a function's owner, e.g. a method's class.
+
+  Note that this returns the object that the function was retrieved from, not
+  necessarily the class where it was defined.
+
+  This function relies on Python stack frame support in the interpreter, and
+  has the same limitations that inspect.currentframe.
+
+  Limitations. This function will only work correctly if the owned class is
+  visible in the caller's global or local variables.
+
+  Args:
+    m: A user defined function
+
+  Returns:
+    The class that this function was retrieved from, or None if the function
+    is not an object or class method, or the class that owns the object or
+    method is not visible to m.
+
+  Raises:
+    ValueError: if the class could not be resolved for any unexpected reason.
+  """
 
   # Instance method and class methods: should be bound to a non-null "self".
   # If self is a class, then it's a class method.
@@ -57,34 +109,38 @@ def getmethodclass(m, namespace):
         return m.__self__
       return type(m.__self__)
 
-  # Class and static methods: platform specific.
-  if hasattr(m, 'im_class'):  # Python 2
-    return m.im_class
-
-  if hasattr(m, '__qualname__'):  # Python 3
-    qn = m.__qualname__.split('.')
-    if len(qn) < 2:
-      return None
-    owner_name, func_name = qn[-2:]
-    assert func_name == m.__name__, (
-        'inconsistent names detected '
-        '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % (func_name,
-                                                               m.__name__, m))
-    if owner_name == '<locals>':
-      return None
-    if owner_name not in namespace:
-      raise ValueError(
-          'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' %
-          (owner_name, m, namespace))
-    return namespace[owner_name]
-
-  if six.PY2:
-    # In Python 2 it's impossible, to our knowledge, to detect the class of a
-    # static function. So we're forced to walk all the objects in the
-    # namespace and see if they own it. If any reader finds a better solution,
-    # please let us know.
-    for _, v in namespace.items():
-      if hasattr(v, m.__name__) and getattr(v, m.__name__) is m:
-        return v
+  # Class, static and unbound methods: search all defined classes in any
+  # namespace. This is inefficient but more robust method.
+  owners = []
+  caller_frame = tf_inspect.currentframe().f_back
+  try:
+    # TODO(mdan): This doesn't consider cell variables.
+    # TODO(mdan): This won't work if the owner is hidden inside a container.
+    # Cell variables may be pulled using co_freevars and the closure.
+    for v in itertools.chain(caller_frame.f_locals.values(),
+                             caller_frame.f_globals.values()):
+      if hasattr(v, m.__name__):
+        candidate = getattr(v, m.__name__)
+        # Py2 methods may be bound or unbound, extract im_func to get the
+        # underlying function.
+        if hasattr(candidate, 'im_func'):
+          candidate = candidate.im_func
+        if hasattr(m, 'im_func'):
+          m = m.im_func
+        if candidate is m:
+          owners.append(v)
+  finally:
+    del caller_frame
+
+  if owners:
+    if len(owners) == 1:
+      return owners[0]
+
+    # If multiple owners are found, and are not subclasses, raise an error.
+    owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
+    for o in owner_types:
+      if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
+        return o
+    raise ValueError('Found too many owners of %s: %s' % (m, owners))
 
   return None
index 5d92e75..d96c3df 100644 (file)
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 from functools import wraps
 
+import six
+
 from tensorflow.contrib.py2tf.pyct import inspect_utils
 from tensorflow.python.platform import test
 
@@ -76,6 +78,10 @@ def free_function():
   pass
 
 
+def factory():
+  return free_function
+
+
 def free_factory():
   def local_function():
     pass
@@ -84,6 +90,43 @@ def free_factory():
 
 class InspectUtilsTest(test.TestCase):
 
+  def test_getnamespace_globals(self):
+    ns = inspect_utils.getnamespace(factory)
+    self.assertEqual(ns['free_function'], free_function)
+
+  def test_getnamespace_hermetic(self):
+
+    # Intentionally hiding the global function to make sure we don't overwrite
+    # it in the global namespace.
+    free_function = object()  # pylint:disable=redefined-outer-name
+
+    def test_fn():
+      return free_function
+
+    ns = inspect_utils.getnamespace(test_fn)
+    globs = six.get_function_globals(test_fn)
+    self.assertTrue(ns['free_function'] is free_function)
+    self.assertFalse(globs['free_function'] is free_function)
+
+  def test_getnamespace_locals(self):
+
+    def called_fn():
+      return 0
+
+    closed_over_list = []
+    closed_over_primitive = 1
+
+    def local_fn():
+      closed_over_list.append(1)
+      local_var = 1
+      return called_fn() + local_var + closed_over_primitive
+
+    ns = inspect_utils.getnamespace(local_fn)
+    self.assertEqual(ns['called_fn'], called_fn)
+    self.assertEqual(ns['closed_over_list'], closed_over_list)
+    self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
+    self.assertTrue('local_var' not in ns)
+
   def test_getcallargs_constructor(self):
 
     class TestSuperclass(object):
@@ -123,48 +166,47 @@ class InspectUtilsTest(test.TestCase):
   def test_getmethodclass(self):
 
     self.assertEqual(
-        inspect_utils.getmethodclass(free_function, {}), None)
+        inspect_utils.getmethodclass(free_function), None)
     self.assertEqual(
-        inspect_utils.getmethodclass(free_factory(), {}), None)
+        inspect_utils.getmethodclass(free_factory()), None)
 
-    ns = {'TestClass': TestClass}
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.member_function, ns),
+        inspect_utils.getmethodclass(TestClass.member_function),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.decorated_member, ns),
+        inspect_utils.getmethodclass(TestClass.decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.fn_decorated_member, ns),
+        inspect_utils.getmethodclass(TestClass.fn_decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.wrap_decorated_member, ns),
+        inspect_utils.getmethodclass(TestClass.wrap_decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.static_method, ns),
+        inspect_utils.getmethodclass(TestClass.static_method),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(TestClass.class_method, ns),
+        inspect_utils.getmethodclass(TestClass.class_method),
         TestClass)
 
     test_obj = TestClass()
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.member_function, ns),
+        inspect_utils.getmethodclass(test_obj.member_function),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.static_method, ns),
+        inspect_utils.getmethodclass(test_obj.static_method),
         TestClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.class_method, ns),
+        inspect_utils.getmethodclass(test_obj.class_method),
         TestClass)
 
   def test_getmethodclass_locals(self):
@@ -190,34 +232,33 @@ class InspectUtilsTest(test.TestCase):
         pass
 
     self.assertEqual(
-        inspect_utils.getmethodclass(local_function, {}), None)
+        inspect_utils.getmethodclass(local_function), None)
 
-    ns = {'LocalClass': LocalClass}
     self.assertEqual(
-        inspect_utils.getmethodclass(LocalClass.member_function, ns),
+        inspect_utils.getmethodclass(LocalClass.member_function),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(LocalClass.decorated_member, ns),
+        inspect_utils.getmethodclass(LocalClass.decorated_member),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(LocalClass.fn_decorated_member, ns),
+        inspect_utils.getmethodclass(LocalClass.fn_decorated_member),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(LocalClass.wrap_decorated_member, ns),
+        inspect_utils.getmethodclass(LocalClass.wrap_decorated_member),
         LocalClass)
 
     test_obj = LocalClass()
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.member_function, ns),
+        inspect_utils.getmethodclass(test_obj.member_function),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.decorated_member),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
         LocalClass)
     self.assertEqual(
-        inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns),
+        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
         LocalClass)