Add a convenience function, build_supervised_input_receiver_fn_from_input_fn,
authorKarmel Allison <karmel@google.com>
Wed, 30 May 2018 23:25:00 +0000 (16:25 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 30 May 2018 23:27:30 +0000 (16:27 -0700)
that takes an Estimator input_fn and returns an input receiver function.

PiperOrigin-RevId: 198638593

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/python/BUILD
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/export/export.py
tensorflow/python/estimator/export/export_test.py
tensorflow/python/estimator/util.py
tensorflow/python/estimator/util_test.py [new file with mode: 0644]

index aeb7ba5..4465833 100644 (file)
@@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.estimator.export import export_output as export_output_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -2748,7 +2749,8 @@ class _Inputs(object):
     """
     iterator = self._dataset.make_initializable_iterator()
     # pylint: disable=protected-access
-    hook = estimator_lib._DatasetInitializerHook(iterator)
+    hook = estimator_util._DatasetInitializerHook(iterator)
+    # pylint: enable=protected-access
     self._iterator = iterator
     return hook
 
index 679ef93..0542c2f 100644 (file)
@@ -2699,7 +2699,6 @@ py_library(
         ":util",
         ":variables",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/estimator:util",
         "@six_archive//:six",
     ],
 )
index 0754041..9c4d58b 100644 (file)
@@ -446,7 +446,26 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:platform",
+        "//tensorflow/python:training",
         "//tensorflow/python:util",
+        "//tensorflow/python/data",
+    ],
+)
+
+py_test(
+    name = "util_test",
+    srcs = ["util_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],  # b/67510291
+    deps = [
+        ":util",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:training",
+        "//tensorflow/python/data",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
     ],
 )
 
@@ -598,6 +617,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
index 331ee74..cfbf7e2 100644 (file)
@@ -32,10 +32,10 @@ from tensorflow.core.framework import summary_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session as tf_session
-from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.estimator.export import export as export_helpers
 from tensorflow.python.estimator.export import export_output
 from tensorflow.python.framework import errors
@@ -964,17 +964,9 @@ class Estimator(object):
   def _get_features_from_input_fn(self, input_fn, mode):
     """Extracts the `features` from return values of `input_fn`."""
     result = self._call_input_fn(input_fn, mode)
-    input_hooks = []
-    if isinstance(result, dataset_ops.Dataset):
-      iterator = result.make_initializable_iterator()
-      input_hooks.append(_DatasetInitializerHook(iterator))
-      result = iterator.get_next()
-    if isinstance(result, (list, tuple)):
-      # Unconditionally drop the label (the second element of result).
-      result = result[0]
-
+    result, _, hooks = estimator_util.parse_input_fn_result(result)
     self._validate_features_in_predict_input(result)
-    return result, input_hooks
+    return result, hooks
 
   def _validate_features_in_predict_input(self, result):
     if not _has_dataset_or_queue_runner(result):
@@ -984,25 +976,13 @@ class Estimator(object):
 
   def _get_features_and_labels_from_input_fn(self, input_fn, mode):
     """Extracts the `features` and labels from return values of `input_fn`."""
-    input_hooks = []
     if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
       result = self._distribution.distribute_dataset(
           lambda: self._call_input_fn(input_fn, mode))
-      iterator = result.make_initializable_iterator()
-      input_hooks.append(_DatasetInitializerHook(iterator))
-      result = iterator.get_next()
     else:
       result = self._call_input_fn(input_fn, mode)
-      if isinstance(result, dataset_ops.Dataset):
-        iterator = result.make_initializable_iterator()
-        input_hooks.append(_DatasetInitializerHook(iterator))
-        result = iterator.get_next()
-    if isinstance(result, (list, tuple)):
-      if len(result) != 2:
-        raise ValueError(
-            'input_fn should return (features, labels) as a len 2 tuple.')
-      return result[0], result[1], input_hooks
-    return result, None, input_hooks
+
+    return estimator_util.parse_input_fn_result(result)
 
   def _extract_batch_length(self, preds_evaluated):
     """Extracts batch length of predictions."""
@@ -1067,9 +1047,15 @@ class Estimator(object):
       mode: ModeKeys
 
     Returns:
-      Either features or (features, labels) where features and labels are:
-        features - `Tensor` or dictionary of string feature name to `Tensor`.
-        labels - `Tensor` or dictionary of `Tensor` with labels.
+      The return value of the passed input_fn, which should be one of:
+
+        * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+            tuple (features, labels) with same constraints as below.
+        * A tuple (features, labels): Where `features` is a `Tensor` or a
+          dictionary of string feature name to `Tensor` and `labels` is a
+          `Tensor` or a dictionary of string label name to `Tensor`. Both
+          `features` and `labels` are consumed by `model_fn`. They should
+          satisfy the expectation of `model_fn` from inputs.
 
     Raises:
       ValueError: if input_fn takes invalid arguments.
@@ -1610,19 +1596,6 @@ def _has_dataset_or_queue_runner(maybe_tensor):
   # Now, check queue.
   return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
 
-
-class _DatasetInitializerHook(training.SessionRunHook):
-
-  def __init__(self, iterator):
-    self._iterator = iterator
-
-  def begin(self):
-    self._initializer = self._iterator.initializer
-
-  def after_create_session(self, session, coord):
-    del coord
-    session.run(self._initializer)
-
 VocabInfo = warm_starting_util.VocabInfo  # pylint: disable=invalid-name
 tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
 
index 48ae8cd..ff19a0a 100644 (file)
@@ -404,6 +404,42 @@ def build_raw_supervised_input_receiver_fn(features,
   return supervised_input_receiver_fn
 
 
+def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
+  """Get a function that returns a SupervisedInputReceiver matching an input_fn.
+
+  Note that this function calls the input_fn in a local graph in order to
+  extract features and labels. Placeholders are then created from those
+  features and labels in the default graph.
+
+  Args:
+    input_fn: An Estimator input_fn, which is a function that returns one of:
+
+      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+          tuple (features, labels) with same constraints as below.
+      * A tuple (features, labels): Where `features` is a `Tensor` or a
+        dictionary of string feature name to `Tensor` and `labels` is a
+        `Tensor` or a dictionary of string label name to `Tensor`. Both
+        `features` and `labels` are consumed by `model_fn`. They should
+        satisfy the expectation of `model_fn` from inputs.
+
+    **input_fn_args: set of kwargs to be passed to the input_fn. Note that
+      these will not be checked or validated here, and any errors raised by
+      the input_fn will be thrown to the top.
+
+  Returns:
+    A function taking no arguments that, when called, returns a
+    SupervisedInputReceiver. This function can be passed in as part of the
+    input_receiver_map when exporting SavedModels from Estimator with multiple
+    modes.
+  """
+  # Wrap the input_fn call in a graph to prevent sullying the default namespace
+  with ops.Graph().as_default():
+    result = input_fn(**input_fn_args)
+    features, labels, _ = util.parse_input_fn_result(result)
+  # Placeholders are created back in the default graph.
+  return build_raw_supervised_input_receiver_fn(features, labels)
+
+
 ### Below utilities are specific to SavedModel exports.
 
 
index 0af587f..a707471 100644 (file)
@@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase):
     with self.assertRaises(ValueError):
       export.build_raw_supervised_input_receiver_fn(features, labels)
 
+  def test_build_supervised_input_receiver_fn_from_input_fn(self):
+    def dummy_input_fn():
+      return ({"x": constant_op.constant([[1], [1]]),
+               "y": constant_op.constant(["hello", "goodbye"])},
+              constant_op.constant([[1], [1]]))
+
+    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+        dummy_input_fn)
+
+    with ops.Graph().as_default():
+      input_receiver = input_receiver_fn()
+      self.assertEqual(set(["x", "y"]),
+                       set(input_receiver.features.keys()))
+      self.assertIsInstance(input_receiver.labels, ops.Tensor)
+      self.assertEqual(set(["x", "y", "label"]),
+                       set(input_receiver.receiver_tensors.keys()))
+
+  def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
+    def dummy_input_fn(feature_key="x"):
+      return ({feature_key: constant_op.constant([[1], [1]]),
+               "y": constant_op.constant(["hello", "goodbye"])},
+              {"my_label": constant_op.constant([[1], [1]])})
+
+    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+        dummy_input_fn, feature_key="z")
+
+    with ops.Graph().as_default():
+      input_receiver = input_receiver_fn()
+      self.assertEqual(set(["z", "y"]),
+                       set(input_receiver.features.keys()))
+      self.assertEqual(set(["my_label"]),
+                       set(input_receiver.labels.keys()))
+      self.assertEqual(set(["z", "y", "my_label"]),
+                       set(input_receiver.receiver_tensors.keys()))
+
   def test_build_all_signature_defs_without_receiver_alternatives(self):
     receiver_tensor = array_ops.placeholder(dtypes.string)
     output_1 = constant_op.constant([1.])
index e4e1d37..924ca30 100644 (file)
@@ -24,6 +24,7 @@ import time
 
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
 
@@ -72,3 +73,59 @@ def get_timestamped_dir(dir_base):
         result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
   raise RuntimeError('Failed to obtain a unique export directory name after '
                      '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
+
+
+def parse_input_fn_result(result):
+  """Gets features, labels, and hooks from the result of an Estimator input_fn.
+
+  Args:
+    result: output of an input_fn to an estimator, which should be one of:
+
+      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+          tuple (features, labels) with same constraints as below.
+      * A tuple (features, labels): Where `features` is a `Tensor` or a
+        dictionary of string feature name to `Tensor` and `labels` is a
+        `Tensor` or a dictionary of string label name to `Tensor`. Both
+        `features` and `labels` are consumed by `model_fn`. They should
+        satisfy the expectation of `model_fn` from inputs.
+
+  Returns:
+    Tuple of features, labels, and input_hooks, where features are as described
+    above, labels are as described above or None, and input_hooks are a list
+    of SessionRunHooks to be included when running.
+
+  Raises:
+    ValueError: if the result is a list or tuple of length != 2.
+  """
+  input_hooks = []
+  try:
+    # We can't just check whether this is a tf.data.Dataset instance here,
+    # as this is plausibly a PerDeviceDataset. Try treating as a dataset first.
+    iterator = result.make_initializable_iterator()
+  except AttributeError:
+    # Not a dataset or dataset-like-object. Move along.
+    pass
+  else:
+    input_hooks.append(_DatasetInitializerHook(iterator))
+    result = iterator.get_next()
+
+  if isinstance(result, (list, tuple)):
+    if len(result) != 2:
+      raise ValueError(
+          'input_fn should return (features, labels) as a len 2 tuple.')
+    return result[0], result[1], input_hooks
+  return result, None, input_hooks
+
+
+class _DatasetInitializerHook(training.SessionRunHook):
+  """Creates a SessionRunHook that initializes the passed iterator."""
+
+  def __init__(self, iterator):
+    self._iterator = iterator
+
+  def begin(self):
+    self._initializer = self._iterator.initializer
+
+  def after_create_session(self, session, coord):
+    del coord
+    session.run(self._initializer)
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py
new file mode 100644 (file)
index 0000000..d7e0610
--- /dev/null
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for util.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class UtilTest(test.TestCase):
+  """Tests for miscellaneous Estimator utils."""
+
+  def test_parse_input_fn_result_tuple(self):
+    def _input_fn():
+      features = constant_op.constant(np.arange(100))
+      labels = constant_op.constant(np.arange(100, 200))
+      return features, labels
+
+    features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+    with self.test_session() as sess:
+      vals = sess.run([features, labels])
+
+    self.assertAllEqual(vals[0], np.arange(100))
+    self.assertAllEqual(vals[1], np.arange(100, 200))
+    self.assertEqual(hooks, [])
+
+  def test_parse_input_fn_result_dataset(self):
+    def _input_fn():
+      features = np.expand_dims(np.arange(100), 0)
+      labels = np.expand_dims(np.arange(100, 200), 0)
+      return dataset_ops.Dataset.from_tensor_slices((features, labels))
+
+    features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+    with training.MonitoredSession(hooks=hooks) as sess:
+      vals = sess.run([features, labels])
+
+    self.assertAllEqual(vals[0], np.arange(100))
+    self.assertAllEqual(vals[1], np.arange(100, 200))
+    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+  def test_parse_input_fn_result_features_only(self):
+    def _input_fn():
+      return constant_op.constant(np.arange(100))
+
+    features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+    with self.test_session() as sess:
+      vals = sess.run([features])
+
+    self.assertAllEqual(vals[0], np.arange(100))
+    self.assertEqual(labels, None)
+    self.assertEqual(hooks, [])
+
+  def test_parse_input_fn_result_features_only_dataset(self):
+    def _input_fn():
+      features = np.expand_dims(np.arange(100), 0)
+      return dataset_ops.Dataset.from_tensor_slices(features)
+
+    features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+    with training.MonitoredSession(hooks=hooks) as sess:
+      vals = sess.run([features])
+
+    self.assertAllEqual(vals[0], np.arange(100))
+    self.assertEqual(labels, None)
+    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+  def test_parse_input_fn_result_invalid(self):
+    def _input_fn():
+      features = np.expand_dims(np.arange(100), 0)
+      labels = np.expand_dims(np.arange(100, 200), 0)
+      return dataset_ops.Dataset.from_tensor_slices((features, labels, labels))
+
+    with self.assertRaisesRegexp(ValueError, 'input_fn should return'):
+      util.parse_input_fn_result(_input_fn())
+
+
+if __name__ == '__main__':
+  test.main()