TPUEstimator.export_savedmodel() saves a SavedModel with both TPU and CPU graphs.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sun, 27 May 2018 17:49:12 +0000 (10:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 27 May 2018 17:51:39 +0000 (10:51 -0700)
PiperOrigin-RevId: 198229550

tensorflow/contrib/tpu/python/tpu/tpu.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/python/estimator/estimator.py

index f531ae5..7d165fd 100644 (file)
@@ -330,6 +330,7 @@ def outside_compilation(computation, args=None):
   Returns:
     The Tensors returned by computation.
   """
+  args = [] if args is None else args
   graph = ops.get_default_graph()
 
   # If we are in a TPUReplicateContext, signal that we are now
index f0c7564..f273756 100644 (file)
@@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output as export_output_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -61,6 +62,7 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary import summary
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import evaluation
@@ -71,6 +73,7 @@ from tensorflow.python.util import function_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
+
 _INITIAL_LOSS = 1e7
 _ZERO_LOSS = 0.
 _TPU_ESTIMATOR = 'tpu_estimator'
@@ -81,6 +84,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
 _ONE_GIGABYTE = 1024 * 1024 * 1024
 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
 _TPU_TRAIN_OP = '_tpu_train_op'
+_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
 
 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
 
@@ -1773,8 +1777,45 @@ class TPUEstimator(estimator_lib.Estimator):
   Exporting
   =========
 
-  Exporting `SavedModel` support on TPU is not yet implemented. So,
-  `export_savedmodel` is executed on CPU, even if `use_tpu` is true.
+  `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
+  and another with `tag_constants.SERVING` and `tag_constants.TPU`.
+  At serving time, these tags are used to select metagraph to load.
+
+  Before running the graph on TPU, TPU system needs to be initialized. If
+  TensorFlow Serving model-server is used, this is done automatically. If
+  not, please call `session.run(tpu.initialize_system())`.
+
+  `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
+  `model_fn`.
+
+  Example:
+  ----------------
+
+  ```
+  def model_fn(features, labels, mode, config, params):
+    ...
+    logits = ...
+    export_outputs = {
+      'logits': export_output_lib.PredictOutput(
+        {'logits': logits})
+    }
+
+    def host_call(logits):
+      class_ids = math_ops.argmax(logits)
+      classes = string_ops.as_string(class_ids)
+      export_outputs['classes'] =
+        export_output_lib.ClassificationOutput(classes=classes)
+
+    tpu.outside_compilation(host_call, [logits])
+
+    ...
+  ```
+
+  Current limitations:
+  --------------------
+
+  1. Outside compilation does not work yet (b/79991729).
+
   """
 
   def __init__(self,
@@ -1903,6 +1944,103 @@ class TPUEstimator(estimator_lib.Estimator):
 
     self._is_input_fn_invoked = None
 
+  def _add_meta_graph_for_mode(self,
+                               builder,
+                               input_receiver_fn_map,
+                               checkpoint_path,
+                               strip_default_attrs,
+                               save_variables=True,
+                               mode=model_fn_lib.ModeKeys.PREDICT,
+                               export_tags=None):
+    if mode != model_fn_lib.ModeKeys.PREDICT:
+      raise NotImplementedError(
+          'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
+          'got {}.'.format(mode))
+
+    super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+                                                       input_receiver_fn_map,
+                                                       checkpoint_path,
+                                                       strip_default_attrs,
+                                                       save_variables,
+                                                       mode=mode)
+
+    input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
+                             input_receiver_fn_map[mode]}
+    export_tags = [tag_constants.SERVING, tag_constants.TPU]
+    mode = _REWRITE_FOR_INFERENCE_MODE
+    super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+                                                       input_receiver_fn_map,
+                                                       checkpoint_path,
+                                                       strip_default_attrs,
+                                                       save_variables=False,
+                                                       mode=mode,
+                                                       export_tags=export_tags)
+
+  def _call_model_fn(self, features, labels, mode, config):
+    if mode == _REWRITE_FOR_INFERENCE_MODE:
+      return self._call_model_fn_for_inference(features, labels, mode, config)
+    else:
+      return super(TPUEstimator, self)._call_model_fn(
+          features, labels, mode, config)
+
+  def _call_model_fn_for_inference(self, features, labels, mode, config):
+    """Wraps `_call_model_fn` for `export_savedmodel`."""
+    if mode != _REWRITE_FOR_INFERENCE_MODE:
+      raise ValueError('mode must be {}; '
+                       'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))
+
+    capture = _CapturedObject()
+
+    def computation():
+      """Compute tpu tensors used in export_outputs.
+
+      Passed to rewrite_for_inference so that model_fn will be called under
+      the rewriting contexts. Only tpu tensors are returned, but export_outputs
+      and scaffold are captured.
+
+      Returns:
+         A list of Tensors used in export_outputs and not marked for
+         outside_compilation.
+      """
+      # We should only call model fn once and it should be inside `computation`
+      # so that building the graph will happen under `rewrite_for_inference`.
+      mode = model_fn_lib.ModeKeys.PREDICT
+      estimator_spec = self._call_model_fn(features, labels, mode, config)
+
+      # We pick the TPU tensors out from `export_output` and later return them
+      # from `computation` for rewriting.
+      tensors_dict = collections.OrderedDict(
+          (k, _export_output_to_tensors(v))
+          for k, v in six.iteritems(estimator_spec.export_outputs)
+      )
+      tensors = nest.flatten(tensors_dict)
+      tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]
+
+      # We cannot return anything other than `tpu_tensors` here so we capture
+      # the rest for later use.
+      capture.capture((estimator_spec, tensors_dict, tensors))
+      return tpu_tensors
+
+    tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
+    estimator_spec, tensors_dict, tensors = capture.get()
+
+    # Reconstruct `tensors`, but with `tpu_tensors` replaced with
+    # `tpu_tensors_on_cpu`.
+    new_tensors = [
+        tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t
+        for t in tensors
+    ]
+    # Reconstruct `tensors_dict`.
+    new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
+    # Reconstruct `export_outputs`.
+    export_outputs = estimator_spec.export_outputs
+    new_export_outputs = collections.OrderedDict(
+        (k, _clone_export_output_with_tensors(export_outputs[k], v))
+        for k, v in six.iteritems(new_tensors_dict)
+    )
+
+    return estimator_spec._replace(export_outputs=new_export_outputs)
+
   def _create_global_step(self, graph):
     """Creates a global step suitable for TPUs.
 
@@ -2278,6 +2416,76 @@ class TPUEstimator(estimator_lib.Estimator):
     return _model_fn
 
 
+def _is_tpu_tensor(tensor):
+  if not isinstance(tensor, ops.Tensor):
+    return False
+  try:
+    tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR)  # pylint: disable=protected-access
+  except ValueError:
+    return True
+  else:
+    return False
+
+
+def _export_output_to_tensors(export_output):
+  """Get a list of `Tensors` used in `export_output`.
+
+  Args:
+    export_output: an `ExportOutput` object such as `ClassificationOutput`,
+            `RegressionOutput`, or `PredictOutput`.
+  Returns:
+    a list of tensors used in export_output.
+
+  Raises:
+    ValueError: if `export_output` is not one of `ClassificationOutput`,
+        `RegressionOutput`, or `PredictOutput`.
+  """
+  if isinstance(export_output, export_output_lib.ClassificationOutput):
+    return [export_output.scores, export_output.classes]
+  elif isinstance(export_output, export_output_lib.RegressionOutput):
+    return [export_output.value]
+  elif isinstance(export_output, export_output_lib.PredictOutput):
+    return export_output.outputs.values()
+  else:
+    raise ValueError(
+        '`export_output` must be have type `ClassificationOutput`, '
+        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
+def _clone_export_output_with_tensors(export_output, tensors):
+  """Clones `export_output` but with new `tensors`.
+
+  Args:
+    export_output: an `ExportOutput` object such as `ClassificationOutput`,
+            `RegressionOutput`, or `PredictOutput`.
+    tensors: a list of `Tensors` used to construct a new `export_output`.
+
+  Returns:
+    A dict similar to `export_output` but with `tensors`.
+
+  Raises:
+    ValueError: if `export_output` is not one of `ClassificationOutput`,
+        `RegressionOutput`, or `PredictOutput`.
+  """
+  if isinstance(export_output, export_output_lib.ClassificationOutput):
+    if len(tensors) != 2:
+      raise ValueError('tensors must be of length 2; '
+                       'got {}.'.format(len(tensors)))
+    return export_output_lib.ClassificationOutput(*tensors)
+  elif isinstance(export_output, export_output_lib.RegressionOutput):
+    if len(tensors) != 1:
+      raise ValueError('tensors must be of length 1; '
+                       'got {}'.format(len(tensors)))
+    return export_output_lib.RegressionOutput(*tensors)
+  elif isinstance(export_output, export_output_lib.PredictOutput):
+    return export_output_lib.PredictOutput(
+        dict(zip(export_output.outputs.keys(), tensors)))
+  else:
+    raise ValueError(
+        '`export_output` must be have type `ClassificationOutput`, '
+        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
 def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
   iterations_per_loop_var = _create_or_get_iterations_per_loop()
@@ -2844,4 +3052,3 @@ def _add_item_to_params(params, key, value):
   else:
     # Now params is Python dict.
     params[key] = value
-
index 347a760..331ee74 100644 (file)
@@ -212,8 +212,8 @@ class Estimator(object):
     else:
       self._session_config = self._config.session_config
 
-    self._device_fn = self._config.device_fn or \
-                      _get_replica_device_setter(self._config)
+    self._device_fn = (
+        self._config.device_fn or _get_replica_device_setter(self._config))
 
     if model_fn is None:
       raise ValueError('model_fn must be provided to Estimator.')
@@ -564,7 +564,8 @@ class Estimator(object):
     allowed_overrides = set([
         '_call_input_fn', '_create_global_step',
         '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
-        '_tf_api_names', '_validate_features_in_predict_input'
+        '_tf_api_names', '_validate_features_in_predict_input',
+        '_call_model_fn', '_add_meta_graph_for_mode'
     ])
     estimator_members = set([m for m in Estimator.__dict__.keys()
                              if not m.startswith('__')])
@@ -828,10 +829,14 @@ class Estimator(object):
       gfile.Rename(temp_export_dir, export_dir)
       return export_dir
 
-  def _add_meta_graph_for_mode(
-      self, builder, input_receiver_fn_map, checkpoint_path,
-      strip_default_attrs, save_variables=True,
-      mode=model_fn_lib.ModeKeys.PREDICT):
+  def _add_meta_graph_for_mode(self,
+                               builder,
+                               input_receiver_fn_map,
+                               checkpoint_path,
+                               strip_default_attrs,
+                               save_variables=True,
+                               mode=model_fn_lib.ModeKeys.PREDICT,
+                               export_tags=None):
     # pylint: disable=line-too-long
     """Loads variables and adds them along with a MetaGraphDef for saving.
 
@@ -850,9 +855,14 @@ class Estimator(object):
         True for the first call to this function, and the SavedModelBuilder will
         raise an error if that is not the case.
       mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+      export_tags: The set of tags with which to save `MetaGraphDef`. If None,
+        a default set will be selected to matched the passed mode.
     """
     # pylint: enable=line-too-long
+    if export_tags is None:
+      export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
     input_receiver_fn = input_receiver_fn_map[mode]
+
     with ops.Graph().as_default() as g:
       self._create_and_assert_global_step(g)
       random_seed.set_random_seed(self._config.tf_random_seed)
@@ -877,8 +887,6 @@ class Estimator(object):
 
       with tf_session.Session(config=self._session_config) as session:
 
-        export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
-
         local_init_op = (
             estimator_spec.scaffold.local_init_op or
             monitored_session.Scaffold.default_local_init_op())