Move fn_args utility into core TensorFlow from Estimator.
authorMichael Case <mikecase@google.com>
Fri, 11 May 2018 17:58:17 +0000 (10:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 18:01:30 +0000 (11:01 -0700)
Working on untangling TF/Estimator deps. Some core TF code depends on Estimator
by using the fn_args utility function within Estimator.

PiperOrigin-RevId: 196277612

20 files changed:
tensorflow/contrib/eager/python/network.py
tensorflow/contrib/estimator/BUILD
tensorflow/contrib/estimator/python/estimator/extenders.py
tensorflow/contrib/estimator/python/estimator/logit_fns.py
tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
tensorflow/contrib/learn/python/learn/experiment.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/python/BUILD
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator_test.py
tensorflow/python/estimator/run_config.py
tensorflow/python/estimator/util.py
tensorflow/python/keras/_impl/keras/engine/base_layer.py
tensorflow/python/layers/base.py
tensorflow/python/ops/variable_scope.py
tensorflow/python/training/monitored_session.py
tensorflow/python/util/function_utils.py [new file with mode: 0644]
tensorflow/python/util/function_utils_test.py [moved from tensorflow/python/estimator/util_test.py with 85% similarity]

index 44828be..9af50ee 100644 (file)
@@ -23,7 +23,6 @@ import os
 import weakref
 
 from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.framework import ops
 from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
 from tensorflow.python.layers import base
@@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
 
 # pylint: disable=protected-access
 # Explanation for protected-access disable: Network has lots of same-class and
@@ -545,10 +545,10 @@ class Sequential(Network):
 
   def add(self, layer_func):
     if isinstance(layer_func, base.Layer):
-      args = estimator_util.fn_args(layer_func.call)
+      args = function_utils.fn_args(layer_func.call)
       self.track_layer(layer_func)
     elif callable(layer_func):
-      args = estimator_util.fn_args(layer_func)
+      args = function_utils.fn_args(layer_func)
     else:
       raise TypeError(
           "Sequential.add() takes only tf.layers.Layer objects or callables; "
index 53bbafd..df08dc2 100644 (file)
@@ -366,9 +366,9 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:util",
         "//tensorflow/python/estimator:dnn",
         "//tensorflow/python/estimator:linear",
-        "//tensorflow/python/estimator:util",
     ],
 )
 
index 201699e..bf08be0 100644 (file)
@@ -22,12 +22,12 @@ import six
 
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.estimator.export.export_output import PredictOutput
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.util import function_utils
 
 
 _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
@@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer):
 
 
 def _verify_metric_fn_args(metric_fn):
-  args = set(estimator_util.fn_args(metric_fn))
+  args = set(function_utils.fn_args(metric_fn))
   invalid_args = list(args - _VALID_METRIC_FN_ARGS)
   if invalid_args:
     raise ValueError('metric_fn (%s) has following not expected args: %s' %
@@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn):
 
 def _call_metric_fn(metric_fn, features, labels, predictions, config):
   """Calls metric fn with proper arguments."""
-  metric_fn_args = estimator_util.fn_args(metric_fn)
+  metric_fn_args = function_utils.fn_args(metric_fn)
   kwargs = {}
   if 'features' in metric_fn_args:
     kwargs['features'] = features
index 09c2862..c8b0dd6 100644 (file)
@@ -41,10 +41,10 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.python.estimator import util
 from tensorflow.python.estimator.canned import dnn as dnn_core
 from tensorflow.python.estimator.canned import linear as linear_core
 from tensorflow.python.framework import ops
+from tensorflow.python.util import function_utils
 
 # pylint: disable=protected-access
 dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
@@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config):
     ValueError: if logit_fn does not return a Tensor or a dictionary mapping
       strings to Tensors.
   """
-  logit_fn_args = util.fn_args(logit_fn)
+  logit_fn_args = function_utils.fn_args(logit_fn)
   kwargs = {}
   if 'mode' in logit_fn_args:
     kwargs['mode'] = mode
index f856444..cda23aa 100644 (file)
@@ -32,7 +32,6 @@ import six
 from tensorflow.core.framework import node_def_pb2
 from tensorflow.python.client import device_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
 from tensorflow.python.estimator.export import export_output as export_output_lib
 from tensorflow.python.framework import device as framework_device
 from tensorflow.python.framework import ops as ops_lib
@@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import device_setter as device_setter_lib
 from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
 
 
 @deprecation.deprecated(
@@ -521,7 +521,7 @@ def _get_loss_towers(model_fn,
   """Replicate the loss computation across devices."""
   tower_specs = []
 
-  model_fn_args = util.fn_args(model_fn)
+  model_fn_args = function_utils.fn_args(model_fn)
   optional_params = {}
   if 'params' in model_fn_args:
     optional_params['params'] = copy.deepcopy(params)
index dfc6a39..541da90 100644 (file)
@@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import run_config
 from tensorflow.contrib.tpu.python.tpu import tpu_estimator
 from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import saver
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
 
 __all__ = ["Experiment"]
 
 
 def _get_standardized_predicate_fn(predicate_fn):
-  pred_fn_args = estimator_util.fn_args(predicate_fn)
+  pred_fn_args = function_utils.fn_args(predicate_fn)
   if "checkpoint_path" not in pred_fn_args:
     # pylint: disable=unused-argument
     def _pred_fn_wrapper(eval_results, checkpoint_path):
index afc8c7d..1bf2fc5 100644 (file)
@@ -46,7 +46,6 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -68,6 +67,7 @@ from tensorflow.python.training import evaluation
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
@@ -1269,7 +1269,7 @@ class _ModelFnWrapper(object):
 
   def _call_model_fn(self, features, labels, is_export_mode=False):
     """Calls the model_fn with required parameters."""
-    model_fn_args = util.fn_args(self._model_fn)
+    model_fn_args = function_utils.fn_args(self._model_fn)
     kwargs = {}
 
     # Makes deep copy with `config` and params` in case user mutates them.
@@ -1361,7 +1361,7 @@ class _OutfeedHostCall(object):
 
       if isinstance(host_call[1], (tuple, list)):
         fullargspec = tf_inspect.getfullargspec(host_call[0])
-        fn_args = util.fn_args(host_call[0])
+        fn_args = function_utils.fn_args(host_call[0])
         # wrapped_hostcall_with_global_step uses varargs, so we allow that.
         if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
           raise RuntimeError(
@@ -1938,7 +1938,7 @@ class TPUEstimator(estimator_lib.Estimator):
     Raises:
       ValueError: if input_fn takes invalid arguments or does not have `params`.
     """
-    input_fn_args = util.fn_args(input_fn)
+    input_fn_args = function_utils.fn_args(input_fn)
     config = self.config  # a deep copy.
     kwargs = {}
     if 'params' in input_fn_args:
index 8b904a1..cc96d5a 100644 (file)
@@ -3250,6 +3250,16 @@ py_test(
 )
 
 py_test(
+    name = "function_utils_test",
+    srcs = ["util/function_utils_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":util",
+    ],
+)
+
+py_test(
     name = "tf_contextlib_test",
     size = "small",
     srcs = ["util/tf_contextlib_test.py"],
index 2d9a084..a498e85 100644 (file)
@@ -445,16 +445,6 @@ py_library(
     ],
 )
 
-py_test(
-    name = "util_test",
-    srcs = ["util_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":util",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
 py_library(
     name = "estimator",
     srcs = [
@@ -645,7 +635,6 @@ py_library(
         ":metric_keys",
         ":model_fn",
         ":prediction_keys",
-        ":util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:control_flow_ops",
@@ -659,6 +648,7 @@ py_library(
         "//tensorflow/python:string_ops",
         "//tensorflow/python:summary",
         "//tensorflow/python:training",
+        "//tensorflow/python:util",
         "//tensorflow/python:weights_broadcast_ops",
         "//tensorflow/python/feature_column",
         "//tensorflow/python/ops/losses",
index 2326373..dcf8b15 100644 (file)
@@ -24,7 +24,6 @@ import collections
 import six
 
 from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
 from tensorflow.python.estimator.canned import metric_keys
 from tensorflow.python.estimator.canned import prediction_keys
 from tensorflow.python.estimator.export import export_output
@@ -46,6 +45,7 @@ from tensorflow.python.ops.losses import losses
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
@@ -461,7 +461,7 @@ def _validate_loss_fn_args(loss_fn):
   Raises:
     ValueError: If the signature is unexpected.
   """
-  loss_fn_args = util.fn_args(loss_fn)
+  loss_fn_args = function_utils.fn_args(loss_fn)
   for required_arg in ['labels', 'logits']:
     if required_arg not in loss_fn_args:
       raise ValueError(
@@ -484,7 +484,7 @@ def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
   Returns:
     Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
   """
-  loss_fn_args = util.fn_args(loss_fn)
+  loss_fn_args = function_utils.fn_args(loss_fn)
   kwargs = {}
   if 'features' in loss_fn_args:
     kwargs['features'] = features
index 9cfc680..5fdda04 100644 (file)
@@ -36,7 +36,6 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
 from tensorflow.python.estimator.export import export as export_helpers
 from tensorflow.python.estimator.export import export_output
 from tensorflow.python.framework import errors
@@ -63,6 +62,7 @@ from tensorflow.python.training import training_util
 from tensorflow.python.training import warm_starting_util
 from tensorflow.python.util import compat
 from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
@@ -1052,7 +1052,7 @@ class Estimator(object):
     Raises:
       ValueError: if input_fn takes invalid arguments.
     """
-    input_fn_args = util.fn_args(input_fn)
+    input_fn_args = function_utils.fn_args(input_fn)
     kwargs = {}
     if 'mode' in input_fn_args:
       kwargs['mode'] = mode
@@ -1078,7 +1078,7 @@ class Estimator(object):
     Raises:
       ValueError: if model_fn returns invalid objects.
     """
-    model_fn_args = util.fn_args(self._model_fn)
+    model_fn_args = function_utils.fn_args(self._model_fn)
     kwargs = {}
     if 'labels' in model_fn_args:
       kwargs['labels'] = labels
@@ -1483,7 +1483,7 @@ def _get_replica_device_setter(config):
 
 def _verify_model_fn_args(model_fn, params):
   """Verifies model fn arguments."""
-  args = set(util.fn_args(model_fn))
+  args = set(function_utils.fn_args(model_fn))
   if 'features' not in args:
     raise ValueError('model_fn (%s) must include features argument.' % model_fn)
   if params is not None and 'params' not in args:
index 0f268f5..1b70189 100644 (file)
@@ -33,7 +33,6 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
 from tensorflow.python.estimator.export import export
 from tensorflow.python.estimator.export import export_output
 from tensorflow.python.estimator.inputs import numpy_io
@@ -72,6 +71,7 @@ from tensorflow.python.training import saver_test_utils
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
 
 _TMP_DIR = '/tmp'
 _ANOTHER_TMP_DIR = '/another_tmp'
@@ -332,7 +332,7 @@ class EstimatorConstructorTest(test.TestCase):
       _, _, _, _, _ = features, labels, mode, config, params
 
     est = estimator.Estimator(model_fn=model_fn)
-    model_fn_args = util.fn_args(est.model_fn)
+    model_fn_args = function_utils.fn_args(est.model_fn)
     self.assertEqual(
         set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
 
@@ -342,7 +342,7 @@ class EstimatorConstructorTest(test.TestCase):
       _, _ = features, labels
 
     est = estimator.Estimator(model_fn=model_fn)
-    model_fn_args = util.fn_args(est.model_fn)
+    model_fn_args = function_utils.fn_args(est.model_fn)
     self.assertEqual(
         set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
 
index 8162b24..c7707be 100644 (file)
@@ -27,8 +27,8 @@ import six
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
-from tensorflow.python.estimator import util
 from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -283,7 +283,7 @@ def _validate_properties(run_config):
             message='tf_random_seed must be integer.')
 
   _validate('device_fn', lambda device_fn: six.callable(device_fn) and
-            set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+            set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
             message='device_fn must be callable with exactly'
                     ' one argument "op".')
 
index bb4bdd3..e4e1d37 100644 (file)
 # limitations under the License.
 # ==============================================================================
 
-"""Utility to retrieve function args."""
+"""Utilities for Estimators."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
 import os
 import time
 
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _is_bounded_method(fn):
-  _, fn = tf_decorator.unwrap(fn)
-  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
-
-
-def _is_callable_object(obj):
-  return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
-
-
-def fn_args(fn):
-  """Get argument names for function-like object.
-
-  Args:
-    fn: Function, or function-like object (e.g., result of `functools.partial`).
-
-  Returns:
-    `tuple` of string argument names.
-
-  Raises:
-    ValueError: if partial function has positionally bound arguments
-  """
-  if isinstance(fn, functools.partial):
-    args = fn_args(fn.func)
-    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
-  else:
-    if _is_callable_object(fn):
-      fn = fn.__call__
-    args = tf_inspect.getfullargspec(fn).args
-    if _is_bounded_method(fn):
-      args.remove('self')
-  return tuple(args)
+from tensorflow.python.util import function_utils
 
+fn_args = function_utils.fn_args
 
 # When we create a timestamped directory, there is a small chance that the
 # directory already exists because another process is also creating these
index 16ee295..72ab77f 100644 (file)
@@ -25,7 +25,7 @@ import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
+from tensorflow.python.estimator import util as function_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -44,6 +44,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.training import checkpointable
+from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -146,7 +147,7 @@ class Layer(checkpointable.CheckpointableBase):
     # return tensors. When using graph execution, _losses is a list of ops.
     self._losses = []
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
-    self._call_fn_args = estimator_util.fn_args(self.call)
+    self._call_fn_args = function_utils.fn_args(self.call)
     self._compute_previous_mask = ('mask' in self._call_fn_args or
                                    hasattr(self, 'compute_mask'))
     self._uses_inputs_arg = True
@@ -644,7 +645,7 @@ class Layer(checkpointable.CheckpointableBase):
         self._compute_previous_mask):
       previous_mask = collect_previous_mask(inputs)
       if not hasattr(self, '_call_fn_args'):
-        self._call_fn_args = estimator_util.fn_args(self.call)
+        self._call_fn_args = function_utils.fn_args(self.call)
       if ('mask' in self._call_fn_args and 'mask' not in kwargs and
           not generic_utils.is_all_none(previous_mask)):
         # The previous layer generated a mask, and mask was not explicitly pass
index 64db49c..2040e00 100644 (file)
@@ -20,12 +20,12 @@ from __future__ import print_function
 import copy
 
 from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.keras._impl.keras.engine import base_layer
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
@@ -308,7 +308,7 @@ class Layer(base_layer.Layer):
       try:
         call_has_scope_arg = self._call_has_scope_arg
       except AttributeError:
-        self._call_fn_args = estimator_util.fn_args(self.call)
+        self._call_fn_args = function_utils.fn_args(self.call)
         self._call_has_scope_arg = 'scope' in self._call_fn_args
         call_has_scope_arg = self._call_has_scope_arg
       if call_has_scope_arg:
index adb0f59..f5970fd 100644 (file)
@@ -32,7 +32,6 @@ from six import iteritems
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -41,6 +40,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import function_utils
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
 
@@ -422,7 +422,7 @@ class _VariableStore(object):
           "use_resource": use_resource,
       }
       # `fn_args` can handle functions, `functools.partial`, `lambda`.
-      if "constraint" in estimator_util.fn_args(custom_getter):
+      if "constraint" in function_utils.fn_args(custom_getter):
         custom_getter_kwargs["constraint"] = constraint
       return custom_getter(**custom_getter_kwargs)
     else:
index f584a00..fece337 100644 (file)
@@ -25,7 +25,6 @@ import sys
 import six
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.estimator import util
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -41,6 +40,7 @@ from tensorflow.python.training import queue_runner
 from tensorflow.python.training import saver as training_saver
 from tensorflow.python.training import session_manager as sm
 from tensorflow.python.training import session_run_hook
+from tensorflow.python.util import function_utils
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -620,7 +620,7 @@ class _MonitoredSession(object):
         `step_context`. It may also optionally have `self` for cases when it
         belongs to an object.
     """
-    step_fn_arguments = util.fn_args(step_fn)
+    step_fn_arguments = function_utils.fn_args(step_fn)
     if step_fn_arguments != ('step_context',) and step_fn_arguments != (
         'self',
         'step_context',
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
new file mode 100644 (file)
index 0000000..7bbbde3
--- /dev/null
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to retrieve function args."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def _is_bounded_method(fn):
+  _, fn = tf_decorator.unwrap(fn)
+  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
+
+
+def _is_callable_object(obj):
+  return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
+
+
+def fn_args(fn):
+  """Get argument names for function-like object.
+
+  Args:
+    fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+  Returns:
+    `tuple` of string argument names.
+
+  Raises:
+    ValueError: if partial function has positionally bound arguments
+  """
+  if isinstance(fn, functools.partial):
+    args = fn_args(fn.func)
+    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
+  else:
+    if _is_callable_object(fn):
+      fn = fn.__call__
+    args = tf_inspect.getfullargspec(fn).args
+    if _is_bounded_method(fn):
+      args.remove('self')
+  return tuple(args)
similarity index 85%
rename from tensorflow/python/estimator/util_test.py
rename to tensorflow/python/util/function_utils_test.py
index 4b2c8d7..e78cf6a 100644 (file)
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import functools
 
-from tensorflow.python.estimator import util
 from tensorflow.python.platform import test
+from tensorflow.python.util import function_utils
 
 
 class FnArgsTest(test.TestCase):
@@ -29,7 +29,7 @@ class FnArgsTest(test.TestCase):
   def test_simple_function(self):
     def fn(a, b):
       return a + b
-    self.assertEqual(('a', 'b'), util.fn_args(fn))
+    self.assertEqual(('a', 'b'), function_utils.fn_args(fn))
 
   def test_callable(self):
 
@@ -38,7 +38,7 @@ class FnArgsTest(test.TestCase):
       def __call__(self, a, b):
         return a + b
 
-    self.assertEqual(('a', 'b'), util.fn_args(Foo()))
+    self.assertEqual(('a', 'b'), function_utils.fn_args(Foo()))
 
   def test_bounded_method(self):
 
@@ -47,7 +47,7 @@ class FnArgsTest(test.TestCase):
       def bar(self, a, b):
         return a + b
 
-    self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
+    self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar))
 
   def test_partial_function(self):
     expected_test_arg = 123
@@ -59,7 +59,7 @@ class FnArgsTest(test.TestCase):
 
     wrapped_fn = functools.partial(fn, test_arg=123)
 
-    self.assertEqual(('a',), util.fn_args(wrapped_fn))
+    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
 
   def test_partial_function_with_positional_args(self):
     expected_test_arg = 123
@@ -71,7 +71,7 @@ class FnArgsTest(test.TestCase):
 
     wrapped_fn = functools.partial(fn, 123)
 
-    self.assertEqual(('a',), util.fn_args(wrapped_fn))
+    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
 
     self.assertEqual(3, wrapped_fn(3))
     self.assertEqual(3, wrapped_fn(a=3))
@@ -88,7 +88,7 @@ class FnArgsTest(test.TestCase):
     wrapped_fn = functools.partial(fn, test_arg2=456)
     double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
 
-    self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
 
   def test_double_partial_with_positional_args_in_outer_layer(self):
     expected_test_arg1 = 123
@@ -102,7 +102,7 @@ class FnArgsTest(test.TestCase):
     wrapped_fn = functools.partial(fn, test_arg2=456)
     double_wrapped_fn = functools.partial(wrapped_fn, 123)
 
-    self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
 
     self.assertEqual(3, double_wrapped_fn(3))
     self.assertEqual(3, double_wrapped_fn(a=3))
@@ -119,7 +119,7 @@ class FnArgsTest(test.TestCase):
     wrapped_fn = functools.partial(fn, 123)  # binds to test_arg1
     double_wrapped_fn = functools.partial(wrapped_fn, 456)  # binds to test_arg2
 
-    self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
 
     self.assertEqual(3, double_wrapped_fn(3))
     self.assertEqual(3, double_wrapped_fn(a=3))