Support functools.partial as callable object in tf_inspect.getargspec.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 19:55:02 +0000 (12:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 19:57:28 +0000 (12:57 -0700)
PiperOrigin-RevId: 197036874

tensorflow/python/util/tf_inspect.py
tensorflow/python/util/tf_inspect_test.py

index 663036d..33b389c 100644 (file)
@@ -18,8 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import namedtuple
+import functools
 import inspect as _inspect
 
+import six
+
 from tensorflow.python.util import tf_decorator
 
 ArgSpec = _inspect.ArgSpec
@@ -43,16 +46,95 @@ def getargspec(object):  # pylint: disable=redefined-builtin
   """TFDecorator-aware replacement for inspect.getargspec.
 
   Args:
-    object: A callable, possibly decorated.
+    object: A callable (function or partial function), possibly decorated.
 
   Returns:
     The `ArgSpec` that describes the signature of the outermost decorator that
     changes the callable's signature. If the callable is not decorated,
     `inspect.getargspec()` will be called directly on the callable.
+
+  Raises:
+    ValueError: When callable's function signature can not be expressed with
+    ArgSpec.
   """
-  decorators, target = tf_decorator.unwrap(object)
-  return next((d.decorator_argspec for d in decorators
-               if d.decorator_argspec is not None), _inspect.getargspec(target))
+
+  def get_argspec_with_decorator(obj):
+    decorators, target = tf_decorator.unwrap(obj)
+    return next((d.decorator_argspec
+                 for d in decorators
+                 if d.decorator_argspec is not None),
+                _inspect.getargspec(target))
+
+  if not isinstance(object, functools.partial):
+    return get_argspec_with_decorator(object)
+
+  # When callable is a functools.partial object, we construct its ArgSpec with
+  # following strategy:
+  # - If callable partial contains default value for positional arguments (ie.
+  # object.args), then final ArgSpec doesn't contain those positional arguments.
+  # - If callable partial contains default value for keyword arguments (ie.
+  # object.keywords), then we merge them with wrapped target. Default values
+  # from callable partial takes precedence over those from wrapped target.
+  #
+  # However, there is a case where it is impossible to construct a valid
+  # ArgSpec. Python requires arguments that have no default values must be
+  # defined before those with default values. ArgSpec structure is only valid
+  # when this presumption holds true because default values are expressed as a
+  # tuple of values without keywords and they are always assumed to belong to
+  # last K arguments where K is number of default values present.
+  #
+  # Since functools.partial can give default value to any argument, this
+  # presumption may no longer hold in some cases. For example:
+  #
+  # def func(m, n):
+  #   return 2 * m + n
+  # partialed = functools.partial(func, m=1)
+  #
+  # This example will result in m having a default value but n doesn't. This is
+  # usually not allowed in Python and can not be expressed in ArgSpec correctly.
+  #
+  # Thus, we must detect cases like this by finding first argument with default
+  # value and ensures all following arguments also have default values. When
+  # this is not true, a ValueError is raised.
+
+  n_prune_args = len(object.args)
+  partial_keywords = object.keywords or {}
+
+  args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
+
+  # Pruning first n_prune_args arguments.
+  args = args[n_prune_args:]
+
+  # Partial function may give default value to any argument, therefore length
+  # of default value list must be len(args) to allow each argument to
+  # potentially be given a default value.
+  all_defaults = [None] * len(args)
+  if defaults:
+    all_defaults[-len(defaults):] = defaults
+
+  # Fill in default values provided by partial function in all_defaults.
+  for kw, default in six.iteritems(partial_keywords):
+    idx = args.index(kw)
+    all_defaults[idx] = default
+
+  # Find first argument with default value set.
+  first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
+
+  # If no default values are found, return ArgSpec with defaults=None.
+  if first_default is None:
+    return ArgSpec(args, varargs, keywords, None)
+
+  # Checks if all arguments have default value set after first one.
+  invalid_default_values = [
+      args[i] for i, j in enumerate(all_defaults) if not j and i > first_default
+  ]
+
+  if invalid_default_values:
+    raise ValueError('Some arguments %s do not have default value, but they '
+                     'are positioned after those with default values. This can '
+                     'not be expressed with ArgSpec.' % invalid_default_values)
+
+  return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
 
 
 def getfullargspec(obj):  # pylint: disable=redefined-builtin
index 1294084..325131c 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import inspect
 
 from tensorflow.python.platform import test
@@ -109,6 +110,141 @@ class TfInspectTest(test.TestCase):
                                                outer_argspec)
     self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
 
+  def testGetArgSpecOnPartialPositionalArgumentOnly(self):
+    """Tests getargspec on partial function with only positional arguments."""
+
+    def func(m, n):
+      return 2 * m + n
+
+    partial_func = functools.partial(func, 7)
+    argspec = tf_inspect.ArgSpec(
+        args=['n'], varargs=None, keywords=None, defaults=None)
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialInvalidArgspec(self):
+    """Tests getargspec on partial function that doesn't have valid argspec."""
+
+    def func(m, n, l, k=4):
+      return 2 * m + l + n * k
+
+    partial_func = functools.partial(func, n=7)
+
+    exception_message = (r"Some arguments \['l'\] do not have default value, "
+                         "but they are positioned after those with default "
+                         "values. This can not be expressed with ArgSpec.")
+    with self.assertRaisesRegexp(ValueError, exception_message):
+      tf_inspect.getargspec(partial_func)
+
+  def testGetArgSpecOnPartialValidArgspec(self):
+    """Tests getargspec on partial function with valid argspec."""
+
+    def func(m, n, l, k=4):
+      return 2 * m + l + n * k
+
+    partial_func = functools.partial(func, n=7, l=2)
+    argspec = tf_inspect.ArgSpec(
+        args=['m', 'n', 'l', 'k'],
+        varargs=None,
+        keywords=None,
+        defaults=(7, 2, 4))
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialNoArgumentsLeft(self):
+    """Tests getargspec on partial function that prunes all arguments."""
+
+    def func(m, n):
+      return 2 * m + n
+
+    partial_func = functools.partial(func, 7, 10)
+    argspec = tf_inspect.ArgSpec(
+        args=[], varargs=None, keywords=None, defaults=None)
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialKeywordArgument(self):
+    """Tests getargspec on partial function that prunes some arguments."""
+
+    def func(m, n):
+      return 2 * m + n
+
+    partial_func = functools.partial(func, n=7)
+    argspec = tf_inspect.ArgSpec(
+        args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
+    """Tests getargspec on partial function that prunes argument by keyword."""
+
+    def func(m=1, n=2):
+      return 2 * m + n
+
+    partial_func = functools.partial(func, n=7)
+    argspec = tf_inspect.ArgSpec(
+        args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialWithVarargs(self):
+    """Tests getargspec on partial function with variable arguments."""
+
+    def func(m, *arg):
+      return m + len(arg)
+
+    partial_func = functools.partial(func, 7, 8)
+    argspec = tf_inspect.ArgSpec(
+        args=[], varargs='arg', keywords=None, defaults=None)
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialWithVarkwargs(self):
+    """Tests getargspec on partial function with variable keyword arguments."""
+
+    def func(m, n, **kwarg):
+      return m * n + len(kwarg)
+
+    partial_func = functools.partial(func, 7)
+    argspec = tf_inspect.ArgSpec(
+        args=['n'], varargs=None, keywords='kwarg', defaults=None)
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialWithDecorator(self):
+    """Tests getargspec on decorated partial function."""
+
+    @test_decorator('decorator')
+    def func(m=1, n=2):
+      return 2 * m + n
+
+    partial_func = functools.partial(func, n=7)
+    argspec = tf_inspect.ArgSpec(
+        args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
+
+    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+  def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
+    """Tests getargspec on partial function with decorated argspec."""
+
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+    decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+                                         argspec)
+    partial_argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(2, 1, 'hello'))
+    partial_with_decorator = functools.partial(decorator, a=2)
+
+    self.assertEqual(argspec, tf_inspect.getargspec(decorator))
+    self.assertEqual(partial_argspec,
+                     tf_inspect.getargspec(partial_with_decorator))
+
   def testGetDoc(self):
     self.assertEqual('Test Decorated Function With Defaults Docstring.',
                      tf_inspect.getdoc(test_decorated_function_with_defaults))