From: A. Unique TensorFlower Date: Sun, 27 May 2018 17:49:12 +0000 (-0700) Subject: TPUEstimator.export_savedmodel() saves a SavedModel with both TPU and CPU graphs. X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~30 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f2177855323f11e4f9620638e238691c57000373;p=platform%2Fupstream%2Ftensorflow.git TPUEstimator.export_savedmodel() saves a SavedModel with both TPU and CPU graphs. PiperOrigin-RevId: 198229550 --- diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index f531ae5..7d165fd 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -330,6 +330,7 @@ def outside_compilation(computation, args=None): Returns: The Tensors returned by computation. """ + args = [] if args is None else args graph = ops.get_default_graph() # If we are in a TPUReplicateContext, signal that we are now diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index f0c7564..f273756 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -61,6 +62,7 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import evaluation @@ -71,6 +73,7 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect + _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'tpu_estimator' @@ -81,6 +84,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' +_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -1773,8 +1777,45 @@ class TPUEstimator(estimator_lib.Estimator): Exporting ========= - Exporting `SavedModel` support on TPU is not yet implemented. So, - `export_savedmodel` is executed on CPU, even if `use_tpu` is true. + `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, + and another with `tag_constants.SERVING` and `tag_constants.TPU`. + At serving time, these tags are used to select metagraph to load. + + Before running the graph on TPU, TPU system needs to be initialized. If + TensorFlow Serving model-server is used, this is done automatically. If + not, please call `session.run(tpu.initialize_system())`. + + `tpu.outside_compilation` can be used to wrap TPU incompatible ops in + `model_fn`. + + Example: + ---------------- + + ``` + def model_fn(features, labels, mode, config, params): + ... + logits = ... + export_outputs = { + 'logits': export_output_lib.PredictOutput( + {'logits': logits}) + } + + def host_call(logits): + class_ids = math_ops.argmax(logits) + classes = string_ops.as_string(class_ids) + export_outputs['classes'] = + export_output_lib.ClassificationOutput(classes=classes) + + tpu.outside_compilation(host_call, [logits]) + + ... + ``` + + Current limitations: + -------------------- + + 1. Outside compilation does not work yet (b/79991729). + """ def __init__(self, @@ -1903,6 +1944,103 @@ class TPUEstimator(estimator_lib.Estimator): self._is_input_fn_invoked = None + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None): + if mode != model_fn_lib.ModeKeys.PREDICT: + raise NotImplementedError( + 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' + 'got {}.'.format(mode)) + + super(TPUEstimator, self)._add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode) + + input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: + input_receiver_fn_map[mode]} + export_tags = [tag_constants.SERVING, tag_constants.TPU] + mode = _REWRITE_FOR_INFERENCE_MODE + super(TPUEstimator, self)._add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=False, + mode=mode, + export_tags=export_tags) + + def _call_model_fn(self, features, labels, mode, config): + if mode == _REWRITE_FOR_INFERENCE_MODE: + return self._call_model_fn_for_inference(features, labels, mode, config) + else: + return super(TPUEstimator, self)._call_model_fn( + features, labels, mode, config) + + def _call_model_fn_for_inference(self, features, labels, mode, config): + """Wraps `_call_model_fn` for `export_savedmodel`.""" + if mode != _REWRITE_FOR_INFERENCE_MODE: + raise ValueError('mode must be {}; ' + 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + + capture = _CapturedObject() + + def computation(): + """Compute tpu tensors used in export_outputs. + + Passed to rewrite_for_inference so that model_fn will be called under + the rewriting contexts. Only tpu tensors are returned, but export_outputs + and scaffold are captured. + + Returns: + A list of Tensors used in export_outputs and not marked for + outside_compilation. + """ + # We should only call model fn once and it should be inside `computation` + # so that building the graph will happen under `rewrite_for_inference`. + mode = model_fn_lib.ModeKeys.PREDICT + estimator_spec = self._call_model_fn(features, labels, mode, config) + + # We pick the TPU tensors out from `export_output` and later return them + # from `computation` for rewriting. + tensors_dict = collections.OrderedDict( + (k, _export_output_to_tensors(v)) + for k, v in six.iteritems(estimator_spec.export_outputs) + ) + tensors = nest.flatten(tensors_dict) + tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + + # We cannot return anything other than `tpu_tensors` here so we capture + # the rest for later use. + capture.capture((estimator_spec, tensors_dict, tensors)) + return tpu_tensors + + tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) + estimator_spec, tensors_dict, tensors = capture.get() + + # Reconstruct `tensors`, but with `tpu_tensors` replaced with + # `tpu_tensors_on_cpu`. + new_tensors = [ + tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t + for t in tensors + ] + # Reconstruct `tensors_dict`. + new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) + # Reconstruct `export_outputs`. + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_tensors_dict) + ) + + return estimator_spec._replace(export_outputs=new_export_outputs) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -2278,6 +2416,76 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn +def _is_tpu_tensor(tensor): + if not isinstance(tensor, ops.Tensor): + return False + try: + tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access + except ValueError: + return True + else: + return False + + +def _export_output_to_tensors(export_output): + """Get a list of `Tensors` used in `export_output`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + Returns: + a list of tensors used in export_output. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + return [export_output.scores, export_output.classes] + elif isinstance(export_output, export_output_lib.RegressionOutput): + return [export_output.value] + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output.outputs.values() + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + +def _clone_export_output_with_tensors(export_output, tensors): + """Clones `export_output` but with new `tensors`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + tensors: a list of `Tensors` used to construct a new `export_output`. + + Returns: + A dict similar to `export_output` but with `tensors`. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + if len(tensors) != 2: + raise ValueError('tensors must be of length 2; ' + 'got {}.'.format(len(tensors))) + return export_output_lib.ClassificationOutput(*tensors) + elif isinstance(export_output, export_output_lib.RegressionOutput): + if len(tensors) != 1: + raise ValueError('tensors must be of length 1; ' + 'got {}'.format(len(tensors))) + return export_output_lib.RegressionOutput(*tensors) + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output_lib.PredictOutput( + dict(zip(export_output.outputs.keys(), tensors))) + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() @@ -2844,4 +3052,3 @@ def _add_item_to_params(params, key, value): else: # Now params is Python dict. params[key] = value - diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 347a760..331ee74 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -212,8 +212,8 @@ class Estimator(object): else: self._session_config = self._config.session_config - self._device_fn = self._config.device_fn or \ - _get_replica_device_setter(self._config) + self._device_fn = ( + self._config.device_fn or _get_replica_device_setter(self._config)) if model_fn is None: raise ValueError('model_fn must be provided to Estimator.') @@ -564,7 +564,8 @@ class Estimator(object): allowed_overrides = set([ '_call_input_fn', '_create_global_step', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_validate_features_in_predict_input' + '_tf_api_names', '_validate_features_in_predict_input', + '_call_model_fn', '_add_meta_graph_for_mode' ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) @@ -828,10 +829,14 @@ class Estimator(object): gfile.Rename(temp_export_dir, export_dir) return export_dir - def _add_meta_graph_for_mode( - self, builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT): + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None): # pylint: disable=line-too-long """Loads variables and adds them along with a MetaGraphDef for saving. @@ -850,9 +855,14 @@ class Estimator(object): True for the first call to this function, and the SavedModelBuilder will raise an error if that is not the case. mode: tf.estimator.ModeKeys value indicating which mode will be exported. + export_tags: The set of tags with which to save `MetaGraphDef`. If None, + a default set will be selected to matched the passed mode. """ # pylint: enable=line-too-long + if export_tags is None: + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] input_receiver_fn = input_receiver_fn_map[mode] + with ops.Graph().as_default() as g: self._create_and_assert_global_step(g) random_seed.set_random_seed(self._config.tf_random_seed) @@ -877,8 +887,6 @@ class Estimator(object): with tf_session.Session(config=self._session_config) as session: - export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] - local_init_op = ( estimator_spec.scaffold.local_init_op or monitored_session.Scaffold.default_local_init_op())