# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
- keys.AUC_PR: 0.5972,
+ keys.AUC_PR: 0.7639,
}
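# A minimal NumPy cross-check (a sketch, not the library implementation) for
# the updated AUC_PR expectations throughout this change: trapezoidal PR AUC
# with epsilon moved into the numerator of precision, so the curve is anchored
# at (recall=0, precision=1) instead of (0, 0). The threshold grid and epsilon
# below are illustrative assumptions, not TensorFlow's exact binning.
import numpy as np

def pr_auc_sketch(labels, predictions, num_thresholds=200, eps=1e-7):
  thresholds = np.linspace(0.0, 1.0, num_thresholds)
  tp = np.array([((predictions > t) & (labels == 1)).sum() for t in thresholds], float)
  fp = np.array([((predictions > t) & (labels == 0)).sum() for t in thresholds], float)
  fn = np.array([((predictions <= t) & (labels == 1)).sum() for t in thresholds], float)
  recall = tp / (tp + fn + eps)
  precision = (tp + eps) / (tp + fp + eps)  # the fix: eps in the numerator too
  # Trapezoidal rule; thresholds ascend, so recall descends toward 0.
  return ((recall[:-1] - recall[1:]) * (precision[:-1] + precision[1:]) / 2.).sum()

# Perfect predictions: the fixed curve stays at precision 1, giving area ~1.0;
# the old formula pinned the tp = fp = 0 endpoint to precision 0 and gave ~0.5.
print(pr_auc_sketch(np.ones(4, int), np.ones(4)))  # ~1.0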
self._test_eval(
head=head,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
- keys.AUC_PR: 0.5972,
+ keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
- keys.AUC_PR: 0.5972,
+ keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
- keys.AUC_PR: 0.5972,
+ keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
- keys.AUC_PR: 0.5972,
+ keys.AUC_PR: 0.7639,
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.2000,
- keys.AUC_PR: 0.5833,
+ keys.AUC_PR: 0.7833,
}
# Assert spec contains expected tensors.
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.4977,
- keys.AUC_PR: 0.4037,
+ keys.AUC_PR: 0.6645,
}
self._test_eval(
head=head,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC + '/head1': 0.1667,
keys.AUC + '/head2': 0.3333,
- keys.AUC_PR + '/head1': 0.49999964,
- keys.AUC_PR + '/head2': 0.33333313,
+ keys.AUC_PR + '/head1': 0.6667,
+ keys.AUC_PR + '/head2': 0.5000,
}
# Assert spec contains expected tensors.
"auc_precision_recall": 0.166667,
"auc_precision_recall/class0": 0,
"auc_precision_recall/class1": 0.,
- "auc_precision_recall/class2": 0.49999,
+ "auc_precision_recall/class2": 1.,
"labels/actual_label_mean/class0": self._labels[0][0],
"labels/actual_label_mean/class1": self._labels[0][1],
"labels/actual_label_mean/class2": self._labels[0][2],
"accuracy/baseline_label_mean": label_mean,
"accuracy/threshold_0.500000_mean": 1. / 2,
"auc": 1. / 2,
- "auc_precision_recall": 0.25,
+ "auc_precision_recall": 0.749999,
"labels/actual_label_mean": label_mean,
"labels/prediction_mean": .731059, # softmax
"loss": expected_loss,
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3)
+ self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
with self.test_session() as sess:
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
+ self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
with self.test_session() as sess:
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
+ self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.49999976, sess.run(update_op), 6)
+ self.assertAlmostEqual(1, sess.run(update_op), 6)
- self.assertAlmostEqual(0.49999976, auc.eval(), 6)
+ self.assertAlmostEqual(1, auc.eval(), 6)
def testWithMultipleUpdates(self):
num_samples = 1000
# [[0, 25, 0],
# [0, 0, 25],
# [25, 0, 0]]
- # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
- # labels, predictions)
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
expect = -0.333333333333
with self.test_session() as sess:
weights_t: weights[batch_start:batch_end]
})
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
- # labels_np, predictions_np,
- # sample_weight=weights_np)
+ # labels_np, predictions_np, sample_weight=weights_np)
expect = 0.289965397924
self.assertAlmostEqual(expect, kappa.eval(), 5)
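# The kappa expectations above are pinned to scikit-learn's reference
# implementation (the comments cite v0.19). A hedged sketch of that
# cross-check with placeholder arrays; the real labels_np, predictions_np,
# and weights_np are built inside the surrounding test:
import numpy as np
from sklearn.metrics import cohen_kappa_score

labels_np = np.array([0, 1, 2, 0, 1, 2])       # placeholder data
predictions_np = np.array([1, 2, 0, 1, 2, 0])  # placeholder data
weights_np = np.ones_like(labels_np, dtype=float)
print(cohen_kappa_score(labels_np, predictions_np, sample_weight=weights_np))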
metric_keys.MetricKeys.LABEL_MEAN: 1.,
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
metric_keys.MetricKeys.AUC: 0.,
- metric_keys.MetricKeys.AUC_PR: 0.5,
+ metric_keys.MetricKeys.AUC_PR: 1.,
}
else:
# Multi classes: loss = 1 * -log ( softmax(logits)[label] )
metric_keys.MetricKeys.LABEL_MEAN: 0.5,
metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
metric_keys.MetricKeys.AUC: 0.5,
- metric_keys.MetricKeys.AUC_PR: 0.25,
+ metric_keys.MetricKeys.AUC_PR: 0.75,
}
else:
# Expand logits since batch_size=2
metric_keys.MetricKeys.ACCURACY_BASELINE: (
max(label_mean, 1-label_mean)),
metric_keys.MetricKeys.AUC: 0.5,
- metric_keys.MetricKeys.AUC_PR: 0.16666645,
+ metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),
}
else:
# Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )
# There is no good way to calculate AUC for only two data points. But
# that is what the algorithm returns.
metric_keys.MetricKeys.AUC: 0.5,
- metric_keys.MetricKeys.AUC_PR: 0.25,
+ metric_keys.MetricKeys.AUC_PR: 0.75,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))
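# Worked check for the 2. / (1. + 2.) expectation above (an assumption about
# that test's geometry, consistent with the fixed formula): a single operating
# point at (recall=1, precision=1/3), joined to the new (recall=0, precision=1)
# anchor, gives one trapezoid of area (1 + 1/3) / 2 = 2/3.
print((1. + 1. / 3.) / 2.)  # 0.666..., i.e. 2. / (1. + 2.)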
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
- keys.AUC_PR: 0.74999905,
+ keys.AUC_PR: 1.,
}
# Assert spec contains expected tensors.
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
- keys.AUC_PR: 0.75,
+ keys.AUC_PR: 1.,
}
# Assert predictions, loss, and metrics.
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
- keys.AUC_PR: 0.74999905,
+ keys.AUC_PR: 1.,
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1.,
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1.,
keys.RECALL_AT_THRESHOLD % thresholds[0]: 1.,
keys.LABEL_MEAN: expected_label_mean,
keys.ACCURACY_BASELINE: 1 - expected_label_mean,
keys.AUC: .45454565,
- keys.AUC_PR: .21923049,
+ keys.AUC_PR: .6737757325172424,
}
# Assert spec contains expected tensors.
# We cannot reliably calculate AUC with only 4 data points, but the
# values should not change because of backwards-compatibility.
keys.AUC: 0.5222,
- keys.AUC_PR: 0.5119,
+ keys.AUC_PR: 0.7341,
}
tol = 1e-2
metric_keys.MetricKeys.LABEL_MEAN: 1.,
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
metric_keys.MetricKeys.AUC: 0.,
- metric_keys.MetricKeys.AUC_PR: 0.5,
+ metric_keys.MetricKeys.AUC_PR: 1.,
}
else:
# Multi classes: loss = 1 * -log ( soft_max(logits)[label] )
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3)
+ self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
with self.test_session() as sess:
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
+ self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
with self.test_session() as sess:
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
- self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
-
- def testFourthAUCPRSpecialCase(self):
- # Create the labels and data.
- labels = np.array([
- 0, 0, 0, 0, 0, 0, 0, 1, 0, 1])
- predictions = np.array([
- 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35])
-
- with self.test_session() as sess:
- auc, _ = metrics.auc(
- labels, predictions, curve='PR', num_thresholds=11)
-
- sess.run(variables.local_variables_initializer())
- # Since this is only approximate, we can't expect a 6 digits match.
- # Although with higher number of samples/thresholds we should see the
- # accuracy improving
- self.assertAlmostEqual(0.0, auc.eval(), delta=0.001)
+ self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
self.assertAlmostEqual(1, auc.eval(), 6)
- def testRecallOneAndPrecisionOne(self):
+ def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
with self.test_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
- self.assertAlmostEqual(0.5, sess.run(update_op), 6)
+ self.assertAlmostEqual(1, sess.run(update_op), 6)
- self.assertAlmostEqual(0.5, auc.eval(), 6)
+ self.assertAlmostEqual(1, auc.eval(), 6)
def np_auc(self, predictions, labels, weights):
"""Computes the AUC explicitly using Numpy.
x = fp_rate
y = rec
else: # curve == 'PR'.
- prec = math_ops.div(tp, tp + fp + epsilon)
+ prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
x = rec
y = prec
if summation_method == 'trapezoidal':
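# Effect of the one-line fix above: at the highest threshold tp = fp = 0, so
# the old precision tp / (tp + fp + eps) evaluated to 0, while the fixed
# (tp + eps) / (tp + fp + eps) evaluates to 1. For curves with a single real
# operating point at (recall=1, precision=p), the trapezoid grows from
# (0 + p) / 2 to (1 + p) / 2, matching the +0.5 shifts in the tests above
# (0.25 -> 0.75, 0.5 -> 1.0, 0.16666645 -> 2/3). A two-line check, with an
# illustrative eps of 1e-7:
eps = 1e-7
print(0. / (0. + 0. + eps), (0. + eps) / (0. + 0. + eps))  # 0.0 vs 1.0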
weights = array_ops.reshape(weights, [-1])
weights = math_ops.to_float(weights)
- is_correct *= weights
- ones *= weights
+ is_correct = is_correct * weights
+ ones = ones * weights
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)