[tf.data] Replace the old rejection_resample implementation with the new one.
author    joel-shor <joelshor@google.com>
          Thu, 26 Apr 2018 23:21:44 +0000 (02:21 +0300)
committer joel-shor <joelshor@google.com>
          Thu, 26 Apr 2018 23:21:44 +0000 (02:21 +0300)
Also add an optimization / bug fix that short-circuits combining the two datasets when one of them should always be sampled from.
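
For example (a minimal usage sketch of the transformation being consolidated;
illustrative only):

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import resampling

    # Raw data is heavily skewed toward class 0.
    data = np.random.choice(5, 10000, p=[0.9, 0.05, 0.03, 0.01, 0.01])
    dataset = tf.data.Dataset.from_tensor_slices(data).repeat()
    # Resample toward a uniform class distribution. The transformation
    # yields (class, example) pairs.
    dataset = dataset.apply(resampling.rejection_resample(
        class_func=lambda x: x,
        target_dist=[0.2] * 5,
        initial_dist=[0.9, 0.05, 0.03, 0.01, 0.01]))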

Tested:

bazel test :resample_test

tensorflow/contrib/data/python/kernel_tests/resample_test.py
tensorflow/contrib/data/python/ops/resampling.py

diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index 7f007fe..fc84301 100644
@@ -34,14 +34,12 @@ from tensorflow.python.util import compat
 
 
 def _time_resampling(
-    test_obj, data_np, target_dist, init_dist, use_v2, num_to_sample):
+    test_obj, data_np, target_dist, init_dist, num_to_sample):
   dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
 
   # Reshape distribution via rejection sampling.
-  apply_fn = (resampling.rejection_resample_v2 if use_v2 else
-              resampling.rejection_resample)
   dataset = dataset.apply(
-      apply_fn(
+      resampling.rejection_resample(
           class_func=lambda x: x,
           target_dist=target_dist,
           initial_dist=init_dist,
@@ -61,20 +59,17 @@ def _time_resampling(
 class ResampleTest(test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ("InitialnDistributionKnown", True, False),
-      ("InitialDistributionUnknown", False, False),
-      ("InitialDistributionKnownV2", True, True),
-      ("InitialDistributionUnknownV2", False, True))
-  def testDistribution(self, initial_known, use_v2):
+      ("InitialnDistributionKnown", True),
+      ("InitialDistributionUnknown", False))
+  def testDistribution(self, initial_known):
     classes = np.random.randint(5, size=(20000,))  # Uniformly sampled
     target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
     initial_dist = [0.2] * 5 if initial_known else None
     dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
         200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
-    apply_fn = (resampling.rejection_resample_v2 if use_v2 else
-                resampling.rejection_resample)
+
     get_next = dataset.apply(
-        apply_fn(
+        resampling.rejection_resample(
             target_dist=target_dist,
             initial_dist=initial_dist,
             class_func=lambda c, _: c,
@@ -96,11 +91,39 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
     returned_dist = class_counts / total_returned
     self.assertAllClose(target_dist, returned_dist, atol=1e-2)
 
+  @parameterized.named_parameters(
+      ("OnlyInitial", True),
+      ("NotInitial", False))
+  def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
+    init_dist = [0.5, 0.5]
+    target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
+    num_classes = len(init_dist)
+    # We don't need many samples to test that this works.
+    num_samples = 100
+    data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+    dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
+
+    # Reshape distribution.
+    dataset = dataset.apply(
+        resampling.rejection_resample(
+            class_func=lambda x: x,
+            target_dist=target_dist,
+            initial_dist=init_dist))
+
+    get_next = dataset.make_one_shot_iterator().get_next()
+
+    with self.test_session() as sess:
+      returned = []
+      with self.assertRaises(errors.OutOfRangeError):
+        while True:
+          returned.append(sess.run(get_next))
+
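
Note: both named cases short-circuit. With these distributions, the static
probability of sampling from the original dataset (mirroring
`_get_prob_original_static` in resampling.py below) is exactly 1.0 or 0.0:

    import numpy as np
    init = np.array([0.5, 0.5])
    print(np.min(np.array([0.5, 0.5]) / init))  # 1.0 -> pass through original
    print(np.min(np.array([0.0, 1.0]) / init))  # 0.0 -> only the filtered data
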
   def testRandomClasses(self):
     init_dist = [0.25, 0.25, 0.25, 0.25]
     target_dist = [0.0, 0.0, 0.0, 1.0]
     num_classes = len(init_dist)
-    # We don't need many samples to test a dirac-delta target distribution
+    # We don't need many samples to test a dirac-delta target distribution.
     num_samples = 100
     data_np = np.random.choice(num_classes, num_samples, p=init_dist)
 
@@ -134,26 +157,8 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllClose(target_dist, bincount, atol=1e-2)
 
-  @parameterized.named_parameters(
-      ("SmallSkewManySamples", [0.1, 0.1, 0.1, 0.7], 1000),
-      ("BigSkewManySamples", [0.01, 0.01, 0.01, 0.97], 1000),
-      ("SmallSkewFewSamples", [0.1, 0.1, 0.1, 0.7], 100),
-      ("BigSkewFewSamples", [0.01, 0.01, 0.01, 0.97], 100))
-  def testNewResampleIsFaster(self, target_dist, num_to_sample):
-    init_dist = [0.25, 0.25, 0.25, 0.25]
-    num_classes = len(init_dist)
-    num_samples = 1000
-    data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
-    fast_time = _time_resampling(self, data_np, target_dist, init_dist,
-                                 use_v2=True, num_to_sample=num_to_sample)
-    slow_time = _time_resampling(self, data_np, target_dist, init_dist,
-                                 use_v2=False, num_to_sample=num_to_sample)
-
-    self.assertLess(fast_time, slow_time)
-
 
-class MapDatasetBenchmark(test.Benchmark):
+class ResampleDatasetBenchmark(test.Benchmark):
 
   def benchmarkResamplePerformance(self):
     init_dist = [0.25, 0.25, 0.25, 0.25]
@@ -164,25 +169,11 @@ class MapDatasetBenchmark(test.Benchmark):
     data_np = np.random.choice(num_classes, num_samples, p=init_dist)
 
     resample_time = _time_resampling(
-        self, data_np, target_dist, init_dist, use_v2=False, num_to_sample=1000)
+        self, data_np, target_dist, init_dist, num_to_sample=1000)
 
     self.report_benchmark(
         iters=1000, wall_time=resample_time, name="benchmark_resample")
 
-  def benchmarkResampleAndBatchPerformance(self):
-    init_dist = [0.25, 0.25, 0.25, 0.25]
-    target_dist = [0.0, 0.0, 0.0, 1.0]
-    num_classes = len(init_dist)
-    # We don't need many samples to test a dirac-delta target distribution
-    num_samples = 1000
-    data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
-    resample_time = _time_resampling(
-        self, data_np, target_dist, init_dist, use_v2=True, num_to_sample=1000)
-
-    self.report_benchmark(
-        iters=1000, wall_time=resample_time, name="benchmark_resample_v2")
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 16d851b..66eaf9b 100644
@@ -58,62 +58,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
 
     # Get initial distribution.
     if initial_dist is not None:
-      initial_dist_t = ops.convert_to_tensor(
-          initial_dist, name="initial_dist")
-      acceptance_dist = _calculate_acceptance_probs(initial_dist_t,
-                                                    target_dist_t)
-      initial_dist_ds = dataset_ops.Dataset.from_tensors(
-          initial_dist_t).repeat()
-      acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
-          acceptance_dist).repeat()
-    else:
-      initial_dist_ds = _estimate_initial_dist_ds(
-          target_dist_t, class_values_ds)
-      acceptance_dist_ds = initial_dist_ds.map(
-          lambda initial: _calculate_acceptance_probs(initial, target_dist_t))
-    return _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
-                      class_values_ds, seed)
-
-  return _apply_fn
-
-
-def rejection_resample_v2(class_func, target_dist, initial_dist=None,
-                          seed=None):
-  """A transformation that resamples a dataset to achieve a target distribution.
-
-  This differs from v1 in that it will also sample from the original dataset
-  with some probability, so it makes strictly fewer data rejections. Due to an
-  implementation detail it must initialize a separate dataset initializer, so
-  the dataset becomes stateful after this transformation is applied
-  (`make_one_shot_iterator` won't work; users must use
-  `make_initializable_iterator`). This transformation is faster than the
-  original, except for overhead.
-
-  **NOTE** Resampling is performed via rejection sampling; some fraction
-  of the input values will be dropped.
-
-  Args:
-    class_func: A function mapping an element of the input dataset to a scalar
-      `tf.int32` tensor. Values should be in `[0, num_classes)`.
-    target_dist: A floating point type tensor, shaped `[num_classes]`.
-    initial_dist: (Optional.)  A floating point type tensor, shaped
-      `[num_classes]`.  If not provided, the true class distribution is
-      estimated live in a streaming fashion.
-    seed: (Optional.) Python integer seed for the resampler.
-
-  Returns:
-    A `Dataset` transformation function, which can be passed to
-    @{tf.data.Dataset.apply}.
-  """
-  def _apply_fn(dataset):
-    """Function from `Dataset` to `Dataset` that applies the transformation."""
-    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
-    class_values_ds = dataset.map(class_func)
-
-    # Get initial distribution.
-    if initial_dist is not None:
-      initial_dist_t = ops.convert_to_tensor(
-          initial_dist, name="initial_dist")
+      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
       acceptance_dist, prob_of_original = (
           _calculate_acceptance_probs_with_mixing(initial_dist_t,
                                                   target_dist_t))
@@ -133,19 +78,51 @@ def rejection_resample_v2(class_func, target_dist, initial_dist=None,
           lambda accept_prob, _: accept_prob)
       prob_of_original_ds = acceptance_and_original_prob_ds.map(
           lambda _, prob_original: prob_original)
+      prob_of_original = None
     filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
                              class_values_ds, seed)
     # Prefetch filtered dataset for speed.
     filtered_ds = filtered_ds.prefetch(3)
 
-    return interleave_ops.sample_from_datasets(
-        [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
-        weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
-        seed=seed)
+    prob_original_static = _get_prob_original_static(
+        initial_dist_t, target_dist_t) if initial_dist is not None else None
+    if prob_original_static == 1:
+      return dataset_ops.Dataset.zip((class_values_ds, dataset))
+    elif prob_original_static == 0:
+      return filtered_ds
+    else:
+      return interleave_ops.sample_from_datasets(
+          [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
+          weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
+          seed=seed)
 
   return _apply_fn
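
When neither short circuit applies, `sample_from_datasets` mixes the two
branches using the streaming weight. Conceptually (a hypothetical pure-Python
sketch of the sampling decision, not the library internals):

    import random

    def mix(original_iter, filtered_iter, prob_original_iter, seed=None):
        # Draw from the original stream with probability m, otherwise from
        # the rejection-filtered stream.
        rng = random.Random(seed)
        while True:
            m = next(prob_original_iter)
            yield (next(original_iter) if rng.random() < m
                   else next(filtered_iter))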
 
 
+def _get_prob_original_static(initial_dist_t, target_dist_t):
+  """Returns the static probability of sampling from the original.
+
+  In practice, `tensor_util.constant_value(prob_of_original)` fails to
+  resolve the ratio of two constant Tensors to a constant, so we compute the
+  static value directly from the constant distributions instead.
+
+  Args:
+    initial_dist_t: A tensor of the initial distribution.
+    target_dist_t: A tensor of the target distribution.
+
+  Returns:
+    The probability of sampling from the original distribution as a constant,
+    if it is a constant, or `None`.
+  """
+  init_static = tensor_util.constant_value(initial_dist_t)
+  target_static = tensor_util.constant_value(target_dist_t)
+
+  if init_static is None or target_static is None:
+    return None
+  else:
+    return np.min(target_static / init_static)
+
+
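For example, with a uniform initial distribution and target distribution
[0.1, 0.1, 0.1, 0.7] (the "SmallSkew" case from the removed speed test), the
static mixing weight resolves to 0.4, so neither short circuit fires:

    import numpy as np
    np.min(np.array([0.1, 0.1, 0.1, 0.7]) / 0.25)  # 0.4 -> mix both branches
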
 def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
                seed):
   """Filters a dataset based on per-class acceptance probabilities.
@@ -216,16 +193,42 @@ def _get_target_to_initial_ratio(initial_probs, target_probs):
   return target_probs / denom
 
 
-def _calculate_acceptance_probs(initial_probs, target_probs):
-  """Calculate the per-class acceptance rates.
+def _estimate_data_distribution(c, num_examples_per_class_seen):
+  """Estimate data distribution as labels are seen.
 
   Args:
-    initial_probs: The class probabilities of the data.
-    target_probs: The desired class proportion in minibatches.
+    c: The class labels.  Type `int32`, shape `[batch_size]`.
+    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
+      containing counts.
+
   Returns:
-    A list of the per-class acceptance probabilities.
+    num_examples_per_class_seen: Updated counts.  Type `int64`, shape
+      `[num_classes]`.
+    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
+  """
+  num_classes = num_examples_per_class_seen.get_shape()[0].value
+  # Update the class-count based on what labels are seen in batch.
+  num_examples_per_class_seen = math_ops.add(
+      num_examples_per_class_seen, math_ops.reduce_sum(
+          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
+  init_prob_estimate = math_ops.truediv(
+      num_examples_per_class_seen,
+      math_ops.reduce_sum(num_examples_per_class_seen))
+  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
+  return num_examples_per_class_seen, dist
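
In plain NumPy, the update rule above is equivalent to (a sketch; the TF
version uses `one_hot` + `reduce_sum` rather than `bincount`):

    import numpy as np

    def estimate_data_distribution_np(c, counts):
        # Add this batch's per-class label counts, then renormalize.
        counts = counts + np.bincount(c, minlength=len(counts))
        return counts, (counts / counts.sum()).astype(np.float32)

    counts = np.zeros(5, dtype=np.int64)
    counts, dist = estimate_data_distribution_np(np.array([0, 0, 1]), counts)
    # counts -> [2, 1, 0, 0, 0]; dist -> [0.667, 0.333, 0., 0., 0.]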
 
-  This method is based on solving the following analysis:
+
+def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
+  """Calculates the acceptance probabilities and mixing ratio.
+
+  In this case, we assume that we can *either* sample from the original data
+  distribution with probability `m`, or sample from a reshaped distribution
+  that comes from rejection sampling on the original distribution. This
+  rejection sampling is done on a per-class basis, with `a_i` representing the
+  probability of accepting data from class `i`.
+
+  This method is based on solving the following analysis for the reshaped
+  distribution:
 
   Let F be the probability of a rejection (on any example).
   Let p_i be the proportion of examples in the data in class i (init_probs)
@@ -256,47 +259,6 @@ def _calculate_acceptance_probs(initial_probs, target_probs):
 
   A solution for a_i in terms of the other variables is the following:
     ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
-  """
-  ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
-
-  # Calculate list of acceptance probabilities.
-  max_ratio = math_ops.reduce_max(ratio_l)
-  return ratio_l / max_ratio
-
-
-def _estimate_data_distribution(c, num_examples_per_class_seen):
-  """Estimate data distribution as labels are seen.
-
-  Args:
-    c: The class labels.  Type `int32`, shape `[batch_size]`.
-    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
-      containing counts.
-
-  Returns:
-    num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
-      `[num_classes]`.
-    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
-  """
-  num_classes = num_examples_per_class_seen.get_shape()[0].value
-  # Update the class-count based on what labels are seen in batch.
-  num_examples_per_class_seen = math_ops.add(
-      num_examples_per_class_seen, math_ops.reduce_sum(
-          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
-  init_prob_estimate = math_ops.truediv(
-      num_examples_per_class_seen,
-      math_ops.reduce_sum(num_examples_per_class_seen))
-  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
-  return num_examples_per_class_seen, dist
-
-
-def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
-  """Calculates the acceptance probabilities and mixing ratio.
-
-  In this case, we assume that we can *either* sample from the original data
-  distribution with probability `m`, or sample from a reshaped distribution
-  that comes from rejection sampling on the original distribution. This
-  rejection sampling is done on a per-class basis, with `a_i` representing the
-  probability of accepting data from class `i`.
 
   If we try to minimize the amount of data rejected, we get the following:
 
@@ -312,8 +274,6 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
 
   m = M_min
 
-  See the docstring for `_calculate_acceptance_probs` for more details.
-
   Args:
     initial_probs: A Tensor of the initial probability distribution, given or
       estimated.