Tighten label check in BinaryLogisticHeadWithSigmoidCrossEntropyLoss
author A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Apr 2018 19:12:46 +0000 (12:12 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 16 Apr 2018 19:20:03 +0000 (12:20 -0700)
PiperOrigin-RevId: 193078844

tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py

index 189b81a..c365ea8 100644 (file)
@@ -1039,7 +1039,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
           vocabulary_list=tuple(self._label_vocabulary),
           name='class_id_lookup').lookup(labels)
     labels = math_ops.to_float(labels)
-    labels = _assert_range(labels, 2)
+    labels = _assert_range(labels, n_classes=2)
     if self._loss_fn:
       unweighted_loss = _call_loss_fn(
           loss_fn=self._loss_fn, labels=labels, logits=logits,
@@ -1447,12 +1447,12 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
 
 def _assert_range(labels, n_classes, message=None):
   with ops.name_scope(None, 'assert_range', (labels,)):
-    assert_less = check_ops.assert_less(
+    assert_less = check_ops.assert_less_equal(
         labels,
-        ops.convert_to_tensor(n_classes, dtype=labels.dtype),
-        message=message or 'Label IDs must < n_classes')
+        ops.convert_to_tensor(n_classes - 1, dtype=labels.dtype),
+        message=message or 'Labels must <= n_classes - 1')
     assert_greater = check_ops.assert_non_negative(
-        labels, message=message or 'Label IDs must >= 0')
+        labels, message=message or 'Labels must >= 0')
     with ops.control_dependencies((assert_less, assert_greater)):
       return array_ops.identity(labels)
 
index fe6ee07..7da3df0 100644 (file)
@@ -255,14 +255,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
     with self.test_session():
-      with self.assertRaisesOpError('Label IDs must < n_classes'):
+      with self.assertRaisesOpError('Labels must <= n_classes - 1'):
         training_loss.eval({
             labels_placeholder: labels_2x1_with_large_id,
             logits_placeholder: logits_2x3
         })
 
     with self.test_session():
-      with self.assertRaisesOpError('Label IDs must >= 0'):
+      with self.assertRaisesOpError('Labels must >= 0'):
         training_loss.eval({
             labels_placeholder: labels_2x1_with_negative_id,
             logits_placeholder: logits_2x3
@@ -2090,6 +2090,24 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
               expected_regularization_loss),
       }, summary_str)
 
+  def test_float_labels_invalid_values(self):
+    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
+
+    logits = np.array([[0.5], [-0.3]], dtype=np.float32)
+    labels = np.array([[1.2], [0.4]], dtype=np.float32)
+    features = {'x': np.array([[42]], dtype=np.float32)}
+    training_loss = head.create_loss(
+        features=features,
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels)[0]
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        r'Labels must <= n_classes - 1'):
+      with self.test_session():
+        _initialize_variables(self, monitored_session.Scaffold())
+        training_loss.eval()
+
   def test_float_labels_train_create_loss(self):
     head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()