weights=weights,
processed_labels=processed_labels)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
- """Returns an `EstimatorSpec`.
+ """Returns an `model_fn._TPUEstimatorSpec`.
Args:
features: Input `dict` of `Tensor` or `SparseTensor` objects.
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ `model_fn._TPUEstimatorSpec`.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
classifier_output = head_lib._classification_output( # pylint:disable=protected-access
scores=probabilities, n_classes=self._n_classes,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=processed_labels,
- probabilities=probabilities,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access
+ self._eval_metric_ops, {
+ 'labels': processed_labels,
+ 'probabilities': probabilities,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss,
+ }))
# Train.
if optimizer is not None:
summary.scalar(
head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
STOP = -2
-class TPUEstimatorSpec(
- collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn',
- 'host_call'
- ])):
+class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
self._call_model_fn(features, labels))
loss, train_op = estimator_spec.loss, estimator_spec.train_op
- if isinstance(estimator_spec, TPUEstimatorSpec):
+ if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
else:
captured_scaffold_fn.capture(None)
# outfeed.
with ops.control_dependencies([train_op]):
host_call_outfeed_ops = []
- if (isinstance(estimator_spec, TPUEstimatorSpec) and
- estimator_spec.host_call is not None):
+ if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access
+ and estimator_spec.host_call is not None):
host_call.record({'host_call': estimator_spec.host_call})
host_call_outfeed_ops = host_call.create_enqueue_op()
with ops.control_dependencies(host_call_outfeed_ops):
features, labels = inputs.features_and_labels()
tpu_estimator_spec = self._call_model_fn(features, labels)
- if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+ if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
raise RuntimeError(
'estimator_spec used by TPU evaluation must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
tpu_estimator_spec = self._call_model_fn(
features, labels, is_export_mode=False)
- if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+ if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
raise RuntimeError(
'estimator_spec used by TPU prediction must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
estimator_spec = self._model_fn(features=features, **kwargs)
if (self._ctx.is_running_on_cpu(is_export_mode) and
- isinstance(estimator_spec, TPUEstimatorSpec)):
+ isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
# The estimator_spec will be passed to `Estimator` directly, which expects
# type `EstimatorSpec`.
return estimator_spec.as_estimator_spec()
def _verify_estimator_spec(self, estimator_spec):
"""Validates the estimator_spec."""
- if isinstance(estimator_spec, TPUEstimatorSpec):
+ if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
return estimator_spec
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
Raises:
ValueError: If features has the wrong type.
"""
+ tpu_estimator_spec = _tpu_dnn_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config)
+ return tpu_estimator_spec.as_estimator_spec()
+
+
+def _tpu_dnn_model_fn(features,
+ labels,
+ mode,
+ head,
+ hidden_units,
+ feature_columns,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None):
+ """Deep Neural Net model_fn for TPUEstimator.
+
+ Args:
+ features: dict of `Tensor`.
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+ dtype `int32` or `int64` in the range `[0, n_classes)`.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ head: A `head_lib._Head` instance.
+ hidden_units: Iterable of integer number of hidden units per layer.
+ feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use the Adagrad
+ optimizer with a default learning rate of 0.05.
+ activation_fn: Activation function applied to each layer.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Partitioner for input layer. Defaults
+ to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Returns:
+ A `model_fn.TPUEstimatorSpec` instance.
+
+ Raises:
+ ValueError: If features has the wrong type.
+ """
if not isinstance(features, dict):
raise ValueError('features should be a dictionary of `Tensor`s. '
'Given type: {}'.format(type(features)))
input_layer_partitioner=input_layer_partitioner)
logits = logit_fn(features=features, mode=mode)
- return head.create_estimator_spec(
+ return head._create_tpu_estimator_spec( # pylint: disable=protected-access
features=features,
mode=mode,
labels=labels,
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
"""
- if n_classes == 2:
- head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
- weight_column=weight_column,
- label_vocabulary=label_vocabulary,
- loss_reduction=loss_reduction)
- else:
- head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary,
- loss_reduction=loss_reduction)
-
+ head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
+ n_classes, weight_column, label_vocabulary, loss_reduction)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
hidden_weights_names + hidden_biases_names +
[LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])
- def _create_estimator_spec(
+ def _create_tpu_estimator_spec(
features, mode, logits, labels, train_op_fn=None, optimizer=None):
del features, labels # Not used.
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
train_op = train_op_fn(loss)
elif optimizer is not None:
train_op = optimizer.minimize(loss, global_step=None)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op)
elif mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss))
+ return model_fn._TPUEstimatorSpec(
+ mode=mode, loss=array_ops.identity(loss))
elif mode == model_fn.ModeKeys.PREDICT:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec(
mode=mode, predictions={'logits': array_ops.identity(logits)})
else:
testcase.fail('Invalid mode: {}'.format(mode))
+ def _create_estimator_spec(
+ features, mode, logits, labels, train_op_fn=None, optimizer=None):
+ tpu_spec = _create_tpu_estimator_spec(
+ features, mode, logits, labels, train_op_fn, optimizer)
+ return tpu_spec.as_estimator_spec()
+
head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
head.logits_dimension = logits_dimension
- head.create_estimator_spec = test.mock.MagicMock(wraps=_create_estimator_spec)
+ head._create_tpu_estimator_spec = test.mock.MagicMock(
+ wraps=_create_tpu_estimator_spec)
+ head.create_estimator_spec = test.mock.MagicMock(
+ wraps=_create_estimator_spec)
return head
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
return '%s/%s' % (val, head_name) if head_name else val
+def _create_eval_metrics_tuple(fn, kwargs):
+ """Creates TPU eval metrics tuple.
+
+ Helper function to make eval_metric tuple (eval_metric_fn, fn_kwargs) used
+ by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take
+ exclusively Tensor arguments. This helper can help create such a function from
+ a more generic function that can take both Tensor and non-Tensor arguments.
+
+ Args:
+ fn: A eval_metric_fn that takes both Tensor and non-Tensor arguments.
+ This function must return a dict of form
+ {'metric name': (metric_tensor, eval_op)}
+ kwargs: Dict of arguments for `fn`.
+
+ Returns:
+ `eval_metric` tuple that can be passed to a `model_fn._TPUEstimatorSpec`.
+ """
+ tensor_kwargs = {}
+ nontensor_kwargs = {}
+ for k, v in six.iteritems(kwargs):
+ if tensor_util.is_tensor(v):
+ tensor_kwargs[k] = v
+ else:
+ nontensor_kwargs[k] = v
+ def _fn(**tensors):
+ return fn(**dict(nontensor_kwargs, **tensors))
+ return (_fn, tensor_kwargs)
+
+
class _Head(object):
"""Interface for the head/top of a model.
# TODO(b/65403806): By default, collect regularization_losses from
# GraphKeys.REGULARIZATION_LOSSES collection.
- @abc.abstractmethod
def create_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
Returns:
`EstimatorSpec`.
"""
- raise NotImplementedError('Calling an abstract method.')
+ try:
+ tpu_estimator_spec = (
+ self._create_tpu_estimator_spec(
+ features, mode, logits, labels, optimizer, train_op_fn,
+ regularization_losses))
+ return tpu_estimator_spec.as_estimator_spec()
+ except NotImplementedError:
+ # Not all subclasses of _Head will have implemented
+ # _create_tpu_estimator_spec. If it is implemented, we can use it to
+ # create our `EstimatorSpec` here.
+ raise NotImplementedError(
+ 'Subclasses of _Head must implement `create_estimator_spec()` or '
+ '_create_tpu_estimator_spec().')
+
+ def _create_tpu_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None, regularization_losses=None):
+ """Returns `model_fn._TPUEstimatorSpec` that a model_fn can return.
+
+ Args:
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
+ mode: Estimator's `ModeKeys`.
+ logits: logits `Tensor` to be used by the head.
+ labels: Labels `Tensor`, or `dict` of same.
+ optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+ Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+ updates variables and increments `global_step`.
+ train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
+ to optimize the model with the loss in TRAIN mode. Used if `optimizer`
+ is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in
+ TRAIN mode. None is allowed in other modes. If you want to optimize loss
+ yourself you can pass `lambda _: tf.no_op()` and then use
+ EstimatorSpec.loss to compute and apply gradients.
+ regularization_losses: A list of additional scalar losses to be added to
+ the training loss, such as regularization losses.
+
+ Returns:
+ A `model_fn._TPUEstimatorSpec' instance.
+ """
+ raise NotImplementedError(
+ 'TPUEstimatorSpec not available for this model head.')
def _check_dense_labels_match_logits_and_reshape(
weights=weights,
processed_labels=label_ids)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
- """Returns an `EstimatorSpec`.
+ """Returns a `model_fn._TPUEstimatorSpec`.
Args:
features: Input `dict` of `Tensor` or `SparseTensor` objects.
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ A `model_fn._TPUEstimatorSpec` instance.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
classifier_output = _classification_output(
scores=probabilities, n_classes=self._n_classes,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
regularized_training_loss = training_loss
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=label_ids,
- class_ids=class_ids,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=_create_eval_metrics_tuple(self._eval_metric_ops, {
+ 'labels': label_ids,
+ 'class_ids': class_ids,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss
+ }))
# Train.
if optimizer is not None:
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
weights=weights,
processed_labels=labels)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
"""Returns an `EstimatorSpec`.
classifier_output = _classification_output(
scores=probabilities, n_classes=2,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=processed_labels,
- logits=logits,
- logistic=logistic,
- class_ids=class_ids,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=_create_eval_metrics_tuple(
+ self._eval_metric_ops,
+ {
+ 'labels': processed_labels,
+ 'logits': logits,
+ 'logistic': logistic,
+ 'class_ids': class_ids,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss
+ }
+ ))
# Train.
if optimizer is not None:
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
weights=weights,
processed_labels=labels)
- def create_estimator_spec(
+ def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+ """Returns the Eval metric ops."""
+ keys = metric_keys.MetricKeys
+ # Estimator already adds a metric for loss.
+ eval_metric_ops = {
+ _summary_key(self._name, keys.LOSS_MEAN):
+ metrics_lib.mean(
+ values=unreduced_loss,
+ weights=weights)
+ }
+ if regularization_loss is not None:
+ regularization_loss_key = _summary_key(
+ self._name, keys.LOSS_REGULARIZATION)
+ eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
+ values=regularization_loss,
+ name=keys.LOSS_REGULARIZATION)
+ return eval_metric_ops
+
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
"""Returns an `EstimatorSpec`.
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ A `model_fn._TPUEstimatorSpec` instance.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
if mode == model_fn.ModeKeys.PREDICT:
regression_output = export_output.RegressionOutput(
value=predicted_value)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- keys = metric_keys.MetricKeys
- # Estimator already adds a metric for loss.
- eval_metric_ops = {
- _summary_key(self._name, keys.LOSS_MEAN):
- metrics_lib.mean(
- values=unreduced_loss,
- weights=weights)
- }
- if regularization_loss is not None:
- regularization_loss_key = _summary_key(
- self._name, keys.LOSS_REGULARIZATION)
- eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
- values=regularization_loss,
- name=keys.LOSS_REGULARIZATION)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=eval_metric_ops)
+ eval_metrics=_create_eval_metrics_tuple(
+ self._eval_metric_ops,
+ {
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss,
+ }
+ ))
# Train.
if optimizer is not None:
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
raise ValueError('Weight column should be castable to float. '
'Given dtype: {}'.format(weights.dtype))
return math_ops.to_float(weights, name='weights')
+
+
+def _binary_logistic_or_multi_class_head(
+ n_classes, weight_column, label_vocabulary, loss_reduction):
+ """Creates either binary or multi-class head.
+
+ Args:
+ n_classes: Number of label classes.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings represents possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are
+ already encoded as integer or float within [0, 1] for `n_classes=2` and
+ encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
+ Also there will be errors if vocabulary is not provided and labels are
+ string.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ `head._Head` instance.
+ """
+ if n_classes == 2:
+ head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+ else:
+ head = _multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes, weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+ return head
return 1 / (1 + np.exp(-logits))
+class CreateEstimatorSpecTest(test.TestCase):
+
+ class _HeadWithTPUSupport(head_lib._Head):
+ """Head that overrides _create_tpu_estimator_spec."""
+
+ def name(self):
+ return 'HeadWithTPUSupport'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def _create_tpu_estimator_spec(self, features, mode, logits, labels=None,
+ optimizer=None, train_op_fn=None,
+ regularization_losses=None):
+ return model_fn._TPUEstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+ class _HeadWithOutTPUSupport(head_lib._Head):
+ """Head that overrides create_estimator_spec."""
+
+ def name(self):
+ return 'HeadWithOutTPUSupport'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def create_estimator_spec(self, features, mode, logits, labels=None,
+ optimizer=None, train_op_fn=None,
+ regularization_losses=None):
+ return model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+ class _InvalidHead(head_lib._Head):
+ """Head that overrides neither estimator_spec functions."""
+
+ def name(self):
+ return 'InvalidHead'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def test_head_override_tpu_estimator_spec(self):
+ """Test for `_Head` that overrides _create_tpu_estimator_spec."""
+ head = self._HeadWithTPUSupport()
+
+ tpu_spec = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec))
+ est_spec = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+ def test_head_override_estimator_spec(self):
+ """Test for `_Head` that overrides create_estimator_spec."""
+ head = self._HeadWithOutTPUSupport()
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'TPUEstimatorSpec not available for this model head.'):
+ _ = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ est_spec = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+ def test_invalid_head_class(self):
+ head = self._InvalidHead()
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'TPUEstimatorSpec not available for this model head.'):
+ _ = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ r'Subclasses of _Head must implement `create_estimator_spec\(\)` or '
+ r'_create_tpu_estimator_spec\(\).'):
+ _ = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+
+
class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def setUp(self):
return EstimatorSpec(*new_fields)
+class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
+ 'mode',
+ 'predictions',
+ 'loss',
+ 'train_op',
+ 'eval_metrics',
+ 'export_outputs',
+ 'scaffold_fn',
+ 'host_call'])):
+ """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
+
+ This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See
+ tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed
+ documentation.
+ """
+
+ def __new__(cls,
+ mode,
+ predictions=None,
+ loss=None,
+ train_op=None,
+ eval_metrics=None,
+ export_outputs=None,
+ scaffold_fn=None,
+ host_call=None):
+ """Creates a `_TPUEstimatorSpec` instance."""
+ return super(_TPUEstimatorSpec, cls).__new__(cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn,
+ host_call=host_call)
+
+ def as_estimator_spec(self):
+ """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+ if not self.eval_metrics:
+ eval_metric_ops = None
+ else:
+ metric_fn, tensors = self.eval_metrics
+ eval_metric_ops = metric_fn(**tensors)
+ return EstimatorSpec(mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs)
+
+
def _check_is_tensor_or_operation(x, name):
if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))