Adds regularization_losses in canned heads.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 17 Jan 2018 17:49:03 +0000 (09:49 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 17 Jan 2018 17:52:26 +0000 (09:52 -0800)
PiperOrigin-RevId: 182228149

tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py
tensorflow/python/estimator/canned/metric_keys.py
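
In short: create_estimator_spec gains a regularization_losses argument; the
given scalar losses are summed into the reported training loss, surfaced as a
new regularization_loss eval metric, and written as a scalar summary in TRAIN
mode.

A sketch of how a custom model_fn might feed these losses to a head. The head
constructor and import path below come from this change's files; the feature
name 'x', the L2 regularizer, and the Adagrad optimizer are illustrative
assumptions, not part of this change:

    import tensorflow as tf
    from tensorflow.python.estimator.canned import head as head_lib

    def model_fn(features, labels, mode):
      # A layer with a kernel_regularizer contributes to the
      # REGULARIZATION_LOSSES collection.
      logits = tf.layers.dense(
          tf.to_float(features['x']), units=3,
          kernel_regularizer=tf.keras.regularizers.l2(0.01))
      head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
          n_classes=3, loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

      def _train_op_fn(loss):
        return tf.train.AdagradOptimizer(0.1).minimize(
            loss, global_step=tf.train.get_global_step())

      reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
      return head.create_estimator_spec(
          features=features, mode=mode, logits=logits, labels=labels,
          train_op_fn=_train_op_fn,
          regularization_losses=reg_losses)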

index 6706594665df04aff1d28a9c200c1724fea82283..204e1119f2191457359ecaf9fd012fcb5a2b0463 100644 (file)
@@ -173,7 +173,8 @@ class _Head(object):
 
   @abc.abstractmethod
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None):
+      self, features, mode, logits, labels=None, train_op_fn=None,
+      regularization_losses=None):
     """Returns `EstimatorSpec` that a model_fn can return.
 
     Please note that,
@@ -185,10 +186,12 @@ class _Head(object):
       logits: logits `Tensor` to be used by the head.
       labels: Labels `Tensor`, or `dict` of same.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
-          to optimize the model with the loss. This is used in TRAIN mode and
-          must not be None. None is allowed in other modes. If you want to
-          optimize loss yourself you can pass `no_op_train_fn` and then use
-          EstimatorSpec.loss to compute and apply gradients.
+        to optimize the model with the loss. This is used in TRAIN mode and
+        must not be None. None is allowed in other modes. If you want to
+        optimize loss yourself you can pass `no_op_train_fn` and then use
+        EstimatorSpec.loss to compute and apply gradients.
+      regularization_losses: A list of additional scalar losses to be added to
+        the training loss, such as regularization losses.
 
     Returns:
       `EstimatorSpec`.
@@ -547,10 +550,12 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
   def logits_dimension(self):
     return self._n_classes
 
-  def _eval_metric_ops(self, labels, class_ids, weights, unreduced_loss):
+  def _eval_metric_ops(
+      self, labels, class_ids, weights, unreduced_loss, regularization_loss):
     """Returns the Eval metric ops."""
     with ops.name_scope(
-        None, 'metrics', (labels, class_ids, weights, unreduced_loss)):
+        None, 'metrics',
+        (labels, class_ids, weights, unreduced_loss, regularization_loss)):
       keys = metric_keys.MetricKeys
       metric_ops = {
           # Estimator already adds a metric for loss.
@@ -567,6 +572,11 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
                   weights=weights,
                   name=keys.ACCURACY),
       }
+      if regularization_loss is not None:
+        metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (
+            metrics_lib.mean(
+                values=regularization_loss,
+                name=keys.LOSS_REGULARIZATION))
     return metric_ops
 
   def _label_ids(self, labels):
@@ -607,7 +617,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         processed_labels=label_ids)
 
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None):
+      self, features, mode, logits, labels=None, train_op_fn=None,
+      regularization_losses=None):
     """Returns an `EstimatorSpec`.
 
     Args:
@@ -620,6 +631,12 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         equals `TRAIN` or `EVAL`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns
         `train_op`. Required in TRAIN mode.
+      regularization_losses: A list of additional scalar losses to be added to
+        the training loss, such as regularization losses. These losses are
+        usually expressed as a batch average, so for best results users need to
+        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
+        `loss_reduction=SUM_BY_NONZERO_WEIGHTS` when creating the head to
+        avoid scaling errors.
     Returns:
       `EstimatorSpec`.
     Raises:
@@ -665,17 +682,25 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
 
       training_loss, unreduced_loss, weights, label_ids = self.create_loss(
           features=features, mode=mode, logits=logits, labels=labels)
+      if regularization_losses:
+        regularization_loss = math_ops.add_n(regularization_losses)
+        regularized_training_loss = math_ops.add_n(
+            [training_loss, regularization_loss])
+      else:
+        regularization_loss = None
+        regularized_training_loss = training_loss
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
         return model_fn.EstimatorSpec(
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
-            loss=training_loss,
+            loss=regularized_training_loss,
             eval_metric_ops=self._eval_metric_ops(
                 labels=label_ids,
                 class_ids=class_ids,
                 weights=weights,
-                unreduced_loss=unreduced_loss))
+                unreduced_loss=unreduced_loss,
+                regularization_loss=regularization_loss))
 
       # Train.
       if train_op_fn is None:
@@ -689,18 +714,23 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
       else:
         mean_loss = None
     with ops.name_scope(''):
+      keys = metric_keys.MetricKeys
       summary.scalar(
-          _summary_key(self._name, metric_keys.MetricKeys.LOSS),
-          training_loss)
+          _summary_key(self._name, keys.LOSS),
+          regularized_training_loss)
       if mean_loss is not None:
         summary.scalar(
-            _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
+            _summary_key(self._name, keys.LOSS_MEAN),
             mean_loss)
+      if regularization_loss is not None:
+        summary.scalar(
+            _summary_key(self._name, keys.LOSS_REGULARIZATION),
+            regularization_loss)
     return model_fn.EstimatorSpec(
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
-        loss=training_loss,
-        train_op=train_op_fn(training_loss))
+        loss=regularized_training_loss,
+        train_op=train_op_fn(regularized_training_loss))
 
 
 def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
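
As a minimal numeric sketch of the composition above, using the same numbers
as the new tests (mean cross-entropy 5.0, regularization terms 1.5 and 0.5;
tf.add_n is the public spelling of math_ops.add_n):

    import tensorflow as tf

    training_loss = tf.constant(5.0)  # reduced cross-entropy over the batch
    regularization_losses = [tf.constant(1.5), tf.constant(0.5)]
    regularization_loss = tf.add_n(regularization_losses)
    regularized_training_loss = tf.add_n([training_loss, regularization_loss])
    with tf.Session() as sess:
      print(sess.run(regularization_loss))        # 2.0
      print(sess.run(regularized_training_loss))  # 7.0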
index 98f74a7cedec29453da0878af393c1a285c15fc4..28b8e635fb483252edf68a4140bb57c3b99fb96a 100644 (file)
@@ -484,6 +484,52 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
     ]
     self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
 
+  def test_eval_with_regularization_losses(self):
+    n_classes = 3
+    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+        n_classes, loss_reduction=losses.Reduction.MEAN)
+    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
+    labels = np.array(((1,), (1,)), dtype=np.int64)
+    features = {'x': np.array(((42,),), dtype=np.int32)}
+    regularization_losses = [1.5, 0.5]
+    expected_regularization_loss = 2.
+    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
+    #                    = sum(10, 0) / 2 = 5.
+    expected_unregularized_loss = 5.
+    expected_regularized_loss = (
+        expected_unregularized_loss + expected_regularization_loss)
+    # Create estimator spec.
+    spec = head.create_estimator_spec(
+        features=features,
+        mode=model_fn.ModeKeys.EVAL,
+        logits=logits,
+        labels=labels,
+        regularization_losses=regularization_losses)
+
+    keys = metric_keys.MetricKeys
+    expected_metrics = {
+        keys.LOSS_MEAN: expected_unregularized_loss,
+        keys.LOSS_REGULARIZATION: expected_regularization_loss,
+        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.
+    }
+
+    # Assert predictions, loss, and metrics.
+    tol = 1e-2
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNone(spec.scaffold.summary_op)
+      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+      loss, metrics = sess.run((spec.loss, update_ops))
+      self.assertAllClose(expected_regularized_loss, loss, rtol=tol, atol=tol)
+      # Check results of both update (in `metrics`) and value ops.
+      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
+      self.assertAllClose(
+          expected_metrics, {k: value_ops[k].eval()
+                             for k in value_ops},
+          rtol=tol,
+          atol=tol)
+
   def test_eval_with_label_vocabulary_create_loss(self):
     n_classes = 3
     head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
@@ -741,6 +787,51 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
               expected_loss / 2,
       }, summary_str, tol)
 
+  def test_train_with_regularization_losses(self):
+    n_classes = 3
+    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+        n_classes, loss_reduction=losses.Reduction.MEAN)
+
+    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
+    labels = np.array(((1,), (1,)), dtype=np.int64)
+    features = {'x': np.array(((42,),), dtype=np.int32)}
+    expected_train_result = 'my_train_op'
+    def _train_op_fn(loss):
+      return string_ops.string_join(
+          [constant_op.constant(expected_train_result),
+           string_ops.as_string(loss, precision=2)])
+
+    regularization_losses = [1.5, 0.5]
+    expected_regularization_loss = 2.
+    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
+    #                    = sum(10, 0) / 2 = 5.
+    # loss = unregularized_loss + regularization_loss = 7.
+    expected_loss = 7.
+    spec = head.create_estimator_spec(
+        features=features,
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn,
+        regularization_losses=regularization_losses)
+
+    # Assert predictions, loss, train_op, and summaries.
+    tol = 1e-2
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNotNone(spec.scaffold.summary_op)
+      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+                                                  spec.scaffold.summary_op))
+      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+      self.assertEqual(
+          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
+          train_result)
+      _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          metric_keys.MetricKeys.LOSS_REGULARIZATION: (
+              expected_regularization_loss),
+      }, summary_str, tol)
+
   def test_train_one_dim_create_loss(self):
     """Tests create_loss with 1D labels and weights (shape [batch_size])."""
     head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
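
The new regularization_loss metric follows the usual (value_op, update_op)
contract that the tests above unpack; a standalone sketch of that contract
with tf.metrics.mean, the public spelling of the metrics_lib.mean call used
in head.py:

    import tensorflow as tf

    values = tf.constant([1.5, 0.5])
    mean_value, update_op = tf.metrics.mean(values)
    with tf.Session() as sess:
      # Metric state is kept in local variables.
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)          # accumulate one batch
      print(sess.run(mean_value))  # 1.0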
index 7dc4bfe5ffb5f762b56f4fc91b8a75ee4ba1796e..44eb680939203fea67e3391326a6f1013f022ad5 100644 (file)
@@ -25,6 +25,7 @@ class MetricKeys(object):
   """Metric key strings."""
   LOSS = model_fn.LOSS_METRIC_KEY
   LOSS_MEAN = model_fn.AVERAGE_LOSS_METRIC_KEY
+  LOSS_REGULARIZATION = 'regularization_loss'
 
   ACCURACY = 'accuracy'
   # This is the best the model could do by always predicting one class.
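
With the new key in place, a hypothetical evaluate() call on an estimator
built from one of these heads would report the regularization term alongside
the existing loss keys (estimator and eval_input_fn here are assumed to
exist, and the values echo the test fixtures above):

    metrics = estimator.evaluate(input_fn=eval_input_fn)
    # e.g. {'loss': 7.0, 'average_loss': 5.0, 'regularization_loss': 2.0,
    #       'accuracy': 0.5, 'global_step': ...}
    print(metrics['regularization_loss'])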