From: A. Unique TensorFlower Date: Fri, 9 Feb 2018 00:47:58 +0000 (-0800) Subject: Extended the Halton sequences to support randomization. Implemented the randomization... X-Git-Tag: upstream/v1.7.0~31^2~844 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1fe6b129be3f7f8603b4cf36a6f39493b19f561e;p=platform%2Fupstream%2Ftensorflow.git Extended the Halton sequences to support randomization. Implemented the randomization scheme described in arXiv:1706.02808. PiperOrigin-RevId: 185073515 --- diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py index 0a85862..c516ce4 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py @@ -36,29 +36,35 @@ class HaltonSequenceTest(test.TestCase): def test_known_values_small_bases(self): with self.test_session(): - # The first five elements of the Halton sequence with base 2 and 3 + # The first five elements of the non-randomized Halton sequence + # with base 2 and 3. expected = np.array(((1. / 2, 1. / 3), (1. / 4, 2. / 3), (3. / 4, 1. / 9), (1. / 8, 4. / 9), (5. / 8, 7. / 9)), dtype=np.float32) - sample = halton.sample(2, num_samples=5) + sample = halton.sample(2, num_results=5, randomized=False) self.assertAllClose(expected, sample.eval(), rtol=1e-6) - def test_sample_indices(self): + def test_sequence_indices(self): + """Tests access of sequence elements by index.""" with self.test_session(): dim = 5 indices = math_ops.range(10, dtype=dtypes.int32) - sample_direct = halton.sample(dim, num_samples=10) - sample_from_indices = halton.sample(dim, sample_indices=indices) + sample_direct = halton.sample(dim, num_results=10, randomized=False) + sample_from_indices = halton.sample(dim, sequence_indices=indices, + randomized=False) self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), rtol=1e-6) def test_dtypes_works_correctly(self): + """Tests that all supported dtypes work without error.""" with self.test_session(): dim = 3 - sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) - sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) + sample_float32 = halton.sample(dim, num_results=10, dtype=dtypes.float32, + seed=11) + sample_float64 = halton.sample(dim, num_results=10, dtype=dtypes.float64, + seed=21) self.assertEqual(sample_float32.eval().dtype, np.float32) self.assertEqual(sample_float64.eval().dtype, np.float64) @@ -79,7 +85,8 @@ class HaltonSequenceTest(test.TestCase): p = normal_lib.Normal(loc=mu_p, scale=sigma_p) q = normal_lib.Normal(loc=mu_q, scale=sigma_q) - cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) + cdf_sample = halton.sample(2, num_results=n, dtype=dtypes.float64, + seed=1729) q_sample = q.quantile(cdf_sample) # Compute E_p[X]. @@ -90,7 +97,7 @@ class HaltonSequenceTest(test.TestCase): # Compute E_p[X^2]. e_x2 = mc.expectation_importance_sampler( f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) + seed=1412) stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) # Keep the tolerance levels the same as in monte_carlo_test.py. @@ -100,10 +107,10 @@ class HaltonSequenceTest(test.TestCase): def test_docstring_example(self): # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 + num_results = 1000 dim = 3 with self.test_session(): - sample = halton.sample(dim, num_samples=num_samples) + sample = halton.sample(dim, num_results=num_results, seed=127) # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional # hypercube. @@ -115,16 +122,76 @@ class HaltonSequenceTest(test.TestCase): # Produces a relative absolute error of 1.7%. self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sequence_indices argument can be used to do this. - sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, - dtype=dtypes.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) + sequence_indices = math_ops.range(start=1000, limit=1000 + num_results, + dtype=dtypes.int32) + sample_leaped = halton.sample(dim, sequence_indices=sequence_indices, + seed=111217) integral_leaped = math_ops.reduce_mean( math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) - self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) + self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.01) + + def test_randomized_qmc_basic(self): + """Tests the randomization of the Halton sequences.""" + # This test is identical to the example given in Owen (2017), Figure 5. + + dim = 20 + num_results = 5000 + replica = 10 + + with self.test_session(): + sample = halton.sample(dim, num_results=num_results, seed=121117) + f = math_ops.reduce_mean(math_ops.reduce_sum(sample, axis=1) ** 2) + values = [f.eval() for _ in range(replica)] + self.assertAllClose(np.mean(values), 101.6667, atol=np.std(values) * 2) + + def test_partial_sum_func_qmc(self): + """Tests the QMC evaluation of (x_j + x_{j+1} ...+x_{n})^2. + + A good test of QMC is provided by the function: + + f(x_1,..x_n, x_{n+1}, ..., x_{n+m}) = (x_{n+1} + ... x_{n+m} - m / 2)^2 + + with the coordinates taking values in the unit interval. The mean and + variance of this function (with the uniform distribution over the + unit-hypercube) is exactly calculable: + + = m / 12, Var(f) = m (5m - 3) / 360 + + The purpose of the "shift" (if n > 0) in the coordinate dependence of the + function is to provide a test for Halton sequence which exhibit more + dependence in the higher axes. + + This test confirms that the mean squared error of RQMC estimation falls + as O(N^(2-e)) for any e>0. + """ + + n, m = 10, 10 + dim = n + m + num_results_lo, num_results_hi = 1000, 10000 + replica = 20 + true_mean = m / 12. + + def func_estimate(x): + return math_ops.reduce_mean( + (math_ops.reduce_sum(x[:, -m:], axis=-1) - m / 2.0) ** 2) + + with self.test_session(): + sample_lo = halton.sample(dim, num_results=num_results_lo, seed=1925) + sample_hi = halton.sample(dim, num_results=num_results_hi, seed=898128) + f_lo, f_hi = func_estimate(sample_lo), func_estimate(sample_hi) + + estimates = np.array([(f_lo.eval(), f_hi.eval()) for _ in range(replica)]) + var_lo, var_hi = np.mean((estimates - true_mean) ** 2, axis=0) + + # Expect that the variance scales as N^2 so var_hi / var_lo ~ k / 10^2 + # with k a fudge factor accounting for the residual N dependence + # of the QMC error and the sampling error. + log_rel_err = np.log(100 * var_hi / var_lo) + self.assertAllClose(log_rel_err, 0.0, atol=1.2) if __name__ == '__main__': diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py index 8cabf18..57900d6 100644 --- a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py @@ -26,8 +26,9 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops - +from tensorflow.python.ops import random_ops __all__ = [ 'sample', @@ -39,32 +40,45 @@ __all__ = [ _MAX_DIMENSION = 1000 -def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): - r"""Returns a sample from the `m` dimensional Halton sequence. +def sample(dim, + num_results=None, + sequence_indices=None, + dtype=None, + randomized=True, + seed=None, + name=None): + r"""Returns a sample from the `dim` dimensional Halton sequence. Warning: The sequence elements take values only between 0 and 1. Care must be taken to appropriately transform the domain of a function if it differs from the unit cube before evaluating integrals using Halton samples. It is also - important to remember that quasi-random numbers are not a replacement for - pseudo-random numbers in every context. Quasi random numbers are completely - deterministic and typically have significant negative autocorrelation (unless - randomized). + important to remember that quasi-random numbers without randomization are not + a replacement for pseudo-random numbers in every context. Quasi random numbers + are completely deterministic and typically have significant negative + autocorrelation unless randomization is used. Computes the members of the low discrepancy Halton sequence in dimension - `dim`. The d-dimensional sequence takes values in the unit hypercube in d - dimensions. Currently, only dimensions up to 1000 are supported. The prime - base for the `k`-th axes is the k-th prime starting from 2. For example, - if dim = 3, then the bases will be [2, 3, 5] respectively and the first - element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete - description of the Halton sequences see: + `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in + `dim` dimensions. Currently, only dimensions up to 1000 are supported. The + prime base for the k-th axes is the k-th prime starting from 2. For example, + if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first + element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more + complete description of the Halton sequences see: https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences and their applications see: https://en.wikipedia.org/wiki/Low-discrepancy_sequence. - The user must supply either `num_samples` or `sample_indices` but not both. + If `randomized` is true, this function produces a scrambled version of the + Halton sequence introduced by Owen in arXiv:1706.02808. For the advantages of + randomization of low discrepancy sequences see: + https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo + + The number of samples produced is controlled by the `num_results` and + `sequence_indices` parameters. The user must supply either `num_results` or + `sequence_indices` but not both. The former is the number of samples to produce starting from the first - element. If `sample_indices` is given instead, the specified elements of - the sequence are generated. For example, sample_indices=tf.range(10) is + element. If `sequence_indices` is given instead, the specified elements of + the sequence are generated. For example, sequence_indices=tf.range(10) is equivalent to specifying n=10. Example Use: @@ -73,9 +87,9 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): bf = tf.contrib.bayesflow # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 + num_results = 1000 dim = 3 - sample = bf.halton_sequence.sample(dim, num_samples=num_samples) + sample = bf.halton_sequence.sample(dim, num_results=num_results, seed=127) # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional # hypercube. @@ -89,12 +103,13 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): print ("Estimated: %f, True Value: %f" % values) # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. + # thousand samples. The sequence_indices argument can be used to do this. - sample_indices = tf.range(start=1000, limit=1000 + num_samples, - dtype=tf.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) + sequence_indices = tf.range(start=1000, limit=1000 + num_results, + dtype=tf.int32) + sample_leaped = halton.sample(dim, sequence_indices=sequence_indices, + seed=111217) integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, axis=-1)) @@ -107,51 +122,57 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): Args: dim: Positive Python `int` representing each sample's `event_size.` Must not be greater than 1000. - num_samples: (Optional) positive Python `int`. The number of samples to - generate. Either this parameter or sample_indices must be specified but + num_results: (Optional) positive Python `int`. The number of samples to + generate. Either this parameter or sequence_indices must be specified but not both. If this parameter is None, then the behaviour is determined by - the `sample_indices`. - sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements - of the sequence to compute specified by their position in the sequence. - The entries index into the Halton sequence starting with 0 and hence, - must be whole numbers. For example, sample_indices=[0, 5, 6] will produce - the first, sixth and seventh elements of the sequence. If this parameter - is None, then the `num_samples` parameter must be specified which gives - the number of desired samples starting from the first sample. + the `sequence_indices`. + sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The + elements of the sequence to compute specified by their position in the + sequence. The entries index into the Halton sequence starting with 0 and + hence, must be whole numbers. For example, sequence_indices=[0, 5, 6] will + produce the first, sixth and seventh elements of the sequence. If this + parameter is None, then the `num_results` parameter must be specified + which gives the number of desired samples starting from the first sample. dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. Default is `float32`. + randomized: (Optional) bool indicating whether to produce a randomized + Halton sequence. If True, applies the randomization described in + Owen (2017) [arXiv:1706.02808]. + seed: (Optional) Python integer to seed the random number generator. Only + used if `randomized` is True. If not supplied and `randomized` is True, + no seed is set. name: (Optional) Python `str` describing ops managed by this function. If not supplied the name of this function is used. Returns: halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype - and `shape` `[num_samples, dim]` if `num_samples` was specified or shape - `[s, dim]` where s is the size of `sample_indices` if `sample_indices` + and `shape` `[num_results, dim]` if `num_results` was specified or shape + `[s, dim]` where s is the size of `sequence_indices` if `sequence_indices` were specified. Raises: - ValueError: if both `sample_indices` and `num_samples` were specified or + ValueError: if both `sequence_indices` and `num_results` were specified or if dimension `dim` is less than 1 or greater than 1000. """ if dim < 1 or dim > _MAX_DIMENSION: raise ValueError( 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, dim)) - if (num_samples is None) == (sample_indices is None): - raise ValueError('Either `num_samples` or `sample_indices` must be' + if (num_results is None) == (sequence_indices is None): + raise ValueError('Either `num_results` or `sequence_indices` must be' ' specified but not both.') dtype = dtype or dtypes.float32 if not dtype.is_floating: raise ValueError('dtype must be of `float`-type') - with ops.name_scope(name, 'sample', values=[sample_indices]): + with ops.name_scope(name, 'sample', values=[sequence_indices]): # Here and in the following, the shape layout is as follows: # [sample dimension, event dimension, coefficient dimension]. # The coefficient dimension is an intermediate axes which will hold the # weights of the starting integer when expressed in the (prime) base for # an event dimension. - indices = _get_indices(num_samples, sample_indices, dtype) + indices = _get_indices(num_results, sequence_indices, dtype) radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), @@ -176,11 +197,74 @@ def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): weights = radixes ** capped_exponents coeffs = math_ops.floor_div(indices, weights) coeffs *= 1 - math_ops.cast(weight_mask, dtype) - coeffs = (coeffs % radixes) / radixes - return math_ops.reduce_sum(coeffs / weights, axis=-1) + coeffs %= radixes + if not randomized: + coeffs /= radixes + return math_ops.reduce_sum(coeffs / weights, axis=-1) + coeffs = _randomize(coeffs, radixes, seed=seed) + coeffs *= 1 - math_ops.cast(weight_mask, dtype) + coeffs /= radixes + base_values = math_ops.reduce_sum(coeffs / weights, axis=-1) + + # The randomization used in Owen (2017) does not leave 0 invariant. While + # we have accounted for the randomization of the first `max_size_by_axes` + # coefficients, we still need to correct for the trailing zeros. Luckily, + # this is equivalent to adding a uniform random value scaled so the first + # `max_size_by_axes` coefficients are zero. The following statements perform + # this correction. + zero_correction = random_ops.random_uniform([dim, 1], seed=seed, + dtype=dtype) + zero_correction /= (radixes ** max_sizes_by_axes) + return base_values + array_ops.reshape(zero_correction, [-1]) + + +def _randomize(coeffs, radixes, seed=None): + """Applies the Owen randomization to the coefficients.""" + given_dtype = coeffs.dtype + coeffs = math_ops.to_int32(coeffs) + num_coeffs = array_ops.shape(coeffs)[-1] + radixes = array_ops.reshape(math_ops.to_int32(radixes), [-1]) + perms = _get_permutations(num_coeffs, radixes, seed=seed) + perms = array_ops.reshape(perms, [-1]) + radix_sum = math_ops.reduce_sum(radixes) + radix_offsets = array_ops.reshape(math_ops.cumsum(radixes, exclusive=True), + [-1, 1]) + offsets = radix_offsets + math_ops.range(num_coeffs) * radix_sum + permuted_coeffs = array_ops.gather(perms, coeffs + offsets) + return math_ops.cast(permuted_coeffs, dtype=given_dtype) + + +def _get_permutations(num_results, dims, seed=None): + """Uniform iid sample from the space of permutations. + + Draws a sample of size `num_results` from the group of permutations of degrees + specified by the `dims` tensor. These are packed together into one tensor + such that each row is one sample from each of the dimensions in `dims`. For + example, if dims = [2,3] and num_results = 2, the result is a tensor of shape + [2, 2 + 3] and the first row of the result might look like: + [1, 0, 2, 0, 1]. The first two elements are a permutation over 2 elements + while the next three are a permutation over 3 elements. + Args: + num_results: A positive scalar `Tensor` of integral type. The number of + draws from the discrete uniform distribution over the permutation groups. + dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the + permutation groups from which to sample. + seed: (Optional) Python integer to seed the random number generator. -def _get_indices(n, sample_indices, dtype, name=None): + Returns: + permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same + dtype as `dims`. + """ + sample_range = math_ops.range(num_results) + def generate_one(d): + fn = lambda _: random_ops.random_shuffle(math_ops.range(d), seed=seed) + return functional_ops.map_fn(fn, sample_range) + return array_ops.concat([generate_one(d) for d in array_ops.unstack(dims)], + axis=-1) + + +def _get_indices(n, sequence_indices, dtype, name=None): """Generates starting points for the Halton sequence procedure. The k'th element of the sequence is generated starting from a positive integer @@ -191,10 +275,10 @@ def _get_indices(n, sample_indices, dtype, name=None): Args: n: Positive `int`. The number of samples to generate. If this - parameter is supplied, then `sample_indices` should be None. - sample_indices: `Tensor` of dtype int32 and rank 1. The entries + parameter is supplied, then `sequence_indices` should be None. + sequence_indices: `Tensor` of dtype int32 and rank 1. The entries index into the Halton sequence starting with 0 and hence, must be whole - numbers. For example, sample_indices=[0, 5, 6] will produce the first, + numbers. For example, sequence_indices=[0, 5, 6] will produce the first, sixth and seventh elements of the sequence. If this parameter is not None then `n` must be None. dtype: The dtype of the sample. One of `float32` or `float64`. @@ -204,14 +288,14 @@ def _get_indices(n, sample_indices, dtype, name=None): Returns: indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. """ - with ops.name_scope(name, 'get_indices', [n, sample_indices]): - if sample_indices is None: - sample_indices = math_ops.range(n, dtype=dtype) + with ops.name_scope(name, '_get_indices', [n, sequence_indices]): + if sequence_indices is None: + sequence_indices = math_ops.range(n, dtype=dtype) else: - sample_indices = math_ops.cast(sample_indices, dtype) + sequence_indices = math_ops.cast(sequence_indices, dtype) # Shift the indices so they are 1 based. - indices = sample_indices + 1 + indices = sequence_indices + 1 # Reshape to make space for the event dimension and the place value # coefficients. @@ -261,4 +345,5 @@ def _primes_less_than(n): _PRIMES = _primes_less_than(7919+1) + assert len(_PRIMES) == _MAX_DIMENSION