mcmc_diagnostics.py added to contrib/bayesflow/. potential_scale_reduction function...
author Ian Langmore <langmore@google.com>
Tue, 6 Feb 2018 08:29:29 +0000 (00:29 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 08:36:03 +0000 (00:36 -0800)

PiperOrigin-RevId: 184644450
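
A minimal usage sketch of the new API (illustrative only; the NumPy input and
the printed value are assumptions, not output from this change):

```python
import numpy as np
import tensorflow as tf

tfb = tf.contrib.bayesflow

# Two chains of 1000 scalar samples each, drawn iid from Normal(0, 1).
# Since the chains share a common distribution, R-hat should be close to 1.
state = np.random.randn(1000, 2)

rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
    state, independent_chain_ndims=1)

with tf.Session() as sess:
  print(sess.run(rhat))  # A scalar near 1.0.
```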

tensorflow/contrib/bayesflow/BUILD
tensorflow/contrib/bayesflow/__init__.py
tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py [new file with mode: 0644]

index 6e0f0a0..82944f5 100644 (file)
@@ -138,6 +138,25 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "mcmc_diagnostics_test",
+    size = "small",
+    srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"],
+    additional_deps = [
+        ":bayesflow_py",
+        "//third_party/py/numpy",
+        "//tensorflow/contrib/distributions:distributions_py",
+        "//tensorflow/python/ops/distributions",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:random_seed",
+    ],
+)
+
+cuda_py_test(
     name = "monte_carlo_test",
     size = "small",
     srcs = ["python/kernel_tests/monte_carlo_test.py"],
index 95b9452..c411026 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import custom_grad
 from tensorflow.contrib.bayesflow.python.ops import halton_sequence
 from tensorflow.contrib.bayesflow.python.ops import hmc
 from tensorflow.contrib.bayesflow.python.ops import layers
+from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics
 from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
 from tensorflow.contrib.bayesflow.python.ops import monte_carlo
 from tensorflow.contrib.bayesflow.python.ops import optimizers
@@ -42,6 +43,7 @@ _allowed_symbols = [
     'hmc',
     'layers',
     'metropolis_hastings',
+    'mcmc_diagnostics',
     'monte_carlo',
     'optimizers',
     'special_math',
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
new file mode 100644 (file)
index 0000000..7652b6a
--- /dev/null
@@ -0,0 +1,230 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MCMC diagnostic utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+rng = np.random.RandomState(42)
+
+
+class _PotentialScaleReductionTest(object):
+
+  @property
+  def use_static_shape(self):
+    raise NotImplementedError(
+        "Subclass failed to implement `use_static_shape`.")
+
+  def testListOfStatesWhereFirstPassesSecondFails(self):
+    """Simple test showing API with two states.  Read first!."""
+    n_samples = 1000
+
+    # state_0 is two scalar chains taken from iid Normal(0, 1).  Will pass.
+    state_0 = rng.randn(n_samples, 2)
+
+    # state_1 is three 4-variate chains taken from Normal(0, 1), each shifted
+    # by a different amount.  Since the chains are not identically
+    # distributed, the diagnostic should fail.
+    offset = np.array([1., -1., 2.]).reshape(3, 1)
+    state_1 = rng.randn(n_samples, 3, 4) + offset
+
+    rhat = mcmc_diagnostics.potential_scale_reduction(
+        state=[state_0, state_1], independent_chain_ndims=1)
+
+    self.assertIsInstance(rhat, list)
+    with self.test_session() as sess:
+      rhat_0_, rhat_1_ = sess.run(rhat)
+
+    # rhat_0 should be close to 1, meaning the diagnostic passes.
+    self.assertAllEqual((), rhat_0_.shape)
+    self.assertAllClose(1., rhat_0_, rtol=0.02)
+
+    # rhat_1 should be greater than 1.2, meaning the diagnostic fails.
+    self.assertAllEqual((4,), rhat_1_.shape)
+    self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2)
+
+  def check_results(self, state_, independent_chain_shape, should_pass):
+    sample_ndims = 1
+    independent_chain_ndims = len(independent_chain_shape)
+    with self.test_session():
+      state = array_ops.placeholder_with_default(
+          input=state_, shape=state_.shape if self.use_static_shape else None)
+
+      rhat = mcmc_diagnostics.potential_scale_reduction(
+          state, independent_chain_ndims=independent_chain_ndims)
+
+      if self.use_static_shape:
+        self.assertAllEqual(
+            state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape)
+
+      rhat_ = rhat.eval()
+      if should_pass:
+        self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02)
+      else:
+        self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2)
+
+  def iid_normal_chains_should_pass_wrapper(self,
+                                            sample_shape,
+                                            independent_chain_shape,
+                                            other_shape,
+                                            dtype=np.float32):
+    """Check results with iid normal chains."""
+
+    state_shape = sample_shape + independent_chain_shape + other_shape
+    state_ = rng.randn(*state_shape).astype(dtype)
+
+    # The "other" dimensions do not have to be identical, just independent, so
+    # force them to not be identical.
+    if other_shape:
+      state_ *= rng.rand(*other_shape).astype(dtype)
+
+    self.check_results(state_, independent_chain_shape, should_pass=True)
+
+  def testPassingIIDNdimsAreIndependentOneOtherZero(self):
+    self.iid_normal_chains_should_pass_wrapper(
+        sample_shape=[10000], independent_chain_shape=[4], other_shape=[])
+
+  def testPassingIIDNdimsAreIndependentOneOtherOne(self):
+    self.iid_normal_chains_should_pass_wrapper(
+        sample_shape=[10000], independent_chain_shape=[3], other_shape=[7])
+
+  def testPassingIIDNdimsAreIndependentOneOtherTwo(self):
+    self.iid_normal_chains_should_pass_wrapper(
+        sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7])
+
+  def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self):
+    self.iid_normal_chains_should_pass_wrapper(
+        sample_shape=[10000],
+        independent_chain_shape=[2, 3],
+        other_shape=[5, 7],
+        dtype=np.float64)
+
+  def offset_normal_chains_should_fail_wrapper(
+      self, sample_shape, independent_chain_shape, other_shape):
+    """Check results with normal chains that are offset from each other."""
+
+    state_shape = sample_shape + independent_chain_shape + other_shape
+    state_ = rng.randn(*state_shape)
+
+    # Add a significant offset to the different (formerly iid) chains.
+    offset = np.linspace(
+        0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len(
+            sample_shape) + independent_chain_shape + [1] * len(other_shape))
+    state_ += offset
+
+    self.check_results(state_, independent_chain_shape, should_pass=False)
+
+  def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self):
+    self.offset_normal_chains_should_fail_wrapper(
+        sample_shape=[10000], independent_chain_shape=[2], other_shape=[5])
+
+
+class PotentialScaleReductionStaticTest(test.TestCase,
+                                        _PotentialScaleReductionTest):
+
+  @property
+  def use_static_shape(self):
+    return True
+
+  def testIndependentNdimsLessThanOneRaises(self):
+    with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"):
+      mcmc_diagnostics.potential_scale_reduction(
+          rng.rand(2, 3, 4), independent_chain_ndims=0)
+
+
+class PotentialScaleReductionDynamicTest(test.TestCase,
+                                         _PotentialScaleReductionTest):
+
+  @property
+  def use_static_shape(self):
+    return False
+
+
+class _ReduceVarianceTest(object):
+
+  @property
+  def use_static_shape(self):
+    raise NotImplementedError(
+        "Subclass failed to implement `use_static_shape`.")
+
+  def check_versus_numpy(self, x_, axis, biased, keepdims):
+    with self.test_session():
+      x_ = np.asarray(x_)
+      x = array_ops.placeholder_with_default(
+          input=x_, shape=x_.shape if self.use_static_shape else None)
+      var = mcmc_diagnostics._reduce_variance(
+          x, axis=axis, biased=biased, keepdims=keepdims)
+      np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims)
+
+      if self.use_static_shape:
+        self.assertAllEqual(np_var.shape, var.shape)
+
+      var_ = var.eval()
+      # We will mask below, which changes shape, so check shape explicitly here.
+      self.assertAllEqual(np_var.shape, var_.shape)
+
+      # We get NaN from dividing by zero when the axis size equals ddof.
+      nan_mask = np.isnan(np_var)
+      if nan_mask.any():
+        self.assertTrue(np.isnan(var_[nan_mask]).all())
+      self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02)
+
+  def testScalarBiasedTrue(self):
+    self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False)
+
+  def testScalarBiasedFalse(self):
+    # This should result in NaN.
+    self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False)
+
+  def testShape2x3x4AxisNoneBiasedTrueKeepdimsFalse(self):
+    self.check_versus_numpy(
+        x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False)
+
+  def testShape2x3x4Axis1BiasedTrueKeepdimsTrue(self):
+    self.check_versus_numpy(
+        x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True)
+
+  def testShape2x3x4x5Axis1BiasedTrueKeepdimsTrue(self):
+    self.check_versus_numpy(
+        x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True)
+
+  def testShape2x3x4x5Axis1BiasedFalseKeepdimsFalse(self):
+    self.check_versus_numpy(
+        x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False)
+
+
+class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest):
+
+  @property
+  def use_static_shape(self):
+    return True
+
+
+class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest):
+
+  @property
+  def use_static_shape(self):
+    return False
+
+
+if __name__ == "__main__":
+  test.main()
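
A note on the test design above: each suite is written once against an
abstract `use_static_shape` property and instantiated twice, with
`placeholder_with_default(..., shape=None)` erasing the static shape so the
same assertions exercise both static- and dynamic-shape code paths.  A
condensed sketch of the pattern (hypothetical `my_op` under test):

```python
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class _MyOpTest(object):
  """Shared test body; subclasses choose static vs. dynamic shapes."""

  @property
  def use_static_shape(self):
    raise NotImplementedError(
        "Subclass failed to implement `use_static_shape`.")

  def check_my_op(self, x_):
    with self.test_session():
      # shape=None hides the static shape, forcing dynamic-shape code paths.
      x = array_ops.placeholder_with_default(
          input=x_, shape=x_.shape if self.use_static_shape else None)
      # ... build `my_op(x)` here and assert on its evaluated output ...


class MyOpStaticShapeTest(test.TestCase, _MyOpTest):

  @property
  def use_static_shape(self):
    return True


class MyOpDynamicShapeTest(test.TestCase, _MyOpTest):

  @property
  def use_static_shape(self):
    return False
```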
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
new file mode 100644 (file)
index 0000000..5f3e6ad
--- /dev/null
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Markov Chain Monte Carlo (MCMC) sampling."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    "potential_scale_reduction",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
new file mode 100644 (file)
index 0000000..3b6f924
--- /dev/null
@@ -0,0 +1,228 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.
+
+@@potential_scale_reduction
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = [
+    "potential_scale_reduction",
+]
+
+
+def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
+  """Gelman and Rubin's potential scale reduction factor for chain convergence.
+
+  Given `N > 1` samples from each of `C > 1` independent chains, the potential
+  scale reduction factor, commonly referred to as R-hat, measures convergence of
+  the chains (to the same target) by testing for equality of means.
+  Specifically, R-hat measures the degree to which variance (of the means)
+  between chains exceeds what one would expect if the chains were identically
+  distributed.  See [1], [2].
+
+  Some guidelines:
+
+  * The initial state of the chains should be drawn from a distribution
+    overdispersed with respect to the target.
+  * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1.
+    Before that, R-hat > 1 (except in pathological cases, e.g. if the chain
+    paths were identical).
+  * The above holds for any number of chains `C > 1`.  Increasing `C` improves
+    the effectiveness of the diagnostic.
+  * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of
+    course this is problem dependent.  See [2].
+  * R-hat only measures non-convergence of the mean.  If convergence of higher
+    moments or other statistics is desired, a different diagnostic should be
+    used.  See [2].
+
+  #### Examples
+
+  Diagnosing convergence by monitoring 10 chains that each attempt to
+  sample from a 2-variate normal.
+
+  ```python
+  tfd = tf.contrib.distributions
+  tfb = tf.contrib.bayesflow
+
+  target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
+
+  # Get 10 (2x) overdispersed initial states.
+  initial_state = target.sample(10) * 2.
+  initial_state.shape
+  ==> (10, 2)
+
+  # Get 1000 samples from the 10 independent chains.
+  state = tfb.hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=target.log_prob,
+      current_state=initial_state,
+      step_size=0.05,
+      num_leapfrog_steps=20,
+      num_burnin_steps=200)
+  state.shape
+  ==> (1000, 10, 2)
+
+  rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
+      state, independent_chain_ndims=1)
+
+  # The second dimension needed a longer burn-in.
+  rhat.eval()
+  ==> [1.05, 1.3]
+  ```
+
+  To see why R-hat is reasonable, let `X` be a random variable drawn uniformly
+  from the combined states (combined over all chains).  Then, in the limit
+  `N, C --> infinity`, with `E`, `Var` denoting expectation and variance,
+
+  ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].```
+
+  Using the law of total variance, the numerator is the variance of the combined
+  states, and the denominator is the total variance minus the variance of
+  the individual chain means.  If the chains are all drawing from the same
+  distribution, they will have the same mean, and thus the ratio should be one.
+
+  [1] "Inference from Iterative Simulation Using Multiple Sequences"
+      Andrew Gelman and Donald B. Rubin
+      Statist. Sci. Volume 7, Number 4 (1992), 457-472.
+  [2] "General Methods for Monitoring Convergence of Iterative Simulations"
+      Stephen P. Brooks and Andrew Gelman
+      Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4.
+
+  Args:
+    state:  `Tensor` or Python `list` of `Tensor`s representing the state(s) of
+      a Markov Chain at each result step.  The `ith` state is assumed to have
+      shape `[Ni, Ci1, Ci2,...,CiD] + A`.
+      Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
+      Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
+      chains to be tested for convergence to the same target.
+      The remaining dimensions, `A`, can have any shape (even empty).
+    independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the
+      number of dimensions, from `dim = 1` to `dim = D`, holding independent
+      chain results to be tested for convergence.
+    name: `String` name to prepend to created ops.  Default:
+      `potential_scale_reduction`.
+
+  Returns:
+    `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for
+    the state(s).  Same `dtype` as `state`, and shape equal to
+    `state.shape[1 + independent_chain_ndims:]`.
+
+  Raises:
+    ValueError:  If `independent_chain_ndims < 1`.
+  """
+  # tensor_util.constant_value returns None iff a constant value (as a numpy
+  # array) is not efficiently computable.  Therefore, we try constant_value then
+  # check for None.
+  icn_const_ = tensor_util.constant_value(
+      ops.convert_to_tensor(independent_chain_ndims))
+  if icn_const_ is not None:
+    independent_chain_ndims = icn_const_
+    if icn_const_ < 1:
+      raise ValueError(
+          "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format(
+              independent_chain_ndims))
+  with ops.name_scope(
+      name,
+      "potential_scale_reduction",
+      values=[state, independent_chain_ndims]):
+    if _is_list_like(state):
+      return [
+          _potential_scale_reduction_single_state(s, independent_chain_ndims)
+          for s in state
+      ]
+    return _potential_scale_reduction_single_state(state,
+                                                   independent_chain_ndims)
+
+
+def _potential_scale_reduction_single_state(state, independent_chain_ndims):
+  """potential_scale_reduction for one single state `Tensor`."""
+  # We assume exactly one leading dimension indexes the (possibly correlated)
+  # samples from each Markov chain.
+  state = ops.convert_to_tensor(state, name="state")
+  sample_ndims = 1
+
+  sample_axis = math_ops.range(0, sample_ndims)
+  chain_axis = math_ops.range(sample_ndims,
+                              sample_ndims + independent_chain_ndims)
+  sample_and_chain_axis = math_ops.range(0,
+                                         sample_ndims + independent_chain_ndims)
+
+  n = _axis_size(state, sample_axis)
+  m = _axis_size(state, chain_axis)
+
+  # In the language of [2],
+  # B / n is the between chain variance, the variance of the chain means.
+  # W is the within sequence variance, the mean of the chain variances.
+  b_div_n = _reduce_variance(
+      math_ops.reduce_mean(state, sample_axis, keepdims=True),
+      sample_and_chain_axis,
+      biased=False)
+  w = math_ops.reduce_mean(
+      _reduce_variance(state, sample_axis, keepdims=True, biased=True),
+      sample_and_chain_axis)
+
+  # sigma^2_+ is an estimate of the true variance, which would be unbiased if
+  # each chain were drawn from the target.  Cf. the "law of total variance."
+  sigma_2_plus = w + b_div_n
+
+  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
+
+
+def effective_sample_size(state,
+                          independent_chain_ndims=1,
+                          max_lags=None,
+                          max_lags_threshold=None,
+                          name="effective_sample_size"):
+  """Effective sample size of a sequence of states.  Not yet implemented."""
+  if max_lags is not None and max_lags_threshold is not None:
+    raise ValueError(
+        "Expected at most one of max_lags, max_lags_threshold to be provided.  "
+        "Found: {}, {}".format(max_lags, max_lags_threshold))
+  with ops.name_scope(
+      name,
+      values=[state, independent_chain_ndims, max_lags, max_lags_threshold]):
+    # Only argument validation is implemented in this change; fail loudly
+    # rather than silently returning `None`.
+    raise NotImplementedError(
+        "effective_sample_size is not yet implemented.")
+
+
+# TODO(b/72873233) Move some variant of this to sample_stats.
+def _reduce_variance(x, axis=None, biased=True, keepdims=False):
+  with ops.name_scope("reduce_variance"):
+    x = ops.convert_to_tensor(x, name="x")
+    mean = math_ops.reduce_mean(x, axis=axis, keepdims=True)
+    biased_var = math_ops.reduce_mean(
+        math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims)
+    if biased:
+      return biased_var
+    n = _axis_size(x, axis)
+    return (n / (n - 1.)) * biased_var
+
+
+def _axis_size(x, axis=None):
+  """Get number of elements of `x` in `axis`, as type `x.dtype`."""
+  if axis is None:
+    return math_ops.cast(array_ops.size(x), x.dtype)
+  return math_ops.cast(
+      math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype)
+
+
+def _is_list_like(x):
+  """Helper which returns `True` if input is `list`-like."""
+  return isinstance(x, (tuple, list))
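
As a cross-check of `_potential_scale_reduction_single_state`, the same
computation can be transcribed into plain NumPy.  This is a sketch under the
module's `sample_ndims = 1` assumption; the helper below is illustrative and
not part of this change:

```python
import numpy as np


def potential_scale_reduction_numpy(state, independent_chain_ndims=1):
  """NumPy transcription of the R-hat formula implemented above."""
  state = np.asarray(state)
  # As in the TF version, exactly one leading dimension indexes samples.
  sample_axis = (0,)
  chain_axis = tuple(range(1, 1 + independent_chain_ndims))
  sample_and_chain_axis = sample_axis + chain_axis

  n = state.shape[0]  # Number of samples per chain.
  m = np.prod([state.shape[a] for a in chain_axis])  # Number of chains.

  # B / n: unbiased (ddof=1) variance of the per-chain means.
  b_div_n = np.var(
      state.mean(axis=sample_axis, keepdims=True),
      axis=sample_and_chain_axis,
      ddof=1)

  # W: mean of the biased (ddof=0) within-chain variances.
  w = np.mean(
      np.var(state, axis=sample_axis, keepdims=True, ddof=0),
      axis=sample_and_chain_axis)

  # sigma^2_+ = W + B/n estimates the true variance via the law of total
  # variance; the final expression matches the TF implementation above.
  sigma_2_plus = w + b_div_n
  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)


# Example: 1000 samples from each of 4 iid chains; expect a value near 1.
print(potential_scale_reduction_numpy(np.random.randn(1000, 4)))
```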