Internal Change.
author Michael Case <mikecase@google.com>
Mon, 7 May 2018 22:24:02 +0000 (15:24 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 00:25:40 +0000 (17:25 -0700)
PiperOrigin-RevId: 195731675

tensorflow/contrib/estimator/python/estimator/head.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/python/estimator/canned/dnn.py
tensorflow/python/estimator/canned/dnn_testing_utils.py
tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py
tensorflow/python/estimator/model_fn.py

index 5d19bf4..109fdd3 100644 (file)
@@ -560,10 +560,10 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         weights=weights,
         processed_labels=processed_labels)
 
-  def create_estimator_spec(
+  def _create_tpu_estimator_spec(
       self, features, mode, logits, labels=None, optimizer=None,
       train_op_fn=None, regularization_losses=None):
-    """Returns an `EstimatorSpec`.
+    """Returns a `model_fn._TPUEstimatorSpec`.
 
     Args:
       features: Input `dict` of `Tensor` or `SparseTensor` objects.
@@ -586,7 +586,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
         avoid scaling errors.
     Returns:
-      `EstimatorSpec`.
+      `model_fn._TPUEstimatorSpec`.
     Raises:
       ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
         mode, or if both are set.
@@ -606,7 +606,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
             scores=probabilities, n_classes=self._n_classes,
             label_vocabulary=self._label_vocabulary)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
             export_outputs={
@@ -629,16 +629,18 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
 
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
             loss=regularized_training_loss,
-            eval_metric_ops=self._eval_metric_ops(
-                labels=processed_labels,
-                probabilities=probabilities,
-                weights=weights,
-                unreduced_loss=unreduced_loss,
-                regularization_loss=regularization_loss))
+            eval_metrics=head_lib._create_eval_metrics_tuple(  # pylint:disable=protected-access
+                self._eval_metric_ops, {
+                    'labels': processed_labels,
+                    'probabilities': probabilities,
+                    'weights': weights,
+                    'unreduced_loss': unreduced_loss,
+                    'regularization_loss': regularization_loss,
+                }))
 
       # Train.
       if optimizer is not None:
@@ -672,7 +674,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         summary.scalar(
             head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION),  # pylint:disable=protected-access
             regularization_loss)
-    return model_fn.EstimatorSpec(
+    return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
index a69bfa9..a624ece 100644 (file)
@@ -175,17 +175,7 @@ class _SIGNAL(object):
   STOP = -2
 
 
-class TPUEstimatorSpec(
-    collections.namedtuple('TPUEstimatorSpec', [
-        'mode',
-        'predictions',
-        'loss',
-        'train_op',
-        'eval_metrics',
-        'export_outputs',
-        'scaffold_fn',
-        'host_call'
-    ])):
+class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
   """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
 
   See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
@@ -1156,7 +1146,7 @@ class _ModelFnWrapper(object):
           self._call_model_fn(features, labels))
       loss, train_op = estimator_spec.loss, estimator_spec.train_op
 
-      if isinstance(estimator_spec, TPUEstimatorSpec):
+      if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
       else:
         captured_scaffold_fn.capture(None)
@@ -1165,8 +1155,8 @@ class _ModelFnWrapper(object):
       # outfeed.
       with ops.control_dependencies([train_op]):
         host_call_outfeed_ops = []
-        if (isinstance(estimator_spec, TPUEstimatorSpec) and
-            estimator_spec.host_call is not None):
+        if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
+            and estimator_spec.host_call is not None):
           host_call.record({'host_call': estimator_spec.host_call})
           host_call_outfeed_ops = host_call.create_enqueue_op()
         with ops.control_dependencies(host_call_outfeed_ops):
@@ -1209,7 +1199,7 @@ class _ModelFnWrapper(object):
       features, labels = inputs.features_and_labels()
 
       tpu_estimator_spec = self._call_model_fn(features, labels)
-      if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
         raise RuntimeError(
             'estimator_spec used by TPU evaluation must have type'
             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
@@ -1254,7 +1244,7 @@ class _ModelFnWrapper(object):
 
       tpu_estimator_spec = self._call_model_fn(
           features, labels, is_export_mode=False)
-      if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
         raise RuntimeError(
             'estimator_spec used by TPU prediction must have type'
             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
@@ -1316,7 +1306,7 @@ class _ModelFnWrapper(object):
 
     estimator_spec = self._model_fn(features=features, **kwargs)
     if (self._ctx.is_running_on_cpu(is_export_mode) and
-        isinstance(estimator_spec, TPUEstimatorSpec)):
+        isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
       # The estimator_spec will be passed to `Estimator` directly, which expects
       # type `EstimatorSpec`.
       return estimator_spec.as_estimator_spec()
@@ -1325,7 +1315,7 @@ class _ModelFnWrapper(object):
 
   def _verify_estimator_spec(self, estimator_spec):
     """Validates the estimator_spec."""
-    if isinstance(estimator_spec, TPUEstimatorSpec):
+    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
       return estimator_spec
 
     err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
index 973a6ec..e7fbf8e 100644 (file)
@@ -154,6 +154,59 @@ def _dnn_model_fn(features,
   Raises:
     ValueError: If features has the wrong type.
   """
+  tpu_estimator_spec = _tpu_dnn_model_fn(
+      features=features,
+      labels=labels,
+      mode=mode,
+      head=head,
+      hidden_units=hidden_units,
+      feature_columns=feature_columns,
+      optimizer=optimizer,
+      activation_fn=activation_fn,
+      dropout=dropout,
+      input_layer_partitioner=input_layer_partitioner,
+      config=config)
+  return tpu_estimator_spec.as_estimator_spec()
+
+
+def _tpu_dnn_model_fn(features,
+                      labels,
+                      mode,
+                      head,
+                      hidden_units,
+                      feature_columns,
+                      optimizer='Adagrad',
+                      activation_fn=nn.relu,
+                      dropout=None,
+                      input_layer_partitioner=None,
+                      config=None):
+  """Deep Neural Net model_fn for TPUEstimator.
+
+  Args:
+    features: dict of `Tensor`.
+    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+      dtype `int32` or `int64` in the range `[0, n_classes)`.
+    mode: Defines whether this is training, evaluation or prediction.
+      See `ModeKeys`.
+    head: A `head_lib._Head` instance.
+    hidden_units: Iterable of integer number of hidden units per layer.
+    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
+    optimizer: String, `tf.Optimizer` object, or callable that creates the
+      optimizer to use for training. If not specified, will use the Adagrad
+      optimizer with a default learning rate of 0.05.
+    activation_fn: Activation function applied to each layer.
+    dropout: When not `None`, the probability we will drop out a given
+      coordinate.
+    input_layer_partitioner: Partitioner for input layer. Defaults
+      to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+    config: `RunConfig` object to configure the runtime settings.
+
+  Returns:
+    A `model_fn._TPUEstimatorSpec` instance.
+
+  Raises:
+    ValueError: If features has the wrong type.
+  """
   if not isinstance(features, dict):
     raise ValueError('features should be a dictionary of `Tensor`s. '
                      'Given type: {}'.format(type(features)))
@@ -182,7 +235,7 @@ def _dnn_model_fn(features,
         input_layer_partitioner=input_layer_partitioner)
     logits = logit_fn(features=features, mode=mode)
 
-    return head.create_estimator_spec(
+    return head._create_tpu_estimator_spec(  # pylint: disable=protected-access
         features=features,
         mode=mode,
         labels=labels,
@@ -320,17 +373,8 @@ class DNNClassifier(estimator.Estimator):
       loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
         to reduce training loss over batch. Defaults to `SUM`.
     """
-    if n_classes == 2:
-      head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint: disable=protected-access
-          weight_column=weight_column,
-          label_vocabulary=label_vocabulary,
-          loss_reduction=loss_reduction)
-    else:
-      head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
-          n_classes, weight_column=weight_column,
-          label_vocabulary=label_vocabulary,
-          loss_reduction=loss_reduction)
-
+    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
+        n_classes, weight_column, label_vocabulary, loss_reduction)
     def _model_fn(features, labels, mode, config):
       """Call the defined shared _dnn_model_fn."""
       return _dnn_model_fn(
index 62b13c3..06a6487 100644 (file)
@@ -134,7 +134,7 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
       hidden_weights_names + hidden_biases_names +
       [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])
 
-  def _create_estimator_spec(
+  def _create_tpu_estimator_spec(
       features, mode, logits, labels, train_op_fn=None, optimizer=None):
     del features, labels  # Not used.
     trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
@@ -149,19 +149,29 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
           train_op = train_op_fn(loss)
         elif optimizer is not None:
           train_op = optimizer.minimize(loss, global_step=None)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(
             mode=mode, loss=loss, train_op=train_op)
       elif mode == model_fn.ModeKeys.EVAL:
-        return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss))
+        return model_fn._TPUEstimatorSpec(
+            mode=mode, loss=array_ops.identity(loss))
       elif mode == model_fn.ModeKeys.PREDICT:
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(
             mode=mode, predictions={'logits': array_ops.identity(logits)})
       else:
         testcase.fail('Invalid mode: {}'.format(mode))
 
+  def _create_estimator_spec(
+      features, mode, logits, labels, train_op_fn=None, optimizer=None):
+    tpu_spec = _create_tpu_estimator_spec(
+        features, mode, logits, labels, train_op_fn, optimizer)
+    return tpu_spec.as_estimator_spec()
+
   head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
   head.logits_dimension = logits_dimension
-  head.create_estimator_spec = test.mock.MagicMock(wraps=_create_estimator_spec)
+  head._create_tpu_estimator_spec = test.mock.MagicMock(
+      wraps=_create_tpu_estimator_spec)
+  head.create_estimator_spec = test.mock.MagicMock(
+      wraps=_create_estimator_spec)
 
   return head
 
index 48f448d..2326373 100644 (file)
@@ -32,6 +32,7 @@ from tensorflow.python.feature_column import feature_column as feature_column_li
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
@@ -69,6 +70,35 @@ def _summary_key(head_name, val):
   return '%s/%s' % (val, head_name) if head_name else val
 
 
+def _create_eval_metrics_tuple(fn, kwargs):
+  """Creates TPU eval metrics tuple.
+
+  Helper function to make eval_metrics tuple (eval_metric_fn, fn_kwargs) used
+  by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take
+  exclusively Tensor arguments. This helper can help create such a function from
+  a more generic function that can take both Tensor and non-Tensor arguments.
+
+  Args:
+    fn: An eval_metric_fn that takes both Tensor and non-Tensor arguments.
+        This function must return a dict of form
+        {'metric name': (metric_tensor, eval_op)}
+    kwargs: Dict of arguments for `fn`.
+
+  Returns:
+    `eval_metrics` tuple that can be passed to a `model_fn._TPUEstimatorSpec`.
+  """
+  tensor_kwargs = {}
+  nontensor_kwargs = {}
+  for k, v in six.iteritems(kwargs):
+    if tensor_util.is_tensor(v):
+      tensor_kwargs[k] = v
+    else:
+      nontensor_kwargs[k] = v
+  def _fn(**tensors):
+    return fn(**dict(nontensor_kwargs, **tensors))
+  return (_fn, tensor_kwargs)
+
+
 class _Head(object):
   """Interface for the head/top of a model.
 
@@ -174,7 +204,6 @@ class _Head(object):
 
   # TODO(b/65403806): By default, collect regularization_losses from
   # GraphKeys.REGULARIZATION_LOSSES collection.
-  @abc.abstractmethod
   def create_estimator_spec(
       self, features, mode, logits, labels=None, optimizer=None,
       train_op_fn=None, regularization_losses=None):
@@ -203,7 +232,47 @@ class _Head(object):
     Returns:
       `EstimatorSpec`.
     """
-    raise NotImplementedError('Calling an abstract method.')
+    try:
+      tpu_estimator_spec = (
+          self._create_tpu_estimator_spec(
+              features, mode, logits, labels, optimizer, train_op_fn,
+              regularization_losses))
+      return tpu_estimator_spec.as_estimator_spec()
+    except NotImplementedError:
+      # Not all subclasses of _Head will have implemented
+      # _create_tpu_estimator_spec. If it is implemented, we can use it to
+      # create our `EstimatorSpec` here.
+      raise NotImplementedError(
+          'Subclasses of _Head must implement `create_estimator_spec()` or '
+          '_create_tpu_estimator_spec().')
+
+  def _create_tpu_estimator_spec(
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
+    """Returns `model_fn._TPUEstimatorSpec` that a model_fn can return.
+
+    Args:
+      features: Input `dict` of `Tensor` or `SparseTensor` objects.
+      mode: Estimator's `ModeKeys`.
+      logits: logits `Tensor` to be used by the head.
+      labels: Labels `Tensor`, or `dict` of same.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
+      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
+        to optimize the model with the loss in TRAIN mode. Used if `optimizer`
+        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in
+        TRAIN mode. None is allowed in other modes. If you want to optimize loss
+        yourself you can pass `lambda _: tf.no_op()` and then use
+        EstimatorSpec.loss to compute and apply gradients.
+      regularization_losses: A list of additional scalar losses to be added to
+        the training loss, such as regularization losses.
+
+    Returns:
+      A `model_fn._TPUEstimatorSpec` instance.
+    """
+    raise NotImplementedError(
+        'TPUEstimatorSpec not available for this model head.')
 
 
 def _check_dense_labels_match_logits_and_reshape(
@@ -702,10 +771,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         weights=weights,
         processed_labels=label_ids)
 
-  def create_estimator_spec(
+  def _create_tpu_estimator_spec(
       self, features, mode, logits, labels=None, optimizer=None,
       train_op_fn=None, regularization_losses=None):
-    """Returns an `EstimatorSpec`.
+    """Returns a `model_fn._TPUEstimatorSpec`.
 
     Args:
       features: Input `dict` of `Tensor` or `SparseTensor` objects.
@@ -727,7 +796,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
         avoid scaling errors.
     Returns:
-      `EstimatorSpec`.
+      A `model_fn._TPUEstimatorSpec` instance.
     Raises:
       ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
         mode, or if both are set.
@@ -761,7 +830,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         classifier_output = _classification_output(
             scores=probabilities, n_classes=self._n_classes,
             label_vocabulary=self._label_vocabulary)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
             export_outputs={
@@ -781,16 +850,17 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         regularized_training_loss = training_loss
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
             loss=regularized_training_loss,
-            eval_metric_ops=self._eval_metric_ops(
-                labels=label_ids,
-                class_ids=class_ids,
-                weights=weights,
-                unreduced_loss=unreduced_loss,
-                regularization_loss=regularization_loss))
+            eval_metrics=_create_eval_metrics_tuple(self._eval_metric_ops, {
+                'labels': label_ids,
+                'class_ids': class_ids,
+                'weights': weights,
+                'unreduced_loss': unreduced_loss,
+                'regularization_loss': regularization_loss
+            }))
 
       # Train.
       if optimizer is not None:
@@ -824,7 +894,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         summary.scalar(
             _summary_key(self._name, keys.LOSS_REGULARIZATION),
             regularization_loss)
-    return model_fn.EstimatorSpec(
+    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
@@ -1060,7 +1130,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         weights=weights,
         processed_labels=labels)
 
-  def create_estimator_spec(
+  def _create_tpu_estimator_spec(
       self, features, mode, logits, labels=None, optimizer=None,
       train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
@@ -1122,7 +1192,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         classifier_output = _classification_output(
             scores=probabilities, n_classes=2,
             label_vocabulary=self._label_vocabulary)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
             export_outputs={
@@ -1146,18 +1216,22 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
 
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
             loss=regularized_training_loss,
-            eval_metric_ops=self._eval_metric_ops(
-                labels=processed_labels,
-                logits=logits,
-                logistic=logistic,
-                class_ids=class_ids,
-                weights=weights,
-                unreduced_loss=unreduced_loss,
-                regularization_loss=regularization_loss))
+            eval_metrics=_create_eval_metrics_tuple(
+                self._eval_metric_ops,
+                {
+                    'labels': processed_labels,
+                    'logits': logits,
+                    'logistic': logistic,
+                    'class_ids': class_ids,
+                    'weights': weights,
+                    'unreduced_loss': unreduced_loss,
+                    'regularization_loss': regularization_loss
+                }
+            ))
 
       # Train.
       if optimizer is not None:
@@ -1190,7 +1264,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         summary.scalar(
             _summary_key(self._name, keys.LOSS_REGULARIZATION),
             regularization_loss)
-    return model_fn.EstimatorSpec(
+    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
@@ -1322,7 +1396,25 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         weights=weights,
         processed_labels=labels)
 
-  def create_estimator_spec(
+  def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+    """Returns the Eval metric ops."""
+    keys = metric_keys.MetricKeys
+    # Estimator already adds a metric for loss.
+    eval_metric_ops = {
+        _summary_key(self._name, keys.LOSS_MEAN):
+            metrics_lib.mean(
+                values=unreduced_loss,
+                weights=weights)
+    }
+    if regularization_loss is not None:
+      regularization_loss_key = _summary_key(
+          self._name, keys.LOSS_REGULARIZATION)
+      eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
+          values=regularization_loss,
+          name=keys.LOSS_REGULARIZATION)
+    return eval_metric_ops
+
+  def _create_tpu_estimator_spec(
       self, features, mode, logits, labels=None, optimizer=None,
       train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
@@ -1348,7 +1440,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
         avoid scaling errors.
     Returns:
-      `EstimatorSpec`.
+      A `model_fn._TPUEstimatorSpec` instance.
     Raises:
       ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
         mode, or if both are set.
@@ -1369,7 +1461,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
       if mode == model_fn.ModeKeys.PREDICT:
         regression_output = export_output.RegressionOutput(
             value=predicted_value)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
             export_outputs={
@@ -1390,25 +1482,18 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
 
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
-        keys = metric_keys.MetricKeys
-        # Estimator already adds a metric for loss.
-        eval_metric_ops = {
-            _summary_key(self._name, keys.LOSS_MEAN):
-                metrics_lib.mean(
-                    values=unreduced_loss,
-                    weights=weights)
-        }
-        if regularization_loss is not None:
-          regularization_loss_key = _summary_key(
-              self._name, keys.LOSS_REGULARIZATION)
-          eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
-              values=regularization_loss,
-              name=keys.LOSS_REGULARIZATION)
-        return model_fn.EstimatorSpec(
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
             loss=regularized_training_loss,
-            eval_metric_ops=eval_metric_ops)
+            eval_metrics=_create_eval_metrics_tuple(
+                self._eval_metric_ops,
+                {
+                    'weights': weights,
+                    'unreduced_loss': unreduced_loss,
+                    'regularization_loss': regularization_loss,
+                }
+            ))
 
       # Train.
       if optimizer is not None:
@@ -1441,7 +1526,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         summary.scalar(
             _summary_key(self._name, keys.LOSS_REGULARIZATION),
             regularization_loss)
-    return model_fn.EstimatorSpec(
+    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
@@ -1478,3 +1563,42 @@ def _weights(features, weight_column):
       raise ValueError('Weight column should be castable to float. '
                        'Given dtype: {}'.format(weights.dtype))
     return math_ops.to_float(weights, name='weights')
+
+
+def _binary_logistic_or_multi_class_head(
+    n_classes, weight_column, label_vocabulary, loss_reduction):
+  """Creates either binary or multi-class head.
+
+  Args:
+    n_classes: Number of label classes.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example. If it is a string, it is
+      used as a key to fetch weight tensor from the `features`. If it is a
+      `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+      then weight_column.normalizer_fn is applied on it to get weight tensor.
+    label_vocabulary: A list of strings represents possible label values. If
+      given, labels must be string type and have any value in
+      `label_vocabulary`. If it is not given, that means labels are
+      already encoded as integer or float within [0, 1] for `n_classes=2` and
+      encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
+      Also there will be errors if vocabulary is not provided and labels are
+      string.
+    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+      to reduce training loss over batch. Defaults to `SUM`.
+
+  Returns:
+    `head._Head` instance.
+  """
+  if n_classes == 2:
+    head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        weight_column=weight_column,
+        label_vocabulary=label_vocabulary,
+        loss_reduction=loss_reduction)
+  else:
+    head = _multi_class_head_with_softmax_cross_entropy_loss(
+        n_classes, weight_column=weight_column,
+        label_vocabulary=label_vocabulary,
+        loss_reduction=loss_reduction)
+  return head
index 32a6339..ecca3e8 100644 (file)
@@ -86,6 +86,98 @@ def _sigmoid(logits):
   return 1 / (1 + np.exp(-logits))
 
 
+class CreateEstimatorSpecTest(test.TestCase):
+
+  class _HeadWithTPUSupport(head_lib._Head):
+    """Head that overrides _create_tpu_estimator_spec."""
+
+    def name(self):
+      return 'HeadWithTPUSupport'
+
+    def logits_dimension(self):
+      return None
+
+    def create_loss(self, features, mode, logits, labels):
+      return None
+
+    def _create_tpu_estimator_spec(self, features, mode, logits, labels=None,
+                                   optimizer=None, train_op_fn=None,
+                                   regularization_losses=None):
+      return model_fn._TPUEstimatorSpec(
+          mode=model_fn.ModeKeys.EVAL,
+          loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+  class _HeadWithOutTPUSupport(head_lib._Head):
+    """Head that overrides create_estimator_spec."""
+
+    def name(self):
+      return 'HeadWithOutTPUSupport'
+
+    def logits_dimension(self):
+      return None
+
+    def create_loss(self, features, mode, logits, labels):
+      return None
+
+    def create_estimator_spec(self, features, mode, logits, labels=None,
+                              optimizer=None, train_op_fn=None,
+                              regularization_losses=None):
+      return model_fn.EstimatorSpec(
+          mode=model_fn.ModeKeys.EVAL,
+          loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+  class _InvalidHead(head_lib._Head):
+    """Head that overrides neither of the estimator_spec functions."""
+
+    def name(self):
+      return 'InvalidHead'
+
+    def logits_dimension(self):
+      return None
+
+    def create_loss(self, features, mode, logits, labels):
+      return None
+
+  def test_head_override_tpu_estimator_spec(self):
+    """Test for `_Head` that overrides _create_tpu_estimator_spec."""
+    head = self._HeadWithTPUSupport()
+
+    tpu_spec = head._create_tpu_estimator_spec(
+        features=None, mode=None, logits=None)
+    self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec))
+    est_spec = head.create_estimator_spec(
+        features=None, mode=None, logits=None)
+    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+  def test_head_override_estimator_spec(self):
+    """Test for `_Head` that overrides create_estimator_spec."""
+    head = self._HeadWithOutTPUSupport()
+
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        'TPUEstimatorSpec not available for this model head.'):
+      _ = head._create_tpu_estimator_spec(
+          features=None, mode=None, logits=None)
+    est_spec = head.create_estimator_spec(
+        features=None, mode=None, logits=None)
+    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+  def test_invalid_head_class(self):
+    head = self._InvalidHead()
+
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        'TPUEstimatorSpec not available for this model head.'):
+      _ = head._create_tpu_estimator_spec(
+          features=None, mode=None, logits=None)
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        r'Subclasses of _Head must implement `create_estimator_spec\(\)` or '
+        r'_create_tpu_estimator_spec\(\).'):
+      _ = head.create_estimator_spec(
+          features=None, mode=None, logits=None)
+
+
 class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
 
   def setUp(self):
index 4ab2578..3edf9fe 100644 (file)
@@ -334,6 +334,57 @@ class EstimatorSpec(
     return EstimatorSpec(*new_fields)
 
 
+class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
+    'mode',
+    'predictions',
+    'loss',
+    'train_op',
+    'eval_metrics',
+    'export_outputs',
+    'scaffold_fn',
+    'host_call'])):
+  """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
+
+  This is a simplified implementation of `tf.contrib.tpu.TPUEstimatorSpec`. See
+  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed
+  documentation.
+  """
+
+  def __new__(cls,
+              mode,
+              predictions=None,
+              loss=None,
+              train_op=None,
+              eval_metrics=None,
+              export_outputs=None,
+              scaffold_fn=None,
+              host_call=None):
+    """Creates a `_TPUEstimatorSpec` instance."""
+    return super(_TPUEstimatorSpec, cls).__new__(cls,
+                                                 mode=mode,
+                                                 predictions=predictions,
+                                                 loss=loss,
+                                                 train_op=train_op,
+                                                 eval_metrics=eval_metrics,
+                                                 export_outputs=export_outputs,
+                                                 scaffold_fn=scaffold_fn,
+                                                 host_call=host_call)
+
+  def as_estimator_spec(self):
+    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+    if not self.eval_metrics:
+      eval_metric_ops = None
+    else:
+      metric_fn, tensors = self.eval_metrics
+      eval_metric_ops = metric_fn(**tensors)
+    return EstimatorSpec(mode=self.mode,
+                         predictions=self.predictions,
+                         loss=self.loss,
+                         train_op=self.train_op,
+                         eval_metric_ops=eval_metric_ops,
+                         export_outputs=self.export_outputs)
+
+
 def _check_is_tensor_or_operation(x, name):
   if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
     raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))