[tf.data] Allow `sample_from_datasets` to accept a `tf.data.Dataset` object for `weights`.
author     joel-shor <joelshor@google.com>
           Thu, 19 Apr 2018 22:41:28 +0000 (01:41 +0300)
committer  joel-shor <joelshor@google.com>
           Thu, 19 Apr 2018 22:41:28 +0000 (01:41 +0300)
Tested:
bazel test :interleave_dataset_op_test
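
With this change, `weights` may be either a list of floats (as before) or a
`tf.data.Dataset` that yields weight vectors, one per sample. A minimal usage
sketch (module path as in this change; the three constant datasets are
illustrative):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import interleave_ops

    datasets = [tf.data.Dataset.from_tensors(i).repeat() for i in range(3)]

    # Static weights, as before.
    sampled = interleave_ops.sample_from_datasets(
        datasets, weights=[0.5, 0.25, 0.25])

    # New: weights provided as a dataset of vectors, so the sampling
    # distribution can vary over time.
    weights_ds = tf.data.Dataset.from_tensors([0.5, 0.25, 0.25]).repeat()
    sampled = interleave_ops.sample_from_datasets(datasets, weights=weights_ds)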

tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
tensorflow/contrib/data/python/ops/interleave_ops.py

index ff6d0c3..43aa4b1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -928,8 +928,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
         sess.run(next_element)
 
   def _normalize(self, vec):
-    batched = (len(vec.shape) == 2)
-    return vec / vec.sum(axis=1, keepdims=True) if batched else vec / vec.sum()
+    return vec / vec.sum()
 
   def _chi2(self, expected, actual):
     actual = np.asarray(actual)
@@ -938,35 +937,43 @@ class DirectedInterleaveDatasetTest(test.TestCase):
     chi2 = np.sum(diff * diff / expected, axis=0)
     return chi2
 
+  def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
+    # Create a dataset that samples each integer in `[0, num_datasets)`
+    # with probability given by `weights[i]`.
+    dataset = interleave_ops.sample_from_datasets([
+        dataset_ops.Dataset.from_tensors(i).repeat(None)
+        for i in range(num_datasets)
+    ], weights)
+    dataset = dataset.take(num_samples)
+    iterator = dataset.make_one_shot_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      freqs = np.zeros([num_datasets])
+      for _ in range(num_samples):
+        freqs[sess.run(next_element)] += 1
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+    return freqs
+
   def testSampleFromDatasets(self):
-    random_seed.set_random_seed(1618)
+    random_seed.set_random_seed(1619)
     num_samples = 10000
-    rand_probs = self._normalize(np.random.random_sample((10,)))
-    rand_probs2 = self._normalize(np.random.random_sample((15,)))
+    rand_probs = self._normalize(np.random.random_sample((15,)))
 
-    for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
+    # Use chi-squared test to assert that the observed distribution matches the
+    # expected distribution. Based on the implementation in
+    # "tensorflow/python/kernel_tests/multinomial_op_test.py".
+    for probs in [[.85, .05, .1], rand_probs]:
       probs = np.asarray(probs)
+      classes = len(probs)
+      freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
+      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
 
-      # Create a dataset that samples each integer in `[0, probs.shape[0])`
-      # with probability given by `probs[i]`.
-      dataset = interleave_ops.sample_from_datasets([
-          dataset_ops.Dataset.from_tensors(i).repeat(None)
-          for i in range(probs.shape[0])
-      ], probs)
-      dataset = dataset.take(num_samples)
-      iterator = dataset.make_one_shot_iterator()
-      next_element = iterator.get_next()
-
-      with self.test_session() as sess:
-        freqs = np.zeros_like(probs)
-        for _ in range(num_samples):
-          freqs[sess.run(next_element)] += 1
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(next_element)
-
-      # Use chi-squared test to assert that the observed distribution
-      # matches the expected distribution. Based on the implementation
-      # in "tensorflow/python/kernel_tests/multinomial_op_test.py".
+      # Also check that `weights` as a dataset samples correctly.
+      probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
+      freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
       self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
 
   def testErrors(self):
index 106a1ef..5ae1fa9 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -200,10 +200,10 @@ def sample_from_datasets(datasets, weights=None, seed=None):
 
   Args:
     datasets: A list of @{tf.data.Dataset} objects with compatible structure.
-    weights: (Optional.) A list of `len(datasets)` floating-point values,
-      where `weights[i]` represents the probability with which an element
-      should be sampled from `datasets[i]`. Defaults to a uniform distribution
-      across `datasets`.
+    weights: (Optional.) A list of `len(datasets)` floating-point values or a
+      @{tf.data.Dataset} object, where `weights[i]` represents the probability
+      with which an element should be sampled from `datasets[i]`. Defaults to a
+      uniform distribution across `datasets`.
     seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
       random seed that will be used to create the distribution. See
       @{tf.set_random_seed} for behavior.
@@ -219,24 +219,23 @@ def sample_from_datasets(datasets, weights=None, seed=None):
   """
   num_datasets = len(datasets)
   if weights is None:
-    weights = array_ops.ones(
-        [num_datasets], dtype=dtypes.float32, name="weights")
-  else:
+    weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat()
+  elif not isinstance(weights, dataset_ops.Dataset):
     weights = ops.convert_to_tensor(weights, name="weights")
     if weights.dtype not in (dtypes.float32, dtypes.float64):
       raise TypeError("`weights` must be convertible to a tensor of "
                       "`tf.float32` or `tf.float64` elements.")
     if not weights.shape.is_compatible_with([num_datasets]):
       raise ValueError("`weights` must be a vector of length `len(datasets)`.")
+    weights = dataset_ops.Dataset.from_tensors(weights).repeat()
 
   # The `stateless_multinomial()` op expects log-probabilities, as opposed to
   # weights.
-  logits = math_ops.log(weights, name="logits")
-
-  def select_dataset(seed):
+  logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+  def select_dataset(logits, seed):
     return array_ops.squeeze(
-        stateless.stateless_multinomial([logits], 1, seed=seed), axis=[0, 1])
-
-  selector_input = random_ops.RandomDataset(seed).batch(2).map(select_dataset)
+        stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+  selector_input = dataset_ops.Dataset.zip(
+      (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
 
   return DirectedInterleaveDataset(selector_input, datasets)
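
Note on the mechanism: the selector input now zips a dataset of log-weights
with a dataset of random seeds and draws one multinomial sample per element,
which is what lets the weights vary from step to step. A rough NumPy sketch of
the per-element selection (illustrative only; the real code uses the
`stateless_multinomial` op on `log(weights)`):

    import numpy as np

    def select_dataset(weights, rng):
        # Draw one dataset index i with probability proportional to
        # weights[i] -- the NumPy analogue of stateless_multinomial
        # applied to log(weights).
        probs = np.asarray(weights, dtype=np.float64)
        probs = probs / probs.sum()
        return rng.choice(len(probs), p=probs)

    rng = np.random.default_rng(seed=2)
    freqs = np.zeros(3)
    for _ in range(10000):
        freqs[select_dataset([0.85, 0.05, 0.10], rng)] += 1
    print(freqs / freqs.sum())  # roughly [0.85, 0.05, 0.10]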