)
cuda_py_test(
+ name = "mcmc_diagnostics_test",
+ size = "small",
+ srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/distributions:distributions_py",
+ "//tensorflow/python/ops/distributions",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_seed",
+ ],
+)
+
+cuda_py_test(
name = "monte_carlo_test",
size = "small",
srcs = ["python/kernel_tests/monte_carlo_test.py"],
from tensorflow.contrib.bayesflow.python.ops import halton_sequence
from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops import layers
+from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import optimizers
'hmc',
'layers',
'metropolis_hastings',
+ 'mcmc_diagnostics',
'monte_carlo',
'optimizers',
'special_math',
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MCMC diagnostic utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+rng = np.random.RandomState(42)
+
+
+class _PotentialScaleReductionTest(object):
+
+ @property
+ def use_static_shape(self):
+ raise NotImplementedError(
+ "Subclass failed to impliment `use_static_shape`.")
+
+ def testListOfStatesWhereFirstPassesSecondFails(self):
+ """Simple test showing API with two states. Read first!."""
+ n_samples = 1000
+
+ # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass.
+ state_0 = rng.randn(n_samples, 2)
+
+ # state_1 is three 4-variate chains taken from Normal(0, 1) that have been
+ # shifted. Since every chain is shifted, they are not the same, and the
+ # test should fail.
+ offset = np.array([1., -1., 2.]).reshape(3, 1)
+ state_1 = rng.randn(n_samples, 3, 4) + offset
+
+ rhat = mcmc_diagnostics.potential_scale_reduction(
+ state=[state_0, state_1], independent_chain_ndims=1)
+
+ self.assertIsInstance(rhat, list)
+ with self.test_session() as sess:
+ rhat_0_, rhat_1_ = sess.run(rhat)
+
+ # r_hat_0 should be close to 1, meaning test is passed.
+ self.assertAllEqual((), rhat_0_.shape)
+ self.assertAllClose(1., rhat_0_, rtol=0.02)
+
+ # r_hat_1 should be greater than 1.2, meaning test has failed.
+ self.assertAllEqual((4,), rhat_1_.shape)
+ self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2)
+
+ def check_results(self, state_, independent_chain_shape, should_pass):
+ sample_ndims = 1
+ independent_chain_ndims = len(independent_chain_shape)
+ with self.test_session():
+ state = array_ops.placeholder_with_default(
+ input=state_, shape=state_.shape if self.use_static_shape else None)
+
+ rhat = mcmc_diagnostics.potential_scale_reduction(
+ state, independent_chain_ndims=independent_chain_ndims)
+
+ if self.use_static_shape:
+ self.assertAllEqual(
+ state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape)
+
+ rhat_ = rhat.eval()
+ if should_pass:
+ self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02)
+ else:
+ self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2)
+
+ def iid_normal_chains_should_pass_wrapper(self,
+ sample_shape,
+ independent_chain_shape,
+ other_shape,
+ dtype=np.float32):
+ """Check results with iid normal chains."""
+
+ state_shape = sample_shape + independent_chain_shape + other_shape
+ state_ = rng.randn(*state_shape).astype(dtype)
+
+ # The "other" dimensions do not have to be identical, just independent, so
+ # force them to not be identical.
+ if other_shape:
+ state_ *= rng.rand(*other_shape).astype(dtype)
+
+ self.check_results(state_, independent_chain_shape, should_pass=True)
+
+ def testPassingIIDNdimsAreIndependentOneOtherZero(self):
+ self.iid_normal_chains_should_pass_wrapper(
+ sample_shape=[10000], independent_chain_shape=[4], other_shape=[])
+
+ def testPassingIIDNdimsAreIndependentOneOtherOne(self):
+ self.iid_normal_chains_should_pass_wrapper(
+ sample_shape=[10000], independent_chain_shape=[3], other_shape=[7])
+
+ def testPassingIIDNdimsAreIndependentOneOtherTwo(self):
+ self.iid_normal_chains_should_pass_wrapper(
+ sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7])
+
+ def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self):
+ self.iid_normal_chains_should_pass_wrapper(
+ sample_shape=[10000],
+ independent_chain_shape=[2, 3],
+ other_shape=[5, 7],
+ dtype=np.float64)
+
+ def offset_normal_chains_should_fail_wrapper(
+ self, sample_shape, independent_chain_shape, other_shape):
+ """Check results with normal chains that are offset from each other."""
+
+ state_shape = sample_shape + independent_chain_shape + other_shape
+ state_ = rng.randn(*state_shape)
+
+ # Add a significant offset to the different (formerly iid) chains.
+ offset = np.linspace(
+ 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len(
+ sample_shape) + independent_chain_shape + [1] * len(other_shape))
+ state_ += offset
+
+ self.check_results(state_, independent_chain_shape, should_pass=False)
+
+ def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self):
+ self.offset_normal_chains_should_fail_wrapper(
+ sample_shape=[10000], independent_chain_shape=[2], other_shape=[5])
+
+
+class PotentialScaleReductionStaticTest(test.TestCase,
+ _PotentialScaleReductionTest):
+
+ @property
+ def use_static_shape(self):
+ return True
+
+ def testIndependentNdimsLessThanOneRaises(self):
+ with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"):
+ mcmc_diagnostics.potential_scale_reduction(
+ rng.rand(2, 3, 4), independent_chain_ndims=0)
+
+
+class PotentialScaleReductionDynamicTest(test.TestCase,
+ _PotentialScaleReductionTest):
+
+ @property
+ def use_static_shape(self):
+ return False
+
+
+class _ReduceVarianceTest(object):
+
+ @property
+ def use_static_shape(self):
+ raise NotImplementedError(
+ "Subclass failed to impliment `use_static_shape`.")
+
+ def check_versus_numpy(self, x_, axis, biased, keepdims):
+ with self.test_session():
+ x_ = np.asarray(x_)
+ x = array_ops.placeholder_with_default(
+ input=x_, shape=x_.shape if self.use_static_shape else None)
+ var = mcmc_diagnostics._reduce_variance(
+ x, axis=axis, biased=biased, keepdims=keepdims)
+ np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims)
+
+ if self.use_static_shape:
+ self.assertAllEqual(np_var.shape, var.shape)
+
+ var_ = var.eval()
+ # We will mask below, which changes shape, so check shape explicitly here.
+ self.assertAllEqual(np_var.shape, var_.shape)
+
+ # We get NaN when we divide by zero due to the size being the same as ddof
+ nan_mask = np.isnan(np_var)
+ if nan_mask.any():
+ self.assertTrue(np.isnan(var_[nan_mask]).all())
+ self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02)
+
+ def testScalarBiasedTrue(self):
+ self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False)
+
+ def testScalarBiasedFalse(self):
+ # This should result in NaN.
+ self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False)
+
+ def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self):
+ self.check_versus_numpy(
+ x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False)
+
+ def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self):
+ self.check_versus_numpy(
+ x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True)
+
+ def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self):
+ self.check_versus_numpy(
+ x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True)
+
+ def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self):
+ self.check_versus_numpy(
+ x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False)
+
+
+class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest):
+
+ @property
+ def use_static_shape(self):
+ return True
+
+
+class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest):
+
+ @property
+ def use_static_shape(self):
+ return False
+
+
+if __name__ == "__main__":
+ test.main()
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Markov Chain Monte Carlo (MCMC) sampling."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "potential_scale_reduction",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.
+
+@@potential_scale_reduction
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = [
+ "potential_scale_reduction",
+]
+
+
+def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
+ """Gelman and Rubin's potential scale reduction factor for chain convergence.
+
+ Given `N > 1` samples from each of `C > 1` independent chains, the potential
+ scale reduction factor, commonly referred to as R-hat, measures convergence of
+ the chains (to the same target) by testing for equality of means.
+ Specifically, R-hat measures the degree to which variance (of the means)
+ between chains exceeds what one would expect if the chains were identically
+ distributed. See [1], [2].
+
+ Some guidelines:
+
+ * The initial state of the chains should be drawn from a distribution
+ overdispersed with respect to the target.
+ * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1.
+ Before that, R-hat > 1 (except in pathological cases, e.g. if the chain
+ paths were identical).
+ * The above holds for any number of chains `C > 1`. Increasing `C` does
+ improves effectiveness of the diagnostic.
+ * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of
+ course this is problem depedendent. See [2].
+ * R-hat only measures non-convergence of the mean. If higher moments, or other
+ statistics are desired, a different diagnostic should be used. See [2].
+
+ #### Examples
+
+ Diagnosing convergence by monitoring 10 chains that each attempt to
+ sample from a 2-variate normal.
+
+ ```python
+ tfd = tf.contrib.distributions
+ tfb = tf.contrib.bayesflow
+
+ target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
+
+ # Get 10 (2x) overdispersed initial states.
+ initial_state = target.sample(10) * 2.
+ ==> (10, 2)
+
+ # Get 1000 samples from the 10 independent chains.
+ state = tfb.hmc.sample_chain(
+ num_results=1000,
+ target_log_prob_fn=target.log_prob,
+ current_state=initial_state,
+ step_size=0.05,
+ num_leapfrog_steps=20,
+ num_burnin_steps=200)
+ state.shape
+ ==> (1000, 10, 2)
+
+ rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
+ state, independent_chain_ndims=1)
+
+ # The second dimension needed a longer burn-in.
+ rhat.eval()
+ ==> [1.05, 1.3]
+ ```
+
+ To see why R-hat is reasonable, let `X` be a random variable drawn uniformly
+ from the combined states (combined over all chains). Then, in the limit
+ `N, C --> infinity`, with `E`, `Var` denoting expectation and variance,
+
+ ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].```
+
+ Using the law of total variance, the numerator is the variance of the combined
+ states, and the denominator is the total variance minus the variance of the
+ the individual chain means. If the chains are all drawing from the same
+ distribution, they will have the same mean, and thus the ratio should be one.
+
+ [1] "Inference from Iterative Simulation Using Multiple Sequences"
+ Andrew Gelman and Donald B. Rubin
+ Statist. Sci. Volume 7, Number 4 (1992), 457-472.
+ [2] "General Methods for Monitoring Convergence of Iterative Simulations"
+ Stephen P. Brooks and Andrew Gelman
+ Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4.
+
+ Args:
+ state: `Tensor` or Python `list` of `Tensor`s representing the state(s) of
+ a Markov Chain at each result step. The `ith` state is assumed to have
+ shape `[Ni, Ci1, Ci2,...,CiD] + A`.
+ Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
+ Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
+ chains to be tested for convergence to the same target.
+ The remaining dimensions, `A`, can have any shape (even empty).
+ independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the
+ number of giving the number of dimensions, from `dim = 1` to `dim = D`,
+ holding independent chain results to be tested for convergence.
+ name: `String` name to prepend to created ops. Default:
+ `potential_scale_reduction`.
+
+ Returns:
+ `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for
+ the state(s). Same `dtype` as `state`, and shape equal to
+ `state.shape[1 + independent_chain_ndims:]`.
+
+ Raises:
+ ValueError: If `independent_chain_ndims < 1`.
+ """
+ # tensor_util.constant_value returns None iff a constant value (as a numpy
+ # array) is not efficiently computable. Therefore, we try constant_value then
+ # check for None.
+ icn_const_ = tensor_util.constant_value(
+ ops.convert_to_tensor(independent_chain_ndims))
+ if icn_const_ is not None:
+ independent_chain_ndims = icn_const_
+ if icn_const_ < 1:
+ raise ValueError(
+ "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format(
+ independent_chain_ndims))
+ with ops.name_scope(
+ name,
+ "potential_scale_reduction",
+ values=[state, independent_chain_ndims]):
+ if _is_list_like(state):
+ return [
+ _potential_scale_reduction_single_state(s, independent_chain_ndims)
+ for s in state
+ ]
+ return _potential_scale_reduction_single_state(state,
+ independent_chain_ndims)
+
+
+def _potential_scale_reduction_single_state(state, independent_chain_ndims):
+ """potential_scale_reduction for one single state `Tensor`."""
+ # We assume exactly one leading dimension indexes e.g. correlated samples from
+ # each Markov chain.
+ state = ops.convert_to_tensor(state, name="state")
+ sample_ndims = 1
+
+ sample_axis = math_ops.range(0, sample_ndims)
+ chain_axis = math_ops.range(sample_ndims,
+ sample_ndims + independent_chain_ndims)
+ sample_and_chain_axis = math_ops.range(0,
+ sample_ndims + independent_chain_ndims)
+
+ n = _axis_size(state, sample_axis)
+ m = _axis_size(state, chain_axis)
+
+ # In the language of [2],
+ # B / n is the between chain variance, the variance of the chain means.
+ # W is the within sequence variance, the mean of the chain variances.
+ b_div_n = _reduce_variance(
+ math_ops.reduce_mean(state, sample_axis, keepdims=True),
+ sample_and_chain_axis,
+ biased=False)
+ w = math_ops.reduce_mean(
+ _reduce_variance(state, sample_axis, keepdims=True, biased=True),
+ sample_and_chain_axis)
+
+ # sigma^2_+ is an estimate of the true variance, which would be unbiased if
+ # each chain was drawn from the target. c.f. "law of total variance."
+ sigma_2_plus = w + b_div_n
+
+ return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
+
+
+def effective_sample_size(state,
+ independent_chain_ndims=1,
+ max_lags=None,
+ max_lags_threshold=None,
+ name="effective_sample_size"):
+ if max_lags is not None and max_lags_threshold is not None:
+ raise ValueError(
+ "Expected at most one of max_lags, max_lags_threshold to be provided. "
+ "Found: {}, {}".format(max_lags, max_lags_threshold))
+ with ops.name_scope(
+ name,
+ values=[state, independent_chain_ndims, max_lags, max_lags_threshold]):
+ pass
+
+
+# TODO(b/72873233) Move some variant of this to sample_stats.
+def _reduce_variance(x, axis=None, biased=True, keepdims=False):
+ with ops.name_scope("reduce_variance"):
+ x = ops.convert_to_tensor(x, name="x")
+ mean = math_ops.reduce_mean(x, axis=axis, keepdims=True)
+ biased_var = math_ops.reduce_mean(
+ math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims)
+ if biased:
+ return biased_var
+ n = _axis_size(x, axis)
+ return (n / (n - 1.)) * biased_var
+
+
+def _axis_size(x, axis=None):
+ """Get number of elements of `x` in `axis`, as type `x.dtype`."""
+ if axis is None:
+ return math_ops.cast(array_ops.size(x), x.dtype)
+ return math_ops.cast(
+ math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype)
+
+
+def _is_list_like(x):
+ """Helper which returns `True` if input is `list`-like."""
+ return isinstance(x, (tuple, list))