[tf.data] Split out the `tf.contrib.data.sample_from_datasets()` tests.
authorDerek Murray <mrry@google.com>
Thu, 24 May 2018 00:32:54 +0000 (17:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 00:35:41 +0000 (17:35 -0700)
These were previously broken and disabled in CI builds; this change also fixes
them up.

PiperOrigin-RevId: 197818554

tensorflow/contrib/cmake/tf_core_kernels.cmake
tensorflow/contrib/data/python/kernel_tests/BUILD
tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py [new file with mode: 0644]
tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py

index 90c5852..2d76bf5 100644 (file)
@@ -69,6 +69,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
       "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc"
+      "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc"
index 53da494..d269b5b 100644 (file)
@@ -208,6 +208,23 @@ py_test(
     ],
 )
 
+py_test(
+    name = "directed_interleave_dataset_test",
+    size = "medium",
+    srcs = ["directed_interleave_dataset_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dataset_serialization_test",
+        "//tensorflow/contrib/data/python/ops:interleave_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
 tf_py_test(
     name = "get_single_element_test",
     size = "small",
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
new file mode 100644 (file)
index 0000000..d071eb1
--- /dev/null
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import test
+
+
+class DirectedInterleaveDatasetTest(test.TestCase):
+
+  def testBasic(self):
+    selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
+    input_datasets = [
+        dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
+    ]
+    dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
+                                                       input_datasets)
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(iterator.initializer)
+      for _ in range(100):
+        for i in range(10):
+          self.assertEqual(i, sess.run(next_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+  def _normalize(self, vec):
+    return vec / vec.sum()
+
+  def _chi2(self, expected, actual):
+    actual = np.asarray(actual)
+    expected = np.asarray(expected)
+    diff = actual - expected
+    chi2 = np.sum(diff * diff / expected, axis=0)
+    return chi2
+
+  def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
+    # Create a dataset that samples each integer in `[0, num_datasets)`
+    # with probability given by `weights[i]`.
+    dataset = interleave_ops.sample_from_datasets([
+        dataset_ops.Dataset.from_tensors(i).repeat(None)
+        for i in range(num_datasets)
+    ], weights)
+    dataset = dataset.take(num_samples)
+    iterator = dataset.make_one_shot_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      freqs = np.zeros([num_datasets])
+      for _ in range(num_samples):
+        freqs[sess.run(next_element)] += 1
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+    return freqs
+
+  def testSampleFromDatasets(self):
+    random_seed.set_random_seed(1619)
+    num_samples = 5000
+    rand_probs = self._normalize(np.random.random_sample((15,)))
+
+    # Use chi-squared test to assert that the observed distribution matches the
+    # expected distribution. Based on the implementation in
+    # "tensorflow/python/kernel_tests/multinomial_op_test.py".
+    for probs in [[.85, .05, .1], rand_probs]:
+      probs = np.asarray(probs)
+      classes = len(probs)
+      freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
+      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+      # Also check that `weights` as a dataset samples correctly.
+      probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
+      freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
+      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+  def testErrors(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 r"vector of length `len\(datasets\)`"):
+      interleave_ops.sample_from_datasets(
+          [dataset_ops.Dataset.range(10),
+           dataset_ops.Dataset.range(20)],
+          weights=[0.25, 0.25, 0.25, 0.25])
+
+    with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
+      interleave_ops.sample_from_datasets(
+          [dataset_ops.Dataset.range(10),
+           dataset_ops.Dataset.range(20)],
+          weights=[1, 1])
+
+    with self.assertRaisesRegexp(TypeError, "must have the same type"):
+      interleave_ops.sample_from_datasets([
+          dataset_ops.Dataset.from_tensors(0),
+          dataset_ops.Dataset.from_tensors(0.0)
+      ])
+
+
+class SampleFromDatasetsSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def _build_dataset(self, probs, num_samples):
+    dataset = interleave_ops.sample_from_datasets(
+        [
+            dataset_ops.Dataset.from_tensors(i).repeat(None)
+            for i in range(len(probs))
+        ],
+        probs,
+        seed=1813)
+    return dataset.take(num_samples)
+
+  def testSerializationCore(self):
+    self.run_core_tests(
+        lambda: self._build_dataset([0.5, 0.5], 100),
+        lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
+if __name__ == "__main__":
+  test.main()
index 43aa4b1..bee561e 100644 (file)
@@ -30,7 +30,6 @@ from tensorflow.contrib.data.python.ops import interleave_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -907,114 +906,5 @@ class ParallelInterleaveDatasetTest(test.TestCase):
         sess.run(self.next_element)
 
 
-class DirectedInterleaveDatasetTest(test.TestCase):
-
-  def testBasic(self):
-    selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
-    input_datasets = [
-        dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
-    ]
-    dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
-                                                       input_datasets)
-    iterator = dataset.make_initializable_iterator()
-    next_element = iterator.get_next()
-
-    with self.test_session() as sess:
-      sess.run(iterator.initializer)
-      for _ in range(100):
-        for i in range(10):
-          self.assertEqual(i, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-  def _normalize(self, vec):
-    return vec / vec.sum()
-
-  def _chi2(self, expected, actual):
-    actual = np.asarray(actual)
-    expected = np.asarray(expected)
-    diff = actual - expected
-    chi2 = np.sum(diff * diff / expected, axis=0)
-    return chi2
-
-  def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
-    # Create a dataset that samples each integer in `[0, num_datasets)`
-    # with probability given by `weights[i]`.
-    dataset = interleave_ops.sample_from_datasets([
-        dataset_ops.Dataset.from_tensors(i).repeat(None)
-        for i in range(num_datasets)
-    ], weights)
-    dataset = dataset.take(num_samples)
-    iterator = dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-
-    with self.test_session() as sess:
-      freqs = np.zeros([num_datasets])
-      for _ in range(num_samples):
-        freqs[sess.run(next_element)] += 1
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-    return freqs
-
-  def testSampleFromDatasets(self):
-    random_seed.set_random_seed(1619)
-    num_samples = 10000
-    rand_probs = self._normalize(np.random.random_sample((15,)))
-
-    # Use chi-squared test to assert that the observed distribution matches the
-    # expected distribution. Based on the implementation in
-    # "tensorflow/python/kernel_tests/multinomial_op_test.py".
-    for probs in [[.85, .05, .1], rand_probs]:
-      probs = np.asarray(probs)
-      classes = len(probs)
-      freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
-      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
-
-      # Also check that `weights` as a dataset samples correctly.
-      probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
-      freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
-      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
-
-  def testErrors(self):
-    with self.assertRaisesRegexp(ValueError,
-                                 r"vector of length `len\(datasets\)`"):
-      interleave_ops.sample_from_datasets(
-          [dataset_ops.Dataset.range(10),
-           dataset_ops.Dataset.range(20)],
-          weights=[0.25, 0.25, 0.25, 0.25])
-
-    with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
-      interleave_ops.sample_from_datasets(
-          [dataset_ops.Dataset.range(10),
-           dataset_ops.Dataset.range(20)],
-          weights=[1, 1])
-
-    with self.assertRaisesRegexp(TypeError, "must have the same type"):
-      interleave_ops.sample_from_datasets([
-          dataset_ops.Dataset.from_tensors(0),
-          dataset_ops.Dataset.from_tensors(0.0)
-      ])
-
-
-class SampleFromDatasetsSerializationTest(
-    dataset_serialization_test_base.DatasetSerializationTestBase):
-
-  def _build_dataset(self, probs, num_samples):
-    dataset = interleave_ops.sample_from_datasets(
-        [
-            dataset_ops.Dataset.from_tensors(i).repeat(None)
-            for i in range(len(probs))
-        ],
-        probs,
-        seed=1813)
-    return dataset.take(num_samples)
-
-  def testSerializationCore(self):
-    self.run_core_tests(
-        lambda: self._build_dataset([0.5, 0.5], 100),
-        lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
-
-
 if __name__ == "__main__":
   test.main()