Extended the Halton sequences to support randomization. Implemented the randomization...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Feb 2018 00:47:58 +0000 (16:47 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 00:51:47 +0000 (16:51 -0800)
PiperOrigin-RevId: 185073515

tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py
tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py

index 0a85862..c516ce4 100644 (file)
@@ -36,29 +36,35 @@ class HaltonSequenceTest(test.TestCase):
 
   def test_known_values_small_bases(self):
     with self.test_session():
-      # The first five elements of the Halton sequence with base 2 and 3
+      # The first five elements of the non-randomized Halton sequence
+      # with base 2 and 3.
       expected = np.array(((1. / 2, 1. / 3),
                            (1. / 4, 2. / 3),
                            (3. / 4, 1. / 9),
                            (1. / 8, 4. / 9),
                            (5. / 8, 7. / 9)), dtype=np.float32)
-      sample = halton.sample(2, num_samples=5)
+      sample = halton.sample(2, num_results=5, randomized=False)
       self.assertAllClose(expected, sample.eval(), rtol=1e-6)
 
-  def test_sample_indices(self):
+  def test_sequence_indices(self):
+    """Tests access of sequence elements by index."""
     with self.test_session():
       dim = 5
       indices = math_ops.range(10, dtype=dtypes.int32)
-      sample_direct = halton.sample(dim, num_samples=10)
-      sample_from_indices = halton.sample(dim, sample_indices=indices)
+      sample_direct = halton.sample(dim, num_results=10, randomized=False)
+      sample_from_indices = halton.sample(dim, sequence_indices=indices,
+                                          randomized=False)
       self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(),
                           rtol=1e-6)
 
   def test_dtypes_works_correctly(self):
+    """Tests that all supported dtypes work without error."""
     with self.test_session():
       dim = 3
-      sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32)
-      sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64)
+      sample_float32 = halton.sample(dim, num_results=10, dtype=dtypes.float32,
+                                     seed=11)
+      sample_float64 = halton.sample(dim, num_results=10, dtype=dtypes.float64,
+                                     seed=21)
       self.assertEqual(sample_float32.eval().dtype, np.float32)
       self.assertEqual(sample_float64.eval().dtype, np.float64)
 
@@ -79,7 +85,8 @@ class HaltonSequenceTest(test.TestCase):
       p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
       q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
 
-      cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64)
+      cdf_sample = halton.sample(2, num_results=n, dtype=dtypes.float64,
+                                 seed=1729)
       q_sample = q.quantile(cdf_sample)
 
       # Compute E_p[X].
@@ -90,7 +97,7 @@ class HaltonSequenceTest(test.TestCase):
       # Compute E_p[X^2].
       e_x2 = mc.expectation_importance_sampler(
           f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample,
-          seed=42)
+          seed=1412)
 
       stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))
       # Keep the tolerance levels the same as in monte_carlo_test.py.
@@ -100,10 +107,10 @@ class HaltonSequenceTest(test.TestCase):
 
   def test_docstring_example(self):
     # Produce the first 1000 members of the Halton sequence in 3 dimensions.
-    num_samples = 1000
+    num_results = 1000
     dim = 3
     with self.test_session():
-      sample = halton.sample(dim, num_samples=num_samples)
+      sample = halton.sample(dim, num_results=num_results, seed=127)
 
       # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
       # hypercube.
@@ -115,16 +122,76 @@ class HaltonSequenceTest(test.TestCase):
       # Produces a relative absolute error of 1.7%.
       self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02)
 
-    # Now skip the first 1000 samples and recompute the integral with the next
-    # thousand samples. The sample_indices argument can be used to do this.
+      # Now skip the first 1000 samples and recompute the integral with the next
+      # thousand samples. The sequence_indices argument can be used to do this.
 
-      sample_indices = math_ops.range(start=1000, limit=1000 + num_samples,
-                                      dtype=dtypes.int32)
-      sample_leaped = halton.sample(dim, sample_indices=sample_indices)
+      sequence_indices = math_ops.range(start=1000, limit=1000 + num_results,
+                                        dtype=dtypes.int32)
+      sample_leaped = halton.sample(dim, sequence_indices=sequence_indices,
+                                    seed=111217)
 
       integral_leaped = math_ops.reduce_mean(
           math_ops.reduce_prod(sample_leaped ** powers, axis=-1))
-      self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001)
+      self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.01)
+
+  def test_randomized_qmc_basic(self):
+    """Tests the randomization of the Halton sequences."""
+    # This test is identical to the example given in Owen (2017), Figure 5.
+
+    dim = 20
+    num_results = 5000
+    replica = 10
+
+    with self.test_session():
+      sample = halton.sample(dim, num_results=num_results, seed=121117)
+      f = math_ops.reduce_mean(math_ops.reduce_sum(sample, axis=1) ** 2)
+      values = [f.eval() for _ in range(replica)]
+      self.assertAllClose(np.mean(values), 101.6667, atol=np.std(values) * 2)
+
+  def test_partial_sum_func_qmc(self):
+    """Tests the QMC evaluation of (x_j + x_{j+1} ...+x_{n})^2.
+
+    A good test of QMC is provided by the function:
+
+      f(x_1,..x_n, x_{n+1}, ..., x_{n+m}) = (x_{n+1} + ... x_{n+m} - m / 2)^2
+
+    with the coordinates taking values in the unit interval. The mean and
+    variance of this function (with the uniform distribution over the
+    unit-hypercube) is exactly calculable:
+
+      <f> = m / 12, Var(f) = m (5m - 3) / 360
+
+    The purpose of the "shift" (if n > 0) in the coordinate dependence of the
+    function is to provide a test for Halton sequence which exhibit more
+    dependence in the higher axes.
+
+    This test confirms that the mean squared error of RQMC estimation falls
+    as O(N^(2-e)) for any e>0.
+    """
+
+    n, m = 10, 10
+    dim = n + m
+    num_results_lo, num_results_hi = 1000, 10000
+    replica = 20
+    true_mean = m / 12.
+
+    def func_estimate(x):
+      return math_ops.reduce_mean(
+          (math_ops.reduce_sum(x[:, -m:], axis=-1) - m / 2.0) ** 2)
+
+    with self.test_session():
+      sample_lo = halton.sample(dim, num_results=num_results_lo, seed=1925)
+      sample_hi = halton.sample(dim, num_results=num_results_hi, seed=898128)
+      f_lo, f_hi = func_estimate(sample_lo), func_estimate(sample_hi)
+
+      estimates = np.array([(f_lo.eval(), f_hi.eval()) for _ in range(replica)])
+      var_lo, var_hi = np.mean((estimates - true_mean) ** 2, axis=0)
+
+      # Expect that the variance scales as N^2 so var_hi / var_lo ~ k / 10^2
+      # with k a fudge factor accounting for the residual N dependence
+      # of the QMC error and the sampling error.
+      log_rel_err = np.log(100 * var_hi / var_lo)
+      self.assertAllClose(log_rel_err, 0.0, atol=1.2)
 
 
 if __name__ == '__main__':
index 8cabf18..57900d6 100644 (file)
@@ -26,8 +26,9 @@ import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
-
+from tensorflow.python.ops import random_ops
 
 __all__ = [
     'sample',
@@ -39,32 +40,45 @@ __all__ = [
 _MAX_DIMENSION = 1000
 
 
-def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
-  r"""Returns a sample from the `m` dimensional Halton sequence.
+def sample(dim,
+           num_results=None,
+           sequence_indices=None,
+           dtype=None,
+           randomized=True,
+           seed=None,
+           name=None):
+  r"""Returns a sample from the `dim` dimensional Halton sequence.
 
   Warning: The sequence elements take values only between 0 and 1. Care must be
   taken to appropriately transform the domain of a function if it differs from
   the unit cube before evaluating integrals using Halton samples. It is also
-  important to remember that quasi-random numbers are not a replacement for
-  pseudo-random numbers in every context. Quasi random numbers are completely
-  deterministic and typically have significant negative autocorrelation (unless
-  randomized).
+  important to remember that quasi-random numbers without randomization are not
+  a replacement for pseudo-random numbers in every context. Quasi random numbers
+  are completely deterministic and typically have significant negative
+  autocorrelation unless randomization is used.
 
   Computes the members of the low discrepancy Halton sequence in dimension
-  `dim`. The d-dimensional sequence takes values in the unit hypercube in d
-  dimensions. Currently, only dimensions up to 1000 are supported. The prime
-  base for the `k`-th axes is the k-th prime starting from 2. For example,
-  if dim = 3, then the bases will be [2, 3, 5] respectively and the first
-  element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete
-  description of the Halton sequences see:
+  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in
+  `dim` dimensions. Currently, only dimensions up to 1000 are supported. The
+  prime base for the k-th axes is the k-th prime starting from 2. For example,
+  if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first
+  element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more
+  complete description of the Halton sequences see:
   https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences
   and their applications see:
   https://en.wikipedia.org/wiki/Low-discrepancy_sequence.
 
-  The user must supply either `num_samples` or `sample_indices` but not both.
+  If `randomized` is true, this function produces a scrambled version of the
+  Halton sequence introduced by Owen in arXiv:1706.02808. For the advantages of
+  randomization of low discrepancy sequences see:
+  https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo
+
+  The number of samples produced is controlled by the `num_results` and
+  `sequence_indices` parameters. The user must supply either `num_results` or
+  `sequence_indices` but not both.
   The former is the number of samples to produce starting from the first
-  element. If `sample_indices` is given instead, the specified elements of
-  the sequence are generated. For example, sample_indices=tf.range(10) is
+  element. If `sequence_indices` is given instead, the specified elements of
+  the sequence are generated. For example, sequence_indices=tf.range(10) is
   equivalent to specifying n=10.
 
   Example Use:
@@ -73,9 +87,9 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
   bf = tf.contrib.bayesflow
 
   # Produce the first 1000 members of the Halton sequence in 3 dimensions.
-  num_samples = 1000
+  num_results = 1000
   dim = 3
-  sample = bf.halton_sequence.sample(dim, num_samples=num_samples)
+  sample = bf.halton_sequence.sample(dim, num_results=num_results, seed=127)
 
   # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
   # hypercube.
@@ -89,12 +103,13 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
   print ("Estimated: %f, True Value: %f" % values)
 
   # Now skip the first 1000 samples and recompute the integral with the next
-  # thousand samples. The sample_indices argument can be used to do this.
+  # thousand samples. The sequence_indices argument can be used to do this.
 
 
-  sample_indices = tf.range(start=1000, limit=1000 + num_samples,
-                            dtype=tf.int32)
-  sample_leaped = halton.sample(dim, sample_indices=sample_indices)
+  sequence_indices = tf.range(start=1000, limit=1000 + num_results,
+                              dtype=tf.int32)
+  sample_leaped = halton.sample(dim, sequence_indices=sequence_indices,
+                                seed=111217)
 
   integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                   axis=-1))
@@ -107,51 +122,57 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
   Args:
     dim: Positive Python `int` representing each sample's `event_size.` Must
       not be greater than 1000.
-    num_samples: (Optional) positive Python `int`. The number of samples to
-      generate. Either this parameter or sample_indices must be specified but
+    num_results: (Optional) positive Python `int`. The number of samples to
+      generate. Either this parameter or sequence_indices must be specified but
       not both. If this parameter is None, then the behaviour is determined by
-      the `sample_indices`.
-    sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements
-      of the sequence to compute specified by their position in the sequence.
-      The entries index into the Halton sequence starting with 0 and hence,
-      must be whole numbers. For example, sample_indices=[0, 5, 6] will produce
-      the first, sixth and seventh elements of the sequence. If this parameter
-      is None, then the `num_samples` parameter must be specified which gives
-      the number of desired samples starting from the first sample.
+      the `sequence_indices`.
+    sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The
+      elements of the sequence to compute specified by their position in the
+      sequence. The entries index into the Halton sequence starting with 0 and
+      hence, must be whole numbers. For example, sequence_indices=[0, 5, 6] will
+      produce the first, sixth and seventh elements of the sequence. If this
+      parameter is None, then the `num_results` parameter must be specified
+      which gives the number of desired samples starting from the first sample.
     dtype: (Optional) The dtype of the sample. One of `float32` or `float64`.
       Default is `float32`.
+    randomized: (Optional) bool indicating whether to produce a randomized
+      Halton sequence. If True, applies the randomization described in
+      Owen (2017) [arXiv:1706.02808].
+    seed: (Optional) Python integer to seed the random number generator. Only
+      used if `randomized` is True. If not supplied and `randomized` is True,
+      no seed is set.
     name:  (Optional) Python `str` describing ops managed by this function. If
     not supplied the name of this function is used.
 
   Returns:
     halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype
-    and `shape` `[num_samples, dim]` if `num_samples` was specified or shape
-    `[s, dim]` where s is the size of `sample_indices` if `sample_indices`
+    and `shape` `[num_results, dim]` if `num_results` was specified or shape
+    `[s, dim]` where s is the size of `sequence_indices` if `sequence_indices`
     were specified.
 
   Raises:
-    ValueError: if both `sample_indices` and `num_samples` were specified or
+    ValueError: if both `sequence_indices` and `num_results` were specified or
     if dimension `dim` is less than 1 or greater than 1000.
   """
   if dim < 1 or dim > _MAX_DIMENSION:
     raise ValueError(
         'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION,
                                                                  dim))
-  if (num_samples is None) == (sample_indices is None):
-    raise ValueError('Either `num_samples` or `sample_indices` must be'
+  if (num_results is None) == (sequence_indices is None):
+    raise ValueError('Either `num_results` or `sequence_indices` must be'
                      ' specified but not both.')
 
   dtype = dtype or dtypes.float32
   if not dtype.is_floating:
     raise ValueError('dtype must be of `float`-type')
 
-  with ops.name_scope(name, 'sample', values=[sample_indices]):
+  with ops.name_scope(name, 'sample', values=[sequence_indices]):
     # Here and in the following, the shape layout is as follows:
     # [sample dimension, event dimension, coefficient dimension].
     # The coefficient dimension is an intermediate axes which will hold the
     # weights of the starting integer when expressed in the (prime) base for
     # an event dimension.
-    indices = _get_indices(num_samples, sample_indices, dtype)
+    indices = _get_indices(num_results, sequence_indices, dtype)
     radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])
 
     max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices),
@@ -176,11 +197,74 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
     weights = radixes ** capped_exponents
     coeffs = math_ops.floor_div(indices, weights)
     coeffs *= 1 - math_ops.cast(weight_mask, dtype)
-    coeffs = (coeffs % radixes) / radixes
-    return math_ops.reduce_sum(coeffs / weights, axis=-1)
+    coeffs %= radixes
+    if not randomized:
+      coeffs /= radixes
+      return math_ops.reduce_sum(coeffs / weights, axis=-1)
+    coeffs = _randomize(coeffs, radixes, seed=seed)
+    coeffs *= 1 - math_ops.cast(weight_mask, dtype)
+    coeffs /= radixes
+    base_values = math_ops.reduce_sum(coeffs / weights, axis=-1)
+
+    # The randomization used in Owen (2017) does not leave 0 invariant. While
+    # we have accounted for the randomization of the first `max_size_by_axes`
+    # coefficients, we still need to correct for the trailing zeros. Luckily,
+    # this is equivalent to adding a uniform random value scaled so the first
+    # `max_size_by_axes` coefficients are zero. The following statements perform
+    # this correction.
+    zero_correction = random_ops.random_uniform([dim, 1], seed=seed,
+                                                dtype=dtype)
+    zero_correction /= (radixes ** max_sizes_by_axes)
+    return base_values + array_ops.reshape(zero_correction, [-1])
+
+
+def _randomize(coeffs, radixes, seed=None):
+  """Applies the Owen randomization to the coefficients."""
+  given_dtype = coeffs.dtype
+  coeffs = math_ops.to_int32(coeffs)
+  num_coeffs = array_ops.shape(coeffs)[-1]
+  radixes = array_ops.reshape(math_ops.to_int32(radixes), [-1])
+  perms = _get_permutations(num_coeffs, radixes, seed=seed)
+  perms = array_ops.reshape(perms, [-1])
+  radix_sum = math_ops.reduce_sum(radixes)
+  radix_offsets = array_ops.reshape(math_ops.cumsum(radixes, exclusive=True),
+                                    [-1, 1])
+  offsets = radix_offsets + math_ops.range(num_coeffs) * radix_sum
+  permuted_coeffs = array_ops.gather(perms, coeffs + offsets)
+  return math_ops.cast(permuted_coeffs, dtype=given_dtype)
+
+
+def _get_permutations(num_results, dims, seed=None):
+  """Uniform iid sample from the space of permutations.
+
+  Draws a sample of size `num_results` from the group of permutations of degrees
+  specified by the `dims` tensor. These are packed together into one tensor
+  such that each row is one sample from each of the dimensions in `dims`. For
+  example, if dims = [2,3] and num_results = 2, the result is a tensor of shape
+  [2, 2 + 3] and the first row of the result might look like:
+  [1, 0, 2, 0, 1]. The first two elements are a permutation over 2 elements
+  while the next three are a permutation over 3 elements.
 
+  Args:
+    num_results: A positive scalar `Tensor` of integral type. The number of
+      draws from the discrete uniform distribution over the permutation groups.
+    dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
+      permutation groups from which to sample.
+    seed: (Optional) Python integer to seed the random number generator.
 
-def _get_indices(n, sample_indices, dtype, name=None):
+  Returns:
+    permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
+    dtype as `dims`.
+  """
+  sample_range = math_ops.range(num_results)
+  def generate_one(d):
+    fn = lambda _: random_ops.random_shuffle(math_ops.range(d), seed=seed)
+    return functional_ops.map_fn(fn, sample_range)
+  return array_ops.concat([generate_one(d) for d in array_ops.unstack(dims)],
+                          axis=-1)
+
+
+def _get_indices(n, sequence_indices, dtype, name=None):
   """Generates starting points for the Halton sequence procedure.
 
   The k'th element of the sequence is generated starting from a positive integer
@@ -191,10 +275,10 @@ def _get_indices(n, sample_indices, dtype, name=None):
 
   Args:
     n: Positive `int`. The number of samples to generate. If this
-      parameter is supplied, then `sample_indices` should be None.
-    sample_indices: `Tensor` of dtype int32 and rank 1. The entries
+      parameter is supplied, then `sequence_indices` should be None.
+    sequence_indices: `Tensor` of dtype int32 and rank 1. The entries
       index into the Halton sequence starting with 0 and hence, must be whole
-      numbers. For example, sample_indices=[0, 5, 6] will produce the first,
+      numbers. For example, sequence_indices=[0, 5, 6] will produce the first,
       sixth and seventh elements of the sequence. If this parameter is not None
       then `n` must be None.
     dtype: The dtype of the sample. One of `float32` or `float64`.
@@ -204,14 +288,14 @@ def _get_indices(n, sample_indices, dtype, name=None):
   Returns:
     indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`.
   """
-  with ops.name_scope(name, 'get_indices', [n, sample_indices]):
-    if sample_indices is None:
-      sample_indices = math_ops.range(n, dtype=dtype)
+  with ops.name_scope(name, '_get_indices', [n, sequence_indices]):
+    if sequence_indices is None:
+      sequence_indices = math_ops.range(n, dtype=dtype)
     else:
-      sample_indices = math_ops.cast(sample_indices, dtype)
+      sequence_indices = math_ops.cast(sequence_indices, dtype)
 
     # Shift the indices so they are 1 based.
-    indices = sample_indices + 1
+    indices = sequence_indices + 1
 
     # Reshape to make space for the event dimension and the place value
     # coefficients.
@@ -261,4 +345,5 @@ def _primes_less_than(n):
 
 _PRIMES = _primes_less_than(7919+1)
 
+
 assert len(_PRIMES) == _MAX_DIMENSION