Add precision and recall metrics to _BinaryLogisticHeadWithSigmoidCrossEntropyLoss.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sun, 18 Mar 2018 22:18:06 +0000 (15:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 18 Mar 2018 22:22:08 +0000 (15:22 -0700)
This change makes most of the binary classifiers in the canned estimators provide precision and recall metrics during evaluation. This matches the behavior of the canned estimators defined in the deprecated tf.contrib.learn.estimator.

PiperOrigin-RevId: 189522420

tensorflow/python/estimator/canned/baseline_test.py
tensorflow/python/estimator/canned/dnn_testing_utils.py
tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py
tensorflow/python/estimator/canned/linear_testing_utils.py
tensorflow/python/estimator/canned/metric_keys.py

index 96639e8..7833df2 100644 (file)
@@ -1071,6 +1071,8 @@ class BaselineClassifierEvaluationTest(test.TestCase):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: 1.3133,
           metric_keys.MetricKeys.ACCURACY: 0.,
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
           metric_keys.MetricKeys.LABEL_MEAN: 1.,
           metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
@@ -1132,6 +1134,8 @@ class BaselineClassifierEvaluationTest(test.TestCase):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
           metric_keys.MetricKeys.ACCURACY: 0.5,
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
           metric_keys.MetricKeys.LABEL_MEAN: 0.5,
           metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
@@ -1207,6 +1211,8 @@ class BaselineClassifierEvaluationTest(test.TestCase):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
           metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.),
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,
           metric_keys.MetricKeys.LABEL_MEAN: label_mean,
           metric_keys.MetricKeys.ACCURACY_BASELINE: (
@@ -1542,4 +1548,3 @@ class BaselineLogitFnTest(test.TestCase):
 
 if __name__ == '__main__':
   test.main()
-
index 9a7d088..85b058c 100644 (file)
@@ -1035,6 +1035,8 @@ class BaseDNNClassifierEvaluateTest(object):
         metric_keys.MetricKeys.LOSS: expected_loss,
         metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2.,
         metric_keys.MetricKeys.ACCURACY: 0.5,
+        metric_keys.MetricKeys.PRECISION: 0.0,
+        metric_keys.MetricKeys.RECALL: 0.0,
         metric_keys.MetricKeys.PREDICTION_MEAN: 0.11105597,
         metric_keys.MetricKeys.LABEL_MEAN: 0.5,
         metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
@@ -1042,6 +1044,7 @@ class BaseDNNClassifierEvaluateTest(object):
         # that is what the algorithm returns.
         metric_keys.MetricKeys.AUC: 0.5,
         metric_keys.MetricKeys.AUC_PR: 0.75,
+
         ops.GraphKeys.GLOBAL_STEP: global_step
     }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))
 
index 8d742a2..f68204a 100644 (file)
@@ -940,6 +940,18 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
                   predictions=class_ids,
                   weights=weights,
                   name=keys.ACCURACY),
+          _summary_key(self._name, keys.PRECISION):
+              metrics_lib.precision(
+                  labels=labels,
+                  predictions=class_ids,
+                  weights=weights,
+                  name=keys.PRECISION),
+          _summary_key(self._name, keys.RECALL):
+              metrics_lib.recall(
+                  labels=labels,
+                  predictions=class_ids,
+                  weights=weights,
+                  name=keys.RECALL),
           _summary_key(self._name, keys.PREDICTION_MEAN):
               _predictions_mean(
                   predictions=logistic,
index b40758f..b5d35c9 100644 (file)
@@ -1559,6 +1559,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         # loss_mean = loss/2 = 41./2 = 20.5
         keys.LOSS_MEAN: 20.5,
         keys.ACCURACY: 1./2,
+        keys.PRECISION: 1.,
+        keys.RECALL: 1./2,
         keys.PREDICTION_MEAN: 1./2,
         keys.LABEL_MEAN: 2./2,
         keys.ACCURACY_BASELINE: 2./2,
@@ -1602,11 +1604,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     expected_metric_keys = [
         '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS_MEAN),
         '{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY),
+        '{}/some_binary_head'.format(metric_keys.MetricKeys.PRECISION),
+        '{}/some_binary_head'.format(metric_keys.MetricKeys.RECALL),
         '{}/some_binary_head'.format(metric_keys.MetricKeys.PREDICTION_MEAN),
         '{}/some_binary_head'.format(metric_keys.MetricKeys.LABEL_MEAN),
         '{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY_BASELINE),
         '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC),
-        '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC_PR)
+        '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC_PR),
     ]
     self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
 
@@ -1637,6 +1641,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         keys.LOSS_MEAN: expected_unregularized_loss,
         keys.LOSS_REGULARIZATION: expected_regularization_loss,
         keys.ACCURACY: 1./2,
+        keys.PRECISION: 1.,
+        keys.RECALL: 1./2,
         keys.PREDICTION_MEAN: 1./2,
         keys.LABEL_MEAN: 2./2,
         keys.ACCURACY_BASELINE: 2./2,
@@ -1742,6 +1748,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     expected_metrics = {
         keys.LOSS_MEAN: 1.62652338 / 2.,
         keys.ACCURACY: 1./2,
+        keys.PRECISION: 1.,
+        keys.RECALL: .5,
         keys.PREDICTION_MEAN: 1./2,
         keys.LABEL_MEAN: 2./2,
         keys.ACCURACY_BASELINE: 2./2,
@@ -2187,6 +2195,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         keys.LOSS_MEAN: 26.9615384615,
         # accuracy = (1*1 + .1*0 + 1.5*0)/(1 + .1 + 1.5) = 1/2.6 = .38461538461
         keys.ACCURACY: .38461538461,
+        keys.PRECISION: 1./2.5,
+        keys.RECALL: 1./1.1,
         # prediction_mean = (1*1 + .1*0 + 1.5*1)/(1 + .1 + 1.5) = 2.5/2.6
         #                 = .96153846153
         keys.PREDICTION_MEAN: .96153846153,
@@ -2486,6 +2496,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     expected_metrics = {
         keys.LOSS_MEAN: expected_loss / np.sum(weights),
         keys.ACCURACY: (1.*0. + 1.5*1. + 2.*1. + 2.5*0.) / np.sum(weights),
+        keys.PRECISION: 2.0/3.0,
+        keys.RECALL: 2.0/4.5,
         keys.PREDICTION_MEAN: (1.*1 + 1.5*0 + 2.*1 + 2.5*0) / np.sum(weights),
         keys.LABEL_MEAN: (1.*0 + 1.5*0 + 2.*1 + 2.5*1) / np.sum(weights),
         keys.ACCURACY_BASELINE: (1.*0 + 1.5*0 + 2.*1 + 2.5*1) / np.sum(weights),
index 8e506a7..da3ce86 100644 (file)
@@ -1337,6 +1337,8 @@ class BaseLinearClassifierEvaluationTest(object):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: 41.,
           metric_keys.MetricKeys.ACCURACY: 0.,
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: 0.,
           metric_keys.MetricKeys.LABEL_MEAN: 1.,
           metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
@@ -1406,6 +1408,8 @@ class BaseLinearClassifierEvaluationTest(object):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
           metric_keys.MetricKeys.ACCURACY: 0.,
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: 0.5,
           metric_keys.MetricKeys.LABEL_MEAN: 0.5,
           metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
@@ -1487,6 +1491,8 @@ class BaseLinearClassifierEvaluationTest(object):
           ops.GraphKeys.GLOBAL_STEP: 100,
           metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
           metric_keys.MetricKeys.ACCURACY: 0.,
+          metric_keys.MetricKeys.PRECISION: 0.,
+          metric_keys.MetricKeys.RECALL: 0.,
           metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,
           metric_keys.MetricKeys.LABEL_MEAN: label_mean,
           metric_keys.MetricKeys.ACCURACY_BASELINE: (
index 44eb680..f374d31 100644 (file)
@@ -28,6 +28,8 @@ class MetricKeys(object):
   LOSS_REGULARIZATION = 'regularization_loss'
 
   ACCURACY = 'accuracy'
+  PRECISION = 'precision'
+  RECALL = 'recall'
   # This is the best the model could do by always predicting one class.
   # Should be < ACCURACY in a trained model.
   ACCURACY_BASELINE = 'accuracy_baseline'