from __future__ import print_function
from collections import namedtuple
+import functools
import inspect as _inspect
+import six
+
from tensorflow.python.util import tf_decorator
ArgSpec = _inspect.ArgSpec
"""TFDecorator-aware replacement for inspect.getargspec.
Args:
- object: A callable, possibly decorated.
+ object: A callable (function or partial function), possibly decorated.
Returns:
The `ArgSpec` that describes the signature of the outermost decorator that
changes the callable's signature. If the callable is not decorated,
`inspect.getargspec()` will be called directly on the callable.
+
+ Raises:
+ ValueError: When callable's function signature can not be expressed with
+ ArgSpec.
"""
- decorators, target = tf_decorator.unwrap(object)
- return next((d.decorator_argspec for d in decorators
- if d.decorator_argspec is not None), _inspect.getargspec(target))
+
+ def get_argspec_with_decorator(obj):
+ decorators, target = tf_decorator.unwrap(obj)
+ return next((d.decorator_argspec
+ for d in decorators
+ if d.decorator_argspec is not None),
+ _inspect.getargspec(target))
+
+ if not isinstance(object, functools.partial):
+ return get_argspec_with_decorator(object)
+
+ # When callable is a functools.partial object, we construct its ArgSpec with
+ # following strategy:
+ # - If callable partial contains default value for positional arguments (ie.
+ # object.args), then final ArgSpec doesn't contain those positional arguments.
+ # - If callable partial contains default value for keyword arguments (ie.
+ # object.keywords), then we merge them with wrapped target. Default values
+ # from callable partial takes precedence over those from wrapped target.
+ #
+ # However, there is a case where it is impossible to construct a valid
+ # ArgSpec. Python requires arguments that have no default values must be
+ # defined before those with default values. ArgSpec structure is only valid
+ # when this presumption holds true because default values are expressed as a
+ # tuple of values without keywords and they are always assumed to belong to
+ # last K arguments where K is number of default values present.
+ #
+ # Since functools.partial can give default value to any argument, this
+ # presumption may no longer hold in some cases. For example:
+ #
+ # def func(m, n):
+ # return 2 * m + n
+ # partialed = functools.partial(func, m=1)
+ #
+ # This example will result in m having a default value but n doesn't. This is
+ # usually not allowed in Python and can not be expressed in ArgSpec correctly.
+ #
+ # Thus, we must detect cases like this by finding first argument with default
+ # value and ensures all following arguments also have default values. When
+ # this is not true, a ValueError is raised.
+
+ n_prune_args = len(object.args)
+ partial_keywords = object.keywords or {}
+
+ args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
+
+ # Pruning first n_prune_args arguments.
+ args = args[n_prune_args:]
+
+ # Partial function may give default value to any argument, therefore length
+ # of default value list must be len(args) to allow each argument to
+ # potentially be given a default value.
+ all_defaults = [None] * len(args)
+ if defaults:
+ all_defaults[-len(defaults):] = defaults
+
+ # Fill in default values provided by partial function in all_defaults.
+ for kw, default in six.iteritems(partial_keywords):
+ idx = args.index(kw)
+ all_defaults[idx] = default
+
+ # Find first argument with default value set.
+ first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
+
+ # If no default values are found, return ArgSpec with defaults=None.
+ if first_default is None:
+ return ArgSpec(args, varargs, keywords, None)
+
+ # Checks if all arguments have default value set after first one.
+ invalid_default_values = [
+ args[i] for i, j in enumerate(all_defaults) if not j and i > first_default
+ ]
+
+ if invalid_default_values:
+ raise ValueError('Some arguments %s do not have default value, but they '
+ 'are positioned after those with default values. This can '
+ 'not be expressed with ArgSpec.' % invalid_default_values)
+
+ return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
def getfullargspec(obj): # pylint: disable=redefined-builtin
from __future__ import division
from __future__ import print_function
+import functools
import inspect
from tensorflow.python.platform import test
outer_argspec)
self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
+ def testGetArgSpecOnPartialPositionalArgumentOnly(self):
+ """Tests getargspec on partial function with only positional arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, 7)
+ argspec = tf_inspect.ArgSpec(
+ args=['n'], varargs=None, keywords=None, defaults=None)
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialInvalidArgspec(self):
+ """Tests getargspec on partial function that doesn't have valid argspec."""
+
+ def func(m, n, l, k=4):
+ return 2 * m + l + n * k
+
+ partial_func = functools.partial(func, n=7)
+
+ exception_message = (r"Some arguments \['l'\] do not have default value, "
+ "but they are positioned after those with default "
+ "values. This can not be expressed with ArgSpec.")
+ with self.assertRaisesRegexp(ValueError, exception_message):
+ tf_inspect.getargspec(partial_func)
+
+ def testGetArgSpecOnPartialValidArgspec(self):
+ """Tests getargspec on partial function with valid argspec."""
+
+ def func(m, n, l, k=4):
+ return 2 * m + l + n * k
+
+ partial_func = functools.partial(func, n=7, l=2)
+ argspec = tf_inspect.ArgSpec(
+ args=['m', 'n', 'l', 'k'],
+ varargs=None,
+ keywords=None,
+ defaults=(7, 2, 4))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialNoArgumentsLeft(self):
+ """Tests getargspec on partial function that prunes all arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, 7, 10)
+ argspec = tf_inspect.ArgSpec(
+ args=[], varargs=None, keywords=None, defaults=None)
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialKeywordArgument(self):
+ """Tests getargspec on partial function that prunes some arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, n=7)
+ argspec = tf_inspect.ArgSpec(
+ args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
+ """Tests getargspec on partial function that prunes argument by keyword."""
+
+ def func(m=1, n=2):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, n=7)
+ argspec = tf_inspect.ArgSpec(
+ args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialWithVarargs(self):
+ """Tests getargspec on partial function with variable arguments."""
+
+ def func(m, *arg):
+ return m + len(arg)
+
+ partial_func = functools.partial(func, 7, 8)
+ argspec = tf_inspect.ArgSpec(
+ args=[], varargs='arg', keywords=None, defaults=None)
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialWithVarkwargs(self):
+ """Tests getargspec on partial function with variable keyword arguments."""
+
+ def func(m, n, **kwarg):
+ return m * n + len(kwarg)
+
+ partial_func = functools.partial(func, 7)
+ argspec = tf_inspect.ArgSpec(
+ args=['n'], varargs=None, keywords='kwarg', defaults=None)
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialWithDecorator(self):
+ """Tests getargspec on decorated partial function."""
+
+ @test_decorator('decorator')
+ def func(m=1, n=2):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, n=7)
+ argspec = tf_inspect.ArgSpec(
+ args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+
+ def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
+ """Tests getargspec on partial function with decorated argspec."""
+
+ argspec = tf_inspect.ArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ keywords=None,
+ defaults=(1, 'hello'))
+ decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+ argspec)
+ partial_argspec = tf_inspect.ArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ keywords=None,
+ defaults=(2, 1, 'hello'))
+ partial_with_decorator = functools.partial(decorator, a=2)
+
+ self.assertEqual(argspec, tf_inspect.getargspec(decorator))
+ self.assertEqual(partial_argspec,
+ tf_inspect.getargspec(partial_with_decorator))
+
def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults))