Generalize assert_true_mean_equal and assert_true_mean_equal_two_sample to assert_tru...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 29 May 2018 15:36:14 +0000 (08:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 15:38:52 +0000 (08:38 -0700)
PiperOrigin-RevId: 198400265

tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
tensorflow/contrib/distributions/python/ops/statistical_testing.py

index ce6cf70..4a5a6b5 100644 (file)
@@ -129,16 +129,41 @@ class StatisticalTestingTest(test.TestCase):
 
       # Test that the test assertion confirms that the mean of the
       # standard uniform distribution is not 0.4.
-      with self.assertRaisesOpError("Mean confidence interval too high"):
+      with self.assertRaisesOpError("true mean greater than expected"):
         sess.run(st.assert_true_mean_equal_by_dkwm(
             samples, 0., 1., 0.4, false_fail_rate=1e-6))
 
       # Test that the test assertion confirms that the mean of the
       # standard uniform distribution is not 0.6.
-      with self.assertRaisesOpError("Mean confidence interval too low"):
+      with self.assertRaisesOpError("true mean smaller than expected"):
         sess.run(st.assert_true_mean_equal_by_dkwm(
             samples, 0., 1., 0.6, false_fail_rate=1e-6))
 
+  def test_dkwm_mean_in_interval_one_sample_assertion(self):
+    rng = np.random.RandomState(seed=0)
+    num_samples = 5000
+
+    # Test that the test assertion agrees that the mean of the standard
+    # uniform distribution is between 0.4 and 0.6.
+    samples = rng.uniform(size=num_samples).astype(np.float32)
+    self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+        samples, 0., 1.,
+        expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6))
+
+    # Test that the test assertion confirms that the mean of the
+    # standard uniform distribution is not between 0.2 and 0.4.
+    with self.assertRaisesOpError("true mean greater than expected"):
+      self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+          samples, 0., 1.,
+          expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6))
+
+    # Test that the test assertion confirms that the mean of the
+    # standard uniform distribution is not between 0.6 and 0.8.
+    with self.assertRaisesOpError("true mean smaller than expected"):
+      self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+          samples, 0., 1.,
+          expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6))
+
   def test_dkwm_mean_two_sample_assertion(self):
     rng = np.random.RandomState(seed=0)
     num_samples = 4000
@@ -172,7 +197,7 @@ class StatisticalTestingTest(test.TestCase):
       # Test that the test assertion confirms that the mean of the
       # standard uniform distribution is different from the mean of beta(2, 1).
       beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
-      with self.assertRaisesOpError("samples1 has a smaller mean"):
+      with self.assertRaisesOpError("true mean smaller than expected"):
         sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
             samples1, 0., 1.,
             beta_high_samples, 0., 1.,
@@ -190,7 +215,7 @@ class StatisticalTestingTest(test.TestCase):
       # Test that the test assertion confirms that the mean of the
       # standard uniform distribution is different from the mean of beta(1, 2).
       beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
-      with self.assertRaisesOpError("samples2 has a smaller mean"):
+      with self.assertRaisesOpError("true mean greater than expected"):
         sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
             samples1, 0., 1.,
             beta_low_samples, 0., 1.,
index 9c69435..3ea9a33 100644 (file)
@@ -140,6 +140,7 @@ __all__ = [
     "assert_true_mean_equal_by_dkwm",
     "min_discrepancy_of_true_means_detectable_by_dkwm",
     "min_num_samples_for_dkwm_mean_test",
+    "assert_true_mean_in_interval_by_dkwm",
     "assert_true_mean_equal_by_dkwm_two_sample",
     "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample",
     "min_num_samples_for_dkwm_mean_two_sample_test",
@@ -454,20 +455,8 @@ def assert_true_mean_equal_by_dkwm(
   with ops.name_scope(
       name, "assert_true_mean_equal_by_dkwm",
       [samples, low, high, expected, false_fail_rate]):
-    samples = ops.convert_to_tensor(samples, name="samples")
-    low = ops.convert_to_tensor(low, name="low")
-    high = ops.convert_to_tensor(high, name="high")
-    expected = ops.convert_to_tensor(expected, name="expected")
-    false_fail_rate = ops.convert_to_tensor(
-        false_fail_rate, name="false_fail_rate")
-    samples = _check_shape_dominates(samples, [low, high, expected])
-    min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
-        samples, low, high, error_rate=false_fail_rate)
-    less_op = check_ops.assert_less(
-        min_mean, expected, message="Mean confidence interval too high")
-    with ops.control_dependencies([less_op]):
-      return check_ops.assert_greater(
-          max_mean, expected, message="Mean confidence interval too low")
+    return assert_true_mean_in_interval_by_dkwm(
+        samples, low, high, expected, expected, false_fail_rate)
 
 
 def min_discrepancy_of_true_means_detectable_by_dkwm(
@@ -505,12 +494,15 @@ def min_discrepancy_of_true_means_detectable_by_dkwm(
   some scalar distribution supported on `[low[i], high[i]]` is enough
   to detect a difference in means of size `discr[i]` or more.
   Specifically, we guarantee that (a) if the true mean is the expected
-  mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
-  most `false_fail_rate / K` (which amounts to `false_fail_rate` if
-  applied to the whole batch at once), and (b) if the true mean
-  differs from the expected mean by at least `discr[i]`,
-  `assert_true_mean_equal_by_dkwm` will pass with probability at most
-  `false_pass_rate`.
+  mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+  probability at most `false_fail_rate / K` (which amounts to
+  `false_fail_rate` if applied to the whole batch at once), and (b) if
+  the true mean differs from the expected mean (resp. falls outside
+  the expected interval) by at least `discr[i]`,
+  `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+  probability at most `false_pass_rate`.
 
   The detectable discrepancy scales as
 
@@ -578,12 +570,15 @@ def min_num_samples_for_dkwm_mean_test(
   some scalar distribution supported on `[low[i], high[i]]` is enough
   to detect a difference in means of size `discrepancy[i]` or more.
   Specifically, we guarantee that (a) if the true mean is the expected
-  mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
-  most `false_fail_rate / K` (which amounts to `false_fail_rate` if
-  applied to the whole batch at once), and (b) if the true mean
-  differs from the expected mean by at least `discrepancy[i]`,
-  `assert_true_mean_equal_by_dkwm` will pass with probability at most
-  `false_pass_rate`.
+  mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+  probability at most `false_fail_rate / K` (which amounts to
+  `false_fail_rate` if applied to the whole batch at once), and (b) if
+  the true mean differs from the expected mean (resp. falls outside
+  the expected interval) by at least `discrepancy[i]`,
+  `assert_true_mean_equal_by_dkwm`
+  (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+  probability at most `false_pass_rate`.
 
   The required number of samples scales
   as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`,
@@ -610,6 +605,76 @@ def min_num_samples_for_dkwm_mean_test(
     return math_ops.maximum(n1, n2)
 
 
+def assert_true_mean_in_interval_by_dkwm(
+    samples, low, high, expected_low, expected_high,
+    false_fail_rate=1e-6, name=None):
+  """Asserts the mean of the given distribution is in the given interval.
+
+  More precisely, fails if there is enough evidence (using the
+  [Dvoretzky-Kiefer-Wolfowitz-Massart inequality]
+  (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval))
+  that the mean of the distribution from which the given samples are
+  drawn is _outside_ the given interval with statistical significance
+  `false_fail_rate` or stronger, otherwise passes.  If you also want
+  to check that you are gathering enough evidence that a pass is not
+  spurious, see `min_num_samples_for_dkwm_mean_test` and
+  `min_discrepancy_of_true_means_detectable_by_dkwm`.
+
+  Note that `false_fail_rate` is a total false failure rate for all
+  the assertions in the batch.  As such, if the batch is nontrivial,
+  the assertion will insist on stronger evidence to fail any one member.
+
+  Args:
+    samples: Floating-point `Tensor` of samples from the distribution(s)
+      of interest.  Entries are assumed IID across the 0th dimension.
+      The other dimensions must broadcast with `low` and `high`.
+      The support is bounded: `low <= samples <= high`.
+    low: Floating-point `Tensor` of lower bounds on the distributions'
+      supports.
+    high: Floating-point `Tensor` of upper bounds on the distributions'
+      supports.
+    expected_low: Floating-point `Tensor` of lower bounds on the
+      expected true means.
+    expected_high: Floating-point `Tensor` of upper bounds on the
+      expected true means.
+    false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+      rate of mistakes.
+    name: A name for this operation (optional).
+
+  Returns:
+    check: Op that raises `InvalidArgumentError` if any expected mean
+      interval does not overlap with the corresponding confidence
+      interval.
+  """
+  with ops.name_scope(
+      name, "assert_true_mean_in_interval_by_dkwm",
+      [samples, low, high, expected_low, expected_high, false_fail_rate]):
+    samples = ops.convert_to_tensor(samples, name="samples")
+    low = ops.convert_to_tensor(low, name="low")
+    high = ops.convert_to_tensor(high, name="high")
+    expected_low = ops.convert_to_tensor(expected_low, name="expected_low")
+    expected_high = ops.convert_to_tensor(expected_high, name="expected_high")
+    false_fail_rate = ops.convert_to_tensor(
+        false_fail_rate, name="false_fail_rate")
+    samples = _check_shape_dominates(
+        samples, [low, high, expected_low, expected_high])
+    min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
+        samples, low, high, false_fail_rate)
+    # Assert that the interval [min_mean, max_mean] intersects the
+    # interval [expected_low, expected_high].  This is true if
+    #   max_mean >= expected_low and min_mean <= expected_high.
+    # By DeMorgan's law, that's also equivalent to
+    #   not (max_mean < expected_low or min_mean > expected_high),
+    # which is a way of saying the two intervals are not disjoint.
+    check_confidence_interval_can_intersect = check_ops.assert_greater_equal(
+        max_mean, expected_low, message="Confidence interval does not "
+        "intersect: true mean smaller than expected")
+    with ops.control_dependencies([check_confidence_interval_can_intersect]):
+      return check_ops.assert_less_equal(
+          min_mean, expected_high, message="Confidence interval does not "
+          "intersect: true mean greater than expected")
+
+
 def assert_true_mean_equal_by_dkwm_two_sample(
     samples1, low1, high1, samples2, low2, high2,
     false_fail_rate=1e-6, name=None):
@@ -676,20 +741,10 @@ def assert_true_mean_equal_by_dkwm_two_sample(
       # and sample counts should be valid; however, because the intervals
       # scale as O(-log(false_fail_rate)), there doesn't seem to be much
       # room to win.
-      min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm(
-          samples1, low1, high1, false_fail_rate / 2.)
       min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm(
           samples2, low2, high2, false_fail_rate / 2.)
-      # I want to assert
-      #   not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2),
-      # but I think I only have and-combination of asserts, so use DeMorgan.
-      check_confidence_intervals_can_intersect = check_ops.assert_greater_equal(
-          max_mean_1, min_mean_2, message="Confidence intervals do not "
-          "intersect: samples1 has a smaller mean than samples2")
-      with ops.control_dependencies([check_confidence_intervals_can_intersect]):
-        return check_ops.assert_less_equal(
-            min_mean_1, max_mean_2, message="Confidence intervals do not "
-            "intersect: samples2 has a smaller mean than samples1")
+      return assert_true_mean_in_interval_by_dkwm(
+          samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.)
 
 
 def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(