from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
+
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
+_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
Exporting
=========
- Exporting `SavedModel` support on TPU is not yet implemented. So,
- `export_savedmodel` is executed on CPU, even if `use_tpu` is true.
+  `export_savedmodel` exports two metagraphs: one with `tag_constants.SERVING`
+  and another with both `tag_constants.SERVING` and `tag_constants.TPU`. At
+  serving time, these tags are used to select the appropriate metagraph to
+  load.
+
+  Before running the graph on TPU, the TPU system needs to be initialized. If
+  the TensorFlow Serving model-server is used, this is done automatically. If
+  not, please call `session.run(tpu.initialize_system())` first.
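+
+  For illustration only, loading and running the TPU metagraph outside of
+  TensorFlow Serving might look like the sketch below; `tpu_grpc_url`,
+  `export_dir`, `fetches` and `feed_dict` are placeholders for user-supplied
+  values:
+
+  ```
+  with tf.Session(tpu_grpc_url) as session:
+    # The TPU system must be initialized before the TPU metagraph can run.
+    session.run(tpu.initialize_system())
+    tf.saved_model.loader.load(
+        session, [tag_constants.SERVING, tag_constants.TPU], export_dir)
+    session.run(fetches, feed_dict=feed_dict)
+    session.run(tpu.shutdown_system())
+  ```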
+
+  `tpu.outside_compilation` can be used to wrap TPU-incompatible ops in
+  `model_fn`.
+
+ Example:
+ ----------------
+
+  ```
+  def model_fn(features, labels, mode, config, params):
+    ...
+    logits = ...
+    export_outputs = {
+      'logits': export_output_lib.PredictOutput(
+          {'logits': logits})
+    }
+
+    def host_call(logits):
+      class_ids = math_ops.argmax(logits)
+      classes = string_ops.as_string(class_ids)
+      export_outputs['classes'] = (
+          export_output_lib.ClassificationOutput(classes=classes))
+
+    tpu.outside_compilation(host_call, [logits])
+
+    ...
+  ```
+
+ Current limitations:
+ --------------------
+
+ 1. Outside compilation does not work yet (b/79991729).
+
"""
def __init__(self,
self._is_input_fn_invoked = None
+ def _add_meta_graph_for_mode(self,
+ builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=True,
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ export_tags=None):
+ if mode != model_fn_lib.ModeKeys.PREDICT:
+ raise NotImplementedError(
+ 'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
+ 'got {}.'.format(mode))
+
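+    # First add the MetaGraphDef tagged for regular (CPU) serving; variables
+    # are saved as part of this call when `save_variables` is True.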
+ super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables,
+ mode=mode)
+
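+    # Then add a second MetaGraphDef, rewritten for TPU inference and tagged
+    # with both SERVING and TPU. Variables were already saved above, so
+    # `save_variables` is False for this call.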
+ input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
+ input_receiver_fn_map[mode]}
+ export_tags = [tag_constants.SERVING, tag_constants.TPU]
+ mode = _REWRITE_FOR_INFERENCE_MODE
+ super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=False,
+ mode=mode,
+ export_tags=export_tags)
+
+ def _call_model_fn(self, features, labels, mode, config):
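+    # `_REWRITE_FOR_INFERENCE_MODE` is an internal pseudo-mode used only by
+    # `_add_meta_graph_for_mode` above when exporting the TPU metagraph.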
+ if mode == _REWRITE_FOR_INFERENCE_MODE:
+ return self._call_model_fn_for_inference(features, labels, mode, config)
+ else:
+ return super(TPUEstimator, self)._call_model_fn(
+ features, labels, mode, config)
+
+ def _call_model_fn_for_inference(self, features, labels, mode, config):
+ """Wraps `_call_model_fn` for `export_savedmodel`."""
+ if mode != _REWRITE_FOR_INFERENCE_MODE:
+ raise ValueError('mode must be {}; '
+ 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))
+
+ capture = _CapturedObject()
+
+ def computation():
+ """Compute tpu tensors used in export_outputs.
+
+ Passed to rewrite_for_inference so that model_fn will be called under
+ the rewriting contexts. Only tpu tensors are returned, but export_outputs
+ and scaffold are captured.
+
+ Returns:
+ A list of Tensors used in export_outputs and not marked for
+ outside_compilation.
+ """
+ # We should only call model fn once and it should be inside `computation`
+ # so that building the graph will happen under `rewrite_for_inference`.
+ mode = model_fn_lib.ModeKeys.PREDICT
+ estimator_spec = self._call_model_fn(features, labels, mode, config)
+
+ # We pick the TPU tensors out from `export_output` and later return them
+ # from `computation` for rewriting.
+ tensors_dict = collections.OrderedDict(
+ (k, _export_output_to_tensors(v))
+ for k, v in six.iteritems(estimator_spec.export_outputs)
+ )
+ tensors = nest.flatten(tensors_dict)
+ tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]
+
+ # We cannot return anything other than `tpu_tensors` here so we capture
+ # the rest for later use.
+ capture.capture((estimator_spec, tensors_dict, tensors))
+ return tpu_tensors
+
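+    # Rewrite `computation` for TPU inference. The returned tensors are the
+    # CPU-accessible counterparts of the TPU tensors returned by
+    # `computation`, in the same order.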
+ tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
+ estimator_spec, tensors_dict, tensors = capture.get()
+
+ # Reconstruct `tensors`, but with `tpu_tensors` replaced with
+ # `tpu_tensors_on_cpu`.
+ new_tensors = [
+ tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t
+ for t in tensors
+ ]
+ # Reconstruct `tensors_dict`.
+ new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
+ # Reconstruct `export_outputs`.
+ export_outputs = estimator_spec.export_outputs
+ new_export_outputs = collections.OrderedDict(
+ (k, _clone_export_output_with_tensors(export_outputs[k], v))
+ for k, v in six.iteritems(new_tensors_dict)
+ )
+
+ return estimator_spec._replace(export_outputs=new_export_outputs)
+
def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
return _model_fn
+def _is_tpu_tensor(tensor):
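+  """Returns True for `Tensor`s that are not marked for outside compilation."""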
+ if not isinstance(tensor, ops.Tensor):
+ return False
+ try:
+ tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access
+ except ValueError:
+ return True
+ else:
+ return False
+
+
+def _export_output_to_tensors(export_output):
+ """Get a list of `Tensors` used in `export_output`.
+
+ Args:
+ export_output: an `ExportOutput` object such as `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ Returns:
+ a list of tensors used in export_output.
+
+ Raises:
+ ValueError: if `export_output` is not one of `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ """
+ if isinstance(export_output, export_output_lib.ClassificationOutput):
+ return [export_output.scores, export_output.classes]
+ elif isinstance(export_output, export_output_lib.RegressionOutput):
+ return [export_output.value]
+ elif isinstance(export_output, export_output_lib.PredictOutput):
+    return list(export_output.outputs.values())
+ else:
+ raise ValueError(
+        '`export_output` must have type `ClassificationOutput`, '
+ '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
+def _clone_export_output_with_tensors(export_output, tensors):
+ """Clones `export_output` but with new `tensors`.
+
+ Args:
+ export_output: an `ExportOutput` object such as `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ tensors: a list of `Tensors` used to construct a new `export_output`.
+
+ Returns:
+    An `ExportOutput` of the same type as `export_output`, constructed with
+    `tensors`.
+
+ Raises:
+ ValueError: if `export_output` is not one of `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ """
+ if isinstance(export_output, export_output_lib.ClassificationOutput):
+ if len(tensors) != 2:
+ raise ValueError('tensors must be of length 2; '
+ 'got {}.'.format(len(tensors)))
+ return export_output_lib.ClassificationOutput(*tensors)
+ elif isinstance(export_output, export_output_lib.RegressionOutput):
+ if len(tensors) != 1:
+ raise ValueError('tensors must be of length 1; '
+ 'got {}'.format(len(tensors)))
+ return export_output_lib.RegressionOutput(*tensors)
+ elif isinstance(export_output, export_output_lib.PredictOutput):
+ return export_output_lib.PredictOutput(
+ dict(zip(export_output.outputs.keys(), tensors)))
+ else:
+ raise ValueError(
+        '`export_output` must have type `ClassificationOutput`, '
+ '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
else:
# Now params is Python dict.
params[key] = value
-
else:
self._session_config = self._config.session_config
- self._device_fn = self._config.device_fn or \
- _get_replica_device_setter(self._config)
+ self._device_fn = (
+ self._config.device_fn or _get_replica_device_setter(self._config))
if model_fn is None:
raise ValueError('model_fn must be provided to Estimator.')
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input'
+ '_tf_api_names', '_validate_features_in_predict_input',
+ '_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
gfile.Rename(temp_export_dir, export_dir)
return export_dir
- def _add_meta_graph_for_mode(
- self, builder, input_receiver_fn_map, checkpoint_path,
- strip_default_attrs, save_variables=True,
- mode=model_fn_lib.ModeKeys.PREDICT):
+ def _add_meta_graph_for_mode(self,
+ builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=True,
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ export_tags=None):
# pylint: disable=line-too-long
"""Loads variables and adds them along with a MetaGraphDef for saving.
True for the first call to this function, and the SavedModelBuilder will
raise an error if that is not the case.
mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+      export_tags: The set of tags with which to save `MetaGraphDef`. If None,
+        a default set will be selected to match the passed mode.
"""
# pylint: enable=line-too-long
+ if export_tags is None:
+ export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
input_receiver_fn = input_receiver_fn_map[mode]
+
with ops.Graph().as_default() as g:
self._create_and_assert_global_step(g)
random_seed.set_random_seed(self._config.tf_random_seed)
with tf_session.Session(config=self._session_config) as session:
- export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
-
local_init_op = (
estimator_spec.scaffold.local_init_op or
monitored_session.Scaffold.default_local_init_op())