Update documentation of ServingInputReceiver for the case where a non-dict is passed as the features argument.
author     A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 10 May 2018 17:49:20 +0000 (10:49 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Thu, 10 May 2018 17:52:05 +0000 (10:52 -0700)
PiperOrigin-RevId: 196138375

tensorflow/python/estimator/export/export.py

index 9aafb56..48ae8cd 100644
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Configuration and utilities for receiving inputs at serving time."""
 
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -37,7 +36,6 @@ from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
-
 _SINGLE_FEATURE_DEFAULT_NAME = 'feature'
 _SINGLE_RECEIVER_DEFAULT_NAME = 'input'
 _SINGLE_LABEL_DEFAULT_NAME = 'label'
@@ -69,11 +67,11 @@ def _wrap_and_check_receiver_tensors(receiver_tensors):
 
 def _check_tensor(tensor, name, error_label='feature'):
   """Check that passed `tensor` is a Tensor or SparseTensor."""
-  if not (isinstance(tensor, ops.Tensor)
-          or isinstance(tensor, sparse_tensor.SparseTensor)):
+  if not (isinstance(tensor, ops.Tensor) or
+          isinstance(tensor, sparse_tensor.SparseTensor)):
     fmt_name = ' {}'.format(name) if name else ''
-    value_error = ValueError(
-        '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name))
+    value_error = ValueError('{}{} must be a Tensor or SparseTensor.'.format(
+        error_label, fmt_name))
     # NOTE(ericmc): This if-else block is a specific carve-out for
     # LabeledTensor, which has a `.tensor` attribute and which is
     # convertible to tf.Tensor via ops.convert_to_tensor.
@@ -92,19 +90,23 @@ def _check_tensor(tensor, name, error_label='feature'):
 
 def _check_tensor_key(name, error_label='feature'):
   if not isinstance(name, six.string_types):
-    raise ValueError(
-        '{} keys must be strings: {}.'.format(error_label, name))
+    raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
 
 
 @tf_export('estimator.export.ServingInputReceiver')
-class ServingInputReceiver(collections.namedtuple(
-    'ServingInputReceiver',
-    ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
+class ServingInputReceiver(
+    collections.namedtuple(
+        'ServingInputReceiver',
+        ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
   """A return type for a serving_input_receiver_fn.
 
   The expected return values are:
     features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
-      `SparseTensor`, specifying the features to be passed to the model.
+      `SparseTensor`, specifying the features to be passed to the model. Note:
+      if `features` passed is not a dict, it will be wrapped in a dict with a
+      single entry, using 'feature' as the key.  Consequently, the model must
+      accept a feature dict of the form {'feature': tensor}.  You may use
+      `TensorServingInputReceiver` if you want the tensor to be passed as is.
     receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
       or `SparseTensor`, specifying input nodes where this receiver expects to
       be fed by default.  Typically, this is a single placeholder expecting
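
For illustration, a minimal sketch of the wrapping behavior described in the updated docstring, assuming the TF 1.x estimator API as of this commit (placeholder names are arbitrary):

    import tensorflow as tf

    # A single (non-dict) feature tensor.
    x = tf.placeholder(tf.float32, shape=[None, 3], name='x')

    receiver = tf.estimator.export.ServingInputReceiver(
        features=x, receiver_tensors={'x': x})
    # The bare tensor is wrapped under the default key 'feature', so the
    # model_fn sees features of the form {'feature': <Tensor>}.
    print(receiver.features)

    # To hand the model the unwrapped tensor instead:
    tensor_receiver = tf.estimator.export.TensorServingInputReceiver(
        features=x, receiver_tensors={'x': x})
    print(tensor_receiver.features)  # the Tensor itself, not a dict
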
@@ -119,7 +121,9 @@ class ServingInputReceiver(collections.namedtuple(
       Defaults to None.
   """
 
-  def __new__(cls, features, receiver_tensors,
+  def __new__(cls,
+              features,
+              receiver_tensors,
               receiver_tensors_alternatives=None):
     if features is None:
       raise ValueError('features must be defined.')
@@ -139,8 +143,9 @@ class ServingInputReceiver(collections.namedtuple(
       for alternative_name, receiver_tensors_alt in (
           six.iteritems(receiver_tensors_alternatives)):
         if not isinstance(receiver_tensors_alt, dict):
-          receiver_tensors_alt = {_SINGLE_RECEIVER_DEFAULT_NAME:
-                                  receiver_tensors_alt}
+          receiver_tensors_alt = {
+              _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
+          }
           # Updating dict during iteration is OK in this case.
           receiver_tensors_alternatives[alternative_name] = (
               receiver_tensors_alt)
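
The same single-entry convenience applies to receiver_tensors_alternatives: a bare tensor alternative is wrapped under the default key 'input'. A hedged sketch (the alternative name 'raw' and the placeholder names are illustrative only):

    import tensorflow as tf

    serialized = tf.placeholder(tf.string, shape=[None], name='examples')
    raw = tf.placeholder(tf.float32, shape=[None, 3], name='raw_input')

    receiver = tf.estimator.export.ServingInputReceiver(
        features={'x': raw},
        receiver_tensors={'examples': serialized},
        # The non-dict alternative below is stored as {'input': raw}.
        receiver_tensors_alternatives={'raw': raw})
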
@@ -157,9 +162,10 @@ class ServingInputReceiver(collections.namedtuple(
 
 
 @tf_export('estimator.export.TensorServingInputReceiver')
-class TensorServingInputReceiver(collections.namedtuple(
-    'TensorServingInputReceiver',
-    ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
+class TensorServingInputReceiver(
+    collections.namedtuple(
+        'TensorServingInputReceiver',
+        ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
   """A return type for a serving_input_receiver_fn.
 
   This is for use with models that expect a single `Tensor` or `SparseTensor`
@@ -194,7 +200,9 @@ class TensorServingInputReceiver(collections.namedtuple(
       Defaults to None.
   """
 
-  def __new__(cls, features, receiver_tensors,
+  def __new__(cls,
+              features,
+              receiver_tensors,
               receiver_tensors_alternatives=None):
     if features is None:
       raise ValueError('features must be defined.')
@@ -212,9 +220,9 @@ class TensorServingInputReceiver(collections.namedtuple(
         receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
 
 
-class SupervisedInputReceiver(collections.namedtuple(
-    'SupervisedInputReceiver',
-    ['features', 'labels', 'receiver_tensors'])):
+class SupervisedInputReceiver(
+    collections.namedtuple('SupervisedInputReceiver',
+                           ['features', 'labels', 'receiver_tensors'])):
   """A return type for a training_input_receiver_fn or eval_input_receiver_fn.
 
   This differs from a ServingInputReceiver in that (1) this receiver expects
@@ -272,11 +280,13 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
   Returns:
     A serving_input_receiver_fn suitable for use in serving.
   """
+
   def serving_input_receiver_fn():
     """An input_fn that expects a serialized tf.Example."""
-    serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
-                                                  shape=[default_batch_size],
-                                                  name='input_example_tensor')
+    serialized_tf_example = array_ops.placeholder(
+        dtype=dtypes.string,
+        shape=[default_batch_size],
+        name='input_example_tensor')
     receiver_tensors = {'examples': serialized_tf_example}
     features = parsing_ops.parse_example(serialized_tf_example, feature_spec)
     return ServingInputReceiver(features, receiver_tensors)
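
A brief usage sketch of the parsing receiver built here; the feature_spec entries are hypothetical, and the resulting fn is what one would normally pass to Estimator.export_savedmodel:

    import tensorflow as tf

    # Hypothetical spec; any dict of FixedLenFeature/VarLenFeature entries works.
    feature_spec = {
        'age': tf.FixedLenFeature([1], dtype=tf.int64),
        'query': tf.VarLenFeature(dtype=tf.string),
    }

    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(
            feature_spec, default_batch_size=None))

    # estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)
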
@@ -295,10 +305,12 @@ def _placeholder_from_tensor(t, default_batch_size=None):
   return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
 
 
-def _placeholders_from_receiver_tensors_dict(
-    input_vals, default_batch_size=None):
-  return {name: _placeholder_from_tensor(t, default_batch_size)
-          for name, t in input_vals.items()}
+def _placeholders_from_receiver_tensors_dict(input_vals,
+                                             default_batch_size=None):
+  return {
+      name: _placeholder_from_tensor(t, default_batch_size)
+      for name, t in input_vals.items()
+  }
 
 
 @tf_export('estimator.export.build_raw_serving_input_receiver_fn')
@@ -316,6 +328,7 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
   Returns:
     A serving_input_receiver_fn.
   """
+
   def serving_input_receiver_fn():
     """A serving_input_receiver_fn that expects features to be fed directly."""
     receiver_tensors = _placeholders_from_receiver_tensors_dict(
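
The raw counterpart feeds the given tensors directly, with no tf.Example parsing step; again a hedged sketch with made-up feature names:

    import tensorflow as tf

    features = {
        'x': tf.placeholder(tf.float32, shape=[None, 3], name='x'),
        'id': tf.placeholder(tf.int64, shape=[None], name='id'),
    }

    serving_input_receiver_fn = (
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            features, default_batch_size=None))
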
@@ -329,8 +342,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
   return serving_input_receiver_fn
 
 
-def build_raw_supervised_input_receiver_fn(
-    features, labels, default_batch_size=None):
+def build_raw_supervised_input_receiver_fn(features,
+                                           labels,
+                                           default_batch_size=None):
   """Build a supervised_input_receiver_fn for raw features and labels.
 
   This function wraps tensor placeholders in a supervised_receiver_fn
@@ -443,11 +457,12 @@ def build_all_signature_defs(receiver_tensors,
     for receiver_name, receiver_tensors_alt in (
         six.iteritems(receiver_tensors_alternatives)):
       if not isinstance(receiver_tensors_alt, dict):
-        receiver_tensors_alt = {_SINGLE_RECEIVER_DEFAULT_NAME:
-                                receiver_tensors_alt}
+        receiver_tensors_alt = {
+            _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
+        }
       for output_key, export_output in export_outputs.items():
-        signature_name = '{}:{}'.format(receiver_name or 'None',
-                                        output_key or 'None')
+        signature_name = '{}:{}'.format(receiver_name or 'None', output_key or
+                                        'None')
         try:
           signature = export_output.as_signature_def(receiver_tensors_alt)
           signature_def_map[signature_name] = signature
@@ -464,8 +479,11 @@ def build_all_signature_defs(receiver_tensors,
   # signatures produced for serving. We skip this check for training and eval
   # signatures, which are not intended for serving.
   if serving_only:
-    signature_def_map = {k: v for k, v in signature_def_map.items()
-                         if signature_def_utils.is_valid_signature(v)}
+    signature_def_map = {
+        k: v
+        for k, v in signature_def_map.items()
+        if signature_def_utils.is_valid_signature(v)
+    }
   return signature_def_map
 
 
@@ -506,8 +524,8 @@ def _log_signature_report(signature_def_map, excluded_signatures):
 
   if not signature_def_map:
     logging.warn('Export includes no signatures!')
-  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-        not in signature_def_map):
+  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
+        signature_def_map):
     logging.warn('Export includes no default signature!')
 
 
@@ -547,6 +565,5 @@ def get_temp_export_dir(timestamped_export_dir):
   """
   (dirname, basename) = os.path.split(timestamped_export_dir)
   temp_export_dir = os.path.join(
-      compat.as_bytes(dirname),
-      compat.as_bytes('temp-{}'.format(basename)))
+      compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename)))
   return temp_export_dir