From af9afcfe44ed97dc13422445b1f1c91eaa98d583 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Fri, 2 Feb 2018 13:07:37 -0800 Subject: [PATCH] Breaking change: Revise HMC interface to accept a list of Tensors representing a partitioning of chain state. PiperOrigin-RevId: 184323369 --- .../bayesflow/python/kernel_tests/hmc_test.py | 680 ++++++--- tensorflow/contrib/bayesflow/python/ops/hmc.py | 11 +- .../contrib/bayesflow/python/ops/hmc_impl.py | 1546 +++++++++++++------- 3 files changed, 1455 insertions(+), 782 deletions(-) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index cbc66b6..d9d0dfc 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -19,29 +19,36 @@ from __future__ import division from __future__ import print_function import numpy as np -from scipy import special from scipy import stats from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.platform import tf_logging as logging_ops + + +def _reduce_variance(x, axis=None, keepdims=False): + sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) + return math_ops.reduce_mean( + math_ops.squared_difference(x, sample_mean), axis, keepdims) -# TODO(b/66964210): Test float16. class HMCTest(test.TestCase): def setUp(self): self._shape_param = 5. self._rate_param = 10. - self._expected_x = (special.digamma(self._shape_param) - - np.log(self._rate_param)) - self._expected_exp_x = self._shape_param / self._rate_param random_seed.set_random_seed(10003) np.random.seed(10003) @@ -63,63 +70,46 @@ class HMCTest(test.TestCase): self._rate_param * math_ops.exp(x), event_dims) - def _log_gamma_log_prob_grad(self, x, event_dims=()): - """Computes log-pdf and gradient of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. Default is (), - i.e., all dimensions are independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - grad: The gradient of the log-pdf with respect to x. 
- """ - return (math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims), - self._shape_param - self._rate_param * math_ops.exp(x)) - - def _n_event_dims(self, x_shape, event_dims): - return np.prod([int(x_shape[i]) for i in event_dims]) - - def _integrator_conserves_energy(self, x, event_dims, sess, + def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, feed_dict=None): - def potential_and_grad(x): - log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims) - return -log_prob, -grad - - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + step_size = array_ops.placeholder(np.float32, [], name="step_size") + hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") if feed_dict is None: feed_dict = {} feed_dict[hmc_lf_steps] = 1000 - m = random_ops.random_normal(array_ops.shape(x)) - potential_0, grad_0 = potential_and_grad(x) - old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m, - event_dims) - - _, new_m, potential_1, _ = ( - hmc.leapfrog_integrator(step_size, hmc_lf_steps, x, - m, potential_and_grad, grad_0)) + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) - new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, + m = random_ops.random_normal(array_ops.shape(x)) + log_prob_0 = self._log_gamma_log_prob(x, event_dims) + grad_0 = gradients_ops.gradients(log_prob_0, x) + old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) + + new_m, _, log_prob_1, _ = _leapfrog_integrator( + current_momentums=[m], + target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), + current_state_parts=[x], + step_sizes=[step_size], + num_leapfrog_steps=hmc_lf_steps, + current_target_log_prob=log_prob_0, + current_grads_target_log_prob=grad_0) + new_m = new_m[0] + + new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, event_dims) x_shape = sess.run(x, feed_dict).shape - n_event_dims = self._n_event_dims(x_shape, event_dims) - feed_dict[step_size] = 0.1 / n_event_dims - old_energy_val, new_energy_val = sess.run([old_energy, new_energy], - feed_dict) - logging.vlog(1, 'average energy change: {}'.format( - abs(old_energy_val - new_energy_val).mean())) - - self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool), - abs(old_energy_val - new_energy_val) < 1.) - - def _integrator_conserves_energy_wrapper(self, event_dims): + event_size = np.prod(x_shape[independent_chain_ndims:]) + feed_dict[step_size] = 0.1 / event_size + old_energy_, new_energy_ = sess.run([old_energy, new_energy], + feed_dict) + logging_ops.vlog(1, "average energy relative change: {}".format( + (1. - new_energy_ / old_energy_).mean())) + self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) + + def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): """Tests the long-term energy conservation of the leapfrog integrator. The leapfrog integrator is symplectic, so for sufficiently small step @@ -127,135 +117,167 @@ class HMCTest(test.TestCase): the energy of the system blowing up or collapsing. Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. + independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. 
""" with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict) + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._integrator_conserves_energy(x_ph, independent_chain_ndims, + sess, feed_dict) def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper([]) + self._integrator_conserves_energy_wrapper(0) def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper([1]) + self._integrator_conserves_energy_wrapper(1) def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper([2]) + self._integrator_conserves_energy_wrapper(2) - def testIntegratorEnergyConservation12(self): - self._integrator_conserves_energy_wrapper([1, 2]) + def testIntegratorEnergyConservation3(self): + self._integrator_conserves_energy_wrapper(3) - def testIntegratorEnergyConservation012(self): - self._integrator_conserves_energy_wrapper([0, 1, 2]) - - def _chain_gets_correct_expectations(self, x, event_dims, sess, - feed_dict=None): + def _chain_gets_correct_expectations(self, x, independent_chain_ndims, + sess, feed_dict=None): def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') - hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps') + num_results = array_ops.placeholder( + np.int32, [], name="num_results") + step_size = array_ops.placeholder( + np.float32, [], name="step_size") + num_leapfrog_steps = array_ops.placeholder( + np.int32, [], name="num_leapfrog_steps") if feed_dict is None: feed_dict = {} - feed_dict.update({step_size: 0.1, - hmc_lf_steps: 2, - hmc_n_steps: 300}) - - sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps], - step_size, - hmc_lf_steps, - x, log_gamma_log_prob, - event_dims) - - acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain], - feed_dict) - samples = samples[feed_dict[hmc_n_steps] // 2:] - expected_x_est = samples.mean() - expected_exp_x_est = np.exp(samples).mean() - - logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format( - self._expected_x, self._expected_exp_x)) - logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format( - expected_x_est, expected_exp_x_est)) - self.assertNear(expected_x_est, self._expected_x, 2e-2) - self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2) - self.assertTrue((acceptance_probs > 0.5).all()) - self.assertTrue((acceptance_probs <= 1.0).all()) - - def _chain_gets_correct_expectations_wrapper(self, event_dims): + feed_dict.update({num_results: 150, + step_size: 0.1, + num_leapfrog_steps: 2}) + + samples, kernel_results = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=log_gamma_log_prob, + current_state=x, + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps, + num_burnin_steps=150, + seed=42) + + expected_x = (math_ops.digamma(self._shape_param) + - np.log(self._rate_param)) + + expected_exp_x = self._shape_param / self._rate_param + + acceptance_probs_, samples_, expected_x_ = sess.run( + [kernel_results.acceptance_probs, samples, expected_x], + feed_dict) + + actual_x = samples_.mean() + actual_exp_x = np.exp(samples_).mean() + + 
logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( + expected_x_, expected_exp_x)) + logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( + actual_x, actual_exp_x)) + self.assertNear(actual_x, expected_x_, 2e-2) + self.assertNear(actual_exp_x, expected_exp_x, 2e-2) + self.assertTrue((acceptance_probs_ > 0.5).all()) + self.assertTrue((acceptance_probs_ <= 1.0).all()) + + def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._chain_gets_correct_expectations(x_ph, event_dims, sess, - feed_dict) + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, + sess, feed_dict) def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper([]) + self._chain_gets_correct_expectations_wrapper(0) def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper([1]) + self._chain_gets_correct_expectations_wrapper(1) def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper([2]) - - def testHMCChainExpectations12(self): - self._chain_gets_correct_expectations_wrapper([1, 2]) + self._chain_gets_correct_expectations_wrapper(2) - def _kernel_leaves_target_invariant(self, initial_draws, event_dims, + def _kernel_leaves_target_invariant(self, initial_draws, + independent_chain_ndims, sess, feed_dict=None): def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) def fake_log_prob(x): """Cooled version of the target distribution.""" return 1.1 * log_gamma_log_prob(x) - step_size = array_ops.placeholder(np.float32, [], name='step_size') + step_size = array_ops.placeholder(np.float32, [], name="step_size") if feed_dict is None: feed_dict = {} feed_dict[step_size] = 0.4 - sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws, - log_gamma_log_prob, event_dims) - bad_sample, bad_acceptance_probs, _, _ = hmc.kernel( - step_size, 5, initial_draws, fake_log_prob, event_dims) - (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val, - updated_draws_val, fake_draws_val) = sess.run([acceptance_probs, - bad_acceptance_probs, - initial_draws, sample, - bad_sample], feed_dict) + sample, kernel_results = hmc.kernel( + target_log_prob_fn=log_gamma_log_prob, + current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=43) + + bad_sample, bad_kernel_results = hmc.kernel( + target_log_prob_fn=fake_log_prob, + current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=44) + + [ + acceptance_probs_, + bad_acceptance_probs_, + initial_draws_, + updated_draws_, + fake_draws_, + ] = sess.run([ + kernel_results.acceptance_probs, + bad_kernel_results.acceptance_probs, + initial_draws, + sample, + bad_sample, + ], feed_dict) + # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_val.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_val.mean(), 0.5) + self.assertGreater(acceptance_probs_.mean(), 0.5) + self.assertGreater(bad_acceptance_probs_.mean(), 0.5) + # Confirm step size is large enough that we sometimes reject. 
- self.assertLess(acceptance_probs_val.mean(), 0.99) - self.assertLess(bad_acceptance_probs_val.mean(), 0.99) - _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(), - updated_draws_val.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(), - fake_draws_val.flatten()) - logging.vlog(1, 'acceptance rate for true target: {}'.format( - acceptance_probs_val.mean())) - logging.vlog(1, 'acceptance rate for fake target: {}'.format( - bad_acceptance_probs_val.mean())) - logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true)) - logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake)) + self.assertLess(acceptance_probs_.mean(), 0.99) + self.assertLess(bad_acceptance_probs_.mean(), 0.99) + + _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), + updated_draws_.flatten()) + _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), + fake_draws_.flatten()) + + logging_ops.vlog(1, "acceptance rate for true target: {}".format( + acceptance_probs_.mean())) + logging_ops.vlog(1, "acceptance rate for fake target: {}".format( + bad_acceptance_probs_.mean())) + logging_ops.vlog(1, "K-S p-value for true target: {}".format( + ks_p_value_true)) + logging_ops.vlog(1, "K-S p-value for fake target: {}".format( + ks_p_value_fake)) # Make sure that the MCMC update hasn't changed the empirical CDF much. self.assertGreater(ks_p_value_true, 1e-3) # Confirm that targeting the wrong distribution does # significantly change the empirical CDF. self.assertLess(ks_p_value_fake, 1e-6) - def _kernel_leaves_target_invariant_wrapper(self, event_dims): + def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): """Tests that the kernel leaves the target distribution invariant. Draws some independent samples from the target distribution, @@ -267,86 +289,116 @@ class HMCTest(test.TestCase): does change the target distribution. (And that we can detect that.) Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. + independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. 
""" with self.test_session() as sess: initial_draws = np.log(np.random.gamma(self._shape_param, size=[50000, 2, 2])) initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name='x_ph') + x_ph = array_ops.placeholder(np.float32, name="x_ph") feed_dict = {x_ph: initial_draws} - self._kernel_leaves_target_invariant(x_ph, event_dims, sess, - feed_dict) - - def testKernelLeavesTargetInvariantNullShape(self): - self._kernel_leaves_target_invariant_wrapper([]) + self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, + sess, feed_dict) def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper([1]) + self._kernel_leaves_target_invariant_wrapper(1) def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper([2]) + self._kernel_leaves_target_invariant_wrapper(2) - def testKernelLeavesTargetInvariant12(self): - self._kernel_leaves_target_invariant_wrapper([1, 2]) + def testKernelLeavesTargetInvariant3(self): + self._kernel_leaves_target_invariant_wrapper(3) - def _ais_gets_correct_log_normalizer(self, init, event_dims, sess, - feed_dict=None): + def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, + sess, feed_dict=None): def proposal_log_prob(x): - return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi), - event_dims) + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), + axis=event_dims) def target_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) if feed_dict is None: feed_dict = {} - w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob, - proposal_log_prob, event_dims) - - w_val = sess.run(w, feed_dict) - init_shape = sess.run(init, feed_dict).shape - normalizer_multiplier = np.prod([init_shape[i] for i in event_dims]) - - true_normalizer = -self._shape_param * np.log(self._rate_param) - true_normalizer += special.gammaln(self._shape_param) - true_normalizer *= normalizer_multiplier - - n_weights = np.prod(w_val.shape) - normalized_w = np.exp(w_val - true_normalizer) - standard_error = np.std(normalized_w) / np.sqrt(n_weights) - logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format( - true_normalizer, np.log(normalized_w.mean()) + true_normalizer, - n_weights)) - self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error) - - def _ais_gets_correct_log_normalizer_wrapper(self, event_dims): + num_steps = 200 + + _, ais_weights, _ = hmc.sample_annealed_importance_chain( + proposal_log_prob_fn=proposal_log_prob, + num_steps=num_steps, + target_log_prob_fn=target_log_prob, + step_size=0.5, + current_state=init, + num_leapfrog_steps=2, + seed=45) + + event_shape = array_ops.shape(init)[independent_chain_ndims:] + event_size = math_ops.reduce_prod(event_shape) + + log_true_normalizer = ( + -self._shape_param * math_ops.log(self._rate_param) + + math_ops.lgamma(self._shape_param)) + log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) + + log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) + - np.log(num_steps)) + + ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) + ais_weights_size = array_ops.size(ais_weights) + standard_error = math_ops.sqrt( + _reduce_variance(ratio_estimate_true) + / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) + + [ + ratio_estimate_true_, + log_true_normalizer_, + 
log_estimated_normalizer_, + standard_error_, + ais_weights_size_, + event_size_, + ] = sess.run([ + ratio_estimate_true, + log_true_normalizer, + log_estimated_normalizer, + standard_error, + ais_weights_size, + event_size, + ], feed_dict) + + logging_ops.vlog(1, " log_true_normalizer: {}\n" + " log_estimated_normalizer: {}\n" + " ais_weights_size: {}\n" + " event_size: {}\n".format( + log_true_normalizer_, + log_estimated_normalizer_, + ais_weights_size_, + event_size_)) + self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) + + def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): """Tests that AIS yields reasonable estimates of normalizers.""" with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - + x_ph = array_ops.placeholder(np.float32, name="x_ph") initial_draws = np.random.normal(size=[30, 2, 1]) - feed_dict = {x_ph: initial_draws} - - self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess, - feed_dict) - - def testAISNullShape(self): - self._ais_gets_correct_log_normalizer_wrapper([]) + self._ais_gets_correct_log_normalizer( + x_ph, + independent_chain_ndims, + sess, + feed_dict={x_ph: initial_draws}) def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper([1]) + self._ais_gets_correct_log_normalizer_wrapper(1) def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper([2]) + self._ais_gets_correct_log_normalizer_wrapper(2) - def testAIS12(self): - self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + def testAIS3(self): + self._ais_gets_correct_log_normalizer_wrapper(3) def testNanRejection(self): """Tests that an update that yields NaN potentials gets rejected. @@ -359,24 +411,29 @@ class HMCTest(test.TestCase): """ def _unbounded_exponential_log_prob(x): """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where(x < 0, - np.nan * array_ops.ones_like(x), - -x) + per_element_potentials = array_ops.where( + x < 0., + array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), + -x) return math_ops.reduce_sum(per_element_potentials) with self.test_session() as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, acceptance_probs, _, _ = hmc.kernel( - 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) - initial_x_val, updated_x_val, acceptance_probs_val = sess.run( - [initial_x, updated_x, acceptance_probs]) - - logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) - logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) - logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) - - self.assertAllEqual(initial_x_val, updated_x_val) - self.assertEqual(acceptance_probs_val, 0.) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_unbounded_exponential_log_prob, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=46) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) 
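To make the revised calling convention these tests exercise concrete, here is a minimal standalone sketch of driving the new `hmc.kernel` with a partitioned (list-of-`Tensor`s) chain state, the headline change in this patch. The target function and shapes below are illustrative assumptions, not code from this commit:

```python
import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import hmc

def target_log_prob(x, y):
  # One log-probability per chain: event (rightmost) dims are summed.
  return (-0.5 * tf.reduce_sum(tf.square(x), axis=-1)
          - 0.5 * tf.reduce_sum(tf.square(y), axis=-1))

num_chains = 4
current_state = [tf.zeros([num_chains, 2]),   # state part 0
                 tf.zeros([num_chains, 3])]   # state part 1

# `kernel` now returns (next_state, KernelResults) rather than a flat tuple;
# with a list state, one step size is supplied per state part.
next_state, kernel_results = hmc.kernel(
    target_log_prob_fn=target_log_prob,
    current_state=current_state,
    step_size=[0.5, 0.5],
    num_leapfrog_steps=3,
    seed=1)

with tf.Session() as sess:
  _, accept_probs = sess.run([next_state, kernel_results.acceptance_probs])
  print(accept_probs.shape)  # (num_chains,)
```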
def testNanFromGradsDontPropagate(self): """Test that update with NaN gradients does not cause NaN in results.""" @@ -385,60 +442,195 @@ class HMCTest(test.TestCase): with self.test_session() as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( - 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) - initial_x_val, updated_x_val, acceptance_probs_val = sess.run( - [initial_x, updated_x, acceptance_probs]) - - logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) - logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) - logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) - - self.assertAllEqual(initial_x_val, updated_x_val) - self.assertEqual(acceptance_probs_val, 0.) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_nan_log_prob_with_nan_gradient, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=47) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) self.assertAllFinite( - gradients_impl.gradients(updated_x, initial_x)[0].eval()) - self.assertTrue( - gradients_impl.gradients(new_grad, initial_x)[0] is None) + gradients_ops.gradients(updated_x, initial_x)[0].eval()) + self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, initial_x)]) + self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, + kernel_results.proposed_state)]) # Gradients of the acceptance probs and new log prob are not finite. - _ = new_log_prob # Prevent unused arg error. 
# self.assertAllFinite( - # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) # self.assertAllFinite( - # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) + + def _testChainWorksDtype(self, dtype): + states, kernel_results = hmc.sample_chain( + num_results=10, + target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), + current_state=np.zeros(5).astype(dtype), + step_size=0.01, + num_leapfrog_steps=10, + seed=48) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run( + [states, kernel_results.acceptance_probs]) + self.assertEqual(dtype, states_.dtype) + self.assertEqual(dtype, acceptance_probs_.dtype) def testChainWorksIn64Bit(self): - def log_prob(x): - return - math_ops.reduce_sum(x * x, axis=-1) - states, acceptance_probs = hmc.chain( - n_iterations=10, - step_size=np.float64(0.01), - n_leapfrog_steps=10, - initial_x=np.zeros(5).astype(np.float64), - target_log_prob_fn=log_prob, - event_dims=[-1]) - with self.test_session() as sess: - states_, acceptance_probs_ = sess.run([states, acceptance_probs]) - self.assertEqual(np.float64, states_.dtype) - self.assertEqual(np.float64, acceptance_probs_.dtype) + self._testChainWorksDtype(np.float64) def testChainWorksIn16Bit(self): - def log_prob(x): - return - math_ops.reduce_sum(x * x, axis=-1) - states, acceptance_probs = hmc.chain( - n_iterations=10, - step_size=np.float16(0.01), - n_leapfrog_steps=10, - initial_x=np.zeros(5).astype(np.float16), - target_log_prob_fn=log_prob, - event_dims=[-1]) + self._testChainWorksDtype(np.float16) + + +class _EnergyComputationTest(object): + + def testHandlesNanFromPotential(self): + with self.test_session() as sess: + x = [1, np.inf, -np.inf, np.nan] + target_log_prob, proposed_target_log_prob = [ + self.dtype(x.flatten()) for x in np.meshgrid(x, x)] + num_chains = len(target_log_prob) + dummy_momentums = [-1, 1] + momentums = [self.dtype([dummy_momentums] * num_chains)] + proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] + + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] + + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + [actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. 
+ self.assertAllEqual(np.ones_like(grads_).astype(np.bool), + np.isfinite(grads_)) + + def testHandlesNanFromKinetic(self): with self.test_session() as sess: - states_, acceptance_probs_ = sess.run([states, acceptance_probs]) - self.assertEqual(np.float16, states_.dtype) - self.assertEqual(np.float16, acceptance_probs_.dtype) + x = [1, np.inf, -np.inf, np.nan] + momentums, proposed_momentums = [ + [np.reshape(self.dtype(x), [-1, 1])] + for x in np.meshgrid(x, x)] + num_chains = len(momentums[0]) + target_log_prob = np.ones(num_chains, self.dtype) + proposed_target_log_prob = np.ones(num_chains, self.dtype) + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] -if __name__ == '__main__': + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + [actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. + g = grads_[0].reshape([len(x), len(x)])[:, 0] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) + + # The remaining gradients are nan because the momentum was itself nan or + # inf. + g = grads_[0].reshape([len(x), len(x)])[:, 1:] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) + + +class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): + dtype = np.float16 + + +class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): + dtype = np.float32 + + +class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): + dtype = np.float64 + + +class _HMCHandlesLists(object): + + def testStateParts(self): + with self.test_session() as sess: + dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) + dist_y = independent_lib.Independent( + gamma_lib.Gamma(concentration=self.dtype([1, 2]), + rate=self.dtype([0.5, 0.75])), + reinterpreted_batch_ndims=1) + def target_log_prob(x, y): + return dist_x.log_prob(x) + dist_y.log_prob(y) + x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] + samples, _ = hmc.sample_chain( + num_results=int(2e3), + target_log_prob_fn=target_log_prob, + current_state=x0, + step_size=0.85, + num_leapfrog_steps=3, + num_burnin_steps=int(250), + seed=49) + actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] + actual_vars = [_reduce_variance(s, axis=0) for s in samples] + expected_means = [dist_x.mean(), dist_y.mean()] + expected_vars = [dist_x.variance(), dist_y.variance()] + [ + actual_means_, + actual_vars_, + expected_means_, + expected_vars_, + ] = sess.run([ + actual_means, + actual_vars, + expected_means, + expected_vars, + ]) + self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) + self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.30) + + +class HMCHandlesLists16(_HMCHandlesLists, test.TestCase): + dtype = np.float16 + + +class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): + dtype = np.float32 + + +class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): + dtype = np.float64 + + +if __name__ == "__main__": test.main() diff --git 
a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py index 977d42f..7fd5652 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. -""" +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" from __future__ import absolute_import from __future__ import division @@ -24,11 +23,9 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disabl from tensorflow.python.util import all_util _allowed_symbols = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 5685a94..f7a11c2 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -14,17 +14,16 @@ # ============================================================================== """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. -@@chain -@@update -@@leapfrog_integrator -@@leapfrog_step -@@ais_chain +@@sample_chain +@@sample_annealed_importance_chain +@@kernel """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import numpy as np from tensorflow.python.framework import dtypes @@ -32,168 +31,292 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.ops.distributions import util as distributions_util __all__ = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] -def _make_potential_and_grad(target_log_prob_fn): - def potential_and_grad(x): - log_prob_result = -target_log_prob_fn(x) - grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result), - x)[0] - return log_prob_result, grad_result - return potential_and_grad - - -def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, event_dims=(), name=None): +KernelResults = collections.namedtuple( + "KernelResults", + [ + "acceptance_probs", + "current_grads_target_log_prob", # "Current result" means "accepted". + "current_target_log_prob", # "Current result" means "accepted". 
+ "energy_change", + "is_accepted", + "proposed_grads_target_log_prob", + "proposed_state", + "proposed_target_log_prob", + "random_positive", + ]) + + +def _make_dummy_kernel_results( + dummy_state, + dummy_target_log_prob, + dummy_grads_target_log_prob): + return KernelResults( + acceptance_probs=dummy_target_log_prob, + current_grads_target_log_prob=dummy_grads_target_log_prob, + current_target_log_prob=dummy_target_log_prob, + energy_change=dummy_target_log_prob, + is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), + proposed_grads_target_log_prob=dummy_grads_target_log_prob, + proposed_state=dummy_state, + proposed_target_log_prob=dummy_target_log_prob, + random_positive=dummy_target_log_prob, + ) + + +def sample_chain( + num_results, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + num_burnin_steps=0, + num_steps_between_results=0, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function samples from an HMC Markov - chain whose initial state is `initial_x` and whose stationary - distribution has log-density `target_log_prob_fn()`. - - This function can update multiple chains in parallel. It assumes - that all dimensions of `initial_x` not specified in `event_dims` are - independent, and should therefore be updated independently. The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; this is up to + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm + that takes a series of gradient-informed steps to produce a Metropolis + proposal. This function samples from an HMC Markov chain at `current_state` + and whose stationary distribution has log-unnormalized-density `target_log_prob_fn()`. - This function basically just wraps `hmc.kernel()` in a tf.scan() loop. + This function samples from multiple chains in parallel. It assumes that the + the leftmost dimensions of (each) `current_state` (part) index an independent + chain. The function `target_log_prob_fn()` sums log-probabilities across + event dimensions (i.e., current state (part) rightmost dimensions). Each + element of the output of `target_log_prob_fn()` represents the (possibly + unnormalized) log-probability of the joint distribution over (all) the current + state (parts). - Args: - n_iterations: Integer number of Markov chain updates to run. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - initial_x: Tensor of initial state(s) of the Markov chain(s). - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. 
-    event_dims: List of dimensions that should not be treated as
-      independent. This allows for multiple chains to be run independently
-      in parallel. Default is (), i.e., all dimensions are independent.
-    name: Python `str` name prefixed to Ops created by this function.
+  The `current_state` can be represented as a single `Tensor` or a `list` of
+  `Tensors` which collectively represent the current state. When specifying a
+  `list`, one must also specify a list of `step_size`s.
 
-  Returns:
-    acceptance_probs: Tensor with the acceptance probabilities for each
-      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
-    chain_states: Tensor with the state of the Markov chain at each iteration.
-      Has shape `[n_iterations, initial_x.shape[0], ..., initial_x.shape[-1]]`.
+  Only one out of every `num_steps_between_results + 1` steps is included in
+  the returned results. This "thinning" comes at a cost of reduced statistical
+  power, while reducing memory requirements and autocorrelation. For more
+  discussion see [1].
+
+  [1]: "Statistically efficient thinning of a Markov chain sampler."
+       Art B. Owen. April 2017.
+       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
 
   #### Examples:
 
-  ```python
-  # Sampling from a standard normal (note `log_joint()` is unnormalized):
-  def log_joint(x):
-    return tf.reduce_sum(-0.5 * tf.square(x))
-  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
-                                      event_dims=[0])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
-  ```
+  ##### Sample from a diagonal-variance Gaussian.
 
   ```python
-  # Sampling from a diagonal-variance Gaussian:
-  variances = tf.linspace(1., 3., 10)
-  def log_joint(x):
-    return tf.reduce_sum(-0.5 / variances * tf.square(x))
-  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
-                                      event_dims=[0])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+  tfd = tf.contrib.distributions
+
+  def make_likelihood(true_variances):
+    return tfd.MultivariateNormalDiag(
+        scale_diag=tf.sqrt(true_variances))
+
+  dims = 10
+  dtype = np.float32
+  true_variances = tf.linspace(dtype(1), dtype(3), dims)
+  likelihood = make_likelihood(true_variances)
+
+  states, kernel_results = hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=likelihood.log_prob,
+      current_state=tf.zeros(dims),
+      step_size=0.5,
+      num_leapfrog_steps=2,
+      num_burnin_steps=500)
+
+  # Compute sample stats.
+  sample_mean = tf.reduce_mean(states, axis=0)
+  sample_var = tf.reduce_mean(
+      tf.squared_difference(states, sample_mean),
+      axis=0)
   ```
 
-  ```python
-  # Sampling from factor-analysis posteriors with known factors W:
-  # mu[i, j] ~ Normal(0, 1)
-  # x[i] ~ Normal(matmul(mu[i], W), I)
-  def log_joint(mu, x, W):
-    prior = -0.5 * tf.reduce_sum(tf.square(mu), 1)
-    x_mean = tf.matmul(mu, W)
-    likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1)
-    return prior + likelihood
-  chain, acceptance_probs = hmc.chain(1000, 0.1, 2,
-                                      tf.zeros([x.shape[0], W.shape[0]]),
-                                      lambda mu: log_joint(mu, x, W),
-                                      event_dims=[1])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+  ##### Sampling from factor-analysis posteriors with known factors.
+
+  I.e.,
+
+  ```none
+  for i=1..n:
+    w[i] ~ Normal(0, eye(d))            # prior
+    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
   ```
+
+  where `F` denotes factors.
 
   ```python
-  # Sampling from the posterior of a Bayesian regression model.:
-
-  # Run 100 chains in parallel, each with a different initialization.
-  initial_beta = tf.random_normal([100, x.shape[1]])
-  chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta,
-                                      log_joint_partial, event_dims=[1])
-  # Discard first halves of chains as warmup/burn-in
-  warmed_up = chain[500:]
-  # Averaging across samples within a chain and across chains
-  mean_est = tf.reduce_mean(warmed_up, [0, 1])
-  var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est)
+  tfd = tf.contrib.distributions
+
+  def make_prior(dims, dtype):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.zeros(dims, dtype))
+
+  def make_likelihood(weights, factors):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.tensordot(weights, factors, axes=[[-1], [-1]]))
+
+  # Setup data.
+  num_weights = 10
+  num_factors = 4
+  num_chains = 100
+  dtype = np.float32
+
+  prior = make_prior(num_weights, dtype)
+  weights = prior.sample(num_chains)
+  factors = np.random.randn(num_factors, num_weights).astype(dtype)
+  x = make_likelihood(weights, factors).sample()
+
+  def target_log_prob(w):
+    # Target joint is: `f(w) = p(w, x | factors)`.
+    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)
+
+  # Get `num_results` samples from `num_chains` independent chains.
+  chains_states, kernels_results = hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=target_log_prob,
+      current_state=tf.zeros([num_chains, num_weights], dtype),
+      step_size=0.1,
+      num_leapfrog_steps=2,
+      num_burnin_steps=500)
+
+  # Compute sample stats.
+  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
+  sample_var = tf.reduce_mean(
+      tf.squared_difference(chains_states, sample_mean),
+      axis=[0, 1])
   ```
-  """
-  with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size,
-                                          n_leapfrog_steps, initial_x]):
-    initial_x = ops.convert_to_tensor(initial_x, name='initial_x')
-    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
-
-    def body(a, _):
-      updated_x, acceptance_probs, log_prob, grad = kernel(
-          step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims,
-          a[2], a[3])
-      return updated_x, acceptance_probs, log_prob, grad
-
-    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
-    potential, grad = potential_and_grad(initial_x)
-    return functional_ops.scan(
-        body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
-        (initial_x,
-         array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
-         -potential, -grad))[:2]
+  Args:
+    num_results: Integer number of Markov chain draws.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
+      size for the leapfrog integrator. Must broadcast with the shape of
+      `current_state`. Larger step sizes lead to faster progress, but too-large
+      step sizes make rejection exponentially more likely. When possible, it's
+      often helpful to match per-variable step sizes to the standard deviations
+      of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    num_burnin_steps: Integer number of chain steps to take before starting to
+      collect results.
+      Default value: 0 (i.e., no burn-in).
+    num_steps_between_results: Integer number of chain steps between collecting
+      a result. Only one out of every `num_steps_between_results + 1` steps is
+      included in the returned results. This "thinning" comes at a cost of
+      reduced statistical power, while reducing memory requirements and
+      autocorrelation. For more discussion see [1].
+      Default value: 0 (i.e., no subsampling).
+    seed: Python integer to seed the random number generator.
+    current_target_log_prob: (Optional) `Tensor` representing the value of
+      `target_log_prob_fn` at the `current_state`. The only reason to specify
+      this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+      representing the gradient of `target_log_prob` at the `current_state`
+      and with respect to the `current_state`. Must have same shape as
+      `current_state`. The only reason to specify this argument is to reduce
+      TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_sample_chain").
+
+  Returns:
+    accepted_states: Tensor or Python list of `Tensor`s representing the
+      state(s) of the Markov chain(s) at each result step. Has same shape as
+      input `current_state` but with a prepended `num_results`-size dimension.
+ kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. + """ + with ops.name_scope( + name, "hmc_sample_chain", + [num_results, current_state, step_size, num_leapfrog_steps, + num_burnin_steps, num_steps_between_results, seed, + current_target_log_prob, current_grads_target_log_prob]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_target_log_prob, + current_grads_target_log_prob, + ] = _prepare_args( + target_log_prob_fn, current_state, step_size, + current_target_log_prob, current_grads_target_log_prob) + def _run_chain(num_steps, current_state, seed, kernel_results): + """Runs the chain(s) for `num_steps`.""" + def _loop_body(iter_, current_state, kernel_results): + return [iter_ + 1] + list(kernel( + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + return control_flow_ops.while_loop( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[0, current_state, kernel_results])[1:] # Lop-off "iter_". + + def _scan_body(args_list, _): + """Closure which implements `tf.scan` body.""" + current_state, kernel_results = args_list + return _run_chain(num_steps_between_results + 1, current_state, seed, + kernel_results) + + current_state, kernel_results = _run_chain( + num_burnin_steps, + current_state, + distributions_util.gen_new_seed( + seed, salt="hmc_sample_chain_burnin"), + _make_dummy_kernel_results( + current_state, + current_target_log_prob, + current_grads_target_log_prob)) -def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, proposal_log_prob_fn, event_dims=(), - name=None): + return functional_ops.scan( + fn=_scan_body, + elems=array_ops.zeros(num_results, dtype=dtypes.bool), # Dummy arg. + initializer=[current_state, kernel_results]) + + +def sample_annealed_importance_chain( + proposal_log_prob_fn, + num_steps, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + name=None): """Runs annealed importance sampling (AIS) to estimate normalizing constants. - This routine uses Hamiltonian Monte Carlo to sample from a series of + This function uses Hamiltonian Monte Carlo to sample from a series of distributions that slowly interpolates between an initial "proposal" - distribution + distribution: `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - and the target distribution + and the target distribution: `exp(target_log_prob_fn(x) - target_log_normalizer)`, @@ -202,113 +325,183 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, normalizing constants of the initial distribution and the target distribution: - E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer). - - Args: - n_iterations: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. 
Total progress per HMC step is roughly
-      proportional to step_size * n_leapfrog_steps.
-    initial_x: Tensor of initial state(s) of the Markov chain(s). Must
-      be a sample from q, or results will be incorrect.
-    target_log_prob_fn: Python callable which takes an argument like `initial_x`
-      and returns its (possibly unnormalized) log-density under the target
-      distribution.
-    proposal_log_prob_fn: Python callable that returns the log density of the
-      initial distribution.
-    event_dims: List of dimensions that should not be treated as
-      independent. This allows for multiple chains to be run independently
-      in parallel. Default is (), i.e., all dimensions are independent.
-    name: Python `str` name prefixed to Ops created by this function.
-
-  Returns:
-    ais_weights: Tensor with the estimated weight(s). Has shape matching
-      `target_log_prob_fn(initial_x)`.
-    chain_states: Tensor with the state(s) of the Markov chain(s) at the final
-      iteration. Has shape matching `initial_x`.
-    acceptance_probs: Tensor with the acceptance probabilities for the final
-      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
+  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
 
   #### Examples:
 
+  ##### Estimate the normalizing constant of a log-gamma distribution.
+
   ```python
-  # Estimating the normalizing constant of a log-gamma distribution:
-  def proposal_log_prob(x):
-    # Standard normal log-probability. This is properly normalized.
-    return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1)
-  def target_log_prob(x):
-    # Unnormalized log-gamma(2, 3) distribution.
-    # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1]
-    return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1)
+  tfd = tf.contrib.distributions
+
   # Run 100 AIS chains in parallel
-  initial_x = tf.random_normal([100, 20])
-  w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob,
-                          proposal_log_prob, event_dims=[1])
-  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+  num_chains = 100
+  dims = 20
+  dtype = np.float32
+
+  proposal = tfd.MultivariateNormalDiag(
+      loc=tf.zeros([dims], dtype=dtype))
+
+  target = tfd.TransformedDistribution(
+      distribution=tfd.Gamma(concentration=dtype(2),
+                             rate=dtype(3)),
+      bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
+      event_shape=[dims])
+
+  chains_state, ais_weights, kernels_results = (
+      hmc.sample_annealed_importance_chain(
+          proposal_log_prob_fn=proposal.log_prob,
+          num_steps=1000,
+          target_log_prob_fn=target.log_prob,
+          step_size=0.2,
+          current_state=proposal.sample(num_chains),
+          num_leapfrog_steps=2))
+
+  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
+                              - np.log(num_chains))
+  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
   ```
 
+  ##### Estimate marginal likelihood of a Bayesian regression model.
+
   ```python
-  # Estimating the marginal likelihood of a Bayesian regression model:
-  base_measure = -0.5 * np.log(2 * np.pi)
-  def proposal_log_prob(x):
-    # Standard normal log-probability. This is properly normalized.
-    return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1)
-  def regression_log_joint(beta, x, y):
-    # This function returns a vector whose ith element is log p(beta[i], y | x).
-    # Each row of beta corresponds to the state of an independent Markov chain.
-    log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1)
-    means = tf.matmul(beta, x, transpose_b=True)
-    log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) +
-                                   base_measure, 1)
-    return log_prior + log_likelihood
-  def log_joint_partial(beta):
-    return regression_log_joint(beta, x, y)
+  tfd = tf.contrib.distributions
+
+  def make_prior(dims, dtype):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.zeros(dims, dtype))
+
+  def make_likelihood(weights, x):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.tensordot(weights, x, axes=[[-1], [-1]]))
+
   # Run 100 AIS chains in parallel
-  initial_beta = tf.random_normal([100, x.shape[1]])
-  w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta,
-                                     log_joint_partial, proposal_log_prob,
-                                     event_dims=[1])
-  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+  num_chains = 100
+  dims = 10
+  dtype = np.float32
+
+  # Make training data.
+  x = np.random.randn(num_chains, dims).astype(dtype)
+  true_weights = np.random.randn(dims).astype(dtype)
+  y = np.dot(x, true_weights) + np.random.randn(num_chains)
+
+  # Setup model.
+  prior = make_prior(dims, dtype)
+  def target_log_prob_fn(weights):
+    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
+
+  proposal = tfd.MultivariateNormalDiag(
+      loc=tf.zeros(dims, dtype))
+
+  weight_samples, ais_weights, kernel_results = (
+      hmc.sample_annealed_importance_chain(
+          num_steps=1000,
+          proposal_log_prob_fn=proposal.log_prob,
+          target_log_prob_fn=target_log_prob_fn,
+          current_state=tf.zeros([num_chains, dims], dtype),
+          step_size=0.1,
+          num_leapfrog_steps=2))
+
+  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
+                             - np.log(num_chains))
   ```
+
+  Args:
+    proposal_log_prob_fn: Python callable that returns the log density of the
+      initial distribution.
+    num_steps: Integer number of Markov chain updates to run. More iterations
+      mean more expense, but smoother annealing between q and p, which in turn
+      means exponentially lower variance for the normalizing constant
+      estimator.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
+      size for the leapfrog integrator. Must broadcast with the shape of
+      `current_state`. Larger step sizes lead to faster progress, but too-large
+      step sizes make rejection exponentially more likely. When possible, it's
+      often helpful to match per-variable step sizes to the standard deviations
+      of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    seed: Python integer to seed the random number generator.
+    name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
+
+  Returns:
+    accepted_state: `Tensor` or Python list of `Tensor`s representing the
+      state(s) of the Markov chain(s) at the final iteration. Has same shape as
+      input `current_state`.
+    ais_weights: Tensor with the estimated weight(s). Has shape matching
+      `target_log_prob_fn(current_state)`.
""" - with ops.name_scope(name, 'hmc_ais_chain', - [n_iterations, step_size, n_leapfrog_steps, initial_x]): - non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) - - beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:] - def _body(a, beta): # pylint: disable=missing-docstring - def log_prob_beta(x): - return ((1 - beta) * proposal_log_prob_fn(x) + - beta * target_log_prob_fn(x)) - last_x = a[0] - w = a[2] - w += (1. / n_iterations) * (target_log_prob_fn(last_x) - - proposal_log_prob_fn(last_x)) - # TODO(b/66917083): There's an opportunity for gradient reuse here. - updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps, - last_x, log_prob_beta, - event_dims) - return updated_x, acceptance_probs, w - - x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, - (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), - array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) - return w[-1], x[-1], acceptance_probs[-1] - - -def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), - x_log_prob=None, x_grad=None, name=None): + def make_convex_combined_log_prob_fn(iter_): + def _fn(*args): + p = proposal_log_prob_fn(*args) + t = target_log_prob_fn(*args) + dtype = p.dtype.base_dtype + beta = (math_ops.cast(iter_ + 1, dtype) + / math_ops.cast(num_steps, dtype)) + return (1. - beta) * p + beta * t + return _fn + + with ops.name_scope( + name, "hmc_sample_annealed_importance_chain", + [num_steps, current_state, step_size, num_leapfrog_steps, seed]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_log_prob, + current_grads_log_prob, + ] = _prepare_args( + make_convex_combined_log_prob_fn(iter_=0), + current_state, + step_size, + description="convex_combined_log_prob") + def _loop_body(iter_, ais_weights, current_state, kernel_results): + """Closure which implements `tf.while_loop` body.""" + current_state_parts = (list(current_state) + if _is_list_like(current_state) + else [current_state]) + ais_weights += ((target_log_prob_fn(*current_state_parts) + - proposal_log_prob_fn(*current_state_parts)) + / math_ops.cast(num_steps, ais_weights.dtype)) + return [iter_ + 1, ais_weights] + list(kernel( + make_convex_combined_log_prob_fn(iter_), + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + + [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[ + 0, # iter_ + array_ops.zeros_like(current_log_prob), # ais_weights + current_state, + _make_dummy_kernel_results(current_state, + current_log_prob, + current_grads_log_prob), + ])[1:] # Lop-off "iter_". + + return [current_state, ais_weights, kernel_results] + + +def kernel(target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs one iteration of Hamiltonian Monte Carlo. Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) @@ -316,334 +509,625 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), a Metropolis proposal. This function applies one step of HMC to randomly update the variable `x`. - This function can update multiple chains in parallel. It assumes - that all dimensions of `x` not specified in `event_dims` are - independent, and should therefore be updated independently. 
The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; for example, if - `event_dims == (1,)`, then `x[0, :]` could have a different target - distribution from x[1, :]. This is up to `target_log_prob_fn()`. - - Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - x: Tensor containing the value(s) of the random variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - x_log_prob (optional): Tensor containing the cached output of a previous - call to `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - x_grad (optional): Tensor containing the cached gradient of - `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - updated_x: The updated variable(s) x. Has shape matching `initial_x`. - acceptance_probs: Tensor with the acceptance probabilities for the final - iteration. This is useful for diagnosing step size problems etc. Has - shape matching `target_log_prob_fn(initial_x)`. - new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`. - new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at - `updated_x`. + This function can update multiple chains in parallel. It assumes that all + leftmost dimensions of `current_state` index independent chain states (and are + therefore updated independently). The output of `target_log_prob_fn()` should + sum log-probabilities across all event dimensions. Slices along the rightmost + dimensions may have different target distributions; for example, + `current_state[0, :]` could have a different target distribution from + `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of + independent chains is `tf.size(target_log_prob_fn(*current_state))`.) #### Examples: + ##### Simple chain with warm-up. 
+
+ ```python
+ tfd = tf.contrib.distributions
+
 # Tuning acceptance rates:
+ dtype = np.float32
 target_accept_rate = 0.631
- def target_log_prob(x):
- # Standard normal
- return tf.reduce_sum(-0.5 * tf.square(x))
- initial_x = tf.zeros([10])
- initial_log_prob = target_log_prob(initial_x)
- initial_grad = tf.gradients(initial_log_prob, initial_x)[0]
- # Algorithm state
- x = tf.Variable(initial_x, name='x')
- step_size = tf.Variable(1., name='step_size')
- last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob')
- last_grad = tf.Variable(initial_grad, name='last_grad')
- # Compute updates
- new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x,
- target_log_prob,
- event_dims=[0],
- x_log_prob=last_log_prob)
- x_update = tf.assign(x, new_x)
- log_prob_update = tf.assign(last_log_prob, log_prob)
- grad_update = tf.assign(last_grad, grad)
- step_size_update = tf.assign(step_size,
- tf.where(acceptance_prob > target_accept_rate,
- step_size * 1.01, step_size / 1.01))
- adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update]
- sampling_updates = [x_update, log_prob_update, grad_update]
-
- sess = tf.Session()
- sess.run(tf.global_variables_initializer())
+ num_warmup_iter = 500
+ num_chain_iter = 500
+
+ x = tf.get_variable(name="x", initializer=dtype(1))
+ step_size = tf.get_variable(name="step_size", initializer=dtype(1))
+
+ target = tfd.Normal(loc=dtype(0), scale=dtype(1))
+
+ new_x, other_results = hmc.kernel(
+ target_log_prob_fn=target.log_prob,
+ current_state=x,
+ step_size=step_size,
+ num_leapfrog_steps=3)
+
+ x_update = x.assign(new_x)
+
+ step_size_update = step_size.assign_add(
+ step_size * tf.where(
+ other_results.acceptance_probs > target_accept_rate,
+ 0.01, -0.01))
+
+ warmup = tf.group(x_update, step_size_update)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+
+ sess.graph.finalize() # No more graph building.
+
 # Warm up the sampler and adapt the step size
- for i in xrange(500):
- sess.run(adaptive_updates)
+ for _ in xrange(num_warmup_iter):
+ sess.run(warmup)
+
 # Collect samples without adapting step size
- samples = np.zeros([500, 10])
- for i in xrange(500):
- x_val, _ = sess.run([new_x, sampling_updates])
- samples[i] = x_val
+ samples = np.zeros([num_chain_iter])
+ for i in xrange(num_chain_iter):
+ _, x_, target_log_prob_, grad_ = sess.run([
+ x_update,
+ x,
+ other_results.current_target_log_prob,
+ other_results.current_grads_target_log_prob])
+ samples[i] = x_
+
+ print(samples.mean(), samples.std())
 ```
-
- ```python
- # Empirical-Bayes estimation of a hyperparameter by MCMC-EM:
-
- # Problem setup
- N = 150
- D = 10
- x = np.random.randn(N, D).astype(np.float32)
- true_sigma = 0.5
- true_beta = true_sigma * np.random.randn(D).astype(np.float32)
- y = x.dot(true_beta) + np.random.randn(N).astype(np.float32)
-
- def log_prior(beta, log_sigma):
- return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) -
- log_sigma)
- def regression_log_joint(beta, log_sigma, x, y):
- # This function returns log p(beta | log_sigma) + log p(y | x, beta).
- means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True) - means = tf.squeeze(means) - log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means)) - return log_prior(beta, log_sigma) + log_likelihood - def log_joint_partial(beta): - return regression_log_joint(beta, log_sigma, x, y) - # Our estimate of log(sigma) - log_sigma = tf.Variable(0., name='log_sigma') - # The state of the Markov chain - beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta') - new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial, - event_dims=[0]) - beta_update = tf.assign(beta, new_beta) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - with tf.control_dependencies([beta_update]): - log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma), - var_list=[log_sigma]) - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - log_sigma_history = np.zeros(1000) - for i in xrange(1000): - log_sigma_val, _ = sess.run([log_sigma, log_sigma_update]) - log_sigma_history[i] = log_sigma_val - # Should converge to something close to true_sigma - plt.plot(np.exp(log_sigma_history)) - ``` - """ - with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): - potential_and_grad = _make_potential_and_grad(target_log_prob_fn) - x = ops.convert_to_tensor(x, name='x') - - x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape, dtype=x.dtype) - - kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) - - if (x_log_prob is not None) and (x_grad is not None): - log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type - else: - if x_log_prob is not None: - logging.warn('x_log_prob was provided, but x_grad was not,' - ' so x_log_prob was not used.') - if x_grad is not None: - logging.warn('x_grad was provided, but x_log_prob was not,' - ' so x_grad was not used.') - log_potential_0, grad_0 = potential_and_grad(x) - - new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator( - step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0) - - kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - - energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 - # Treat NaN as infinite energy (and therefore guaranteed rejection). - energy_change = array_ops.where( - math_ops.is_nan(energy_change), - array_ops.fill(array_ops.shape(energy_change), - energy_change.dtype.as_numpy_dtype(np.inf)), - energy_change) - acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) - accepted = ( - random_ops.random_uniform( - array_ops.shape(acceptance_probs), dtype=x.dtype) - < acceptance_probs) - new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) - - # TODO(b/65738010): This should work, but it doesn't for now. - # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) - reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, - keep_dims=True)) - accepted = array_ops.reshape(accepted, reduced_shape) - accepted = math_ops.logical_or( - accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) - new_x = array_ops.where(accepted, new_x, x) - new_grad = -array_ops.where(accepted, grad_1, grad_0) - - # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect - # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This - # should be fixed. 
- return new_x, acceptance_probs, new_log_prob, new_grad - - -def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, - potential_and_grad, initial_grad, name=None): - """Applies `n_steps` steps of the leapfrog integrator. - - This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing - gradient computations where possible. + ##### Sample from more complicated posterior. - Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_position`. Larger step sizes lead to faster progress, but - too-large step sizes lead to larger discretization error and - worse energy conservation. - n_steps: Number of steps to run the leapfrog integrator. - initial_position: Tensor containing the value(s) of the position variable(s) - to update. - initial_momentum: Tensor containing the value(s) of the momentum variable(s) - to update. - potential_and_grad: Python callable that takes a position tensor like - `initial_position` and returns the potential energy and its gradient at - that position. - initial_grad: Tensor with the value of the gradient of the potential energy - at `initial_position`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - updated_position: Updated value of the position. - updated_momentum: Updated value of the momentum. - new_potential: Potential energy of the new position. Has shape matching - `potential_and_grad(initial_position)`. - new_grad: Gradient from potential_and_grad() evaluated at the new position. - Has shape matching `initial_position`. + I.e., - Example: Simple quadratic potential. + ```none + W ~ MVN(loc=0, scale=sigma * eye(dims)) + for i=1...num_samples: + X[i] ~ MVN(loc=0, scale=eye(dims)) + eps[i] ~ Normal(loc=0, scale=1) + Y[i] = X[i].T * W + eps[i] + ``` ```python - def potential_and_grad(position): - return tf.reduce_sum(0.5 * tf.square(position)), position - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - potential, grad = potential_and_grad(position) - new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator( - 0.1, 3, position, momentum, potential_and_grad, grad) - - sess = tf.Session() - position_val = np.random.randn(10) - momentum_val = np.random.randn(10) - potential_val, grad_val = sess.run([potential, grad], - {position: position_val}) - positions = np.zeros([100, 10]) - for i in xrange(100): - position_val, momentum_val, potential_val, grad_val = sess.run( - [new_position, new_momentum, new_potential, new_grad], - {position: position_val, momentum: momentum_val}) - positions[i] = position_val - # Should trace out sinusoidal dynamics. 
- plt.plot(positions[:, 0])
- ```
- """
- def leapfrog_wrapper(step_size, x, m, grad, l):
- x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad)
- return step_size, x, m, grad, l + 1
+ tfd = tf.contrib.distributions
+
+ def make_training_data(num_samples, dims, sigma):
+ dt = np.asarray(sigma).dtype
+ zeros = tf.zeros(dims, dtype=dt)
+ x = tfd.MultivariateNormalDiag(
+ loc=zeros).sample(num_samples, seed=1)
+ w = tfd.MultivariateNormalDiag(
+ loc=zeros,
+ scale_identity_multiplier=sigma).sample(seed=2)
+ noise = tfd.Normal(
+ loc=dt.type(0),
+ scale=dt.type(1)).sample(num_samples, seed=3)
+ y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
+ return y, x, w
+
+ def make_prior(sigma, dims):
+ # p(w | sigma)
+ return tfd.MultivariateNormalDiag(
+ loc=tf.zeros([dims], dtype=sigma.dtype),
+ scale_identity_multiplier=sigma)
+
+ def make_likelihood(x, w):
+ # p(y | x, w)
+ return tfd.MultivariateNormalDiag(
+ loc=tf.tensordot(x, w, axes=[[1], [0]]))
+
+ # Set up assumptions.
+ dtype = np.float32
+ num_samples = 150
+ dims = 10
+ num_iters = int(5e3)
+
+ true_sigma = dtype(0.5)
+ y, x, true_weights = make_training_data(num_samples, dims, true_sigma)
+
+ # Estimate of `log(true_sigma)`.
+ log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
+ sigma = tf.exp(log_sigma)
+
+ # State of the Markov chain.
+ weights = tf.get_variable(
+ name="weights",
+ initializer=np.random.randn(dims).astype(dtype))
+
+ prior = make_prior(sigma, dims)
+
+ def joint_log_prob_fn(w):
+ # f(w) = log p(w, y | x)
+ return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)
+
+ weights_update = weights.assign(
+ hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
+ current_state=weights,
+ step_size=0.1,
+ num_leapfrog_steps=5)[0])
+
+ with tf.control_dependencies([weights_update]):
+ loss = -prior.log_prob(weights)
+
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
 
- def counter_fn(a, b, c, d, counter): # pylint: disable=unused-argument
- return counter < n_steps
+ sess.graph.finalize() # No more graph building.
 
- with ops.name_scope(name, 'leapfrog_integrator',
- [step_size, n_steps, initial_position, initial_momentum,
- initial_grad]):
- _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop(
- counter_fn, leapfrog_wrapper, [step_size, initial_position,
- initial_momentum, initial_grad,
- array_ops.constant(0)], back_prop=False)
- # We're counting on the runtime to eliminate this redundant computation.
- new_potential, new_grad = potential_and_grad(new_x)
- return new_x, new_m, new_potential, new_grad
+ sigma_history = np.zeros(num_iters, dtype)
+ weights_history = np.zeros([num_iters, dims], dtype)
+ for i in xrange(num_iters):
+ _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
+ weights_history[i, :] = weights_
+ sigma_history[i] = sigma_
 
-def leapfrog_step(step_size, position, momentum, potential_and_grad, grad,
- name=None):
- """Applies one step of the leapfrog integrator.
+ true_weights_ = sess.run(true_weights)
 
- Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2.
+ # Should converge to something close to true_sigma.
+ plt.plot(sigma_history);
+ plt.ylabel("sigma");
+ plt.xlabel("iteration");
+ ```
 
 Args:
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `position`.
Larger step sizes lead to faster progress, but
- too-large step sizes lead to larger discretization error and
- worse energy conservation.
- position: Tensor containing the value(s) of the position variable(s)
- to update.
- momentum: Tensor containing the value(s) of the momentum variable(s)
- to update.
- potential_and_grad: Python callable that takes a position tensor like
- `position` and returns the potential energy and its gradient at that
- position.
- grad: Tensor with the value of the gradient of the potential energy
- at `position`.
+ target_log_prob_fn: Python callable which takes an argument like
+ `current_state` (or `*current_state` if it's a list) and returns its
+ (possibly unnormalized) log-density under the target distribution.
+ current_state: `Tensor` or Python `list` of `Tensor`s representing the
+ current state(s) of the Markov chain(s). The first `r` dimensions index
+ independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+ step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+ for the leapfrog integrator. Must broadcast with the shape of
+ `current_state`. Larger step sizes lead to faster progress, but too-large
+ step sizes make rejection exponentially more likely. When possible, it's
+ often helpful to match per-variable step sizes to the standard deviations
+ of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ seed: Python integer to seed the random number generator.
+ current_target_log_prob: (Optional) `Tensor` representing the value of
+ `target_log_prob_fn` at the `current_state`. The only reason to
+ specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+ representing the gradient of `current_target_log_prob` with respect to
+ `current_state`. Must have same shape as `current_state`. The
+ only reason to specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_kernel").
 
 Returns:
- updated_position: Updated value of the position.
- updated_momentum: Updated value of the momentum.
- new_potential: Potential energy of the new position. Has shape matching
- `potential_and_grad(position)`.
- new_grad: Gradient from potential_and_grad() evaluated at the new position.
- Has shape matching `position`.
+ accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
+ of the Markov chain(s) after one iteration. Has same shape as
+ `current_state`.
+ kernel_results: Internal calculations used to advance the chain, including:
+ `acceptance_probs`, the acceptance probabilities for this iteration
+ (shape matching `target_log_prob_fn(current_state)`); `is_accepted`;
+ the accepted `current_target_log_prob`; and the accepted
+ `current_grads_target_log_prob`.
+
+ Raises:
+ ValueError: if there isn't one `step_size` or a list with same length as
+ `current_state`.
+ TypeError: if `target_log_prob_fn` does not return a `Tensor` with
+ `float` `dtype`.
+
 """
+ with ops.name_scope(
+ name, "hmc_kernel",
+ [current_state, step_size, num_leapfrog_steps, seed,
+ current_target_log_prob, current_grads_target_log_prob]):
+ with ops.name_scope("initialize"):
+ [current_state_parts, step_sizes, current_target_log_prob,
+ current_grads_target_log_prob] = _prepare_args(
+ target_log_prob_fn, current_state, step_size,
+ current_target_log_prob, current_grads_target_log_prob,
+ maybe_expand=True)
+ independent_chain_ndims = distributions_util.prefer_static_rank(
+ current_target_log_prob)
+ def init_momentum(s):
+ return random_ops.random_normal(
+ shape=array_ops.shape(s),
+ dtype=s.dtype.base_dtype,
+ seed=distributions_util.gen_new_seed(
+ seed, salt="hmc_kernel_momentums"))
+ current_momentums = [init_momentum(s) for s in current_state_parts]
+
+ [
+ proposed_momentums,
+ proposed_state_parts,
+ proposed_target_log_prob,
+ proposed_grads_target_log_prob,
+ ] = _leapfrog_integrator(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ num_leapfrog_steps,
+ current_target_log_prob,
+ current_grads_target_log_prob)
+
+ energy_change = _compute_energy_change(current_target_log_prob,
+ current_momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims)
+
+ # Accept the proposal iff u < exp(min(-energy, 0)), where u ~ Uniform[0,1).
+ # <==> -log(u) >= max(energy, 0)
+ # <==> -log(u) >= energy, since -log(u) > 0 for u ~ Uniform[0,1).
+ # (Perhaps surprisingly, we don't have a more direct way to draw the
+ # random positive real `-log(u)`; e.g., `tf.random_uniform(minval=0,
+ # maxval=np.inf)` won't work.)
+ random_uniform = random_ops.random_uniform(
+ shape=array_ops.shape(energy_change),
+ dtype=energy_change.dtype,
+ seed=seed)
+ random_positive = -math_ops.log(random_uniform)
+ is_accepted = random_positive >= energy_change
+
+ accepted_target_log_prob = array_ops.where(is_accepted,
+ proposed_target_log_prob,
+ current_target_log_prob)
+
+ accepted_state_parts = [_choose(is_accepted,
+ proposed_state_part,
+ current_state_part,
+ independent_chain_ndims)
+ for current_state_part, proposed_state_part
+ in zip(current_state_parts, proposed_state_parts)]
+
+ accepted_grads_target_log_prob = [
+ _choose(is_accepted,
+ proposed_grad,
+ grad,
+ independent_chain_ndims)
+ for proposed_grad, grad
+ in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)]
+
+ maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
+ return [
+ maybe_flatten(accepted_state_parts),
+ KernelResults(
+ acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
+ current_grads_target_log_prob=accepted_grads_target_log_prob,
+ current_target_log_prob=accepted_target_log_prob,
+ energy_change=energy_change,
+ is_accepted=is_accepted,
+ proposed_grads_target_log_prob=proposed_grads_target_log_prob,
+ proposed_state=maybe_flatten(proposed_state_parts),
+ proposed_target_log_prob=proposed_target_log_prob,
+ random_positive=random_positive,
+ ),
+ ]
+
+
+def _leapfrog_integrator(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ num_leapfrog_steps,
+ current_target_log_prob=None,
+ current_grads_target_log_prob=None,
+ name=None):
+ """Applies `num_leapfrog_steps` steps of the leapfrog integrator.
+
+ Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.
+
+ #### Examples:
 
- Example: Simple quadratic potential.
+ ##### Simple quadratic potential.
 ```python
- def potential_and_grad(position):
- # Simple quadratic potential
- return tf.reduce_sum(0.5 * tf.square(position)), position
+ tfd = tf.contrib.distributions
+
+ dims = 10
+ num_iter = int(1e3)
+ dtype = np.float32
+
 position = tf.placeholder(np.float32)
 momentum = tf.placeholder(np.float32)
- potential, grad = potential_and_grad(position)
- new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step(
- 0.1, position, momentum, potential_and_grad, grad)
-
- sess = tf.Session()
- position_val = np.random.randn(10)
- momentum_val = np.random.randn(10)
- potential_val, grad_val = sess.run([potential, grad],
- {position: position_val})
- positions = np.zeros([100, 10])
- for i in xrange(100):
- position_val, momentum_val, potential_val, grad_val = sess.run(
- [new_position, new_momentum, new_potential, new_grad],
- {position: position_val, momentum: momentum_val})
- positions[i] = position_val
- # Should trace out sinusoidal dynamics.
- plt.plot(positions[:, 0])
+
+ [
+ new_momentums,
+ new_positions,
+ ] = hmc._leapfrog_integrator(
+ current_momentums=[momentum],
+ target_log_prob_fn=tfd.MultivariateNormalDiag(
+ loc=tf.zeros(dims, dtype)).log_prob,
+ current_state_parts=[position],
+ step_sizes=[0.1],
+ num_leapfrog_steps=3)[:2]
+
+ sess = tf.InteractiveSession()
+ sess.graph.finalize() # No more graph building.
+
+ momentum_ = np.random.randn(dims).astype(dtype)
+ position_ = np.random.randn(dims).astype(dtype)
+
+ positions = np.zeros([num_iter, dims], dtype)
+ for i in xrange(num_iter):
+ momentum_, position_ = sess.run(
+ [new_momentums[0], new_positions[0]],
+ feed_dict={position: position_, momentum: momentum_})
+ positions[i] = position_
+
+ plt.plot(positions[:, 0]); # Sinusoidal.
 ```
+
+ Args:
+ current_momentums: Python `list` of `Tensor`s containing the value(s) of
+ the momentum variable(s) to update.
+ target_log_prob_fn: Python callable which takes an argument like
+ `*current_state_parts` and returns its (possibly unnormalized) log-density
+ under the target distribution.
+ current_state_parts: Python `list` of `Tensor`s representing the current
+ state(s) of the Markov chain(s). The first `independent_chain_ndims` of
+ the `Tensor`(s) index different chains.
+ step_sizes: Python `list` of `Tensor`s representing the step size for the
+ leapfrog integrator. Must broadcast with the shape of
+ `current_state_parts`. Larger step sizes lead to faster progress, but
+ too-large step sizes make rejection exponentially more likely. When
+ possible, it's often helpful to match per-variable step sizes to the
+ standard deviations of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ current_target_log_prob: (Optional) `Tensor` representing the value of
+ `target_log_prob_fn(*current_state_parts)`. The only reason to specify
+ this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+ representing the gradient of `target_log_prob_fn(*current_state_parts)`
+ wrt `current_state_parts`. Must have same shape as `current_state_parts`.
+ The only reason to specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_leapfrog_integrator").
+
+ Returns:
+ proposed_momentums: Python `list` of `Tensor`s representing the updated
+ momentum value(s).
+
+ proposed_state_parts: Tensor or Python list of `Tensor`s representing the
+ state(s) of the Markov chain(s) after `num_leapfrog_steps` steps. Has
+ same shape as input `current_state_parts`.
+ proposed_target_log_prob: `Tensor` representing the value of
+ `target_log_prob_fn` at `proposed_state_parts`.
+ proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
+ `proposed_state_parts`.
+
+ Raises:
+ ValueError: if `len(momentums) != len(state_parts)`.
+ ValueError: if `len(state_parts) != len(step_sizes)`.
+ ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
+ TypeError: if `not target_log_prob.dtype.is_floating`.
+ """
+ def _loop_body(step,
+ current_momentums,
+ current_state_parts,
+ ignore_current_target_log_prob, # pylint: disable=unused-argument
+ current_grads_target_log_prob):
+ return [step + 1] + list(_leapfrog_step(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ current_grads_target_log_prob))
+
+ with ops.name_scope(
+ name, "hmc_leapfrog_integrator",
+ [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
+ current_target_log_prob, current_grads_target_log_prob]):
+ if len(current_momentums) != len(current_state_parts):
+ raise ValueError("`momentums` must be in one-to-one correspondence "
+ "with `state_parts`")
+ num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps,
+ name="num_leapfrog_steps")
+ current_target_log_prob, current_grads_target_log_prob = (
+ _maybe_call_fn_and_grads(
+ target_log_prob_fn,
+ current_state_parts,
+ current_target_log_prob,
+ current_grads_target_log_prob))
+ return control_flow_ops.while_loop(
+ cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
+ body=_loop_body,
+ loop_vars=[
+ 0, # iter_
+ current_momentums,
+ current_state_parts,
+ current_target_log_prob,
+ current_grads_target_log_prob,
+ ],
+ back_prop=False)[1:] # Lop-off "iter_".
+
+
+def _leapfrog_step(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ current_grads_target_log_prob,
+ name=None):
+ """Applies one step of the leapfrog integrator."""
+ with ops.name_scope(
+ name, "_leapfrog_step",
+ [current_momentums, current_state_parts, step_sizes,
+ current_grads_target_log_prob]):
+ # Half step: update momentum using the gradient at the current state.
+ proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+ in zip(current_momentums,
+ step_sizes,
+ current_grads_target_log_prob)]
+ # Full step: update state using the half-updated momentum.
+ proposed_state_parts = [x + ss * m for x, ss, m
+ in zip(current_state_parts,
+ step_sizes,
+ proposed_momentums)]
+ proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
+ if not proposed_target_log_prob.dtype.is_floating:
+ raise TypeError("`target_log_prob_fn` must produce a `Tensor` "
+ "with `float` `dtype`.")
+ proposed_grads_target_log_prob = gradients_ops.gradients(
+ proposed_target_log_prob, proposed_state_parts)
+ if any(g is None for g in proposed_grads_target_log_prob):
+ raise ValueError(
+ "Encountered `None` gradient.
Does your `target_log_prob_fn` "
+ "access all `tf.Variable`s via `tf.get_variable`?\n"
+ " current_state_parts: {}\n"
+ " proposed_state_parts: {}\n"
+ " proposed_grads_target_log_prob: {}".format(
+ current_state_parts,
+ proposed_state_parts,
+ proposed_grads_target_log_prob))
+ # Second half step: complete the momentum update with the new gradient.
+ proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+ in zip(proposed_momentums,
+ step_sizes,
+ proposed_grads_target_log_prob)]
+ return [
+ proposed_momentums,
+ proposed_state_parts,
+ proposed_target_log_prob,
+ proposed_grads_target_log_prob,
+ ]
+
+
+def _compute_energy_change(current_target_log_prob,
+ current_momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims,
+ name=None):
+ """Helper to `kernel` which computes the energy change."""
+ with ops.name_scope(
+ name, "compute_energy_change",
+ ([current_target_log_prob, proposed_target_log_prob,
+ independent_chain_ndims] +
+ current_momentums + proposed_momentums)):
+ # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy
+ # since they're a mouthful and this lets us inline more.
+ lk0, lk1 = [], []
+ for current_momentum, proposed_momentum in zip(current_momentums,
+ proposed_momentums):
+ axis = math_ops.range(independent_chain_ndims,
+ array_ops.rank(current_momentum))
+ lk0.append(_log_sum_sq(current_momentum, axis))
+ lk1.append(_log_sum_sq(proposed_momentum, axis))
+
+ lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1),
+ axis=-1)
+ lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1),
+ axis=-1)
+ lp0 = -current_target_log_prob # log_potential
+ lp1 = -proposed_target_log_prob # proposed_log_potential
+ x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
+ axis=-1)
+
+ # The sum is NaN if any element is NaN or we see both +Inf and -Inf.
+ # Thus we will replace such rows with infinite energy change which implies
+ # rejection. Recall that float-comparisons with NaN are always False.
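+ # For example, with illustrative values for a single scalar chain, if
+ # x = [lp1, exp(lk1), -lp0, -exp(lk0)] = [1., np.nan, -2., -0.5]
+ # then the determinacy check below fails for that row, every entry in the
+ # row is replaced by +Inf, and the summed energy change is +Inf, which
+ # guarantees the proposal is rejected.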
+ is_sum_determinate = (
+ math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) &
+ math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1))
+ is_sum_determinate = array_ops.tile(
+ is_sum_determinate[..., array_ops.newaxis],
+ multiples=array_ops.concat([
+ array_ops.ones(array_ops.rank(is_sum_determinate),
+ dtype=dtypes.int32),
+ [4],
+ ], axis=0))
+ x = array_ops.where(is_sum_determinate,
+ x,
+ array_ops.fill(array_ops.shape(x),
+ value=x.dtype.as_numpy_dtype(np.inf)))
+
+ return math_ops.reduce_sum(x, axis=-1)
+
+
+def _choose(is_accepted,
+ accepted,
+ rejected,
+ independent_chain_ndims,
+ name=None):
+ """Helper to `kernel` which broadcasts `is_accepted` for use with `tf.where`."""
+ def _expand_is_accepted_like(x):
+ with ops.name_scope("_choose"):
+ expand_shape = array_ops.concat([
+ array_ops.shape(is_accepted),
+ array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)],
+ dtype=dtypes.int32),
+ ], axis=0)
+ multiples = array_ops.concat([
+ array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32),
+ array_ops.shape(x)[independent_chain_ndims:],
+ ], axis=0)
+ m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape),
+ multiples)
+ m.set_shape(x.shape)
+ return m
+ with ops.name_scope(name, "_choose", values=[
+ is_accepted, accepted, rejected, independent_chain_ndims]):
+ return array_ops.where(_expand_is_accepted_like(accepted),
+ accepted,
+ rejected)
+
+
+def _maybe_call_fn_and_grads(fn,
+ fn_arg_list,
+ fn_result=None,
+ grads_fn_result=None,
+ description="target_log_prob"):
+ """Helper which computes `fn_result` and `grads` if needed."""
+ fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list)
+ else [fn_arg_list])
+ if fn_result is None:
+ fn_result = fn(*fn_arg_list)
+ if not fn_result.dtype.is_floating:
+ raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format(
+ description))
+ if grads_fn_result is None:
+ grads_fn_result = gradients_ops.gradients(
+ fn_result, fn_arg_list)
+ if len(fn_arg_list) != len(grads_fn_result):
+ raise ValueError("`{}` must be in one-to-one correspondence with "
+ "`grads_{}`".format(*[description]*2))
+ if any(g is None for g in grads_fn_result):
+ raise ValueError("Encountered `None` gradient.")
+ return fn_result, grads_fn_result
+
+
+def _prepare_args(target_log_prob_fn, state, step_size,
+ target_log_prob=None, grads_target_log_prob=None,
+ maybe_expand=False, description="target_log_prob"):
+ """Helper which processes input args to meet list-like assumptions."""
+ state_parts = list(state) if _is_list_like(state) else [state]
+ state_parts = [ops.convert_to_tensor(s, name="state")
+ for s in state_parts]
+ target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads(
+ target_log_prob_fn,
+ state_parts,
+ target_log_prob,
+ grads_target_log_prob,
+ description)
+ step_sizes = list(step_size) if _is_list_like(step_size) else [step_size]
+ step_sizes = [
+ ops.convert_to_tensor(
+ s, name="step_size", dtype=target_log_prob.dtype)
+ for s in step_sizes]
+ if len(step_sizes) == 1:
+ step_sizes *= len(state_parts)
+ if len(state_parts) != len(step_sizes):
+ raise ValueError("There should be exactly one `step_size` or it should "
+ "have same length as `current_state`.")
+ if maybe_expand:
+ maybe_flatten = lambda x: x
+ else:
+ maybe_flatten = lambda x: x if _is_list_like(state) else x[0]
+ return [
+ maybe_flatten(state_parts),
+ maybe_flatten(step_sizes),
+ target_log_prob,
+ grads_target_log_prob,
+ ]
+
+
+def _is_list_like(x):
+ """Helper which returns `True` if input is
`list`-like.""" + return isinstance(x, (tuple, list)) + + +def _log_sum_sq(x, axis=None): + """Computes log(sum(x**2)).""" + return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) -- 2.7.4