import weakref
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
def add(self, layer_func):
if isinstance(layer_func, base.Layer):
- args = estimator_util.fn_args(layer_func.call)
+ args = function_utils.fn_args(layer_func.call)
self.track_layer(layer_func)
elif callable(layer_func):
- args = estimator_util.fn_args(layer_func)
+ args = function_utils.fn_args(layer_func)
else:
raise TypeError(
"Sequential.add() takes only tf.layers.Layer objects or callables; "
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:linear",
- "//tensorflow/python/estimator:util",
],
)
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.util import function_utils
_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
def _verify_metric_fn_args(metric_fn):
- args = set(estimator_util.fn_args(metric_fn))
+ args = set(function_utils.fn_args(metric_fn))
invalid_args = list(args - _VALID_METRIC_FN_ARGS)
if invalid_args:
raise ValueError('metric_fn (%s) has following not expected args: %s' %
def _call_metric_fn(metric_fn, features, labels, predictions, config):
"""Calls metric fn with proper arguments."""
- metric_fn_args = estimator_util.fn_args(metric_fn)
+ metric_fn_args = function_utils.fn_args(metric_fn)
kwargs = {}
if 'features' in metric_fn_args:
kwargs['features'] = features
import six
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
ValueError: if logit_fn does not return a Tensor or a dictionary mapping
strings to Tensors.
"""
- logit_fn_args = util.fn_args(logit_fn)
+ logit_fn_args = function_utils.fn_args(logit_fn)
kwargs = {}
if 'mode' in logit_fn_args:
kwargs['mode'] = mode
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.training import device_setter as device_setter_lib
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
@deprecation.deprecated(
"""Replicate the loss computation across devices."""
tower_specs = []
- model_fn_args = util.fn_args(model_fn)
+ model_fn_args = function_utils.fn_args(model_fn)
optional_params = {}
if 'params' in model_fn_args:
optional_params['params'] = copy.deepcopy(params)
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
__all__ = ["Experiment"]
def _get_standardized_predicate_fn(predicate_fn):
- pred_fn_args = estimator_util.fn_args(predicate_fn)
+ pred_fn_args = function_utils.fn_args(predicate_fn)
if "checkpoint_path" not in pred_fn_args:
# pylint: disable=unused-argument
def _pred_fn_wrapper(eval_results, checkpoint_path):
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
# Makes deep copy with `config` and params` in case user mutates them.
if isinstance(host_call[1], (tuple, list)):
fullargspec = tf_inspect.getfullargspec(host_call[0])
- fn_args = util.fn_args(host_call[0])
+ fn_args = function_utils.fn_args(host_call[0])
# wrapped_hostcall_with_global_step uses varargs, so we allow that.
if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
raise RuntimeError(
Raises:
ValueError: if input_fn takes invalid arguments or does not have `params`.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
config = self.config # a deep copy.
kwargs = {}
if 'params' in input_fn_args:
)
py_test(
+ name = "function_utils_test",
+ srcs = ["util/function_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ ],
+)
+
+py_test(
name = "tf_contextlib_test",
size = "small",
srcs = ["util/tf_contextlib_test.py"],
],
)
-py_test(
- name = "util_test",
- srcs = ["util_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":util",
- "//tensorflow/python:client_testlib",
- ],
-)
-
py_library(
name = "estimator",
srcs = [
":metric_keys",
":model_fn",
":prediction_keys",
- ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/feature_column",
"//tensorflow/python/ops/losses",
import six
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Raises:
ValueError: If the signature is unexpected.
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
for required_arg in ['labels', 'logits']:
if required_arg not in loss_fn_args:
raise ValueError(
Returns:
Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
kwargs = {}
if 'features' in loss_fn_args:
kwargs['features'] = features
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import errors
from tensorflow.python.training import warm_starting_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
Raises:
ValueError: if input_fn takes invalid arguments.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
if 'mode' in input_fn_args:
kwargs['mode'] = mode
Raises:
ValueError: if model_fn returns invalid objects.
"""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
- args = set(util.fn_args(model_fn))
+ args = set(function_utils.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if params is not None and 'params' not in args:
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
_TMP_DIR = '/tmp'
_ANOTHER_TMP_DIR = '/another_tmp'
_, _, _, _, _ = features, labels, mode, config, params
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
_, _ = features, labels
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
-from tensorflow.python.estimator import util
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
message='tf_random_seed must be integer.')
_validate('device_fn', lambda device_fn: six.callable(device_fn) and
- set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+ set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
message='device_fn must be callable with exactly'
' one argument "op".')
# limitations under the License.
# ==============================================================================
-"""Utility to retrieve function args."""
+"""Utilities for Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import time
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _is_bounded_method(fn):
- _, fn = tf_decorator.unwrap(fn)
- return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
-
-
-def _is_callable_object(obj):
- return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
-
-
-def fn_args(fn):
- """Get argument names for function-like object.
-
- Args:
- fn: Function, or function-like object (e.g., result of `functools.partial`).
-
- Returns:
- `tuple` of string argument names.
-
- Raises:
- ValueError: if partial function has positionally bound arguments
- """
- if isinstance(fn, functools.partial):
- args = fn_args(fn.func)
- args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
- else:
- if _is_callable_object(fn):
- fn = fn.__call__
- args = tf_inspect.getfullargspec(fn).args
- if _is_bounded_method(fn):
- args.remove('self')
- return tuple(args)
+from tensorflow.python.util import function_utils
+fn_args = function_utils.fn_args
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
+from tensorflow.python.estimator import util as function_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import checkpointable
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
self._uses_inputs_arg = True
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
import copy
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
try:
call_has_scope_arg = self._call_has_scope_arg
except AttributeError:
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._call_has_scope_arg = 'scope' in self._call_fn_args
call_has_scope_arg = self._call_has_scope_arg
if call_has_scope_arg:
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
"use_resource": use_resource,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in estimator_util.fn_args(custom_getter):
+ if "constraint" in function_utils.fn_args(custom_getter):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
import six
from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.estimator import util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
`step_context`. It may also optionally have `self` for cases when it
belongs to an object.
"""
- step_fn_arguments = util.fn_args(step_fn)
+ step_fn_arguments = function_utils.fn_args(step_fn)
if step_fn_arguments != ('step_context',) and step_fn_arguments != (
'self',
'step_context',
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to retrieve function args."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def _is_bounded_method(fn):
+ _, fn = tf_decorator.unwrap(fn)
+ return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
+
+
+def _is_callable_object(obj):
+ return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
+
+
+def fn_args(fn):
+ """Get argument names for function-like object.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `tuple` of string argument names.
+
+ Raises:
+ ValueError: if partial function has positionally bound arguments
+ """
+ if isinstance(fn, functools.partial):
+ args = fn_args(fn.func)
+ args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
+ else:
+ if _is_callable_object(fn):
+ fn = fn.__call__
+ args = tf_inspect.getfullargspec(fn).args
+ if _is_bounded_method(fn):
+ args.remove('self')
+ return tuple(args)
import functools
-from tensorflow.python.estimator import util
from tensorflow.python.platform import test
+from tensorflow.python.util import function_utils
class FnArgsTest(test.TestCase):
def test_simple_function(self):
def fn(a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(fn))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(fn))
def test_callable(self):
def __call__(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo()))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo()))
def test_bounded_method(self):
def bar(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar))
def test_partial_function(self):
expected_test_arg = 123
wrapped_fn = functools.partial(fn, test_arg=123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
def test_partial_function_with_positional_args(self):
expected_test_arg = 123
wrapped_fn = functools.partial(fn, 123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
self.assertEqual(3, wrapped_fn(3))
self.assertEqual(3, wrapped_fn(a=3))
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
def test_double_partial_with_positional_args_in_outer_layer(self):
expected_test_arg1 = 123
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, 123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))