From: A. Unique TensorFlower
Date: Tue, 29 May 2018 15:36:14 +0000 (-0700)
Subject: Generalize assert_true_mean_equal and assert_true_mean_equal_two_sample to assert_tru...
X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~21
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cff06379c2e1ac01de3b3c0ca32c3a3037d5b833;p=platform%2Fupstream%2Ftensorflow.git

Generalize assert_true_mean_equal and assert_true_mean_equal_two_sample to assert_true_mean_in_interval.

PiperOrigin-RevId: 198400265
---

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
index ce6cf70..4a5a6b5 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
@@ -129,16 +129,41 @@ class StatisticalTestingTest(test.TestCase):
 
     # Test that the test assertion confirms that the mean of the
     # standard uniform distribution is not 0.4.
-    with self.assertRaisesOpError("Mean confidence interval too high"):
+    with self.assertRaisesOpError("true mean greater than expected"):
       sess.run(st.assert_true_mean_equal_by_dkwm(
           samples, 0., 1., 0.4, false_fail_rate=1e-6))
 
     # Test that the test assertion confirms that the mean of the
     # standard uniform distribution is not 0.6.
-    with self.assertRaisesOpError("Mean confidence interval too low"):
+    with self.assertRaisesOpError("true mean smaller than expected"):
       sess.run(st.assert_true_mean_equal_by_dkwm(
           samples, 0., 1., 0.6, false_fail_rate=1e-6))
 
+  def test_dkwm_mean_in_interval_one_sample_assertion(self):
+    rng = np.random.RandomState(seed=0)
+    num_samples = 5000
+
+    # Test that the test assertion agrees that the mean of the standard
+    # uniform distribution is between 0.4 and 0.6.
+    samples = rng.uniform(size=num_samples).astype(np.float32)
+    self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+        samples, 0., 1.,
+        expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6))
+
+    # Test that the test assertion confirms that the mean of the
+    # standard uniform distribution is not between 0.2 and 0.4.
+    with self.assertRaisesOpError("true mean greater than expected"):
+      self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+          samples, 0., 1.,
+          expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6))
+
+    # Test that the test assertion confirms that the mean of the
+    # standard uniform distribution is not between 0.6 and 0.8.
+    with self.assertRaisesOpError("true mean smaller than expected"):
+      self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+          samples, 0., 1.,
+          expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6))
+
   def test_dkwm_mean_two_sample_assertion(self):
     rng = np.random.RandomState(seed=0)
     num_samples = 4000
@@ -172,7 +197,7 @@ class StatisticalTestingTest(test.TestCase):
     # Test that the test assertion confirms that the mean of the
     # standard uniform distribution is different from the mean of beta(2, 1).
     beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
-    with self.assertRaisesOpError("samples1 has a smaller mean"):
+    with self.assertRaisesOpError("true mean smaller than expected"):
       sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
           samples1, 0., 1.,
           beta_high_samples, 0., 1.,
@@ -190,7 +215,7 @@ class StatisticalTestingTest(test.TestCase):
     # Test that the test assertion confirms that the mean of the
     # standard uniform distribution is different from the mean of beta(1, 2).
     beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
-    with self.assertRaisesOpError("samples2 has a smaller mean"):
+    with self.assertRaisesOpError("true mean greater than expected"):
       sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
           samples1, 0., 1.,
           beta_low_samples, 0., 1.,
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index 9c69435..3ea9a33 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -140,6 +140,7 @@ __all__ = [
     "assert_true_mean_equal_by_dkwm",
     "min_discrepancy_of_true_means_detectable_by_dkwm",
     "min_num_samples_for_dkwm_mean_test",
+    "assert_true_mean_in_interval_by_dkwm",
     "assert_true_mean_equal_by_dkwm_two_sample",
     "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample",
     "min_num_samples_for_dkwm_mean_two_sample_test",
@@ -454,20 +455,8 @@ def assert_true_mean_equal_by_dkwm(
   with ops.name_scope(
       name, "assert_true_mean_equal_by_dkwm",
       [samples, low, high, expected, false_fail_rate]):
-    samples = ops.convert_to_tensor(samples, name="samples")
-    low = ops.convert_to_tensor(low, name="low")
-    high = ops.convert_to_tensor(high, name="high")
-    expected = ops.convert_to_tensor(expected, name="expected")
-    false_fail_rate = ops.convert_to_tensor(
-        false_fail_rate, name="false_fail_rate")
-    samples = _check_shape_dominates(samples, [low, high, expected])
-    min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
-        samples, low, high, error_rate=false_fail_rate)
-    less_op = check_ops.assert_less(
-        min_mean, expected, message="Mean confidence interval too high")
-    with ops.control_dependencies([less_op]):
-      return check_ops.assert_greater(
-          max_mean, expected, message="Mean confidence interval too low")
+    return assert_true_mean_in_interval_by_dkwm(
+        samples, low, high, expected, expected, false_fail_rate)
 
 
 def min_discrepancy_of_true_means_detectable_by_dkwm(
@@ -505,12 +494,15 @@ def min_discrepancy_of_true_means_detectable_by_dkwm(
   some scalar distribution supported on `[low[i], high[i]]` is enough
   to detect a difference in means of size `discr[i]` or more.
   Specifically, we guarantee that (a) if the true mean is the expected
-  mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
-  most `false_fail_rate / K` (which amounts to `false_fail_rate` if
-  applied to the whole batch at once), and (b) if the true mean
-  differs from the expected mean by at least `discr[i]`,
-  `assert_true_mean_equal_by_dkwm` will pass with probability at most
-  `false_pass_rate`.
+  mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+  probability at most `false_fail_rate / K` (which amounts to
+  `false_fail_rate` if applied to the whole batch at once), and (b) if
+  the true mean differs from the expected mean (resp. falls outside
+  the expected interval) by at least `discr[i]`,
+  `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+  probability at most `false_pass_rate`.
 
   The detectable discrepancy scales as
 
@@ -578,12 +570,15 @@ def min_num_samples_for_dkwm_mean_test(
   some scalar distribution supported on `[low[i], high[i]]` is enough
   to detect a difference in means of size `discrepancy[i]` or more.
   Specifically, we guarantee that (a) if the true mean is the expected
-  mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
-  most `false_fail_rate / K` (which amounts to `false_fail_rate` if
-  applied to the whole batch at once), and (b) if the true mean
-  differs from the expected mean by at least `discrepancy[i]`,
-  `assert_true_mean_equal_by_dkwm` will pass with probability at most
-  `false_pass_rate`.
+  mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+  probability at most `false_fail_rate / K` (which amounts to
+  `false_fail_rate` if applied to the whole batch at once), and (b) if
+  the true mean differs from the expected mean (resp. falls outside
+  the expected interval) by at least `discrepancy[i]`,
+  `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+  probability at most `false_pass_rate`.
 
   The required number of samples scales as
   `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`,
@@ -610,6 +605,76 @@ def min_num_samples_for_dkwm_mean_test(
   return math_ops.maximum(n1, n2)
 
 
+def assert_true_mean_in_interval_by_dkwm(
+    samples, low, high, expected_low, expected_high,
+    false_fail_rate=1e-6, name=None):
+  """Asserts the mean of the given distribution is in the given interval.
+
+  More precisely, fails if there is enough evidence (using the
+  [Dvoretzky-Kiefer-Wolfowitz-Massart inequality]
+  (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval))
+  that the mean of the distribution from which the given samples are
+  drawn is _outside_ the given interval with statistical significance
+  `false_fail_rate` or stronger, otherwise passes.  If you also want
+  to check that you are gathering enough evidence that a pass is not
+  spurious, see `min_num_samples_for_dkwm_mean_test` and
+  `min_discrepancy_of_true_means_detectable_by_dkwm`.
+
+  Note that `false_fail_rate` is a total false failure rate for all
+  the assertions in the batch.  As such, if the batch is nontrivial,
+  the assertion will insist on stronger evidence to fail any one member.
+
+  Args:
+    samples: Floating-point `Tensor` of samples from the distribution(s)
+      of interest.  Entries are assumed IID across the 0th dimension.
+      The other dimensions must broadcast with `low` and `high`.
+      The support is bounded: `low <= samples <= high`.
+    low: Floating-point `Tensor` of lower bounds on the distributions'
+      supports.
+    high: Floating-point `Tensor` of upper bounds on the distributions'
+      supports.
+    expected_low: Floating-point `Tensor` of lower bounds on the
+      expected true means.
+    expected_high: Floating-point `Tensor` of upper bounds on the
+      expected true means.
+    false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+      rate of mistakes.
+    name: A name for this operation (optional).
+
+  Returns:
+    check: Op that raises `InvalidArgumentError` if any expected mean
+      interval does not overlap with the corresponding confidence
+      interval.
+  """
+  with ops.name_scope(
+      name, "assert_true_mean_in_interval_by_dkwm",
+      [samples, low, high, expected_low, expected_high, false_fail_rate]):
+    samples = ops.convert_to_tensor(samples, name="samples")
+    low = ops.convert_to_tensor(low, name="low")
+    high = ops.convert_to_tensor(high, name="high")
+    expected_low = ops.convert_to_tensor(expected_low, name="expected_low")
+    expected_high = ops.convert_to_tensor(expected_high, name="expected_high")
+    false_fail_rate = ops.convert_to_tensor(
+        false_fail_rate, name="false_fail_rate")
+    samples = _check_shape_dominates(
+        samples, [low, high, expected_low, expected_high])
+    min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
+        samples, low, high, false_fail_rate)
+    # Assert that the interval [min_mean, max_mean] intersects the
+    # interval [expected_low, expected_high].  This is true if
+    #   max_mean >= expected_low and min_mean <= expected_high.
+    # By DeMorgan's law, that's also equivalent to
+    #   not (max_mean < expected_low or min_mean > expected_high),
+    # which is a way of saying the two intervals are not disjoint.
+    check_confidence_interval_can_intersect = check_ops.assert_greater_equal(
+        max_mean, expected_low, message="Confidence interval does not "
+        "intersect: true mean smaller than expected")
+    with ops.control_dependencies([check_confidence_interval_can_intersect]):
+      return check_ops.assert_less_equal(
+          min_mean, expected_high, message="Confidence interval does not "
+          "intersect: true mean greater than expected")
+
+
 def assert_true_mean_equal_by_dkwm_two_sample(
     samples1, low1, high1, samples2, low2, high2,
     false_fail_rate=1e-6, name=None):
@@ -676,20 +741,10 @@ def assert_true_mean_equal_by_dkwm_two_sample(
     # and sample counts should be valid; however, because the intervals
     # scale as O(-log(false_fail_rate)), there doesn't seem to be much
     # room to win.
-    min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm(
-        samples1, low1, high1, false_fail_rate / 2.)
     min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm(
         samples2, low2, high2, false_fail_rate / 2.)
-    # I want to assert
-    #   not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2),
-    # but I think I only have and-combination of asserts, so use DeMorgan.
-    check_confidence_intervals_can_intersect = check_ops.assert_greater_equal(
-        max_mean_1, min_mean_2, message="Confidence intervals do not "
-        "intersect: samples1 has a smaller mean than samples2")
-    with ops.control_dependencies([check_confidence_intervals_can_intersect]):
-      return check_ops.assert_less_equal(
-          min_mean_1, max_mean_2, message="Confidence intervals do not "
-          "intersect: samples2 has a smaller mean than samples1")
+    return assert_true_mean_in_interval_by_dkwm(
+        samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.)
 
 
 def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
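
For reference, a minimal sketch of calling the new assertion directly, assuming
the TF 1.x graph-mode session API and the contrib import path this patch
touches; the `st` alias and the uniform-sample setup mirror the test file above
and are illustrative, not part of the library API.

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.distributions.python.ops import statistical_testing as st

    # IID samples from the distribution under test; a standard uniform
    # has true mean 0.5, inside the asserted interval [0.4, 0.6].
    samples = np.random.RandomState(seed=0).uniform(size=5000).astype(np.float32)

    # Build the check op.  When run, it raises InvalidArgumentError if the
    # DKWM confidence interval for the true mean is disjoint from
    # [expected_low, expected_high]; otherwise it passes.
    check = st.assert_true_mean_in_interval_by_dkwm(
        samples, 0., 1.,
        expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)

    with tf.Session() as sess:
      sess.run(check)  # Passes here; a disjoint interval would raise.

The refactoring makes the interval form the primitive: equality testing becomes
the degenerate interval [expected, expected], and the two-sample test reduces
to checking samples1 against the DKWM confidence interval computed from
samples2.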