Speed up statistical_testing_test by consolidating sess.run calls.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 04:22:54 +0000 (21:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 04:25:24 +0000 (21:25 -0700)
PiperOrigin-RevId: 190721153

tensorflow/contrib/distributions/BUILD
tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py

index 1c381cc..682448b 100644 (file)
@@ -486,6 +486,7 @@ cuda_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
+    shard_count = 4,
     tags = [
         "manual",
         "noasan",
index 3548ac1..c0e7bdd 100644 (file)
@@ -22,39 +22,75 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import statistical_testing as st
 from tensorflow.python.framework import errors
-from tensorflow.python.ops import check_ops
 from tensorflow.python.platform import test
 
 
 class StatisticalTestingTest(test.TestCase):
 
   def test_dkwm_design_mean_one_sample_soundness(self):
-    numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
+    thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
     rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.]
-    with self.test_session() as sess:
-      for ff in rates:
-        for fp in rates:
-          sufficient_n = st.min_num_samples_for_dkwm_mean_test(
-              numbers, 0., 1., false_fail_rate=ff, false_pass_rate=fp)
-          detectable_d = st.min_discrepancy_of_true_means_detectable_by_dkwm(
-              sufficient_n, 0., 1., false_fail_rate=ff, false_pass_rate=fp)
-          sess.run(check_ops.assert_less_equal(detectable_d, numbers))
+    false_fail_rates, false_pass_rates = np.meshgrid(rates, rates)
+    false_fail_rates = false_fail_rates.flatten().astype(np.float32)
+    false_pass_rates = false_pass_rates.flatten().astype(np.float32)
+
+    detectable_discrepancies = []
+    for false_pass_rate, false_fail_rate in zip(
+        false_pass_rates, false_fail_rates):
+      sufficient_n = st.min_num_samples_for_dkwm_mean_test(
+          thresholds, low=0., high=1., false_fail_rate=false_fail_rate,
+          false_pass_rate=false_pass_rate)
+      detectable_discrepancies.append(
+          st.min_discrepancy_of_true_means_detectable_by_dkwm(
+              sufficient_n, low=0., high=1., false_fail_rate=false_fail_rate,
+              false_pass_rate=false_pass_rate))
+
+    detectable_discrepancies_ = self.evaluate(detectable_discrepancies)
+    for discrepancies, false_pass_rate, false_fail_rate in zip(
+        detectable_discrepancies_, false_pass_rates, false_fail_rates):
+      below_threshold = discrepancies <= thresholds
+      self.assertAllEqual(
+          np.ones_like(below_threshold, np.bool), below_threshold,
+          msg='false_pass_rate({}), false_fail_rate({})'.format(
+              false_pass_rate, false_fail_rate))
 
   def test_dkwm_design_mean_two_sample_soundness(self):
-    numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
+    thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
     rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.]
-    with self.test_session() as sess:
-      for ff in rates:
-        for fp in rates:
-          (sufficient_n1,
-           sufficient_n2) = st.min_num_samples_for_dkwm_mean_two_sample_test(
-               numbers, 0., 1., 0., 1.,
-               false_fail_rate=ff, false_pass_rate=fp)
-          d_fn = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample
-          detectable_d = d_fn(
-              sufficient_n1, 0., 1., sufficient_n2, 0., 1.,
-              false_fail_rate=ff, false_pass_rate=fp)
-          sess.run(check_ops.assert_less_equal(detectable_d, numbers))
+    false_fail_rates, false_pass_rates = np.meshgrid(rates, rates)
+    false_fail_rates = false_fail_rates.flatten().astype(np.float32)
+    false_pass_rates = false_pass_rates.flatten().astype(np.float32)
+
+    detectable_discrepancies = []
+    for false_pass_rate, false_fail_rate in zip(
+        false_pass_rates, false_fail_rates):
+      [
+          sufficient_n1,
+          sufficient_n2
+      ] = st.min_num_samples_for_dkwm_mean_two_sample_test(
+          thresholds, low1=0., high1=1., low2=0., high2=1.,
+          false_fail_rate=false_fail_rate,
+          false_pass_rate=false_pass_rate)
+
+      detectable_discrepancies.append(
+          st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
+              n1=sufficient_n1,
+              low1=0.,
+              high1=1.,
+              n2=sufficient_n2,
+              low2=0.,
+              high2=1.,
+              false_fail_rate=false_fail_rate,
+              false_pass_rate=false_pass_rate))
+
+    detectable_discrepancies_ = self.evaluate(detectable_discrepancies)
+    for discrepancies, false_pass_rate, false_fail_rate in zip(
+        detectable_discrepancies_, false_pass_rates, false_fail_rates):
+      below_threshold = discrepancies <= thresholds
+      self.assertAllEqual(
+          np.ones_like(below_threshold, np.bool), below_threshold,
+          msg='false_pass_rate({}), false_fail_rate({})'.format(
+              false_pass_rate, false_fail_rate))
 
   def test_true_mean_confidence_interval_by_dkwm_one_sample(self):
     rng = np.random.RandomState(seed=0)