From: Joshua V. Dillon Date: Mon, 5 Feb 2018 18:45:33 +0000 (-0800) Subject: Automated g4 rollback of changelist 184323369 X-Git-Tag: upstream/v1.7.0~31^2~1010 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16963086368463552d2e53b8cdac11c65482d13f;p=platform%2Fupstream%2Ftensorflow.git Automated g4 rollback of changelist 184323369 PiperOrigin-RevId: 184551259 --- diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index d9d0dfc..cbc66b6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -19,36 +19,29 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import special from scipy import stats from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl as gradients_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging_ops - - -def _reduce_variance(x, axis=None, keepdims=False): - sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) - return math_ops.reduce_mean( - math_ops.squared_difference(x, sample_mean), axis, keepdims) +from tensorflow.python.platform import tf_logging as logging +# TODO(b/66964210): Test float16. class HMCTest(test.TestCase): def setUp(self): self._shape_param = 5. self._rate_param = 10. + self._expected_x = (special.digamma(self._shape_param) + - np.log(self._rate_param)) + self._expected_exp_x = self._shape_param / self._rate_param random_seed.set_random_seed(10003) np.random.seed(10003) @@ -70,46 +63,63 @@ class HMCTest(test.TestCase): self._rate_param * math_ops.exp(x), event_dims) - def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, + def _log_gamma_log_prob_grad(self, x, event_dims=()): + """Computes log-pdf and gradient of a log-gamma random variable. + + Args: + x: Value of the random variable. + event_dims: Dimensions not to treat as independent. Default is (), + i.e., all dimensions are independent. + + Returns: + log_prob: The log-pdf up to a normalizing constant. + grad: The gradient of the log-pdf with respect to x. 
+ """ + return (math_ops.reduce_sum(self._shape_param * x - + self._rate_param * math_ops.exp(x), + event_dims), + self._shape_param - self._rate_param * math_ops.exp(x)) + + def _n_event_dims(self, x_shape, event_dims): + return np.prod([int(x_shape[i]) for i in event_dims]) + + def _integrator_conserves_energy(self, x, event_dims, sess, feed_dict=None): - step_size = array_ops.placeholder(np.float32, [], name="step_size") - hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") + def potential_and_grad(x): + log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims) + return -log_prob, -grad + + step_size = array_ops.placeholder(np.float32, [], name='step_size') + hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') if feed_dict is None: feed_dict = {} feed_dict[hmc_lf_steps] = 1000 - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - m = random_ops.random_normal(array_ops.shape(x)) - log_prob_0 = self._log_gamma_log_prob(x, event_dims) - grad_0 = gradients_ops.gradients(log_prob_0, x) - old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) - - new_m, _, log_prob_1, _ = _leapfrog_integrator( - current_momentums=[m], - target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), - current_state_parts=[x], - step_sizes=[step_size], - num_leapfrog_steps=hmc_lf_steps, - current_target_log_prob=log_prob_0, - current_grads_target_log_prob=grad_0) - new_m = new_m[0] - - new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, + potential_0, grad_0 = potential_and_grad(x) + old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m, + event_dims) + + _, new_m, potential_1, _ = ( + hmc.leapfrog_integrator(step_size, hmc_lf_steps, x, + m, potential_and_grad, grad_0)) + + new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, event_dims) x_shape = sess.run(x, feed_dict).shape - event_size = np.prod(x_shape[independent_chain_ndims:]) - feed_dict[step_size] = 0.1 / event_size - old_energy_, new_energy_ = sess.run([old_energy, new_energy], - feed_dict) - logging_ops.vlog(1, "average energy relative change: {}".format( - (1. - new_energy_ / old_energy_).mean())) - self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) - - def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): + n_event_dims = self._n_event_dims(x_shape, event_dims) + feed_dict[step_size] = 0.1 / n_event_dims + old_energy_val, new_energy_val = sess.run([old_energy, new_energy], + feed_dict) + logging.vlog(1, 'average energy change: {}'.format( + abs(old_energy_val - new_energy_val).mean())) + + self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool), + abs(old_energy_val - new_energy_val) < 1.) + + def _integrator_conserves_energy_wrapper(self, event_dims): """Tests the long-term energy conservation of the leapfrog integrator. The leapfrog integrator is symplectic, so for sufficiently small step @@ -117,167 +127,135 @@ class HMCTest(test.TestCase): the energy of the system blowing up or collapsing. Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. + event_dims: A tuple of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. 
""" with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._integrator_conserves_energy(x_ph, independent_chain_ndims, - sess, feed_dict) + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + feed_dict = {x_ph: np.zeros([50, 10, 2])} + self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict) def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper(0) + self._integrator_conserves_energy_wrapper([]) def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper(1) + self._integrator_conserves_energy_wrapper([1]) def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper(2) + self._integrator_conserves_energy_wrapper([2]) - def testIntegratorEnergyConservation3(self): - self._integrator_conserves_energy_wrapper(3) + def testIntegratorEnergyConservation12(self): + self._integrator_conserves_energy_wrapper([1, 2]) - def _chain_gets_correct_expectations(self, x, independent_chain_ndims, - sess, feed_dict=None): + def testIntegratorEnergyConservation012(self): + self._integrator_conserves_energy_wrapper([0, 1, 2]) + + def _chain_gets_correct_expectations(self, x, event_dims, sess, + feed_dict=None): def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) - num_results = array_ops.placeholder( - np.int32, [], name="num_results") - step_size = array_ops.placeholder( - np.float32, [], name="step_size") - num_leapfrog_steps = array_ops.placeholder( - np.int32, [], name="num_leapfrog_steps") + step_size = array_ops.placeholder(np.float32, [], name='step_size') + hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps') if feed_dict is None: feed_dict = {} - feed_dict.update({num_results: 150, - step_size: 0.1, - num_leapfrog_steps: 2}) - - samples, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_gamma_log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - num_burnin_steps=150, - seed=42) - - expected_x = (math_ops.digamma(self._shape_param) - - np.log(self._rate_param)) - - expected_exp_x = self._shape_param / self._rate_param - - acceptance_probs_, samples_, expected_x_ = sess.run( - [kernel_results.acceptance_probs, samples, expected_x], - feed_dict) - - actual_x = samples_.mean() - actual_exp_x = np.exp(samples_).mean() - - logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( - expected_x_, expected_exp_x)) - logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( - actual_x, actual_exp_x)) - self.assertNear(actual_x, expected_x_, 2e-2) - self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertTrue((acceptance_probs_ > 0.5).all()) - self.assertTrue((acceptance_probs_ <= 1.0).all()) - - def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): + feed_dict.update({step_size: 0.1, + hmc_lf_steps: 2, + hmc_n_steps: 300}) + + sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps], + step_size, + hmc_lf_steps, + x, log_gamma_log_prob, + event_dims) + + acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain], + feed_dict) + samples = samples[feed_dict[hmc_n_steps] // 2:] + expected_x_est = samples.mean() + expected_exp_x_est = np.exp(samples).mean() + + logging.vlog(1, 'True E[x, exp(x)]: 
{}\t{}'.format( + self._expected_x, self._expected_exp_x)) + logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format( + expected_x_est, expected_exp_x_est)) + self.assertNear(expected_x_est, self._expected_x, 2e-2) + self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2) + self.assertTrue((acceptance_probs > 0.5).all()) + self.assertTrue((acceptance_probs <= 1.0).all()) + + def _chain_gets_correct_expectations_wrapper(self, event_dims): with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, - sess, feed_dict) + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + feed_dict = {x_ph: np.zeros([50, 10, 2])} + self._chain_gets_correct_expectations(x_ph, event_dims, sess, + feed_dict) def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper(0) + self._chain_gets_correct_expectations_wrapper([]) def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper(1) + self._chain_gets_correct_expectations_wrapper([1]) def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper(2) + self._chain_gets_correct_expectations_wrapper([2]) + + def testHMCChainExpectations12(self): + self._chain_gets_correct_expectations_wrapper([1, 2]) - def _kernel_leaves_target_invariant(self, initial_draws, - independent_chain_ndims, + def _kernel_leaves_target_invariant(self, initial_draws, event_dims, sess, feed_dict=None): def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) def fake_log_prob(x): """Cooled version of the target distribution.""" return 1.1 * log_gamma_log_prob(x) - step_size = array_ops.placeholder(np.float32, [], name="step_size") + step_size = array_ops.placeholder(np.float32, [], name='step_size') if feed_dict is None: feed_dict = {} feed_dict[step_size] = 0.4 - sample, kernel_results = hmc.kernel( - target_log_prob_fn=log_gamma_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=43) - - bad_sample, bad_kernel_results = hmc.kernel( - target_log_prob_fn=fake_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=44) - - [ - acceptance_probs_, - bad_acceptance_probs_, - initial_draws_, - updated_draws_, - fake_draws_, - ] = sess.run([ - kernel_results.acceptance_probs, - bad_kernel_results.acceptance_probs, - initial_draws, - sample, - bad_sample, - ], feed_dict) - + sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws, + log_gamma_log_prob, event_dims) + bad_sample, bad_acceptance_probs, _, _ = hmc.kernel( + step_size, 5, initial_draws, fake_log_prob, event_dims) + (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val, + updated_draws_val, fake_draws_val) = sess.run([acceptance_probs, + bad_acceptance_probs, + initial_draws, sample, + bad_sample], feed_dict) # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_.mean(), 0.5) - + self.assertGreater(acceptance_probs_val.mean(), 0.5) + self.assertGreater(bad_acceptance_probs_val.mean(), 0.5) # Confirm step size is large enough that we sometimes reject. 
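    # Note: the 0.5/0.99 window asserted next is a tuning sanity check, not an
    # HMC requirement. The Metropolis correction accepts a proposal with
    # probability min(1, exp(-energy_change)), so a well-chosen step size
    # should accept most proposals without accepting essentially all of them.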
- self.assertLess(acceptance_probs_.mean(), 0.99) - self.assertLess(bad_acceptance_probs_.mean(), 0.99) - - _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), - updated_draws_.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), - fake_draws_.flatten()) - - logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs_.mean())) - logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs_.mean())) - logging_ops.vlog(1, "K-S p-value for true target: {}".format( - ks_p_value_true)) - logging_ops.vlog(1, "K-S p-value for fake target: {}".format( - ks_p_value_fake)) + self.assertLess(acceptance_probs_val.mean(), 0.99) + self.assertLess(bad_acceptance_probs_val.mean(), 0.99) + _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(), + updated_draws_val.flatten()) + _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(), + fake_draws_val.flatten()) + logging.vlog(1, 'acceptance rate for true target: {}'.format( + acceptance_probs_val.mean())) + logging.vlog(1, 'acceptance rate for fake target: {}'.format( + bad_acceptance_probs_val.mean())) + logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true)) + logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake)) # Make sure that the MCMC update hasn't changed the empirical CDF much. self.assertGreater(ks_p_value_true, 1e-3) # Confirm that targeting the wrong distribution does # significantly change the empirical CDF. self.assertLess(ks_p_value_fake, 1e-6) - def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): + def _kernel_leaves_target_invariant_wrapper(self, event_dims): """Tests that the kernel leaves the target distribution invariant. Draws some independent samples from the target distribution, @@ -289,116 +267,86 @@ class HMCTest(test.TestCase): does change the target distribution. (And that we can detect that.) Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. + event_dims: A tuple of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. 
""" with self.test_session() as sess: initial_draws = np.log(np.random.gamma(self._shape_param, size=[50000, 2, 2])) initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name="x_ph") + x_ph = array_ops.placeholder(np.float32, name='x_ph') feed_dict = {x_ph: initial_draws} - self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, - sess, feed_dict) + self._kernel_leaves_target_invariant(x_ph, event_dims, sess, + feed_dict) + + def testKernelLeavesTargetInvariantNullShape(self): + self._kernel_leaves_target_invariant_wrapper([]) def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper(1) + self._kernel_leaves_target_invariant_wrapper([1]) def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper(2) + self._kernel_leaves_target_invariant_wrapper([2]) - def testKernelLeavesTargetInvariant3(self): - self._kernel_leaves_target_invariant_wrapper(3) + def testKernelLeavesTargetInvariant12(self): + self._kernel_leaves_target_invariant_wrapper([1, 2]) - def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, - sess, feed_dict=None): + def _ais_gets_correct_log_normalizer(self, init, event_dims, sess, + feed_dict=None): def proposal_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) + return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi), + event_dims) def target_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) if feed_dict is None: feed_dict = {} - num_steps = 200 - - _, ais_weights, _ = hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal_log_prob, - num_steps=num_steps, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=init, - num_leapfrog_steps=2, - seed=45) - - event_shape = array_ops.shape(init)[independent_chain_ndims:] - event_size = math_ops.reduce_prod(event_shape) - - log_true_normalizer = ( - -self._shape_param * math_ops.log(self._rate_param) - + math_ops.lgamma(self._shape_param)) - log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) - - log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) - - np.log(num_steps)) - - ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) - ais_weights_size = array_ops.size(ais_weights) - standard_error = math_ops.sqrt( - _reduce_variance(ratio_estimate_true) - / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) - - [ - ratio_estimate_true_, - log_true_normalizer_, - log_estimated_normalizer_, - standard_error_, - ais_weights_size_, - event_size_, - ] = sess.run([ - ratio_estimate_true, - log_true_normalizer, - log_estimated_normalizer, - standard_error, - ais_weights_size, - event_size, - ], feed_dict) - - logging_ops.vlog(1, " log_true_normalizer: {}\n" - " log_estimated_normalizer: {}\n" - " ais_weights_size: {}\n" - " event_size: {}\n".format( - log_true_normalizer_, - log_estimated_normalizer_, - ais_weights_size_, - event_size_)) - self.assertNear(ratio_estimate_true_.mean(), 1., 4. 
* standard_error_) - - def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): + w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob, + proposal_log_prob, event_dims) + + w_val = sess.run(w, feed_dict) + init_shape = sess.run(init, feed_dict).shape + normalizer_multiplier = np.prod([init_shape[i] for i in event_dims]) + + true_normalizer = -self._shape_param * np.log(self._rate_param) + true_normalizer += special.gammaln(self._shape_param) + true_normalizer *= normalizer_multiplier + + n_weights = np.prod(w_val.shape) + normalized_w = np.exp(w_val - true_normalizer) + standard_error = np.std(normalized_w) / np.sqrt(n_weights) + logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format( + true_normalizer, np.log(normalized_w.mean()) + true_normalizer, + n_weights)) + self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error) + + def _ais_gets_correct_log_normalizer_wrapper(self, event_dims): """Tests that AIS yields reasonable estimates of normalizers.""" with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") + x_ph = array_ops.placeholder(np.float32, name='x_ph') + initial_draws = np.random.normal(size=[30, 2, 1]) - self._ais_gets_correct_log_normalizer( - x_ph, - independent_chain_ndims, - sess, - feed_dict={x_ph: initial_draws}) + feed_dict = {x_ph: initial_draws} + + self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess, + feed_dict) + + def testAISNullShape(self): + self._ais_gets_correct_log_normalizer_wrapper([]) def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper(1) + self._ais_gets_correct_log_normalizer_wrapper([1]) def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper(2) + self._ais_gets_correct_log_normalizer_wrapper([2]) - def testAIS3(self): - self._ais_gets_correct_log_normalizer_wrapper(3) + def testAIS12(self): + self._ais_gets_correct_log_normalizer_wrapper([1, 2]) def testNanRejection(self): """Tests that an update that yields NaN potentials gets rejected. @@ -411,29 +359,24 @@ class HMCTest(test.TestCase): """ def _unbounded_exponential_log_prob(x): """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where( - x < 0., - array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), - -x) + per_element_potentials = array_ops.where(x < 0, + np.nan * array_ops.ones_like(x), + -x) return math_ops.reduce_sum(per_element_potentials) with self.test_session() as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_unbounded_exponential_log_prob, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=46) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) 
+ updated_x, acceptance_probs, _, _ = hmc.kernel( + 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) def testNanFromGradsDontPropagate(self): """Test that update with NaN gradients does not cause NaN in results.""" @@ -442,195 +385,60 @@ class HMCTest(test.TestCase): with self.test_session() as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_nan_log_prob_with_nan_gradient, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=47) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) + updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( + 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) self.assertAllFinite( - gradients_ops.gradients(updated_x, initial_x)[0].eval()) - self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, initial_x)]) - self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, - kernel_results.proposed_state)]) + gradients_impl.gradients(updated_x, initial_x)[0].eval()) + self.assertTrue( + gradients_impl.gradients(new_grad, initial_x)[0] is None) # Gradients of the acceptance probs and new log prob are not finite. + _ = new_log_prob # Prevent unused arg error. 
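The rejection behavior that both NaN tests rely on reduces to how a
Metropolis step treats a non-finite energy change. A minimal NumPy sketch of
such a guard, independent of TensorFlow (`acceptance_prob` here is an
illustrative helper, not the library's API):

```python
import numpy as np

def acceptance_prob(energy_change):
  """Metropolis acceptance with a guard for non-finite energy changes."""
  # Substitute +inf for non-finite inputs so exp() never sees a NaN, then
  # force the acceptance probability for those entries to exactly zero.
  safe = np.where(np.isfinite(energy_change), energy_change, np.inf)
  return np.where(np.isfinite(energy_change),
                  np.minimum(1., np.exp(-safe)), 0.)

print(acceptance_prob(np.array([0.1, np.nan, np.inf])))  # -> [~0.905, 0., 0.]
```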
# self.assertAllFinite( - # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) + # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) # self.assertAllFinite( - # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) - - def _testChainWorksDtype(self, dtype): - states, kernel_results = hmc.sample_chain( - num_results=10, - target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), - current_state=np.zeros(5).astype(dtype), - step_size=0.01, - num_leapfrog_steps=10, - seed=48) - with self.test_session() as sess: - states_, acceptance_probs_ = sess.run( - [states, kernel_results.acceptance_probs]) - self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, acceptance_probs_.dtype) + # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) def testChainWorksIn64Bit(self): - self._testChainWorksDtype(np.float64) + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float64(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float64), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float64, states_.dtype) + self.assertEqual(np.float64, acceptance_probs_.dtype) def testChainWorksIn16Bit(self): - self._testChainWorksDtype(np.float16) - - -class _EnergyComputationTest(object): - - def testHandlesNanFromPotential(self): - with self.test_session() as sess: - x = [1, np.inf, -np.inf, np.nan] - target_log_prob, proposed_target_log_prob = [ - self.dtype(x.flatten()) for x in np.meshgrid(x, x)] - num_chains = len(target_log_prob) - dummy_momentums = [-1, 1] - momentums = [self.dtype([dummy_momentums] * num_chains)] - proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. 
- self.assertAllEqual(np.ones_like(grads_).astype(np.bool), - np.isfinite(grads_)) - - def testHandlesNanFromKinetic(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float16(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float16), + target_log_prob_fn=log_prob, + event_dims=[-1]) with self.test_session() as sess: - x = [1, np.inf, -np.inf, np.nan] - momentums, proposed_momentums = [ - [np.reshape(self.dtype(x), [-1, 1])] - for x in np.meshgrid(x, x)] - num_chains = len(momentums[0]) - target_log_prob = np.ones(num_chains, self.dtype) - proposed_target_log_prob = np.ones(num_chains, self.dtype) + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float16, states_.dtype) + self.assertEqual(np.float16, acceptance_probs_.dtype) - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - g = grads_[0].reshape([len(x), len(x)])[:, 0] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) - - # The remaining gradients are nan because the momentum was itself nan or - # inf. 
- g = grads_[0].reshape([len(x), len(x)])[:, 1:] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) - - -class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): - dtype = np.float16 - - -class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): - dtype = np.float32 - - -class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): - dtype = np.float64 - - -class _HMCHandlesLists(object): - - def testStateParts(self): - with self.test_session() as sess: - dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) - dist_y = independent_lib.Independent( - gamma_lib.Gamma(concentration=self.dtype([1, 2]), - rate=self.dtype([0.5, 0.75])), - reinterpreted_batch_ndims=1) - def target_log_prob(x, y): - return dist_x.log_prob(x) + dist_y.log_prob(y) - x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] - samples, _ = hmc.sample_chain( - num_results=int(2e3), - target_log_prob_fn=target_log_prob, - current_state=x0, - step_size=0.85, - num_leapfrog_steps=3, - num_burnin_steps=int(250), - seed=49) - actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] - actual_vars = [_reduce_variance(s, axis=0) for s in samples] - expected_means = [dist_x.mean(), dist_y.mean()] - expected_vars = [dist_x.variance(), dist_y.variance()] - [ - actual_means_, - actual_vars_, - expected_means_, - expected_vars_, - ] = sess.run([ - actual_means, - actual_vars, - expected_means, - expected_vars, - ]) - self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) - self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.30) - - -class HMCHandlesLists16(_HMCHandlesLists, test.TestCase): - dtype = np.float16 - - -class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): - dtype = np.float32 - - -class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): - dtype = np.float64 - - -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py index 7fd5652..977d42f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. +""" from __future__ import absolute_import from __future__ import division @@ -23,9 +24,11 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disabl from tensorflow.python.util import all_util _allowed_symbols = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", + 'chain', + 'kernel', + 'leapfrog_integrator', + 'leapfrog_step', + 'ais_chain' ] all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index f7a11c2..5685a94 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -14,16 +14,17 @@ # ============================================================================== """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. 
-@@sample_chain -@@sample_annealed_importance_chain -@@kernel +@@chain +@@update +@@leapfrog_integrator +@@leapfrog_step +@@ais_chain """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import numpy as np from tensorflow.python.framework import dtypes @@ -31,292 +32,168 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl as gradients_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import util as distributions_util +from tensorflow.python.platform import tf_logging as logging __all__ = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", + 'chain', + 'kernel', + 'leapfrog_integrator', + 'leapfrog_step', + 'ais_chain' ] -KernelResults = collections.namedtuple( - "KernelResults", - [ - "acceptance_probs", - "current_grads_target_log_prob", # "Current result" means "accepted". - "current_target_log_prob", # "Current result" means "accepted". - "energy_change", - "is_accepted", - "proposed_grads_target_log_prob", - "proposed_state", - "proposed_target_log_prob", - "random_positive", - ]) - - -def _make_dummy_kernel_results( - dummy_state, - dummy_target_log_prob, - dummy_grads_target_log_prob): - return KernelResults( - acceptance_probs=dummy_target_log_prob, - current_grads_target_log_prob=dummy_grads_target_log_prob, - current_target_log_prob=dummy_target_log_prob, - energy_change=dummy_target_log_prob, - is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), - proposed_grads_target_log_prob=dummy_grads_target_log_prob, - proposed_state=dummy_state, - proposed_target_log_prob=dummy_target_log_prob, - random_positive=dummy_target_log_prob, - ) - - -def sample_chain( - num_results, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - num_burnin_steps=0, - num_steps_between_results=0, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): +def _make_potential_and_grad(target_log_prob_fn): + def potential_and_grad(x): + log_prob_result = -target_log_prob_fn(x) + grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result), + x)[0] + return log_prob_result, grad_result + return potential_and_grad + + +def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, + target_log_prob_fn, event_dims=(), name=None): """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm - that takes a series of gradient-informed steps to produce a Metropolis - proposal. This function samples from an HMC Markov chain at `current_state` - and whose stationary distribution has log-unnormalized-density + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) + algorithm that takes a series of gradient-informed steps to produce + a Metropolis proposal. This function samples from an HMC Markov + chain whose initial state is `initial_x` and whose stationary + distribution has log-density `target_log_prob_fn()`. + + This function can update multiple chains in parallel. It assumes + that all dimensions of `initial_x` not specified in `event_dims` are + independent, and should therefore be updated independently. 
The + output of `target_log_prob_fn()` should sum log-probabilities across + all event dimensions. Slices along dimensions not in `event_dims` + may have different target distributions; this is up to `target_log_prob_fn()`. - This function samples from multiple chains in parallel. It assumes that the - the leftmost dimensions of (each) `current_state` (part) index an independent - chain. The function `target_log_prob_fn()` sums log-probabilities across - event dimensions (i.e., current state (part) rightmost dimensions). Each - element of the output of `target_log_prob_fn()` represents the (possibly - unnormalized) log-probability of the joint distribution over (all) the current - state (parts). + This function basically just wraps `hmc.kernel()` in a tf.scan() loop. - The `current_state` can be represented as a single `Tensor` or a `list` of - `Tensors` which collectively represent the current state. When specifying a - `list`, one must also specify a list of `step_size`s. - - Only one out of every `num_steps_between_samples + 1` steps is included in the - returned results. This "thinning" comes at a cost of reduced statistical - power, while reducing memory requirements and autocorrelation. For more - discussion see [1]. + Args: + n_iterations: Integer number of Markov chain updates to run. + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `initial_x`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. + When possible, it's often helpful to match per-variable step + sizes to the standard deviations of the target distribution in + each variable. + n_leapfrog_steps: Integer number of steps to run the leapfrog + integrator for. Total progress per HMC step is roughly + proportional to step_size * n_leapfrog_steps. + initial_x: Tensor of initial state(s) of the Markov chain(s). + target_log_prob_fn: Python callable which takes an argument like `initial_x` + and returns its (possibly unnormalized) log-density under the target + distribution. + event_dims: List of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + name: Python `str` name prefixed to Ops created by this function. - [1]: "Statistically efficient thinning of a Markov chain sampler." - Art B. Owen. April 2017. - http://statweb.stanford.edu/~owen/reports/bestthinning.pdf + Returns: + acceptance_probs: Tensor with the acceptance probabilities for each + iteration. Has shape matching `target_log_prob_fn(initial_x)`. + chain_states: Tensor with the state of the Markov chain at each iteration. + Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`. #### Examples: - ##### Sample from a diagonal-variance Gaussian. - ```python - tfd = tf.contrib.distributions - - def make_likelihood(true_variances): - return tfd.MultivariateNormalDiag( - scale_diag=tf.sqrt(true_variances)) - - dims = 10 - dtype = np.float32 - true_variances = tf.linspace(dtype(1), dtype(3), dims) - likelihood = make_likelihood(true_variances) - - states, kernel_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=likelihood.log_prob, - current_state=tf.zeros(dims), - step_size=0.5, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. 
- sample_mean = tf.reduce_mean(states, axis=0) - sample_var = tf.reduce_mean( - tf.squared_difference(states, sample_mean), - axis=0) + # Sampling from a standard normal (note `log_joint()` is unnormalized): + def log_joint(x): + return tf.reduce_sum(-0.5 * tf.square(x)) + chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, + event_dims=[0]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) ``` - ##### Sampling from factor-analysis posteriors with known factors. - - I.e., - - ```none - for i=1..n: - w[i] ~ Normal(0, eye(d)) # prior - x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood + ```python + # Sampling from a diagonal-variance Gaussian: + variances = tf.linspace(1., 3., 10) + def log_joint(x): + return tf.reduce_sum(-0.5 / variances * tf.square(x)) + chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, + event_dims=[0]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) ``` - where `F` denotes factors. - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, factors): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) - - # Setup data. - num_weights = 10 - num_factors = 4 - num_chains = 100 - dtype = np.float32 - - prior = make_prior(num_weights, dtype) - weights = prior.sample(num_chains) - factors = np.random.randn(num_factors, num_weights).astype(dtype) - x = make_likelihood(weights, factors).sample(num_chains) - - def target_log_prob(w): - # Target joint is: `f(w) = p(w, x | factors)`. - return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) - - # Get `num_results` samples from `num_chains` independent chains. - chains_states, kernels_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target_log_prob, - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) - sample_var = tf.reduce_mean( - tf.squared_difference(chains_states, sample_mean), - axis=[0, 1]) + # Sampling from factor-analysis posteriors with known factors W: + # mu[i, j] ~ Normal(0, 1) + # x[i] ~ Normal(matmul(mu[i], W), I) + def log_joint(mu, x, W): + prior = -0.5 * tf.reduce_sum(tf.square(mu), 1) + x_mean = tf.matmul(mu, W) + likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1) + return prior + likelihood + chain, acceptance_probs = hmc.chain(1000, 0.1, 2, + tf.zeros([x.shape[0], W.shape[0]]), + lambda mu: log_joint(mu, x, W), + event_dims=[1]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) ``` - Args: - num_results: Integer number of Markov chain draws. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). 
The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - num_burnin_steps: Integer number of chain steps to take before starting to - collect results. - Default value: 0 (i.e., no burn-in). - num_steps_between_results: Integer number of chain steps between collecting - a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. This "thinning" comes at a cost of - reduced statistical power, while reducing memory requirements and - autocorrelation. For more discussion see [1]. - Default value: 0 (i.e., no subsampling). - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob` at the `current_state` and wrt - the `current_state`. Must have same shape as `current_state`. The only - reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_chain"). - - Returns: - accepted_states: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state` but with a prepended `num_results`-size dimension. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. + ```python + # Sampling from the posterior of a Bayesian regression model.: + + # Run 100 chains in parallel, each with a different initialization. 
+ initial_beta = tf.random_normal([100, x.shape[1]]) + chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta, + log_joint_partial, event_dims=[1]) + # Discard first halves of chains as warmup/burn-in + warmed_up = chain[500:] + # Averaging across samples within a chain and across chains + mean_est = tf.reduce_mean(warmed_up, [0, 1]) + var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est) + ``` """ - with ops.name_scope( - name, "hmc_sample_chain", - [num_results, current_state, step_size, num_leapfrog_steps, - num_burnin_steps, num_steps_between_results, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob, - ] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob) - def _run_chain(num_steps, current_state, seed, kernel_results): - """Runs the chain(s) for `num_steps`.""" - def _loop_body(iter_, current_state, kernel_results): - return [iter_ + 1] + list(kernel( - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - return control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[0, current_state, kernel_results])[1:] # Lop-off "iter_". - - def _scan_body(args_list, _): - """Closure which implements `tf.scan` body.""" - current_state, kernel_results = args_list - return _run_chain(num_steps_between_results + 1, current_state, seed, - kernel_results) - - current_state, kernel_results = _run_chain( - num_burnin_steps, - current_state, - distributions_util.gen_new_seed( - seed, salt="hmc_sample_chain_burnin"), - _make_dummy_kernel_results( - current_state, - current_target_log_prob, - current_grads_target_log_prob)) - + with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size, + n_leapfrog_steps, initial_x]): + initial_x = ops.convert_to_tensor(initial_x, name='initial_x') + non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) + + def body(a, _): + updated_x, acceptance_probs, log_prob, grad = kernel( + step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims, + a[2], a[3]) + return updated_x, acceptance_probs, log_prob, grad + + potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + potential, grad = potential_and_grad(initial_x) return functional_ops.scan( - fn=_scan_body, - elems=array_ops.zeros(num_results, dtype=dtypes.bool), # Dummy arg. - initializer=[current_state, kernel_results]) - - -def sample_annealed_importance_chain( - proposal_log_prob_fn, - num_steps, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - name=None): + body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), + (initial_x, + array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + -potential, -grad))[:2] + + +def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, + target_log_prob_fn, proposal_log_prob_fn, event_dims=(), + name=None): """Runs annealed importance sampling (AIS) to estimate normalizing constants. 
- This function uses Hamiltonian Monte Carlo to sample from a series of + This routine uses Hamiltonian Monte Carlo to sample from a series of distributions that slowly interpolates between an initial "proposal" - distribution: + distribution `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - and the target distribution: + and the target distribution `exp(target_log_prob_fn(x) - target_log_normalizer)`, @@ -325,183 +202,113 @@ def sample_annealed_importance_chain( normalizing constants of the initial distribution and the target distribution: - `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. + E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer). - #### Examples: + Args: + n_iterations: Integer number of Markov chain updates to run. More + iterations means more expense, but smoother annealing between q + and p, which in turn means exponentially lower variance for the + normalizing constant estimator. + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `initial_x`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. + When possible, it's often helpful to match per-variable step + sizes to the standard deviations of the target distribution in + each variable. + n_leapfrog_steps: Integer number of steps to run the leapfrog + integrator for. Total progress per HMC step is roughly + proportional to step_size * n_leapfrog_steps. + initial_x: Tensor of initial state(s) of the Markov chain(s). Must + be a sample from q, or results will be incorrect. + target_log_prob_fn: Python callable which takes an argument like `initial_x` + and returns its (possibly unnormalized) log-density under the target + distribution. + proposal_log_prob_fn: Python callable that returns the log density of the + initial distribution. + event_dims: List of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + ais_weights: Tensor with the estimated weight(s). Has shape matching + `target_log_prob_fn(initial_x)`. + chain_states: Tensor with the state(s) of the Markov chain(s) the final + iteration. Has shape matching `initial_x`. + acceptance_probs: Tensor with the acceptance probabilities for the final + iteration. Has shape matching `target_log_prob_fn(initial_x)`. - ##### Estimate the normalizing constant of a log-gamma distribution. + #### Examples: ```python - tfd = tf.contrib.distributions - + # Estimating the normalizing constant of a log-gamma distribution: + def proposal_log_prob(x): + # Standard normal log-probability. This is properly normalized. + return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1) + def target_log_prob(x): + # Unnormalized log-gamma(2, 3) distribution. + # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1] + return tf.reduce_sum(2. * x - 3. 
* tf.exp(x), 1) # Run 100 AIS chains in parallel - num_chains = 100 - dims = 20 - dtype = np.float32 - - proposal = tfd.MultivatiateNormalDiag( - loc=tf.zeros([dims], dtype=dtype)) - - target = tfd.TransformedDistribution( - distribution=tfd.Gamma(concentration=dtype(2), - rate=dtype(3)), - bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), - event_shape=[dims]) - - chains_state, ais_weights, kernels_results = ( - hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal.log_prob, - num_steps=1000, - target_log_prob_fn=target.log_prob, - step_size=0.2, - current_state=proposal.sample(num_chains), - num_leapfrog_steps=2)) - - log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) + initial_x = tf.random_normal([100, 20]) + w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob, + proposal_log_prob, event_dims=[1]) + log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) ``` - ##### Estimate marginal likelihood of a Bayesian regression model. - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, x): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, x, axes=[[0], [-1]])) - + # Estimating the marginal likelihood of a Bayesian regression model: + base_measure = -0.5 * np.log(2 * np.pi) + def proposal_log_prob(x): + # Standard normal log-probability. This is properly normalized. + return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1) + def regression_log_joint(beta, x, y): + # This function returns a vector whose ith element is log p(beta[i], y | x). + # Each row of beta corresponds to the state of an independent Markov chain. + log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1) + means = tf.matmul(beta, x, transpose_b=True) + log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) + + base_measure, 1) + return log_prior + log_likelihood + def log_joint_partial(beta): + return regression_log_joint(beta, x, y) # Run 100 AIS chains in parallel - num_chains = 100 - dims = 10 - dtype = np.float32 - - # Make training data. - x = np.random.randn(num_chains, dims).astype(dtype) - true_weights = np.random.randn(dims).astype(dtype) - y = np.dot(x, true_weights) + np.random.randn(num_chains) - - # Setup model. - prior = make_prior(dims, dtype) - def target_log_prob_fn(weights): - return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) - - proposal = tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - weight_samples, ais_weights, kernel_results = ( - hmc.sample_annealed_importance_chain( - num_steps=1000, - proposal_log_prob_fn=proposal.log_prob, - target_log_prob_fn=target_log_prob_fn - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2)) - log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) + initial_beta = tf.random_normal([100, x.shape[1]]) + w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta, + log_joint_partial, proposal_log_prob, + event_dims=[1]) + log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) ``` - - Args: - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - num_steps: Integer number of Markov chain updates to run. 
More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). - - Returns: - accepted_state: `Tensor` or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at the final iteration. Has same shape as - input `current_state`. - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(current_state)`. """ - def make_convex_combined_log_prob_fn(iter_): - def _fn(*args): - p = proposal_log_prob_fn(*args) - t = target_log_prob_fn(*args) - dtype = p.dtype.base_dtype - beta = (math_ops.cast(iter_ + 1, dtype) - / math_ops.cast(num_steps, dtype)) - return (1. - beta) * p + beta * t - return _fn - - with ops.name_scope( - name, "hmc_sample_annealed_importance_chain", - [num_steps, current_state, step_size, num_leapfrog_steps, seed]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_log_prob, - current_grads_log_prob, - ] = _prepare_args( - make_convex_combined_log_prob_fn(iter_=0), - current_state, - step_size, - description="convex_combined_log_prob") - def _loop_body(iter_, ais_weights, current_state, kernel_results): - """Closure which implements `tf.while_loop` body.""" - current_state_parts = (list(current_state) - if _is_list_like(current_state) - else [current_state]) - ais_weights += ((target_log_prob_fn(*current_state_parts) - - proposal_log_prob_fn(*current_state_parts)) - / math_ops.cast(num_steps, ais_weights.dtype)) - return [iter_ + 1, ais_weights] + list(kernel( - make_convex_combined_log_prob_fn(iter_), - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - - [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - 0, # iter_ - array_ops.zeros_like(current_log_prob), # ais_weights - current_state, - _make_dummy_kernel_results(current_state, - current_log_prob, - current_grads_log_prob), - ])[1:] # Lop-off "iter_". 
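Both the deleted and the replacement AIS code reduce the per-chain weights to
a normalizer estimate in the same way, via the identity
`E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
A self-contained NumPy sketch of that reduction, using made-up stand-in
weights:

```python
import numpy as np

rng = np.random.RandomState(0)
ais_weights = rng.normal(loc=2.3, scale=0.05, size=100)  # stand-in weights

# Average exp(w) in log space: log(mean(exp(w))) = logsumexp(w) - log(n).
log_normalizer_estimate = (np.logaddexp.reduce(ais_weights)
                           - np.log(len(ais_weights)))
print(log_normalizer_estimate)  # ~2.3 for weights this tightly concentrated
```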
-
-  return [current_state, ais_weights, kernel_results]
-
-
-def kernel(target_log_prob_fn,
-           current_state,
-           step_size,
-           num_leapfrog_steps,
-           seed=None,
-           current_target_log_prob=None,
-           current_grads_target_log_prob=None,
-           name=None):
+  with ops.name_scope(name, 'hmc_ais_chain',
+                      [n_iterations, step_size, n_leapfrog_steps, initial_x]):
+    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
+
+    beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:]
+    def _body(a, beta):  # pylint: disable=missing-docstring
+      def log_prob_beta(x):
+        return ((1 - beta) * proposal_log_prob_fn(x) +
+                beta * target_log_prob_fn(x))
+      last_x = a[0]
+      w = a[2]
+      w += (1. / n_iterations) * (target_log_prob_fn(last_x) -
+                                  proposal_log_prob_fn(last_x))
+      # TODO(b/66917083): There's an opportunity for gradient reuse here.
+      updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps,
+                                                 last_x, log_prob_beta,
+                                                 event_dims)
+      return updated_x, acceptance_probs, w
+
+    x, acceptance_probs, w = functional_ops.scan(
+        _body, beta_series,
+        (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+         array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
+  return w[-1], x[-1], acceptance_probs[-1]
+
+
+def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
+           x_log_prob=None, x_grad=None, name=None):
  """Runs one iteration of Hamiltonian Monte Carlo.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
@@ -509,625 +316,334 @@ def kernel(target_log_prob_fn,
  a Metropolis proposal. This function applies one step of HMC to randomly
  update the variable `x`.

-  This function can update multiple chains in parallel. It assumes that all
-  leftmost dimensions of `current_state` index independent chain states (and are
-  therefore updated independently). The output of `target_log_prob_fn()` should
-  sum log-probabilities across all event dimensions. Slices along the rightmost
-  dimensions may have different target distributions; for example,
-  `current_state[0, :]` could have a different target distribution from
-  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
-  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
+  This function can update multiple chains in parallel. It assumes
+  that all dimensions of `x` not specified in `event_dims` are
+  independent, and should therefore be updated independently. The
+  output of `target_log_prob_fn()` should sum log-probabilities across
+  all event dimensions. Slices along dimensions not in `event_dims`
+  may have different target distributions; for example, if
+  `event_dims == (1,)`, then `x[0, :]` could have a different target
+  distribution from `x[1, :]`. This is up to `target_log_prob_fn()`.

-  #### Examples:
+  Args:
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `x`. Larger step sizes lead to faster progress, but
+      too-large step sizes make rejection exponentially more likely.
+      When possible, it's often helpful to match per-variable step
+      sizes to the standard deviations of the target distribution in
+      each variable.
+    n_leapfrog_steps: Integer number of steps to run the leapfrog
+      integrator for. Total progress per HMC step is roughly
+      proportional to `step_size * n_leapfrog_steps`.
+    x: Tensor containing the value(s) of the random variable(s) to update.
+    target_log_prob_fn: Python callable which takes an argument like `x`
+      and returns its (possibly unnormalized) log-density under the target
+      distribution.
+    event_dims: List of dimensions that should not be treated as
+      independent. This allows for multiple chains to be run independently
+      in parallel. Default is (), i.e., all dimensions are independent.
+    x_log_prob (optional): Tensor containing the cached output of a previous
+      call to `target_log_prob_fn()` evaluated at `x` (such as that provided by
+      a previous call to `kernel()`). Providing `x_log_prob` and
+      `x_grad` saves one gradient computation per call to `kernel()`.
+    x_grad (optional): Tensor containing the cached gradient of
+      `target_log_prob_fn()` evaluated at `x` (such as that provided by
+      a previous call to `kernel()`). Providing `x_log_prob` and
+      `x_grad` saves one gradient computation per call to `kernel()`.
+    name: Python `str` name prefixed to Ops created by this function.

-  ##### Simple chain with warm-up.
+  Returns:
+    updated_x: The updated variable(s) x. Has shape matching `x`.
+    acceptance_probs: Tensor with the acceptance probabilities for the final
+      iteration. This is useful for diagnosing step size problems etc. Has
+      shape matching `target_log_prob_fn(x)`.
+    new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`.
+    new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at
+      `updated_x`.

-  ```python
-  tfd = tf.contrib.distributions
+  #### Examples:

+  ```python
  # Tuning acceptance rates:
-  dtype = np.float32
  target_accept_rate = 0.631
-  num_warmup_iter = 500
-  num_chain_iter = 500
-
-  x = tf.get_variable(name="x", initializer=dtype(1))
-  step_size = tf.get_variable(name="step_size", initializer=dtype(1))
-
-  target = tfd.Normal(loc=dtype(0), scale=dtype(1))
-
-  new_x, other_results = hmc.kernel(
-      target_log_prob_fn=target.log_prob,
-      current_state=x,
-      step_size=step_size,
-      num_leapfrog_steps=3)
-
-  x_update = x.assign(new_x)
-
-  step_size_update = step_size.assign_add(
-      step_size * tf.where(
-        other_results.acceptance_probs > target_accept_rate,
-        0.01, -0.01))
-
-  warmup = tf.group(x_update, step_size_update)
-
-  tf.global_variables_initializer().run()
-
-  sess.graph.finalize()  # No more graph building.
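-
-  # After finalizing, only sess.run() calls remain: the warm-up loop below
-  # adapts `step_size` toward `target_accept_rate`, after which samples are
-  # collected with the step size held fixed.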
- + def target_log_prob(x): + # Standard normal + return tf.reduce_sum(-0.5 * tf.square(x)) + initial_x = tf.zeros([10]) + initial_log_prob = target_log_prob(initial_x) + initial_grad = tf.gradients(initial_log_prob, initial_x)[0] + # Algorithm state + x = tf.Variable(initial_x, name='x') + step_size = tf.Variable(1., name='step_size') + last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob') + last_grad = tf.Variable(initial_grad, name='last_grad') + # Compute updates + new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x, + target_log_prob, + event_dims=[0], + x_log_prob=last_log_prob) + x_update = tf.assign(x, new_x) + log_prob_update = tf.assign(last_log_prob, log_prob) + grad_update = tf.assign(last_grad, grad) + step_size_update = tf.assign(step_size, + tf.where(acceptance_prob > target_accept_rate, + step_size * 1.01, step_size / 1.01)) + adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update] + sampling_updates = [x_update, log_prob_update, grad_update] + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) # Warm up the sampler and adapt the step size - for _ in xrange(num_warmup_iter): - sess.run(warmup) - + for i in xrange(500): + sess.run(adaptive_updates) # Collect samples without adapting step size - samples = np.zeros([num_chain_iter]) - for i in xrange(num_chain_iter): - _, x_, target_log_prob_, grad_ = sess.run([ - x_update, - x, - other_results.target_log_prob, - other_results.grads_target_log_prob]) - samples[i] = x_ - - print(samples.mean(), samples.std()) + samples = np.zeros([500, 10]) + for i in xrange(500): + x_val, _ = sess.run([new_x, sampling_updates]) + samples[i] = x_val ``` - ##### Sample from more complicated posterior. - - I.e., - - ```none - W ~ MVN(loc=0, scale=sigma * eye(dims)) - for i=1...num_samples: - X[i] ~ MVN(loc=0, scale=eye(dims)) - eps[i] ~ Normal(loc=0, scale=1) - Y[i] = X[i].T * W + eps[i] + ```python + # Empirical-Bayes estimation of a hyperparameter by MCMC-EM: + + # Problem setup + N = 150 + D = 10 + x = np.random.randn(N, D).astype(np.float32) + true_sigma = 0.5 + true_beta = true_sigma * np.random.randn(D).astype(np.float32) + y = x.dot(true_beta) + np.random.randn(N).astype(np.float32) + + def log_prior(beta, log_sigma): + return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) - + log_sigma) + def regression_log_joint(beta, log_sigma, x, y): + # This function returns log p(beta | log_sigma) + log p(y | x, beta). 
+ means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True) + means = tf.squeeze(means) + log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means)) + return log_prior(beta, log_sigma) + log_likelihood + def log_joint_partial(beta): + return regression_log_joint(beta, log_sigma, x, y) + # Our estimate of log(sigma) + log_sigma = tf.Variable(0., name='log_sigma') + # The state of the Markov chain + beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta') + new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial, + event_dims=[0]) + beta_update = tf.assign(beta, new_beta) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + with tf.control_dependencies([beta_update]): + log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma), + var_list=[log_sigma]) + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + log_sigma_history = np.zeros(1000) + for i in xrange(1000): + log_sigma_val, _ = sess.run([log_sigma, log_sigma_update]) + log_sigma_history[i] = log_sigma_val + # Should converge to something close to true_sigma + plt.plot(np.exp(log_sigma_history)) ``` + """ + with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): + potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + x = ops.convert_to_tensor(x, name='x') + + x_shape = array_ops.shape(x) + m = random_ops.random_normal(x_shape, dtype=x.dtype) + + kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) + + if (x_log_prob is not None) and (x_grad is not None): + log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type + else: + if x_log_prob is not None: + logging.warn('x_log_prob was provided, but x_grad was not,' + ' so x_log_prob was not used.') + if x_grad is not None: + logging.warn('x_grad was provided, but x_log_prob was not,' + ' so x_grad was not used.') + log_potential_0, grad_0 = potential_and_grad(x) + + new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator( + step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0) + + kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) + + energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 + # Treat NaN as infinite energy (and therefore guaranteed rejection). + energy_change = array_ops.where( + math_ops.is_nan(energy_change), + array_ops.fill(array_ops.shape(energy_change), + energy_change.dtype.as_numpy_dtype(np.inf)), + energy_change) + acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) + accepted = ( + random_ops.random_uniform( + array_ops.shape(acceptance_probs), dtype=x.dtype) + < acceptance_probs) + new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) + + # TODO(b/65738010): This should work, but it doesn't for now. + # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) + reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, + keep_dims=True)) + accepted = array_ops.reshape(accepted, reduced_shape) + accepted = math_ops.logical_or( + accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) + new_x = array_ops.where(accepted, new_x, x) + new_grad = -array_ops.where(accepted, grad_1, grad_0) + + # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect + # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This + # should be fixed. 
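+  # Note: rejected chains keep their previous state, log-prob, and gradient,
+  # so `new_log_prob` and `new_grad` can be fed back in as `x_log_prob` and
+  # `x_grad` on the next call to `kernel()`.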
+  return new_x, acceptance_probs, new_log_prob, new_grad
+
+
+def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum,
+                        potential_and_grad, initial_grad, name=None):
+  """Applies `n_steps` steps of the leapfrog integrator.
+
+  This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing
+  gradient computations where possible.

-  ```python
-  tfd = tf.contrib.distributions
-
-  def make_training_data(num_samples, dims, sigma):
-    dt = np.asarray(sigma).dtype
-    zeros = tf.zeros(dims, dtype=dt)
-    x = tfd.MultivariateNormalDiag(
-        loc=zeros).sample(num_samples, seed=1)
-    w = tfd.MultivariateNormalDiag(
-        loc=zeros,
-        scale_identity_multiplier=sigma).sample(seed=2)
-    noise = tfd.Normal(
-        loc=dt(0),
-        scale=dt(1)).sample(num_samples, seed=3)
-    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
-    return y, x, w
-
-  def make_prior(sigma, dims):
-    # p(w | sigma)
-    return tfd.MultivariateNormalDiag(
-        loc=tf.zeros([dims], dtype=sigma.dtype),
-        scale_identity_multiplier=sigma)
-
-  def make_likelihood(x, w):
-    # p(y | x, w)
-    return tfd.MultivariateNormalDiag(
-        loc=tf.tensordot(x, w, axes=[[1], [0]]))
-
-  # Setup assumptions.
-  dtype = np.float32
-  num_samples = 150
-  dims = 10
-  num_iters = int(5e3)
-
-  true_sigma = dtype(0.5)
-  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)
-
-  # Estimate of `log(true_sigma)`.
-  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
-  sigma = tf.exp(log_sigma)
-
-  # State of the Markov chain.
-  weights = tf.get_variable(
-      name="weights",
-      initializer=np.random.randn(dims).astype(dtype))
-
-  prior = make_prior(sigma, dims)
-
-  def joint_log_prob_fn(w):
-    # f(w) = log p(w, y | x)
-    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)
-
-  weights_update = weights.assign(
-      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
-                 current_state=weights,
-                 step_size=0.1,
-                 num_leapfrog_steps=5)[0])
-
-  with tf.control_dependencies([weights_update]):
-    loss = -prior.log_prob(weights)
+  Args:
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `initial_position`. Larger step sizes lead to faster progress, but
+      too-large step sizes lead to larger discretization error and
+      worse energy conservation.
+    n_steps: Number of steps to run the leapfrog integrator.
+    initial_position: Tensor containing the value(s) of the position variable(s)
+      to update.
+    initial_momentum: Tensor containing the value(s) of the momentum variable(s)
+      to update.
+    potential_and_grad: Python callable that takes a position tensor like
+      `initial_position` and returns the potential energy and its gradient at
+      that position.
+    initial_grad: Tensor with the value of the gradient of the potential energy
+      at `initial_position`.
+    name: Python `str` name prefixed to Ops created by this function.

-  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])
+  Returns:
+    updated_position: Updated value of the position.
+    updated_momentum: Updated value of the momentum.
+    new_potential: Potential energy of the new position. Has shape matching
+      `potential_and_grad(initial_position)`.
+    new_grad: Gradient from potential_and_grad() evaluated at the new position.
+      Has shape matching `initial_position`.
+
+  Example: Simple quadratic potential.

-  sess.graph.finalize()  # No more graph building.
+  ```python
+  def potential_and_grad(position):
+    return tf.reduce_sum(0.5 * tf.square(position)), position
+  position = tf.placeholder(np.float32)
+  momentum = tf.placeholder(np.float32)
+  potential, grad = potential_and_grad(position)
+  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator(
+      0.1, 3, position, momentum, potential_and_grad, grad)
+
+  sess = tf.Session()
+  position_val = np.random.randn(10)
+  momentum_val = np.random.randn(10)
+  potential_val, grad_val = sess.run([potential, grad],
+                                     {position: position_val})
+  positions = np.zeros([100, 10])
+  for i in xrange(100):
+    position_val, momentum_val, potential_val, grad_val = sess.run(
+        [new_position, new_momentum, new_potential, new_grad],
+        {position: position_val, momentum: momentum_val})
+    positions[i] = position_val
+  # Should trace out sinusoidal dynamics.
+  plt.plot(positions[:, 0])
+  ```
+  """
+  def leapfrog_wrapper(step_size, x, m, grad, l):
+    x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad)
+    return step_size, x, m, grad, l + 1

-  tf.global_variables_initializer().run()

+  def counter_fn(a, b, c, d, counter):  # pylint: disable=unused-argument
+    return counter < n_steps

-  sigma_history = np.zeros(num_iters, dtype)
-  weights_history = np.zeros([num_iters, dims], dtype)
+  with ops.name_scope(name, 'leapfrog_integrator',
+                      [step_size, n_steps, initial_position, initial_momentum,
+                       initial_grad]):
+    _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop(
+        counter_fn, leapfrog_wrapper, [step_size, initial_position,
+                                       initial_momentum, initial_grad,
+                                       array_ops.constant(0)], back_prop=False)
+    # We're counting on the runtime to eliminate this redundant computation.
+    new_potential, new_grad = potential_and_grad(new_x)
+  return new_x, new_m, new_potential, new_grad

-  for i in xrange(num_iters):
-    _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
-    weights_history[i, :] = weights_
-    sigma_history[i] = sigma_
-
-  true_weights_ = sess.run(true_weights)

+def leapfrog_step(step_size, position, momentum, potential_and_grad, grad,
+                  name=None):
+  """Applies one step of the leapfrog integrator.

-  # Should converge to something close to true_sigma.
-  plt.plot(sigma_history);
-  plt.ylabel("sigma");
-  plt.xlabel("iteration");
-  ```

+  Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2.

  Args:
-    target_log_prob_fn: Python callable which takes an argument like
-      `current_state` (or `*current_state` if it's a list) and returns its
-      (possibly unnormalized) log-density under the target distribution.
-    current_state: `Tensor` or Python `list` of `Tensor`s representing the
-      current state(s) of the Markov chain(s). The first `r` dimensions index
-      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
-    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
-      for the leapfrog integrator. Must broadcast with the shape of
-      `current_state`. Larger step sizes lead to faster progress, but too-large
-      step sizes make rejection exponentially more likely. When possible, it's
-      often helpful to match per-variable step sizes to the standard deviations
-      of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    seed: Python integer to seed the random number generator.
- current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `current_target_log_prob` at the `current_state` - and wrt the `current_state`. Must have same shape as `current_state`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `position`. Larger step sizes lead to faster progress, but + too-large step sizes lead to larger discretization error and + worse energy conservation. + position: Tensor containing the value(s) of the position variable(s) + to update. + momentum: Tensor containing the value(s) of the momentum variable(s) + to update. + potential_and_grad: Python callable that takes a position tensor like + `position` and returns the potential energy and its gradient at that + position. + grad: Tensor with the value of the gradient of the potential energy + at `position`. name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_kernel"). Returns: - accepted_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - acceptance_probs: Tensor with the acceptance probabilities for each - iteration. Has shape matching `target_log_prob_fn(current_state)`. - accepted_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `accepted_state`. - accepted_grads_target_log_prob: Python `list` of `Tensor`s representing the - gradient of `accepted_target_log_prob` wrt each `accepted_state`. - - Raises: - ValueError: if there isn't one `step_size` or a list with same length as - `current_state`. 
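-
-  As a rough sketch of the acceptance rule implemented below (illustrative
-  NumPy only; `metropolis_accept` is not part of this module), a proposal is
-  accepted with probability `min(1, exp(-energy_change))`:
-
-  ```python
-  import numpy as np
-
-  def metropolis_accept(energy_change, rng=np.random):
-    # NaN energy changes are treated as +inf, i.e. certain rejection.
-    energy_change = np.where(np.isnan(energy_change), np.inf, energy_change)
-    # For u ~ Uniform[0, 1), u < exp(min(-e, 0)) iff -log(u) >= e.
-    u = rng.uniform(size=np.shape(energy_change))
-    return -np.log(u) >= energy_change
-  ```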
- """ - with ops.name_scope( - name, "hmc_kernel", - [current_state, step_size, num_leapfrog_steps, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [current_state_parts, step_sizes, current_target_log_prob, - current_grads_target_log_prob] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob, - maybe_expand=True) - independent_chain_ndims = distributions_util.prefer_static_rank( - current_target_log_prob) - def init_momentum(s): - return random_ops.random_normal( - shape=array_ops.shape(s), - dtype=s.dtype.base_dtype, - seed=distributions_util.gen_new_seed( - seed, salt="hmc_kernel_momentums")) - current_momentums = [init_momentum(s) for s in current_state_parts] - - [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] = _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob, - current_grads_target_log_prob) - - energy_change = _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims) - - # u < exp(min(-energy, 0)), where u~Uniform[0,1) - # ==> -log(u) >= max(e, 0) - # ==> -log(u) >= e - # (Perhaps surprisingly, we don't have a better way to obtain a random - # uniform from positive reals, i.e., `tf.random_uniform(minval=0, - # maxval=np.inf)` won't work.) - random_uniform = random_ops.random_uniform( - shape=array_ops.shape(energy_change), - dtype=energy_change.dtype, - seed=seed) - random_positive = -math_ops.log(random_uniform) - is_accepted = random_positive >= energy_change - - accepted_target_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - - accepted_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] - - accepted_grads_target_log_prob = [ - _choose(is_accepted, - proposed_grad, - grad, - independent_chain_ndims) - for proposed_grad, grad - in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] - - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(accepted_state_parts), - KernelResults( - acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), - current_grads_target_log_prob=accepted_grads_target_log_prob, - current_target_log_prob=accepted_target_log_prob, - energy_change=energy_change, - is_accepted=is_accepted, - proposed_grads_target_log_prob=proposed_grads_target_log_prob, - proposed_state=maybe_flatten(proposed_state_parts), - proposed_target_log_prob=proposed_target_log_prob, - random_positive=random_positive, - ), - ] - - -def _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Applies `num_leapfrog_steps` of the leapfrog integrator. - - Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. - - #### Examples: + updated_position: Updated value of the position. + updated_momentum: Updated value of the momentum. + new_potential: Potential energy of the new position. Has shape matching + `potential_and_grad(position)`. 
+    new_grad: Gradient from potential_and_grad() evaluated at the new position.
+      Has shape matching `position`.

-  ##### Simple quadratic potential.
+  Example: Simple quadratic potential.

  ```python
-  tfd = tf.contrib.distributions
-
-  dims = 10
-  num_iter = int(1e3)
-  dtype = np.float32
-
+  def potential_and_grad(position):
+    # Simple quadratic potential
+    return tf.reduce_sum(0.5 * tf.square(position)), position
  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)
-
-  [
-    new_momentums,
-    new_positions,
-  ] = hmc._leapfrog_integrator(
-      current_momentums=[momentum],
-      target_log_prob_fn=tfd.MultivariateNormalDiag(
-          loc=tf.zeros(dims, dtype)).log_prob,
-      current_state_parts=[position],
-      step_sizes=[0.1],
-      num_leapfrog_steps=3)[:2]
-
-  sess.graph.finalize()  # No more graph building.
-
-  momentum_ = np.random.randn(dims).astype(dtype)
-  position_ = np.random.randn(dims).astype(dtype)
-
-  positions = np.zeros([num_iter, dims], dtype)
-  for i in xrange(num_iter):
-    position_, momentum_ = sess.run(
-        [new_positions[0], new_momentums[0]],
-        feed_dict={position: position_, momentum: momentum_})
-    positions[i] = position_
-
-  plt.plot(positions[:, 0]);  # Sinusoidal.
+  potential, grad = potential_and_grad(position)
+  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step(
+      0.1, position, momentum, potential_and_grad, grad)
+
+  sess = tf.Session()
+  position_val = np.random.randn(10)
+  momentum_val = np.random.randn(10)
+  potential_val, grad_val = sess.run([potential, grad],
+                                     {position: position_val})
+  positions = np.zeros([100, 10])
+  for i in xrange(100):
+    position_val, momentum_val, potential_val, grad_val = sess.run(
+        [new_position, new_momentum, new_potential, new_grad],
+        {position: position_val, momentum: momentum_val})
+    positions[i] = position_val
+  # Should trace out sinusoidal dynamics.
+  plt.plot(positions[:, 0])
  ```
-
-  Args:
-    current_momentums: Tensor containing the value(s) of the momentum
-      variable(s) to update.
-    target_log_prob_fn: Python callable which takes an argument like
-      `*current_state_parts` and returns its (possibly unnormalized) log-density
-      under the target distribution.
-    current_state_parts: Python `list` of `Tensor`s representing the current
-      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
-      the `Tensor`(s) index different chains.
-    step_sizes: Python `list` of `Tensor`s representing the step size for the
-      leapfrog integrator. Must broadcast with the shape of
-      `current_state_parts`. Larger step sizes lead to faster progress, but
-      too-large step sizes make rejection exponentially more likely. When
-      possible, it's often helpful to match per-variable step sizes to the
-      standard deviations of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    current_target_log_prob: (Optional) `Tensor` representing the value of
-      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
-      this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
-      representing gradient of `target_log_prob_fn(*current_state_parts)` wrt
-      `current_state_parts`. Must have same shape as `current_state_parts`. The
-      only reason to specify this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    name: Python `str` name prefixed to Ops created by this function.
-      Default value: `None` (i.e., "hmc_leapfrog_integrator").
-
-  Returns:
-    proposed_momentums: Updated value of the momentum.
-    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
-      state(s) of the Markov chain(s) at each result step. Has same shape as
-      input `current_state_parts`.
-    proposed_target_log_prob: `Tensor` representing the value of
-      `target_log_prob_fn` at `proposed_state_parts`.
-    proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
-      `proposed_state_parts`.
-
-  Raises:
-    ValueError: if `len(momentums) != len(state_parts)`.
-    ValueError: if `len(state_parts) != len(step_sizes)`.
-    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
-    TypeError: if `not target_log_prob.dtype.is_floating`.
-  """
-  def _loop_body(step,
-                 current_momentums,
-                 current_state_parts,
-                 ignore_current_target_log_prob,  # pylint: disable=unused-argument
-                 current_grads_target_log_prob):
-    return [step + 1] + list(_leapfrog_step(current_momentums,
-                                            target_log_prob_fn,
-                                            current_state_parts,
-                                            step_sizes,
-                                            current_grads_target_log_prob))
-
-  with ops.name_scope(
-      name, "hmc_leapfrog_integrator",
-      [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
-       current_target_log_prob, current_grads_target_log_prob]):
-    if len(current_momentums) != len(current_state_parts):
-      raise ValueError("`momentums` must be in one-to-one correspondence "
-                       "with `state_parts`")
-    num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps,
-                                               name="num_leapfrog_steps")
-    current_target_log_prob, current_grads_target_log_prob = (
-        _maybe_call_fn_and_grads(
-            target_log_prob_fn,
-            current_state_parts,
-            current_target_log_prob,
-            current_grads_target_log_prob))
-    return control_flow_ops.while_loop(
-        cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
-        body=_loop_body,
-        loop_vars=[
-            0,  # iter_
-            current_momentums,
-            current_state_parts,
-            current_target_log_prob,
-            current_grads_target_log_prob,
-        ],
-        back_prop=False)[1:]  # Lop-off "iter_".
-
-
-def _leapfrog_step(current_momentums,
-                   target_log_prob_fn,
-                   current_state_parts,
-                   step_sizes,
-                   current_grads_target_log_prob,
-                   name=None):
-  """Applies one step of the leapfrog integrator."""
-  with ops.name_scope(
-      name, "_leapfrog_step",
-      [current_momentums, current_state_parts, step_sizes,
-       current_grads_target_log_prob]):
-    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
-                          in zip(current_momentums,
-                                 step_sizes,
-                                 current_grads_target_log_prob)]
-    proposed_state_parts = [x + ss * m for x, ss, m
-                            in zip(current_state_parts,
-                                   step_sizes,
-                                   proposed_momentums)]
-    proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
-    if not proposed_target_log_prob.dtype.is_floating:
-      raise TypeError("`target_log_prob_fn` must produce a `Tensor` "
-                      "with `float` `dtype`.")
-    proposed_grads_target_log_prob = gradients_ops.gradients(
-        proposed_target_log_prob, proposed_state_parts)
-    if any(g is None for g in proposed_grads_target_log_prob):
-      raise ValueError(
-          "Encountered `None` gradient. Does your `target_log_prob_fn` "
-          "access all `tf.Variable`s via `tf.get_variable`?\n"
-          "  current_state_parts: {}\n"
-          "  proposed_state_parts: {}\n"
-          "  proposed_grads_target_log_prob: {}".format(
-              current_state_parts,
-              proposed_state_parts,
-              proposed_grads_target_log_prob))
-    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
-                          in zip(proposed_momentums,
-                                 step_sizes,
-                                 proposed_grads_target_log_prob)]
-    return [
-        proposed_momentums,
-        proposed_state_parts,
-        proposed_target_log_prob,
-        proposed_grads_target_log_prob,
-    ]
-
-
-def _compute_energy_change(current_target_log_prob,
-                           current_momentums,
-                           proposed_target_log_prob,
-                           proposed_momentums,
-                           independent_chain_ndims,
-                           name=None):
-  """Helper to `kernel` which computes the energy change."""
-  with ops.name_scope(
-      name, "compute_energy_change",
-      ([current_target_log_prob, proposed_target_log_prob,
-        independent_chain_ndims] +
-       current_momentums + proposed_momentums)):
-    # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy
-    # since the full names are a mouthful and the short names let us inline
-    # more.
-    lk0, lk1 = [], []
-    for current_momentum, proposed_momentum in zip(current_momentums,
-                                                   proposed_momentums):
-      axis = math_ops.range(independent_chain_ndims,
-                            array_ops.rank(current_momentum))
-      lk0.append(_log_sum_sq(current_momentum, axis))
-      lk1.append(_log_sum_sq(proposed_momentum, axis))
-
-    lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1),
-                                                  axis=-1)
-    lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1),
-                                                  axis=-1)
-    lp0 = -current_target_log_prob   # log_potential
-    lp1 = -proposed_target_log_prob  # proposed_log_potential
-    x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
-                        axis=-1)
-
-    # The sum is NaN if any element is NaN or we see both +Inf and -Inf.
-    # Thus we will replace such rows with infinite energy change which implies
-    # rejection. Recall that float-comparisons with NaN are always False.
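-    # For example, a row [lp1, exp(lk1), -lp0, -exp(lk0)] = [1., inf, -2., -inf]
-    # sums to NaN; it is rewritten below to all +inf, making the energy change
-    # +inf and rejection certain.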
-    is_sum_determinate = (
-        math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) &
-        math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1))
-    is_sum_determinate = array_ops.tile(
-        is_sum_determinate[..., array_ops.newaxis],
-        multiples=array_ops.concat([
-            array_ops.ones(array_ops.rank(is_sum_determinate),
-                           dtype=dtypes.int32),
-            [4],
-        ], axis=0))
-    x = array_ops.where(is_sum_determinate,
-                        x,
-                        array_ops.fill(array_ops.shape(x),
-                                       value=x.dtype.as_numpy_dtype(np.inf)))
-
-    return math_ops.reduce_sum(x, axis=-1)
-
-
-def _choose(is_accepted,
-            accepted,
-            rejected,
-            independent_chain_ndims,
-            name=None):
-  """Helper to `kernel` which broadcasts `is_accepted` so `tf.where` applies."""
-  def _expand_is_accepted_like(x):
-    with ops.name_scope("_choose"):
-      expand_shape = array_ops.concat([
-          array_ops.shape(is_accepted),
-          array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)],
-                         dtype=dtypes.int32),
-      ], axis=0)
-      multiples = array_ops.concat([
-          array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32),
-          array_ops.shape(x)[independent_chain_ndims:],
-      ], axis=0)
-      m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape),
-                         multiples)
-      m.set_shape(x.shape)
-      return m
-  with ops.name_scope(name, "_choose", values=[
-      is_accepted, accepted, rejected, independent_chain_ndims]):
-    return array_ops.where(_expand_is_accepted_like(accepted),
-                           accepted,
-                           rejected)
-
-
-def _maybe_call_fn_and_grads(fn,
-                             fn_arg_list,
-                             fn_result=None,
-                             grads_fn_result=None,
-                             description="target_log_prob"):
-  """Helper which computes `fn_result` and `grads_fn_result` if needed."""
-  fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list)
-                 else [fn_arg_list])
-  if fn_result is None:
-    fn_result = fn(*fn_arg_list)
-  if not fn_result.dtype.is_floating:
-    raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format(
-        description))
-  if grads_fn_result is None:
-    grads_fn_result = gradients_ops.gradients(
-        fn_result, fn_arg_list)
-  if len(fn_arg_list) != len(grads_fn_result):
-    raise ValueError("`{}` must be in one-to-one correspondence with "
-                     "`grads_{}`".format(*[description]*2))
-  if any(g is None for g in grads_fn_result):
-    raise ValueError("Encountered `None` gradient.")
-  return fn_result, grads_fn_result
-
-
-def _prepare_args(target_log_prob_fn, state, step_size,
-                  target_log_prob=None, grads_target_log_prob=None,
-                  maybe_expand=False, description="target_log_prob"):
-  """Helper which processes input args to meet list-like assumptions."""
-  state_parts = list(state) if _is_list_like(state) else [state]
-  state_parts = [ops.convert_to_tensor(s, name="state")
-                 for s in state_parts]
-  target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads(
-      target_log_prob_fn,
-      state_parts,
-      target_log_prob,
-      grads_target_log_prob,
-      description)
-  step_sizes = list(step_size) if _is_list_like(step_size) else [step_size]
-  step_sizes = [
-      ops.convert_to_tensor(
-          s, name="step_size", dtype=target_log_prob.dtype)
-      for s in step_sizes]
-  if len(step_sizes) == 1:
-    step_sizes *= len(state_parts)
-  if len(state_parts) != len(step_sizes):
-    raise ValueError("There should be exactly one `step_size` or it should "
-                     "have same length as `current_state`.")
-  if maybe_expand:
-    maybe_flatten = lambda x: x
-  else:
-    maybe_flatten = lambda x: x if _is_list_like(state) else x[0]
-  return [
-      maybe_flatten(state_parts),
-      maybe_flatten(step_sizes),
-      target_log_prob,
-      grads_target_log_prob,
-  ]
-
-
-def _is_list_like(x):
-  """Helper which returns `True` if input is
`list`-like.""" - return isinstance(x, (tuple, list)) - - -def _log_sum_sq(x, axis=None): - """Computes log(sum(x**2)).""" - return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) + with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum, + grad]): + momentum -= 0.5 * step_size * grad + position += step_size * momentum + potential, grad = potential_and_grad(position) + momentum -= 0.5 * step_size * grad + + return position, momentum, potential, grad
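-
-# For example, `_log_sum_sq(tf.constant([3., 4.]))` evaluates to `log(25.)`,
-# i.e. `log(3.**2 + 4.**2)`; working in log-space can avoid overflow when the
-# entries of `x` are very large.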