From 6e96bf989014ef3079d668c93f3ebebff30e3e37 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 24 May 2018 11:37:12 -0700 Subject: [PATCH] [tf.data] Add `tf.contrib.data.choose_from_datasets()`. This is a deterministic counterpart to `tf.contrib.data.sample_from_datasets()`. PiperOrigin-RevId: 197926386 --- tensorflow/contrib/data/__init__.py | 1 + .../directed_interleave_dataset_test.py | 27 +++++++++++++ .../contrib/data/python/ops/interleave_ops.py | 45 ++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index a25aa85..1af1ed0 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index d071eb1..34b6a08 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -96,6 +96,21 @@ class DirectedInterleaveDatasetTest(test.TestCase): freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testErrors(self): with self.assertRaisesRegexp(ValueError, r"vector of length `len\(datasets\)`"): @@ -116,6 +131,18 @@ class DirectedInterleaveDatasetTest(test.TestCase): dataset_ops.Dataset.from_tensors(0.0) ]) + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + class SampleFromDatasetsSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 812a50e..be66fba 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -240,3 +241,47 @@ def sample_from_datasets(datasets, weights=None, seed=None): (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) return DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. + """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return DirectedInterleaveDataset(choice_dataset, datasets) -- 2.7.4