self.assertAlmostEqual(0.7, auc.eval(), 5)
- def testAUCPRSpecialCase(self):
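+ # Regarding the AUC-PR tests: the preferred method for calculating AUC-PR is
+ # summation_method='careful_interpolation'. The "Correct" tests below use it;
+ # the "Incorrect" tests pin the values produced by the 'trapezoidal' rule.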
+ def testCorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.79726744594
+ expected = 1 - math.log(1.5) / 2
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testCorrectAnotherAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
+ shape=(1, 7),
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.61350593198
+ expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testThirdCorrectAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+ shape=(1, 7),
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.90410597584
+ expected = 1 - math.log(4./3) / 3
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testIncorrectAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
- def testAnotherAUCPRSpecialCase(self):
+ def testAnotherIncorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
- def testThirdAUCPRSpecialCase(self):
+ def testThirdIncorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
name: An optional variable_scope name.
- summation_method: Specifies the Riemann summation method used, 'trapezoidal'
- [default] that applies the trapezoidal rule, 'minoring' that applies
- left summation for increasing intervals and right summation for decreasing
- intervals or 'majoring' that applies the opposite.
+ summation_method: Specifies the Riemann summation method used
+ (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
+ applies the trapezoidal rule; 'careful_interpolation', a variant of it
+ that differs only for the PR curve, where it linearly interpolates the
+ true and false positive counts rather than their ratio (precision);
+ 'minoring' that applies left summation for increasing intervals and right
+ summation for decreasing intervals; 'majoring' that does the opposite.
+ Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
+ (to be deprecated soon): it gives the same result for ROC and a more
+ accurate one (see Davis & Goadrich 2006 for details) for the PR curve.
Returns:
auc: A scalar `Tensor` representing the current area-under-curve.
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
+ def interpolate_pr_auc(tp, fp, fn):
+ """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+ Note here we derive & use a closed formula not present in the paper
+ - as follows:
+ Modeling all of TP (true positive weight),
+ FP (false positive weight) and their sum P = TP + FP (positive weight)
+ as varying linearly within each interval [A, B] between successive
+ thresholds, we get
+ Precision = (TP_A + slope * (P - P_A)) / P
+ with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
+ Since dRecall = dTP / total_pos_weight = slope * dP / total_pos_weight,
+ the area within the interval is (slope / total_pos_weight) times
+ int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+ = int_A^B{slope * dP + intercept * dP / P}
+ where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+ int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+ Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+ slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+ where dTP == TP_B - TP_A.
+ Note that when P_A == 0 (hence TP_A == 0 for non-negative weights, so
+ intercept == 0) the above calculation simplifies into
+ int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+ which is really equivalent to imputing constant precision throughout the
+ first bucket having >0 true positives.
+
+ Args:
+ tp: true positive counts
+ fp: false positive counts
+ fn: false negative counts
+ Returns:
+ pr_auc: an approximation of the area under the P-R curve.
+ """
+ dtp = tp[:num_thresholds - 1] - tp[1:]
+ p = tp + fp
+ prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope')
+ intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
+ safe_p_ratio = array_ops.where(
+ math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
+ _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'),
+ array_ops.ones_like(p[1:]))
+ return math_ops.reduce_sum(
+ _safe_div(
+ prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
+ tp[1:] + fn[1:],
+ name='pr_auc_increment'),
+ name='interpolate_pr_auc')
+
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
+ if curve == 'PR':
+ if summation_method == 'trapezoidal':
+ logging.warning(
+ 'Trapezoidal rule is known to produce incorrect PR-AUCs; '
+ 'please switch to "careful_interpolation" instead.')
+ elif summation_method == 'careful_interpolation':
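+ # PR-AUC with careful interpolation uses a closed-form per-interval
+ # integral instead of the Riemann sums below, so a dedicated helper
+ # computes it.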
+ return interpolate_pr_auc(tp, fp, fn)
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
fp_rate = math_ops.div(fp, fp + tn + epsilon)
prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
x = rec
y = prec
- if summation_method == 'trapezoidal':
+ if summation_method in ('trapezoidal', 'careful_interpolation'):
+ # Note that the case ('PR', 'careful_interpolation') has been handled
+ # above.
return math_ops.reduce_sum(
math_ops.multiply(x[:num_thresholds - 1] - x[1:],
(y[:num_thresholds - 1] + y[1:]) / 2.),
weights = array_ops.reshape(weights, [-1])
weights = math_ops.to_float(weights)
- is_correct = is_correct * weights
- ones = ones * weights
+ is_correct *= weights
+ ones *= weights
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)