From f500bcb889b3598f386f59eb69a79af6b704bf50 Mon Sep 17 00:00:00 2001 From: joel-shor Date: Fri, 20 Apr 2018 01:41:28 +0300 Subject: [PATCH] [tf.data] Allow `sample_from_datasets` to accept a tf.Dataset object for `weights`. Tested: bazel test :interleave_dataset_op_test --- .../kernel_tests/interleave_dataset_op_test.py | 59 ++++++++++++---------- .../contrib/data/python/ops/interleave_ops.py | 25 +++++---- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index ff6d0c3..43aa4b1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -928,8 +928,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): sess.run(next_element) def _normalize(self, vec): - batched = (len(vec.shape) == 2) - return vec / vec.sum(axis=1, keepdims=True) if batched else vec / vec.sum() + return vec / vec.sum() def _chi2(self, expected, actual): actual = np.asarray(actual) @@ -938,35 +937,43 @@ class DirectedInterleaveDatasetTest(test.TestCase): chi2 = np.sum(diff * diff / expected, axis=0) return chi2 + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + def testSampleFromDatasets(self): - random_seed.set_random_seed(1618) + random_seed.set_random_seed(1619) num_samples = 10000 - rand_probs = self._normalize(np.random.random_sample((10,))) - rand_probs2 = self._normalize(np.random.random_sample((15,))) + rand_probs = self._normalize(np.random.random_sample((15,))) - for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]: + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - # Create a dataset that samples each integer in `[0, probs.shape[0])` - # with probability given by `probs[i]`. - dataset = interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(probs.shape[0]) - ], probs) - dataset = dataset.take(num_samples) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - freqs = np.zeros_like(probs) - for _ in range(num_samples): - freqs[sess.run(next_element)] += 1 - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Use chi-squared test to assert that the observed distribution - # matches the expected distribution. Based on the implementation - # in "tensorflow/python/kernel_tests/multinomial_op_test.py". + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) def testErrors(self): diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 106a1ef..5ae1fa9 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -200,10 +200,10 @@ def sample_from_datasets(datasets, weights=None, seed=None): Args: datasets: A list of @{tf.data.Dataset} objects with compatible structure. - weights: (Optional.) A list of `len(datasets)` floating-point values, - where `weights[i]` represents the probability with which an element - should be sampled from `datasets[i]`. Defaults to a uniform distribution - across `datasets`. + weights: (Optional.) A list of `len(datasets)` floating-point values or a + @{tf.data.Dataset} object, where `weights[i]` represents the probability + with which an element should be sampled from `datasets[i]`. Defaults to a + uniform distribution across `datasets`. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See @{tf.set_random_seed} for behavior. @@ -219,24 +219,23 @@ def sample_from_datasets(datasets, weights=None, seed=None): """ num_datasets = len(datasets) if weights is None: - weights = array_ops.ones( - [num_datasets], dtype=dtypes.float32, name="weights") - else: + weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() + elif not isinstance(weights, dataset_ops.Dataset): weights = ops.convert_to_tensor(weights, name="weights") if weights.dtype not in (dtypes.float32, dtypes.float64): raise TypeError("`weights` must be convertible to a tensor of " "`tf.float32` or `tf.float64` elements.") if not weights.shape.is_compatible_with([num_datasets]): raise ValueError("`weights` must be a vector of length `len(datasets)`.") + weights = dataset_ops.Dataset.from_tensors(weights).repeat() # The `stateless_multinomial()` op expects log-probabilities, as opposed to # weights. - logits = math_ops.log(weights, name="logits") - - def select_dataset(seed): + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + def select_dataset(logits, seed): return array_ops.squeeze( - stateless.stateless_multinomial([logits], 1, seed=seed), axis=[0, 1]) - - selector_input = random_ops.RandomDataset(seed).batch(2).map(select_dataset) + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) return DirectedInterleaveDataset(selector_input, datasets) -- 2.7.4