Fix PR-AUC calculation, namely the incorrect use of linear interpolation for Precisio...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 07:33:14 +0000 (00:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 07:38:05 +0000 (00:38 -0700)
Also, modify the name of the "trapezoidal" summation method to reflect the fact that the proper interpolation method in this case isn't quite the trapezoidal one.

PiperOrigin-RevId: 191555707

tensorflow/python/kernel_tests/metrics_test.py
tensorflow/python/ops/metrics_impl.py

index ad802f7..5565348 100644 (file)
@@ -1124,40 +1124,91 @@ class AUCTest(test.TestCase):
 
       self.assertAlmostEqual(0.7, auc.eval(), 5)
 
-  def testAUCPRSpecialCase(self):
+  # Regarding the AUC-PR tests: note that the preferred method when
+  # calculating AUC-PR is summation_method='careful_interpolation'.
+  def testCorrectAUCPRSpecialCase(self):
     with self.test_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
-      auc, update_op = metrics.auc(labels, predictions, curve='PR')
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='careful_interpolation')
+
+      sess.run(variables.local_variables_initializer())
+      # expected ~= 0.79726744594
+      expected = 1 - math.log(1.5) / 2
+      self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+      self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+  def testCorrectAnotherAUCPRSpecialCase(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
+          shape=(1, 7),
+          dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='careful_interpolation')
+
+      sess.run(variables.local_variables_initializer())
+      # expected ~= 0.61350593198
+      expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
+      self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+      self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+  def testThirdCorrectAUCPRSpecialCase(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+          shape=(1, 7),
+          dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='careful_interpolation')
+
+      sess.run(variables.local_variables_initializer())
+      # expected ~= 0.90410597584
+      expected = 1 - math.log(4./3) / 3
+      self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+      self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+  def testIncorrectAUCPRSpecialCase(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='trapezoidal')
 
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
 
       self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
 
-  def testAnotherAUCPRSpecialCase(self):
+  def testAnotherIncorrectAUCPRSpecialCase(self):
     with self.test_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
           shape=(1, 7),
           dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
-      auc, update_op = metrics.auc(labels, predictions, curve='PR')
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='trapezoidal')
 
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
 
       self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
 
-  def testThirdAUCPRSpecialCase(self):
+  def testThirdIncorrectAUCPRSpecialCase(self):
     with self.test_session() as sess:
       predictions = constant_op.constant(
           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
           shape=(1, 7),
           dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
-      auc, update_op = metrics.auc(labels, predictions, curve='PR')
+      auc, update_op = metrics.auc(labels, predictions, curve='PR',
+                                   summation_method='trapezoidal')
 
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
index 9ec4954..47eea6e 100644 (file)
@@ -33,6 +33,7 @@ from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
@@ -626,10 +627,16 @@ def auc(labels,
     curve: Specifies the name of the curve to be computed, 'ROC' [default] or
       'PR' for the Precision-Recall-curve.
     name: An optional variable_scope name.
-    summation_method: Specifies the Riemann summation method used, 'trapezoidal'
-      [default] that applies the trapezoidal rule, 'minoring' that applies
-      left summation for increasing intervals and right summation for decreasing
-      intervals or 'majoring' that applies the opposite.
+    summation_method: Specifies the Riemann summation method used
+      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
+      applies the trapezoidal rule; 'careful_interpolation', a variant of it
+      differing only by a more correct interpolation scheme for PR-AUC -
+      interpolating (true/false) positives but not the ratio that is precision;
+      'minoring' that applies left summation for increasing intervals and right
+      summation for decreasing intervals; 'majoring' that does the opposite.
+      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
+      (to be deprecated soon) as it applies the same method for ROC, and a
+      better one (see Davis & Goadrich 2006 for details) for the PR curve.
 
   Returns:
     auc: A scalar `Tensor` representing the current area-under-curve.
@@ -664,8 +671,62 @@ def auc(labels,
     # Add epsilons to avoid dividing by 0.
     epsilon = 1.0e-6
 
+    def interpolate_pr_auc(tp, fp, fn):
+      """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+      Note here we derive & use a closed formula not present in the paper
+      - as follows:
+      Modeling all of TP (true positive weight),
+      FP (false positive weight) and their sum P = TP + FP (positive weight)
+      as varying linearly within each interval [A, B] between successive
+      thresholds, we get
+        Precision = (TP_A + slope * (P - P_A)) / P
+      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
+      The area within the interval is thus (slope / total_pos_weight) times
+        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+         slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
+      where dTP == TP_B - TP_A.
+      Note that when P_A == 0 the above calculation simplifies into
+        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+      which is really equivalent to imputing constant precision throughout the
+      first bucket having >0 true positives.
+
+      Args:
+        tp: true positive counts
+        fp: false positive counts
+        fn: false negative counts
+      Returns:
+        pr_auc: an approximation of the area under the P-R curve.
+      """
+      dtp = tp[:num_thresholds - 1] - tp[1:]
+      p = tp + fp
+      prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope')
+      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
+      safe_p_ratio = array_ops.where(
+          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
+          _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'),
+          array_ops.ones_like(p[1:]))
+      return math_ops.reduce_sum(
+          _safe_div(
+              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
+              tp[1:] + fn[1:],
+              name='pr_auc_increment'),
+          name='interpolate_pr_auc')
+
     def compute_auc(tp, fn, tn, fp, name):
       """Computes the roc-auc or pr-auc based on confusion counts."""
+      if curve == 'PR':
+        if summation_method == 'trapezoidal':
+          logging.warning(
+              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
+              'please switch to "careful_interpolation" instead.')
+        elif summation_method == 'careful_interpolation':
+          # This one is a bit tricky and is handled separately.
+          return interpolate_pr_auc(tp, fp, fn)
       rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
       if curve == 'ROC':
         fp_rate = math_ops.div(fp, fp + tn + epsilon)
@@ -675,7 +736,9 @@ def auc(labels,
         prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
         x = rec
         y = prec
-      if summation_method == 'trapezoidal':
+      if summation_method in ('trapezoidal', 'careful_interpolation'):
+        # Note that the case ('PR', 'careful_interpolation') has been handled
+        # above.
         return math_ops.reduce_sum(
             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                               (y[:num_thresholds - 1] + y[1:]) / 2.),
@@ -923,8 +986,8 @@ def mean_per_class_accuracy(labels,
         weights = array_ops.reshape(weights, [-1])
       weights = math_ops.to_float(weights)
 
-      is_correct = is_correct * weights
-      ones = ones * weights
+      is_correct *= weights
+      ones *= weights
 
     update_total_op = state_ops.scatter_add(total, labels, ones)
     update_count_op = state_ops.scatter_add(count, labels, is_correct)