Breaking change: Revise HMC interface to accept a list of Tensors representing a...
authorJoshua V. Dillon <jvdillon@google.com>
Fri, 2 Feb 2018 21:07:37 +0000 (13:07 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 21:11:19 +0000 (13:11 -0800)
PiperOrigin-RevId: 184323369

tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
tensorflow/contrib/bayesflow/python/ops/hmc.py
tensorflow/contrib/bayesflow/python/ops/hmc_impl.py

index cbc66b6..d9d0dfc 100644 (file)
@@ -19,29 +19,36 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-from scipy import special
 from scipy import stats
 
 from tensorflow.contrib.bayesflow.python.ops import hmc
+from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change
+from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator
 
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_impl as gradients_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.platform import tf_logging as logging_ops
+
+
+def _reduce_variance(x, axis=None, keepdims=False):
+  sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
+  return math_ops.reduce_mean(
+      math_ops.squared_difference(x, sample_mean), axis, keepdims)
 
 
-# TODO(b/66964210): Test float16.
 class HMCTest(test.TestCase):
 
   def setUp(self):
     self._shape_param = 5.
     self._rate_param = 10.
-    self._expected_x = (special.digamma(self._shape_param)
-                        - np.log(self._rate_param))
-    self._expected_exp_x = self._shape_param / self._rate_param
 
     random_seed.set_random_seed(10003)
     np.random.seed(10003)
@@ -63,63 +70,46 @@ class HMCTest(test.TestCase):
                                self._rate_param * math_ops.exp(x),
                                event_dims)
 
-  def _log_gamma_log_prob_grad(self, x, event_dims=()):
-    """Computes log-pdf and gradient of a log-gamma random variable.
-
-    Args:
-      x: Value of the random variable.
-      event_dims: Dimensions not to treat as independent. Default is (),
-        i.e., all dimensions are independent.
-
-    Returns:
-      log_prob: The log-pdf up to a normalizing constant.
-      grad: The gradient of the log-pdf with respect to x.
-    """
-    return (math_ops.reduce_sum(self._shape_param * x -
-                                self._rate_param * math_ops.exp(x),
-                                event_dims),
-            self._shape_param - self._rate_param * math_ops.exp(x))
-
-  def _n_event_dims(self, x_shape, event_dims):
-    return np.prod([int(x_shape[i]) for i in event_dims])
-
-  def _integrator_conserves_energy(self, x, event_dims, sess,
+  def _integrator_conserves_energy(self, x, independent_chain_ndims, sess,
                                    feed_dict=None):
-    def potential_and_grad(x):
-      log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims)
-      return -log_prob, -grad
-
-    step_size = array_ops.placeholder(np.float32, [], name='step_size')
-    hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
+    step_size = array_ops.placeholder(np.float32, [], name="step_size")
+    hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps")
 
     if feed_dict is None:
       feed_dict = {}
     feed_dict[hmc_lf_steps] = 1000
 
-    m = random_ops.random_normal(array_ops.shape(x))
-    potential_0, grad_0 = potential_and_grad(x)
-    old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m,
-                                                         event_dims)
-
-    _, new_m, potential_1, _ = (
-        hmc.leapfrog_integrator(step_size, hmc_lf_steps, x,
-                                m, potential_and_grad, grad_0))
+    event_dims = math_ops.range(independent_chain_ndims,
+                                array_ops.rank(x))
 
-    new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
+    m = random_ops.random_normal(array_ops.shape(x))
+    log_prob_0 = self._log_gamma_log_prob(x, event_dims)
+    grad_0 = gradients_ops.gradients(log_prob_0, x)
+    old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims)
+
+    new_m, _, log_prob_1, _ = _leapfrog_integrator(
+        current_momentums=[m],
+        target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims),
+        current_state_parts=[x],
+        step_sizes=[step_size],
+        num_leapfrog_steps=hmc_lf_steps,
+        current_target_log_prob=log_prob_0,
+        current_grads_target_log_prob=grad_0)
+    new_m = new_m[0]
+
+    new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
                                                          event_dims)
 
     x_shape = sess.run(x, feed_dict).shape
-    n_event_dims = self._n_event_dims(x_shape, event_dims)
-    feed_dict[step_size] = 0.1 / n_event_dims
-    old_energy_val, new_energy_val = sess.run([old_energy, new_energy],
-                                              feed_dict)
-    logging.vlog(1, 'average energy change: {}'.format(
-        abs(old_energy_val - new_energy_val).mean()))
-
-    self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool),
-                        abs(old_energy_val - new_energy_val) < 1.)
-
-  def _integrator_conserves_energy_wrapper(self, event_dims):
+    event_size = np.prod(x_shape[independent_chain_ndims:])
+    feed_dict[step_size] = 0.1 / event_size
+    old_energy_, new_energy_ = sess.run([old_energy, new_energy],
+                                        feed_dict)
+    logging_ops.vlog(1, "average energy relative change: {}".format(
+        (1. - new_energy_ / old_energy_).mean()))
+    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
+
+  def _integrator_conserves_energy_wrapper(self, independent_chain_ndims):
     """Tests the long-term energy conservation of the leapfrog integrator.
 
     The leapfrog integrator is symplectic, so for sufficiently small step
@@ -127,135 +117,167 @@ class HMCTest(test.TestCase):
     the energy of the system blowing up or collapsing.
 
     Args:
-      event_dims: A tuple of dimensions that should not be treated as
-        independent. This allows for multiple chains to be run independently
-        in parallel. Default is (), i.e., all dimensions are independent.
+      independent_chain_ndims: Python `int` scalar representing the number of
+        dims associated with independent chains.
     """
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name='x_ph')
-
-      feed_dict = {x_ph: np.zeros([50, 10, 2])}
-      self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict)
+      x_ph = array_ops.placeholder(np.float32, name="x_ph")
+      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
+      self._integrator_conserves_energy(x_ph, independent_chain_ndims,
+                                        sess, feed_dict)
 
   def testIntegratorEnergyConservationNullShape(self):
-    self._integrator_conserves_energy_wrapper([])
+    self._integrator_conserves_energy_wrapper(0)
 
   def testIntegratorEnergyConservation1(self):
-    self._integrator_conserves_energy_wrapper([1])
+    self._integrator_conserves_energy_wrapper(1)
 
   def testIntegratorEnergyConservation2(self):
-    self._integrator_conserves_energy_wrapper([2])
+    self._integrator_conserves_energy_wrapper(2)
 
-  def testIntegratorEnergyConservation12(self):
-    self._integrator_conserves_energy_wrapper([1, 2])
+  def testIntegratorEnergyConservation3(self):
+    self._integrator_conserves_energy_wrapper(3)
 
-  def testIntegratorEnergyConservation012(self):
-    self._integrator_conserves_energy_wrapper([0, 1, 2])
-
-  def _chain_gets_correct_expectations(self, x, event_dims, sess,
-                                       feed_dict=None):
+  def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
+                                       sess, feed_dict=None):
     def log_gamma_log_prob(x):
+      event_dims = math_ops.range(independent_chain_ndims,
+                                  array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
-    step_size = array_ops.placeholder(np.float32, [], name='step_size')
-    hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
-    hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps')
+    num_results = array_ops.placeholder(
+        np.int32, [], name="num_results")
+    step_size = array_ops.placeholder(
+        np.float32, [], name="step_size")
+    num_leapfrog_steps = array_ops.placeholder(
+        np.int32, [], name="num_leapfrog_steps")
 
     if feed_dict is None:
       feed_dict = {}
-    feed_dict.update({step_size: 0.1,
-                      hmc_lf_steps: 2,
-                      hmc_n_steps: 300})
-
-    sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps],
-                                                    step_size,
-                                                    hmc_lf_steps,
-                                                    x, log_gamma_log_prob,
-                                                    event_dims)
-
-    acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain],
-                                         feed_dict)
-    samples = samples[feed_dict[hmc_n_steps] // 2:]
-    expected_x_est = samples.mean()
-    expected_exp_x_est = np.exp(samples).mean()
-
-    logging.vlog(1, 'True      E[x, exp(x)]: {}\t{}'.format(
-        self._expected_x, self._expected_exp_x))
-    logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format(
-        expected_x_est, expected_exp_x_est))
-    self.assertNear(expected_x_est, self._expected_x, 2e-2)
-    self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2)
-    self.assertTrue((acceptance_probs > 0.5).all())
-    self.assertTrue((acceptance_probs <= 1.0).all())
-
-  def _chain_gets_correct_expectations_wrapper(self, event_dims):
+    feed_dict.update({num_results: 150,
+                      step_size: 0.1,
+                      num_leapfrog_steps: 2})
+
+    samples, kernel_results = hmc.sample_chain(
+        num_results=num_results,
+        target_log_prob_fn=log_gamma_log_prob,
+        current_state=x,
+        step_size=step_size,
+        num_leapfrog_steps=num_leapfrog_steps,
+        num_burnin_steps=150,
+        seed=42)
+
+    expected_x = (math_ops.digamma(self._shape_param)
+                  - np.log(self._rate_param))
+
+    expected_exp_x = self._shape_param / self._rate_param
+
+    acceptance_probs_, samples_, expected_x_ = sess.run(
+        [kernel_results.acceptance_probs, samples, expected_x],
+        feed_dict)
+
+    actual_x = samples_.mean()
+    actual_exp_x = np.exp(samples_).mean()
+
+    logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
+        expected_x_, expected_exp_x))
+    logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
+        actual_x, actual_exp_x))
+    self.assertNear(actual_x, expected_x_, 2e-2)
+    self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
+    self.assertTrue((acceptance_probs_ > 0.5).all())
+    self.assertTrue((acceptance_probs_ <= 1.0).all())
+
+  def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name='x_ph')
-
-      feed_dict = {x_ph: np.zeros([50, 10, 2])}
-      self._chain_gets_correct_expectations(x_ph, event_dims, sess,
-                                            feed_dict)
+      x_ph = array_ops.placeholder(np.float32, name="x_ph")
+      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
+      self._chain_gets_correct_expectations(x_ph, independent_chain_ndims,
+                                            sess, feed_dict)
 
   def testHMCChainExpectationsNullShape(self):
-    self._chain_gets_correct_expectations_wrapper([])
+    self._chain_gets_correct_expectations_wrapper(0)
 
   def testHMCChainExpectations1(self):
-    self._chain_gets_correct_expectations_wrapper([1])
+    self._chain_gets_correct_expectations_wrapper(1)
 
   def testHMCChainExpectations2(self):
-    self._chain_gets_correct_expectations_wrapper([2])
-
-  def testHMCChainExpectations12(self):
-    self._chain_gets_correct_expectations_wrapper([1, 2])
+    self._chain_gets_correct_expectations_wrapper(2)
 
-  def _kernel_leaves_target_invariant(self, initial_draws, event_dims,
+  def _kernel_leaves_target_invariant(self, initial_draws,
+                                      independent_chain_ndims,
                                       sess, feed_dict=None):
     def log_gamma_log_prob(x):
+      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
     def fake_log_prob(x):
       """Cooled version of the target distribution."""
       return 1.1 * log_gamma_log_prob(x)
 
-    step_size = array_ops.placeholder(np.float32, [], name='step_size')
+    step_size = array_ops.placeholder(np.float32, [], name="step_size")
 
     if feed_dict is None:
       feed_dict = {}
 
     feed_dict[step_size] = 0.4
 
-    sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws,
-                                                log_gamma_log_prob, event_dims)
-    bad_sample, bad_acceptance_probs, _, _ = hmc.kernel(
-        step_size, 5, initial_draws, fake_log_prob, event_dims)
-    (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val,
-     updated_draws_val, fake_draws_val) = sess.run([acceptance_probs,
-                                                    bad_acceptance_probs,
-                                                    initial_draws, sample,
-                                                    bad_sample], feed_dict)
+    sample, kernel_results = hmc.kernel(
+        target_log_prob_fn=log_gamma_log_prob,
+        current_state=initial_draws,
+        step_size=step_size,
+        num_leapfrog_steps=5,
+        seed=43)
+
+    bad_sample, bad_kernel_results = hmc.kernel(
+        target_log_prob_fn=fake_log_prob,
+        current_state=initial_draws,
+        step_size=step_size,
+        num_leapfrog_steps=5,
+        seed=44)
+
+    [
+        acceptance_probs_,
+        bad_acceptance_probs_,
+        initial_draws_,
+        updated_draws_,
+        fake_draws_,
+    ] = sess.run([
+        kernel_results.acceptance_probs,
+        bad_kernel_results.acceptance_probs,
+        initial_draws,
+        sample,
+        bad_sample,
+    ], feed_dict)
+
     # Confirm step size is small enough that we usually accept.
-    self.assertGreater(acceptance_probs_val.mean(), 0.5)
-    self.assertGreater(bad_acceptance_probs_val.mean(), 0.5)
+    self.assertGreater(acceptance_probs_.mean(), 0.5)
+    self.assertGreater(bad_acceptance_probs_.mean(), 0.5)
+
     # Confirm step size is large enough that we sometimes reject.
-    self.assertLess(acceptance_probs_val.mean(), 0.99)
-    self.assertLess(bad_acceptance_probs_val.mean(), 0.99)
-    _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(),
-                                        updated_draws_val.flatten())
-    _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(),
-                                        fake_draws_val.flatten())
-    logging.vlog(1, 'acceptance rate for true target: {}'.format(
-        acceptance_probs_val.mean()))
-    logging.vlog(1, 'acceptance rate for fake target: {}'.format(
-        bad_acceptance_probs_val.mean()))
-    logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true))
-    logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake))
+    self.assertLess(acceptance_probs_.mean(), 0.99)
+    self.assertLess(bad_acceptance_probs_.mean(), 0.99)
+
+    _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
+                                        updated_draws_.flatten())
+    _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(),
+                                        fake_draws_.flatten())
+
+    logging_ops.vlog(1, "acceptance rate for true target: {}".format(
+        acceptance_probs_.mean()))
+    logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
+        bad_acceptance_probs_.mean()))
+    logging_ops.vlog(1, "K-S p-value for true target: {}".format(
+        ks_p_value_true))
+    logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
+        ks_p_value_fake))
     # Make sure that the MCMC update hasn't changed the empirical CDF much.
     self.assertGreater(ks_p_value_true, 1e-3)
     # Confirm that targeting the wrong distribution does
     # significantly change the empirical CDF.
     self.assertLess(ks_p_value_fake, 1e-6)
 
-  def _kernel_leaves_target_invariant_wrapper(self, event_dims):
+  def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims):
     """Tests that the kernel leaves the target distribution invariant.
 
     Draws some independent samples from the target distribution,
@@ -267,86 +289,116 @@ class HMCTest(test.TestCase):
     does change the target distribution. (And that we can detect that.)
 
     Args:
-      event_dims: A tuple of dimensions that should not be treated as
-        independent. This allows for multiple chains to be run independently
-        in parallel. Default is (), i.e., all dimensions are independent.
+      independent_chain_ndims: Python `int` scalar representing the number of
+        dims associated with independent chains.
     """
     with self.test_session() as sess:
       initial_draws = np.log(np.random.gamma(self._shape_param,
                                              size=[50000, 2, 2]))
       initial_draws -= np.log(self._rate_param)
-      x_ph = array_ops.placeholder(np.float32, name='x_ph')
+      x_ph = array_ops.placeholder(np.float32, name="x_ph")
 
       feed_dict = {x_ph: initial_draws}
 
-      self._kernel_leaves_target_invariant(x_ph, event_dims, sess,
-                                           feed_dict)
-
-  def testKernelLeavesTargetInvariantNullShape(self):
-    self._kernel_leaves_target_invariant_wrapper([])
+      self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims,
+                                           sess, feed_dict)
 
   def testKernelLeavesTargetInvariant1(self):
-    self._kernel_leaves_target_invariant_wrapper([1])
+    self._kernel_leaves_target_invariant_wrapper(1)
 
   def testKernelLeavesTargetInvariant2(self):
-    self._kernel_leaves_target_invariant_wrapper([2])
+    self._kernel_leaves_target_invariant_wrapper(2)
 
-  def testKernelLeavesTargetInvariant12(self):
-    self._kernel_leaves_target_invariant_wrapper([1, 2])
+  def testKernelLeavesTargetInvariant3(self):
+    self._kernel_leaves_target_invariant_wrapper(3)
 
-  def _ais_gets_correct_log_normalizer(self, init, event_dims, sess,
-                                       feed_dict=None):
+  def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
+                                       sess, feed_dict=None):
     def proposal_log_prob(x):
-      return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi),
-                                 event_dims)
+      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
+      return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
+                                        axis=event_dims)
 
     def target_log_prob(x):
+      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
     if feed_dict is None:
       feed_dict = {}
 
-    w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob,
-                            proposal_log_prob, event_dims)
-
-    w_val = sess.run(w, feed_dict)
-    init_shape = sess.run(init, feed_dict).shape
-    normalizer_multiplier = np.prod([init_shape[i] for i in event_dims])
-
-    true_normalizer = -self._shape_param * np.log(self._rate_param)
-    true_normalizer += special.gammaln(self._shape_param)
-    true_normalizer *= normalizer_multiplier
-
-    n_weights = np.prod(w_val.shape)
-    normalized_w = np.exp(w_val - true_normalizer)
-    standard_error = np.std(normalized_w) / np.sqrt(n_weights)
-    logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format(
-        true_normalizer, np.log(normalized_w.mean()) + true_normalizer,
-        n_weights))
-    self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error)
-
-  def _ais_gets_correct_log_normalizer_wrapper(self, event_dims):
+    num_steps = 200
+
+    _, ais_weights, _ = hmc.sample_annealed_importance_chain(
+        proposal_log_prob_fn=proposal_log_prob,
+        num_steps=num_steps,
+        target_log_prob_fn=target_log_prob,
+        step_size=0.5,
+        current_state=init,
+        num_leapfrog_steps=2,
+        seed=45)
+
+    event_shape = array_ops.shape(init)[independent_chain_ndims:]
+    event_size = math_ops.reduce_prod(event_shape)
+
+    log_true_normalizer = (
+        -self._shape_param * math_ops.log(self._rate_param)
+        + math_ops.lgamma(self._shape_param))
+    log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
+
+    log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
+                                - np.log(num_steps))
+
+    ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
+    ais_weights_size = array_ops.size(ais_weights)
+    standard_error = math_ops.sqrt(
+        _reduce_variance(ratio_estimate_true)
+        / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
+
+    [
+        ratio_estimate_true_,
+        log_true_normalizer_,
+        log_estimated_normalizer_,
+        standard_error_,
+        ais_weights_size_,
+        event_size_,
+    ] = sess.run([
+        ratio_estimate_true,
+        log_true_normalizer,
+        log_estimated_normalizer,
+        standard_error,
+        ais_weights_size,
+        event_size,
+    ], feed_dict)
+
+    logging_ops.vlog(1, "        log_true_normalizer: {}\n"
+                        "   log_estimated_normalizer: {}\n"
+                        "           ais_weights_size: {}\n"
+                        "                 event_size: {}\n".format(
+                            log_true_normalizer_,
+                            log_estimated_normalizer_,
+                            ais_weights_size_,
+                            event_size_))
+    self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
+
+  def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
     """Tests that AIS yields reasonable estimates of normalizers."""
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name='x_ph')
-
+      x_ph = array_ops.placeholder(np.float32, name="x_ph")
       initial_draws = np.random.normal(size=[30, 2, 1])
-      feed_dict = {x_ph: initial_draws}
-
-      self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess,
-                                            feed_dict)
-
-  def testAISNullShape(self):
-    self._ais_gets_correct_log_normalizer_wrapper([])
+      self._ais_gets_correct_log_normalizer(
+          x_ph,
+          independent_chain_ndims,
+          sess,
+          feed_dict={x_ph: initial_draws})
 
   def testAIS1(self):
-    self._ais_gets_correct_log_normalizer_wrapper([1])
+    self._ais_gets_correct_log_normalizer_wrapper(1)
 
   def testAIS2(self):
-    self._ais_gets_correct_log_normalizer_wrapper([2])
+    self._ais_gets_correct_log_normalizer_wrapper(2)
 
-  def testAIS12(self):
-    self._ais_gets_correct_log_normalizer_wrapper([1, 2])
+  def testAIS3(self):
+    self._ais_gets_correct_log_normalizer_wrapper(3)
 
   def testNanRejection(self):
     """Tests that an update that yields NaN potentials gets rejected.
@@ -359,24 +411,29 @@ class HMCTest(test.TestCase):
     """
     def _unbounded_exponential_log_prob(x):
       """An exponential distribution with log-likelihood NaN for x < 0."""
-      per_element_potentials = array_ops.where(x < 0,
-                                               np.nan * array_ops.ones_like(x),
-                                               -x)
+      per_element_potentials = array_ops.where(
+          x < 0.,
+          array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)),
+          -x)
       return math_ops.reduce_sum(per_element_potentials)
 
     with self.test_session() as sess:
       initial_x = math_ops.linspace(0.01, 5, 10)
-      updated_x, acceptance_probs, _, _ = hmc.kernel(
-          2., 5, initial_x, _unbounded_exponential_log_prob, [0])
-      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
-          [initial_x, updated_x, acceptance_probs])
-
-      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
-      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
-      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
-
-      self.assertAllEqual(initial_x_val, updated_x_val)
-      self.assertEqual(acceptance_probs_val, 0.)
+      updated_x, kernel_results = hmc.kernel(
+          target_log_prob_fn=_unbounded_exponential_log_prob,
+          current_state=initial_x,
+          step_size=2.,
+          num_leapfrog_steps=5,
+          seed=46)
+      initial_x_, updated_x_, acceptance_probs_ = sess.run(
+          [initial_x, updated_x, kernel_results.acceptance_probs])
+
+      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
+      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
+      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+
+      self.assertAllEqual(initial_x_, updated_x_)
+      self.assertEqual(acceptance_probs_, 0.)
 
   def testNanFromGradsDontPropagate(self):
     """Test that update with NaN gradients does not cause NaN in results."""
@@ -385,60 +442,195 @@ class HMCTest(test.TestCase):
 
     with self.test_session() as sess:
       initial_x = math_ops.linspace(0.01, 5, 10)
-      updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
-          2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
-      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
-          [initial_x, updated_x, acceptance_probs])
-
-      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
-      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
-      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
-
-      self.assertAllEqual(initial_x_val, updated_x_val)
-      self.assertEqual(acceptance_probs_val, 0.)
+      updated_x, kernel_results = hmc.kernel(
+          target_log_prob_fn=_nan_log_prob_with_nan_gradient,
+          current_state=initial_x,
+          step_size=2.,
+          num_leapfrog_steps=5,
+          seed=47)
+      initial_x_, updated_x_, acceptance_probs_ = sess.run(
+          [initial_x, updated_x, kernel_results.acceptance_probs])
+
+      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
+      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
+      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+
+      self.assertAllEqual(initial_x_, updated_x_)
+      self.assertEqual(acceptance_probs_, 0.)
 
       self.assertAllFinite(
-          gradients_impl.gradients(updated_x, initial_x)[0].eval())
-      self.assertTrue(
-          gradients_impl.gradients(new_grad, initial_x)[0] is None)
+          gradients_ops.gradients(updated_x, initial_x)[0].eval())
+      self.assertAllEqual([True], [g is None for g in gradients_ops.gradients(
+          kernel_results.proposed_grads_target_log_prob, initial_x)])
+      self.assertAllEqual([False], [g is None for g in gradients_ops.gradients(
+          kernel_results.proposed_grads_target_log_prob,
+          kernel_results.proposed_state)])
 
       # Gradients of the acceptance probs and new log prob are not finite.
-      _ = new_log_prob  # Prevent unused arg error.
       # self.assertAllFinite(
-      #     gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
+      #     gradients_ops.gradients(acceptance_probs, initial_x)[0].eval())
       # self.assertAllFinite(
-      #     gradients_impl.gradients(new_log_prob, initial_x)[0].eval())
+      #     gradients_ops.gradients(new_log_prob, initial_x)[0].eval())
+
+  def _testChainWorksDtype(self, dtype):
+    states, kernel_results = hmc.sample_chain(
+        num_results=10,
+        target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
+        current_state=np.zeros(5).astype(dtype),
+        step_size=0.01,
+        num_leapfrog_steps=10,
+        seed=48)
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run(
+          [states, kernel_results.acceptance_probs])
+    self.assertEqual(dtype, states_.dtype)
+    self.assertEqual(dtype, acceptance_probs_.dtype)
 
   def testChainWorksIn64Bit(self):
-    def log_prob(x):
-      return - math_ops.reduce_sum(x * x, axis=-1)
-    states, acceptance_probs = hmc.chain(
-        n_iterations=10,
-        step_size=np.float64(0.01),
-        n_leapfrog_steps=10,
-        initial_x=np.zeros(5).astype(np.float64),
-        target_log_prob_fn=log_prob,
-        event_dims=[-1])
-    with self.test_session() as sess:
-      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
-    self.assertEqual(np.float64, states_.dtype)
-    self.assertEqual(np.float64, acceptance_probs_.dtype)
+    self._testChainWorksDtype(np.float64)
 
   def testChainWorksIn16Bit(self):
-    def log_prob(x):
-      return - math_ops.reduce_sum(x * x, axis=-1)
-    states, acceptance_probs = hmc.chain(
-        n_iterations=10,
-        step_size=np.float16(0.01),
-        n_leapfrog_steps=10,
-        initial_x=np.zeros(5).astype(np.float16),
-        target_log_prob_fn=log_prob,
-        event_dims=[-1])
+    self._testChainWorksDtype(np.float16)
+
+
+class _EnergyComputationTest(object):
+
+  def testHandlesNanFromPotential(self):
+    with self.test_session() as sess:
+      x = [1, np.inf, -np.inf, np.nan]
+      target_log_prob, proposed_target_log_prob = [
+          self.dtype(x.flatten()) for x in np.meshgrid(x, x)]
+      num_chains = len(target_log_prob)
+      dummy_momentums = [-1, 1]
+      momentums = [self.dtype([dummy_momentums] * num_chains)]
+      proposed_momentums = [self.dtype([dummy_momentums] * num_chains)]
+
+      target_log_prob = ops.convert_to_tensor(target_log_prob)
+      momentums = [ops.convert_to_tensor(momentums[0])]
+      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
+      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
+
+      energy = _compute_energy_change(
+          target_log_prob,
+          momentums,
+          proposed_target_log_prob,
+          proposed_momentums,
+          independent_chain_ndims=1)
+      grads = gradients_ops.gradients(energy, momentums)
+
+      [actual_energy, grads_] = sess.run([energy, grads])
+
+      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
+      # finite otherwise.
+      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
+      self.assertAllEqual(expected_energy, actual_energy)
+
+      # Ensure gradient is finite.
+      self.assertAllEqual(np.ones_like(grads_).astype(np.bool),
+                          np.isfinite(grads_))
+
+  def testHandlesNanFromKinetic(self):
     with self.test_session() as sess:
-      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
-    self.assertEqual(np.float16, states_.dtype)
-    self.assertEqual(np.float16, acceptance_probs_.dtype)
+      x = [1, np.inf, -np.inf, np.nan]
+      momentums, proposed_momentums = [
+          [np.reshape(self.dtype(x), [-1, 1])]
+          for x in np.meshgrid(x, x)]
+      num_chains = len(momentums[0])
+      target_log_prob = np.ones(num_chains, self.dtype)
+      proposed_target_log_prob = np.ones(num_chains, self.dtype)
 
+      target_log_prob = ops.convert_to_tensor(target_log_prob)
+      momentums = [ops.convert_to_tensor(momentums[0])]
+      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
+      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
 
-if __name__ == '__main__':
+      energy = _compute_energy_change(
+          target_log_prob,
+          momentums,
+          proposed_target_log_prob,
+          proposed_momentums,
+          independent_chain_ndims=1)
+      grads = gradients_ops.gradients(energy, momentums)
+
+      [actual_energy, grads_] = sess.run([energy, grads])
+
+      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
+      # finite otherwise.
+      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
+      self.assertAllEqual(expected_energy, actual_energy)
+
+      # Ensure gradient is finite.
+      g = grads_[0].reshape([len(x), len(x)])[:, 0]
+      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g))
+
+      # The remaining gradients are nan because the momentum was itself nan or
+      # inf.
+      g = grads_[0].reshape([len(x), len(x)])[:, 1:]
+      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g))
+
+
+class EnergyComputationTest16(test.TestCase, _EnergyComputationTest):
+  dtype = np.float16
+
+
+class EnergyComputationTest32(test.TestCase, _EnergyComputationTest):
+  dtype = np.float32
+
+
+class EnergyComputationTest64(test.TestCase, _EnergyComputationTest):
+  dtype = np.float64
+
+
+class _HMCHandlesLists(object):
+
+  def testStateParts(self):
+    with self.test_session() as sess:
+      dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
+      dist_y = independent_lib.Independent(
+          gamma_lib.Gamma(concentration=self.dtype([1, 2]),
+                          rate=self.dtype([0.5, 0.75])),
+          reinterpreted_batch_ndims=1)
+      def target_log_prob(x, y):
+        return dist_x.log_prob(x) + dist_y.log_prob(y)
+      x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
+      samples, _ = hmc.sample_chain(
+          num_results=int(2e3),
+          target_log_prob_fn=target_log_prob,
+          current_state=x0,
+          step_size=0.85,
+          num_leapfrog_steps=3,
+          num_burnin_steps=int(250),
+          seed=49)
+      actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
+      actual_vars = [_reduce_variance(s, axis=0) for s in samples]
+      expected_means = [dist_x.mean(), dist_y.mean()]
+      expected_vars = [dist_x.variance(), dist_y.variance()]
+      [
+          actual_means_,
+          actual_vars_,
+          expected_means_,
+          expected_vars_,
+      ] = sess.run([
+          actual_means,
+          actual_vars,
+          expected_means,
+          expected_vars,
+      ])
+      self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
+      self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.30)
+
+
+class HMCHandlesLists16(_HMCHandlesLists, test.TestCase):
+  dtype = np.float16
+
+
+class HMCHandlesLists32(_HMCHandlesLists, test.TestCase):
+  dtype = np.float32
+
+
+class HMCHandlesLists64(_HMCHandlesLists, test.TestCase):
+  dtype = np.float64
+
+
+if __name__ == "__main__":
   test.main()
index 977d42f..7fd5652 100644 (file)
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
-"""
+"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -24,11 +23,9 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import *  # pylint: disabl
 from tensorflow.python.util import all_util
 
 _allowed_symbols = [
-    'chain',
-    'kernel',
-    'leapfrog_integrator',
-    'leapfrog_step',
-    'ais_chain'
+    "sample_chain",
+    "sample_annealed_importance_chain",
+    "kernel",
 ]
 
 all_util.remove_undocumented(__name__, _allowed_symbols)
index 5685a94..f7a11c2 100644 (file)
 # ==============================================================================
 """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
 
-@@chain
-@@update
-@@leapfrog_integrator
-@@leapfrog_step
-@@ais_chain
+@@sample_chain
+@@sample_annealed_importance_chain
+@@kernel
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import numpy as np
 
 from tensorflow.python.framework import dtypes
@@ -32,168 +31,292 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_impl as gradients_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.ops.distributions import util as distributions_util
 
 __all__ = [
-    'chain',
-    'kernel',
-    'leapfrog_integrator',
-    'leapfrog_step',
-    'ais_chain'
+    "sample_chain",
+    "sample_annealed_importance_chain",
+    "kernel",
 ]
 
 
-def _make_potential_and_grad(target_log_prob_fn):
-  def potential_and_grad(x):
-    log_prob_result = -target_log_prob_fn(x)
-    grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result),
-                                           x)[0]
-    return log_prob_result, grad_result
-  return potential_and_grad
-
-
-def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
-          target_log_prob_fn, event_dims=(), name=None):
+KernelResults = collections.namedtuple(
+    "KernelResults",
+    [
+        "acceptance_probs",
+        "current_grads_target_log_prob",  # "Current result" means "accepted".
+        "current_target_log_prob",  # "Current result" means "accepted".
+        "energy_change",
+        "is_accepted",
+        "proposed_grads_target_log_prob",
+        "proposed_state",
+        "proposed_target_log_prob",
+        "random_positive",
+    ])
+
+
+def _make_dummy_kernel_results(
+    dummy_state,
+    dummy_target_log_prob,
+    dummy_grads_target_log_prob):
+  return KernelResults(
+      acceptance_probs=dummy_target_log_prob,
+      current_grads_target_log_prob=dummy_grads_target_log_prob,
+      current_target_log_prob=dummy_target_log_prob,
+      energy_change=dummy_target_log_prob,
+      is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool),
+      proposed_grads_target_log_prob=dummy_grads_target_log_prob,
+      proposed_state=dummy_state,
+      proposed_target_log_prob=dummy_target_log_prob,
+      random_positive=dummy_target_log_prob,
+  )
+
+
+def sample_chain(
+    num_results,
+    target_log_prob_fn,
+    current_state,
+    step_size,
+    num_leapfrog_steps,
+    num_burnin_steps=0,
+    num_steps_between_results=0,
+    seed=None,
+    current_target_log_prob=None,
+    current_grads_target_log_prob=None,
+    name=None):
   """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains.
 
-  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
-  algorithm that takes a series of gradient-informed steps to produce
-  a Metropolis proposal. This function samples from an HMC Markov
-  chain whose initial state is `initial_x` and whose stationary
-  distribution has log-density `target_log_prob_fn()`.
-
-  This function can update multiple chains in parallel. It assumes
-  that all dimensions of `initial_x` not specified in `event_dims` are
-  independent, and should therefore be updated independently. The
-  output of `target_log_prob_fn()` should sum log-probabilities across
-  all event dimensions. Slices along dimensions not in `event_dims`
-  may have different target distributions; this is up to
+  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
+  that takes a series of gradient-informed steps to produce a Metropolis
+  proposal. This function samples from an HMC Markov chain at `current_state`
+  and whose stationary distribution has log-unnormalized-density
   `target_log_prob_fn()`.
 
-  This function basically just wraps `hmc.kernel()` in a tf.scan() loop.
+  This function samples from multiple chains in parallel. It assumes that the
+  the leftmost dimensions of (each) `current_state` (part) index an independent
+  chain.  The function `target_log_prob_fn()` sums log-probabilities across
+  event dimensions (i.e., current state (part) rightmost dimensions). Each
+  element of the output of `target_log_prob_fn()` represents the (possibly
+  unnormalized) log-probability of the joint distribution over (all) the current
+  state (parts).
 
-  Args:
-    n_iterations: Integer number of Markov chain updates to run.
-    step_size: Scalar step size or array of step sizes for the
-      leapfrog integrator. Broadcasts to the shape of
-      `initial_x`. Larger step sizes lead to faster progress, but
-      too-large step sizes make rejection exponentially more likely.
-      When possible, it's often helpful to match per-variable step
-      sizes to the standard deviations of the target distribution in
-      each variable.
-    n_leapfrog_steps: Integer number of steps to run the leapfrog
-      integrator for. Total progress per HMC step is roughly
-      proportional to step_size * n_leapfrog_steps.
-    initial_x: Tensor of initial state(s) of the Markov chain(s).
-    target_log_prob_fn: Python callable which takes an argument like `initial_x`
-      and returns its (possibly unnormalized) log-density under the target
-      distribution.
-    event_dims: List of dimensions that should not be treated as
-      independent. This allows for multiple chains to be run independently
-      in parallel. Default is (), i.e., all dimensions are independent.
-    name: Python `str` name prefixed to Ops created by this function.
+  The `current_state` can be represented as a single `Tensor` or a `list` of
+  `Tensors` which collectively represent the current state. When specifying a
+  `list`, one must also specify a list of `step_size`s.
 
-  Returns:
-    acceptance_probs: Tensor with the acceptance probabilities for each
-      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
-    chain_states: Tensor with the state of the Markov chain at each iteration.
-      Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`.
+  Only one out of every `num_steps_between_samples + 1` steps is included in the
+  returned results. This "thinning" comes at a cost of reduced statistical
+  power, while reducing memory requirements and autocorrelation. For more
+  discussion see [1].
+
+  [1]: "Statistically efficient thinning of a Markov chain sampler."
+       Art B. Owen. April 2017.
+       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
 
   #### Examples:
 
-  ```python
-  # Sampling from a standard normal (note `log_joint()` is unnormalized):
-  def log_joint(x):
-    return tf.reduce_sum(-0.5 * tf.square(x))
-  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
-                                      event_dims=[0])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
-  ```
+  ##### Sample from a diagonal-variance Gaussian.
 
   ```python
-  # Sampling from a diagonal-variance Gaussian:
-  variances = tf.linspace(1., 3., 10)
-  def log_joint(x):
-    return tf.reduce_sum(-0.5 / variances * tf.square(x))
-  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
-                                      event_dims=[0])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+  tfd = tf.contrib.distributions
+
+  def make_likelihood(true_variances):
+    return tfd.MultivariateNormalDiag(
+        scale_diag=tf.sqrt(true_variances))
+
+  dims = 10
+  dtype = np.float32
+  true_variances = tf.linspace(dtype(1), dtype(3), dims)
+  likelihood = make_likelihood(true_variances)
+
+  states, kernel_results = hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=likelihood.log_prob,
+      current_state=tf.zeros(dims),
+      step_size=0.5,
+      num_leapfrog_steps=2,
+      num_burnin_steps=500)
+
+  # Compute sample stats.
+  sample_mean = tf.reduce_mean(states, axis=0)
+  sample_var = tf.reduce_mean(
+      tf.squared_difference(states, sample_mean),
+      axis=0)
   ```
 
-  ```python
-  # Sampling from factor-analysis posteriors with known factors W:
-  # mu[i, j] ~ Normal(0, 1)
-  # x[i] ~ Normal(matmul(mu[i], W), I)
-  def log_joint(mu, x, W):
-    prior = -0.5 * tf.reduce_sum(tf.square(mu), 1)
-    x_mean = tf.matmul(mu, W)
-    likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1)
-    return prior + likelihood
-  chain, acceptance_probs = hmc.chain(1000, 0.1, 2,
-                                      tf.zeros([x.shape[0], W.shape[0]]),
-                                      lambda mu: log_joint(mu, x, W),
-                                      event_dims=[1])
-  # Discard first half of chain as warmup/burn-in
-  warmed_up = chain[500:]
-  mean_est = tf.reduce_mean(warmed_up, 0)
-  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+  ##### Sampling from factor-analysis posteriors with known factors.
+
+  I.e.,
+
+  ```none
+  for i=1..n:
+    w[i] ~ Normal(0, eye(d))            # prior
+    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
   ```
 
+  where `F` denotes factors.
+
   ```python
-  # Sampling from the posterior of a Bayesian regression model.:
-
-  # Run 100 chains in parallel, each with a different initialization.
-  initial_beta = tf.random_normal([100, x.shape[1]])
-  chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta,
-                                      log_joint_partial, event_dims=[1])
-  # Discard first halves of chains as warmup/burn-in
-  warmed_up = chain[500:]
-  # Averaging across samples within a chain and across chains
-  mean_est = tf.reduce_mean(warmed_up, [0, 1])
-  var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est)
+  tfd = tf.contrib.distributions
+
+  def make_prior(dims, dtype):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.zeros(dims, dtype))
+
+  def make_likelihood(weights, factors):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))
+
+  # Setup data.
+  num_weights = 10
+  num_factors = 4
+  num_chains = 100
+  dtype = np.float32
+
+  prior = make_prior(num_weights, dtype)
+  weights = prior.sample(num_chains)
+  factors = np.random.randn(num_factors, num_weights).astype(dtype)
+  x = make_likelihood(weights, factors).sample(num_chains)
+
+  def target_log_prob(w):
+    # Target joint is: `f(w) = p(w, x | factors)`.
+    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)
+
+  # Get `num_results` samples from `num_chains` independent chains.
+  chains_states, kernels_results = hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=target_log_prob,
+      current_state=tf.zeros([num_chains, dims], dtype),
+      step_size=0.1,
+      num_leapfrog_steps=2,
+      num_burnin_steps=500)
+
+  # Compute sample stats.
+  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
+  sample_var = tf.reduce_mean(
+      tf.squared_difference(chains_states, sample_mean),
+      axis=[0, 1])
   ```
-  """
-  with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size,
-                                          n_leapfrog_steps, initial_x]):
-    initial_x = ops.convert_to_tensor(initial_x, name='initial_x')
-    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
-
-    def body(a, _):
-      updated_x, acceptance_probs, log_prob, grad = kernel(
-          step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims,
-          a[2], a[3])
-      return updated_x, acceptance_probs, log_prob, grad
-
-    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
-    potential, grad = potential_and_grad(initial_x)
-    return functional_ops.scan(
-        body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
-        (initial_x,
-         array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
-         -potential, -grad))[:2]
 
+  Args:
+    num_results: Integer number of Markov chain draws.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+      for the leapfrog integrator. Must broadcast with the shape of
+      `current_state`. Larger step sizes lead to faster progress, but too-large
+      step sizes make rejection exponentially more likely. When possible, it's
+      often helpful to match per-variable step sizes to the standard deviations
+      of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    num_burnin_steps: Integer number of chain steps to take before starting to
+      collect results.
+      Default value: 0 (i.e., no burn-in).
+    num_steps_between_results: Integer number of chain steps between collecting
+      a result. Only one out of every `num_steps_between_samples + 1` steps is
+      included in the returned results. This "thinning" comes at a cost of
+      reduced statistical power, while reducing memory requirements and
+      autocorrelation. For more discussion see [1].
+      Default value: 0 (i.e., no subsampling).
+    seed: Python integer to seed the random number generator.
+    current_target_log_prob: (Optional) `Tensor` representing the value of
+      `target_log_prob_fn` at the `current_state`. The only reason to specify
+      this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+      representing gradient of `target_log_prob` at the `current_state` and wrt
+      the `current_state`. Must have same shape as `current_state`. The only
+      reason to specify this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_sample_chain").
+
+  Returns:
+    accepted_states: Tensor or Python list of `Tensor`s representing the
+      state(s) of the Markov chain(s) at each result step. Has same shape as
+      input `current_state` but with a prepended `num_results`-size dimension.
+    kernel_results: `collections.namedtuple` of internal calculations used to
+      advance the chain.
+  """
+  with ops.name_scope(
+      name, "hmc_sample_chain",
+      [num_results, current_state, step_size, num_leapfrog_steps,
+       num_burnin_steps, num_steps_between_results, seed,
+       current_target_log_prob, current_grads_target_log_prob]):
+    with ops.name_scope("initialize"):
+      [
+          current_state,
+          step_size,
+          current_target_log_prob,
+          current_grads_target_log_prob,
+      ] = _prepare_args(
+          target_log_prob_fn, current_state, step_size,
+          current_target_log_prob, current_grads_target_log_prob)
+    def _run_chain(num_steps, current_state, seed, kernel_results):
+      """Runs the chain(s) for `num_steps`."""
+      def _loop_body(iter_, current_state, kernel_results):
+        return [iter_ + 1] + list(kernel(
+            target_log_prob_fn,
+            current_state,
+            step_size,
+            num_leapfrog_steps,
+            seed,
+            kernel_results.current_target_log_prob,
+            kernel_results.current_grads_target_log_prob))
+      return control_flow_ops.while_loop(
+          cond=lambda iter_, *args: iter_ < num_steps,
+          body=_loop_body,
+          loop_vars=[0, current_state, kernel_results])[1:]  # Lop-off "iter_".
+
+    def _scan_body(args_list, _):
+      """Closure which implements `tf.scan` body."""
+      current_state, kernel_results = args_list
+      return _run_chain(num_steps_between_results + 1, current_state, seed,
+                        kernel_results)
+
+    current_state, kernel_results = _run_chain(
+        num_burnin_steps,
+        current_state,
+        distributions_util.gen_new_seed(
+            seed, salt="hmc_sample_chain_burnin"),
+        _make_dummy_kernel_results(
+            current_state,
+            current_target_log_prob,
+            current_grads_target_log_prob))
 
-def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
-              target_log_prob_fn, proposal_log_prob_fn, event_dims=(),
-              name=None):
+    return functional_ops.scan(
+        fn=_scan_body,
+        elems=array_ops.zeros(num_results, dtype=dtypes.bool),  # Dummy arg.
+        initializer=[current_state, kernel_results])
+
+
+def sample_annealed_importance_chain(
+    proposal_log_prob_fn,
+    num_steps,
+    target_log_prob_fn,
+    current_state,
+    step_size,
+    num_leapfrog_steps,
+    seed=None,
+    name=None):
   """Runs annealed importance sampling (AIS) to estimate normalizing constants.
 
-  This routine uses Hamiltonian Monte Carlo to sample from a series of
+  This function uses Hamiltonian Monte Carlo to sample from a series of
   distributions that slowly interpolates between an initial "proposal"
-  distribution
+  distribution:
 
   `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`
 
-  and the target distribution
+  and the target distribution:
 
   `exp(target_log_prob_fn(x) - target_log_normalizer)`,
 
@@ -202,113 +325,183 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
   normalizing constants of the initial distribution and the target
   distribution:
 
-  E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer).
-
-  Args:
-    n_iterations: Integer number of Markov chain updates to run. More
-      iterations means more expense, but smoother annealing between q
-      and p, which in turn means exponentially lower variance for the
-      normalizing constant estimator.
-    step_size: Scalar step size or array of step sizes for the
-      leapfrog integrator. Broadcasts to the shape of
-      `initial_x`. Larger step sizes lead to faster progress, but
-      too-large step sizes make rejection exponentially more likely.
-      When possible, it's often helpful to match per-variable step
-      sizes to the standard deviations of the target distribution in
-      each variable.
-    n_leapfrog_steps: Integer number of steps to run the leapfrog
-      integrator for. Total progress per HMC step is roughly
-      proportional to step_size * n_leapfrog_steps.
-    initial_x: Tensor of initial state(s) of the Markov chain(s). Must
-      be a sample from q, or results will be incorrect.
-    target_log_prob_fn: Python callable which takes an argument like `initial_x`
-      and returns its (possibly unnormalized) log-density under the target
-      distribution.
-    proposal_log_prob_fn: Python callable that returns the log density of the
-      initial distribution.
-    event_dims: List of dimensions that should not be treated as
-      independent. This allows for multiple chains to be run independently
-      in parallel. Default is (), i.e., all dimensions are independent.
-    name: Python `str` name prefixed to Ops created by this function.
-
-  Returns:
-    ais_weights: Tensor with the estimated weight(s). Has shape matching
-      `target_log_prob_fn(initial_x)`.
-    chain_states: Tensor with the state(s) of the Markov chain(s) the final
-      iteration. Has shape matching `initial_x`.
-    acceptance_probs: Tensor with the acceptance probabilities for the final
-      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
+  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
 
   #### Examples:
 
+  ##### Estimate the normalizing constant of a log-gamma distribution.
+
   ```python
-  # Estimating the normalizing constant of a log-gamma distribution:
-  def proposal_log_prob(x):
-    # Standard normal log-probability. This is properly normalized.
-    return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1)
-  def target_log_prob(x):
-    # Unnormalized log-gamma(2, 3) distribution.
-    # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1]
-    return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1)
+  tfd = tf.contrib.distributions
+
   # Run 100 AIS chains in parallel
-  initial_x = tf.random_normal([100, 20])
-  w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob,
-                          proposal_log_prob, event_dims=[1])
-  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+  num_chains = 100
+  dims = 20
+  dtype = np.float32
+
+  proposal = tfd.MultivatiateNormalDiag(
+     loc=tf.zeros([dims], dtype=dtype))
+
+  target = tfd.TransformedDistribution(
+    distribution=tfd.Gamma(concentration=dtype(2),
+                           rate=dtype(3)),
+    bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
+    event_shape=[dims])
+
+  chains_state, ais_weights, kernels_results = (
+      hmc.sample_annealed_importance_chain(
+          proposal_log_prob_fn=proposal.log_prob,
+          num_steps=1000,
+          target_log_prob_fn=target.log_prob,
+          step_size=0.2,
+          current_state=proposal.sample(num_chains),
+          num_leapfrog_steps=2))
+
+  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
+                              - np.log(num_chains))
+  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
   ```
 
+  ##### Estimate marginal likelihood of a Bayesian regression model.
+
   ```python
-  # Estimating the marginal likelihood of a Bayesian regression model:
-  base_measure = -0.5 * np.log(2 * np.pi)
-  def proposal_log_prob(x):
-    # Standard normal log-probability. This is properly normalized.
-    return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1)
-  def regression_log_joint(beta, x, y):
-    # This function returns a vector whose ith element is log p(beta[i], y | x).
-    # Each row of beta corresponds to the state of an independent Markov chain.
-    log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1)
-    means = tf.matmul(beta, x, transpose_b=True)
-    log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) +
-                                   base_measure, 1)
-    return log_prior + log_likelihood
-  def log_joint_partial(beta):
-    return regression_log_joint(beta, x, y)
+  tfd = tf.contrib.distributions
+
+  def make_prior(dims, dtype):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.zeros(dims, dtype))
+
+  def make_likelihood(weights, x):
+    return tfd.MultivariateNormalDiag(
+        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))
+
   # Run 100 AIS chains in parallel
-  initial_beta = tf.random_normal([100, x.shape[1]])
-  w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta,
-                                     log_joint_partial, proposal_log_prob,
-                                     event_dims=[1])
-  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+  num_chains = 100
+  dims = 10
+  dtype = np.float32
+
+  # Make training data.
+  x = np.random.randn(num_chains, dims).astype(dtype)
+  true_weights = np.random.randn(dims).astype(dtype)
+  y = np.dot(x, true_weights) + np.random.randn(num_chains)
+
+  # Setup model.
+  prior = make_prior(dims, dtype)
+  def target_log_prob_fn(weights):
+    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
+
+  proposal = tfd.MultivariateNormalDiag(
+      loc=tf.zeros(dims, dtype))
+
+  weight_samples, ais_weights, kernel_results = (
+      hmc.sample_annealed_importance_chain(
+        num_steps=1000,
+        proposal_log_prob_fn=proposal.log_prob,
+        target_log_prob_fn=target_log_prob_fn
+        current_state=tf.zeros([num_chains, dims], dtype),
+        step_size=0.1,
+        num_leapfrog_steps=2))
+  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
+                             - np.log(num_chains))
   ```
+
+  Args:
+    proposal_log_prob_fn: Python callable that returns the log density of the
+      initial distribution.
+    num_steps: Integer number of Markov chain updates to run. More
+      iterations means more expense, but smoother annealing between q
+      and p, which in turn means exponentially lower variance for the
+      normalizing constant estimator.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+      for the leapfrog integrator. Must broadcast with the shape of
+      `current_state`. Larger step sizes lead to faster progress, but too-large
+      step sizes make rejection exponentially more likely. When possible, it's
+      often helpful to match per-variable step sizes to the standard deviations
+      of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    seed: Python integer to seed the random number generator.
+    name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
+
+  Returns:
+    accepted_state: `Tensor` or Python list of `Tensor`s representing the
+      state(s) of the Markov chain(s) at the final iteration. Has same shape as
+      input `current_state`.
+    ais_weights: Tensor with the estimated weight(s). Has shape matching
+      `target_log_prob_fn(current_state)`.
   """
-  with ops.name_scope(name, 'hmc_ais_chain',
-                      [n_iterations, step_size, n_leapfrog_steps, initial_x]):
-    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
-
-    beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:]
-    def _body(a, beta):  # pylint: disable=missing-docstring
-      def log_prob_beta(x):
-        return ((1 - beta) * proposal_log_prob_fn(x) +
-                beta * target_log_prob_fn(x))
-      last_x = a[0]
-      w = a[2]
-      w += (1. / n_iterations) * (target_log_prob_fn(last_x) -
-                                  proposal_log_prob_fn(last_x))
-      # TODO(b/66917083): There's an opportunity for gradient reuse here.
-      updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps,
-                                                 last_x, log_prob_beta,
-                                                 event_dims)
-      return updated_x, acceptance_probs, w
-
-    x, acceptance_probs, w = functional_ops.scan(
-        _body, beta_series,
-        (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
-         array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
-  return w[-1], x[-1], acceptance_probs[-1]
-
-
-def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
-           x_log_prob=None, x_grad=None, name=None):
+  def make_convex_combined_log_prob_fn(iter_):
+    def _fn(*args):
+      p = proposal_log_prob_fn(*args)
+      t = target_log_prob_fn(*args)
+      dtype = p.dtype.base_dtype
+      beta = (math_ops.cast(iter_ + 1, dtype)
+              / math_ops.cast(num_steps, dtype))
+      return (1. - beta) * p + beta * t
+    return _fn
+
+  with ops.name_scope(
+      name, "hmc_sample_annealed_importance_chain",
+      [num_steps, current_state, step_size, num_leapfrog_steps, seed]):
+    with ops.name_scope("initialize"):
+      [
+          current_state,
+          step_size,
+          current_log_prob,
+          current_grads_log_prob,
+      ] = _prepare_args(
+          make_convex_combined_log_prob_fn(iter_=0),
+          current_state,
+          step_size,
+          description="convex_combined_log_prob")
+    def _loop_body(iter_, ais_weights, current_state, kernel_results):
+      """Closure which implements `tf.while_loop` body."""
+      current_state_parts = (list(current_state)
+                             if _is_list_like(current_state)
+                             else [current_state])
+      ais_weights += ((target_log_prob_fn(*current_state_parts)
+                       - proposal_log_prob_fn(*current_state_parts))
+                      / math_ops.cast(num_steps, ais_weights.dtype))
+      return [iter_ + 1, ais_weights] + list(kernel(
+          make_convex_combined_log_prob_fn(iter_),
+          current_state,
+          step_size,
+          num_leapfrog_steps,
+          seed,
+          kernel_results.current_target_log_prob,
+          kernel_results.current_grads_target_log_prob))
+
+    [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop(
+        cond=lambda iter_, *args: iter_ < num_steps,
+        body=_loop_body,
+        loop_vars=[
+            0,  # iter_
+            array_ops.zeros_like(current_log_prob),  # ais_weights
+            current_state,
+            _make_dummy_kernel_results(current_state,
+                                       current_log_prob,
+                                       current_grads_log_prob),
+        ])[1:]  # Lop-off "iter_".
+
+    return [current_state, ais_weights, kernel_results]
+
+
+def kernel(target_log_prob_fn,
+           current_state,
+           step_size,
+           num_leapfrog_steps,
+           seed=None,
+           current_target_log_prob=None,
+           current_grads_target_log_prob=None,
+           name=None):
   """Runs one iteration of Hamiltonian Monte Carlo.
 
   Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
@@ -316,334 +509,625 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
   a Metropolis proposal. This function applies one step of HMC to
   randomly update the variable `x`.
 
-  This function can update multiple chains in parallel. It assumes
-  that all dimensions of `x` not specified in `event_dims` are
-  independent, and should therefore be updated independently. The
-  output of `target_log_prob_fn()` should sum log-probabilities across
-  all event dimensions. Slices along dimensions not in `event_dims`
-  may have different target distributions; for example, if
-  `event_dims == (1,)`, then `x[0, :]` could have a different target
-  distribution from x[1, :]. This is up to `target_log_prob_fn()`.
-
-  Args:
-    step_size: Scalar step size or array of step sizes for the
-      leapfrog integrator. Broadcasts to the shape of
-      `x`. Larger step sizes lead to faster progress, but
-      too-large step sizes make rejection exponentially more likely.
-      When possible, it's often helpful to match per-variable step
-      sizes to the standard deviations of the target distribution in
-      each variable.
-    n_leapfrog_steps: Integer number of steps to run the leapfrog
-      integrator for. Total progress per HMC step is roughly
-      proportional to step_size * n_leapfrog_steps.
-    x: Tensor containing the value(s) of the random variable(s) to update.
-    target_log_prob_fn: Python callable which takes an argument like `initial_x`
-      and returns its (possibly unnormalized) log-density under the target
-      distribution.
-    event_dims: List of dimensions that should not be treated as
-      independent. This allows for multiple chains to be run independently
-      in parallel. Default is (), i.e., all dimensions are independent.
-    x_log_prob (optional): Tensor containing the cached output of a previous
-      call to `target_log_prob_fn()` evaluated at `x` (such as that provided by
-      a previous call to `kernel()`). Providing `x_log_prob` and
-      `x_grad` saves one gradient computation per call to `kernel()`.
-    x_grad (optional): Tensor containing the cached gradient of
-      `target_log_prob_fn()` evaluated at `x` (such as that provided by
-      a previous call to `kernel()`). Providing `x_log_prob` and
-      `x_grad` saves one gradient computation per call to `kernel()`.
-    name: Python `str` name prefixed to Ops created by this function.
-
-  Returns:
-    updated_x: The updated variable(s) x. Has shape matching `initial_x`.
-    acceptance_probs: Tensor with the acceptance probabilities for the final
-      iteration. This is useful for diagnosing step size problems etc. Has
-      shape matching `target_log_prob_fn(initial_x)`.
-    new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`.
-    new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at
-      `updated_x`.
+  This function can update multiple chains in parallel. It assumes that all
+  leftmost dimensions of `current_state` index independent chain states (and are
+  therefore updated independently). The output of `target_log_prob_fn()` should
+  sum log-probabilities across all event dimensions. Slices along the rightmost
+  dimensions may have different target distributions; for example,
+  `current_state[0, :]` could have a different target distribution from
+  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
+  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
 
   #### Examples:
 
+  ##### Simple chain with warm-up.
+
   ```python
+  tfd = tf.contrib.distributions
+
   # Tuning acceptance rates:
+  dtype = np.float32
   target_accept_rate = 0.631
-  def target_log_prob(x):
-    # Standard normal
-    return tf.reduce_sum(-0.5 * tf.square(x))
-  initial_x = tf.zeros([10])
-  initial_log_prob = target_log_prob(initial_x)
-  initial_grad = tf.gradients(initial_log_prob, initial_x)[0]
-  # Algorithm state
-  x = tf.Variable(initial_x, name='x')
-  step_size = tf.Variable(1., name='step_size')
-  last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob')
-  last_grad = tf.Variable(initial_grad, name='last_grad')
-  # Compute updates
-  new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x,
-                                                      target_log_prob,
-                                                      event_dims=[0],
-                                                      x_log_prob=last_log_prob)
-  x_update = tf.assign(x, new_x)
-  log_prob_update = tf.assign(last_log_prob, log_prob)
-  grad_update = tf.assign(last_grad, grad)
-  step_size_update = tf.assign(step_size,
-                               tf.where(acceptance_prob > target_accept_rate,
-                                        step_size * 1.01, step_size / 1.01))
-  adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update]
-  sampling_updates = [x_update, log_prob_update, grad_update]
-
-  sess = tf.Session()
-  sess.run(tf.global_variables_initializer())
+  num_warmup_iter = 500
+  num_chain_iter = 500
+
+  x = tf.get_variable(name="x", initializer=dtype(1))
+  step_size = tf.get_variable(name="step_size", initializer=dtype(1))
+
+  target = tfd.Normal(loc=dtype(0), scale=dtype(1))
+
+  new_x, other_results = hmc.kernel(
+      target_log_prob_fn=target.log_prob,
+      current_state=x,
+      step_size=step_size,
+      num_leapfrog_steps=3)[:4]
+
+  x_update = x.assign(new_x)
+
+  step_size_update = step_size.assign_add(
+      step_size * tf.where(
+        other_results.acceptance_probs > target_accept_rate,
+        0.01, -0.01))
+
+  warmup = tf.group([x_update, step_size_update])
+
+  tf.global_variables_initializer().run()
+
+  sess.graph.finalize()  # No more graph building.
+
   # Warm up the sampler and adapt the step size
-  for i in xrange(500):
-    sess.run(adaptive_updates)
+  for _ in xrange(num_warmup_iter):
+    sess.run(warmup)
+
   # Collect samples without adapting step size
-  samples = np.zeros([500, 10])
-  for i in xrange(500):
-    x_val, _ = sess.run([new_x, sampling_updates])
-    samples[i] = x_val
+  samples = np.zeros([num_chain_iter])
+  for i in xrange(num_chain_iter):
+    _, x_, target_log_prob_, grad_ = sess.run([
+        x_update,
+        x,
+        other_results.target_log_prob,
+        other_results.grads_target_log_prob])
+    samples[i] = x_
+
+  print(samples.mean(), samples.std())
   ```
 
-  ```python
-  # Empirical-Bayes estimation of a hyperparameter by MCMC-EM:
-
-  # Problem setup
-  N = 150
-  D = 10
-  x = np.random.randn(N, D).astype(np.float32)
-  true_sigma = 0.5
-  true_beta = true_sigma * np.random.randn(D).astype(np.float32)
-  y = x.dot(true_beta) + np.random.randn(N).astype(np.float32)
-
-  def log_prior(beta, log_sigma):
-    return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) -
-                         log_sigma)
-  def regression_log_joint(beta, log_sigma, x, y):
-    # This function returns log p(beta | log_sigma) + log p(y | x, beta).
-    means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True)
-    means = tf.squeeze(means)
-    log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means))
-    return log_prior(beta, log_sigma) + log_likelihood
-  def log_joint_partial(beta):
-    return regression_log_joint(beta, log_sigma, x, y)
-  # Our estimate of log(sigma)
-  log_sigma = tf.Variable(0., name='log_sigma')
-  # The state of the Markov chain
-  beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta')
-  new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial,
-                                 event_dims=[0])
-  beta_update = tf.assign(beta, new_beta)
-  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-  with tf.control_dependencies([beta_update]):
-    log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma),
-                                          var_list=[log_sigma])
-
-  sess = tf.Session()
-  sess.run(tf.global_variables_initializer())
-  log_sigma_history = np.zeros(1000)
-  for i in xrange(1000):
-    log_sigma_val, _ = sess.run([log_sigma, log_sigma_update])
-    log_sigma_history[i] = log_sigma_val
-  # Should converge to something close to true_sigma
-  plt.plot(np.exp(log_sigma_history))
-  ```
-  """
-  with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
-    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
-    x = ops.convert_to_tensor(x, name='x')
-
-    x_shape = array_ops.shape(x)
-    m = random_ops.random_normal(x_shape, dtype=x.dtype)
-
-    kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)
-
-    if (x_log_prob is not None) and (x_grad is not None):
-      log_potential_0, grad_0 = -x_log_prob, -x_grad  # pylint: disable=invalid-unary-operand-type
-    else:
-      if x_log_prob is not None:
-        logging.warn('x_log_prob was provided, but x_grad was not,'
-                     ' so x_log_prob was not used.')
-      if x_grad is not None:
-        logging.warn('x_grad was provided, but x_log_prob was not,'
-                     ' so x_grad was not used.')
-      log_potential_0, grad_0 = potential_and_grad(x)
-
-    new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator(
-        step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0)
-
-    kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims)
-
-    energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0
-    # Treat NaN as infinite energy (and therefore guaranteed rejection).
-    energy_change = array_ops.where(
-        math_ops.is_nan(energy_change),
-        array_ops.fill(array_ops.shape(energy_change),
-                       energy_change.dtype.as_numpy_dtype(np.inf)),
-        energy_change)
-    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
-    accepted = (
-        random_ops.random_uniform(
-            array_ops.shape(acceptance_probs), dtype=x.dtype)
-        < acceptance_probs)
-    new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)
-
-    # TODO(b/65738010): This should work, but it doesn't for now.
-    # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
-    reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
-                                                        keep_dims=True))
-    accepted = array_ops.reshape(accepted, reduced_shape)
-    accepted = math_ops.logical_or(
-        accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
-    new_x = array_ops.where(accepted, new_x, x)
-    new_grad = -array_ops.where(accepted, grad_1, grad_0)
-
-  # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
-  # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate).  This
-  # should be fixed.
-  return new_x, acceptance_probs, new_log_prob, new_grad
-
-
-def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum,
-                        potential_and_grad, initial_grad, name=None):
-  """Applies `n_steps` steps of the leapfrog integrator.
-
-  This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing
-  gradient computations where possible.
+  ##### Sample from more complicated posterior.
 
-  Args:
-    step_size: Scalar step size or array of step sizes for the
-      leapfrog integrator. Broadcasts to the shape of
-      `initial_position`. Larger step sizes lead to faster progress, but
-      too-large step sizes lead to larger discretization error and
-      worse energy conservation.
-    n_steps: Number of steps to run the leapfrog integrator.
-    initial_position: Tensor containing the value(s) of the position variable(s)
-      to update.
-    initial_momentum: Tensor containing the value(s) of the momentum variable(s)
-      to update.
-    potential_and_grad: Python callable that takes a position tensor like
-      `initial_position` and returns the potential energy and its gradient at
-      that position.
-    initial_grad: Tensor with the value of the gradient of the potential energy
-      at `initial_position`.
-    name: Python `str` name prefixed to Ops created by this function.
-
-  Returns:
-    updated_position: Updated value of the position.
-    updated_momentum: Updated value of the momentum.
-    new_potential: Potential energy of the new position. Has shape matching
-      `potential_and_grad(initial_position)`.
-    new_grad: Gradient from potential_and_grad() evaluated at the new position.
-      Has shape matching `initial_position`.
+  I.e.,
 
-  Example: Simple quadratic potential.
+  ```none
+    W ~ MVN(loc=0, scale=sigma * eye(dims))
+    for i=1...num_samples:
+        X[i] ~ MVN(loc=0, scale=eye(dims))
+      eps[i] ~ Normal(loc=0, scale=1)
+        Y[i] = X[i].T * W + eps[i]
+  ```
 
   ```python
-  def potential_and_grad(position):
-    return tf.reduce_sum(0.5 * tf.square(position)), position
-  position = tf.placeholder(np.float32)
-  momentum = tf.placeholder(np.float32)
-  potential, grad = potential_and_grad(position)
-  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator(
-    0.1, 3, position, momentum, potential_and_grad, grad)
-
-  sess = tf.Session()
-  position_val = np.random.randn(10)
-  momentum_val = np.random.randn(10)
-  potential_val, grad_val = sess.run([potential, grad],
-                                     {position: position_val})
-  positions = np.zeros([100, 10])
-  for i in xrange(100):
-    position_val, momentum_val, potential_val, grad_val = sess.run(
-      [new_position, new_momentum, new_potential, new_grad],
-      {position: position_val, momentum: momentum_val})
-    positions[i] = position_val
-  # Should trace out sinusoidal dynamics.
-  plt.plot(positions[:, 0])
-  ```
-  """
-  def leapfrog_wrapper(step_size, x, m, grad, l):
-    x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad)
-    return step_size, x, m, grad, l + 1
+  tfd = tf.contrib.distributions
+
+  def make_training_data(num_samples, dims, sigma):
+    dt = np.asarray(sigma).dtype
+    zeros = tf.zeros(dims, dtype=dt)
+    x = tfd.MultivariateNormalDiag(
+        loc=zeros).sample(num_samples, seed=1)
+    w = tfd.MultivariateNormalDiag(
+        loc=zeros,
+        scale_identity_multiplier=sigma).sample(seed=2)
+    noise = tfd.Normal(
+        loc=dt(0),
+        scale=dt(1)).sample(num_samples, seed=3)
+    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
+    return y, x, w
+
+  def make_prior(sigma, dims):
+    # p(w | sigma)
+    return tfd.MultivariateNormalDiag(
+        loc=tf.zeros([dims], dtype=sigma.dtype),
+        scale_identity_multiplier=sigma)
+
+  def make_likelihood(x, w):
+    # p(y | x, w)
+    return tfd.MultivariateNormalDiag(
+        loc=tf.tensordot(x, w, axes=[[1], [0]]))
+
+  # Setup assumptions.
+  dtype = np.float32
+  num_samples = 150
+  dims = 10
+  num_iters = int(5e3)
+
+  true_sigma = dtype(0.5)
+  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)
+
+  # Estimate of `log(true_sigma)`.
+  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
+  sigma = tf.exp(log_sigma)
+
+  # State of the Markov chain.
+  weights = tf.get_variable(
+      name="weights",
+      initializer=np.random.randn(dims).astype(dtype))
+
+  prior = make_prior(sigma, dims)
+
+  def joint_log_prob_fn(w):
+    # f(w) = log p(w, y | x)
+    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)
+
+  weights_update = weights.assign(
+      hmc.kernel(target_log_prob_fn=joint_log_prob,
+                 current_state=weights,
+                 step_size=0.1,
+                 num_leapfrog_steps=5)[0])
+
+  with tf.control_dependencies([weights_update]):
+    loss = -prior.log_prob(weights)
+
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])
+
+  sess.graph.finalize()  # No more graph building.
 
-  def counter_fn(a, b, c, d, counter):  # pylint: disable=unused-argument
-    return counter < n_steps
+  tf.global_variables_initializer().run()
 
-  with ops.name_scope(name, 'leapfrog_integrator',
-                      [step_size, n_steps, initial_position, initial_momentum,
-                       initial_grad]):
-    _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop(
-        counter_fn, leapfrog_wrapper, [step_size, initial_position,
-                                       initial_momentum, initial_grad,
-                                       array_ops.constant(0)], back_prop=False)
-    # We're counting on the runtime to eliminate this redundant computation.
-    new_potential, new_grad = potential_and_grad(new_x)
-  return new_x, new_m, new_potential, new_grad
+  sigma_history = np.zeros(num_iters, dtype)
+  weights_history = np.zeros([num_iters, dims], dtype)
 
+  for i in xrange(num_iters):
+    _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights])
+    weights_history[i, :] = weights_
+    sigma_history[i] = sigma_
 
-def leapfrog_step(step_size, position, momentum, potential_and_grad, grad,
-                  name=None):
-  """Applies one step of the leapfrog integrator.
+  true_weights_ = sess.run(true_weights)
 
-  Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2.
+  # Should converge to something close to true_sigma.
+  plt.plot(sigma_history);
+  plt.ylabel("sigma");
+  plt.xlabel("iteration");
+  ```
 
   Args:
-    step_size: Scalar step size or array of step sizes for the
-      leapfrog integrator. Broadcasts to the shape of
-      `position`. Larger step sizes lead to faster progress, but
-      too-large step sizes lead to larger discretization error and
-      worse energy conservation.
-    position: Tensor containing the value(s) of the position variable(s)
-      to update.
-    momentum: Tensor containing the value(s) of the momentum variable(s)
-      to update.
-    potential_and_grad: Python callable that takes a position tensor like
-      `position` and returns the potential energy and its gradient at that
-      position.
-    grad: Tensor with the value of the gradient of the potential energy
-      at `position`.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+      for the leapfrog integrator. Must broadcast with the shape of
+      `current_state`. Larger step sizes lead to faster progress, but too-large
+      step sizes make rejection exponentially more likely. When possible, it's
+      often helpful to match per-variable step sizes to the standard deviations
+      of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    seed: Python integer to seed the random number generator.
+    current_target_log_prob: (Optional) `Tensor` representing the value of
+      `target_log_prob_fn` at the `current_state`. The only reason to
+      specify this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+      representing gradient of `current_target_log_prob` at the `current_state`
+      and wrt the `current_state`. Must have same shape as `current_state`. The
+      only reason to specify this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
     name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_kernel").
 
   Returns:
-    updated_position: Updated value of the position.
-    updated_momentum: Updated value of the momentum.
-    new_potential: Potential energy of the new position. Has shape matching
-      `potential_and_grad(position)`.
-    new_grad: Gradient from potential_and_grad() evaluated at the new position.
-      Has shape matching `position`.
+    accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
+      of the Markov chain(s) at each result step. Has same shape as
+      `current_state`.
+    acceptance_probs: Tensor with the acceptance probabilities for each
+      iteration. Has shape matching `target_log_prob_fn(current_state)`.
+    accepted_target_log_prob: `Tensor` representing the value of
+      `target_log_prob_fn` at `accepted_state`.
+    accepted_grads_target_log_prob: Python `list` of `Tensor`s representing the
+      gradient of `accepted_target_log_prob` wrt each `accepted_state`.
+
+  Raises:
+    ValueError: if there isn't one `step_size` or a list with same length as
+      `current_state`.
+  """
+  with ops.name_scope(
+      name, "hmc_kernel",
+      [current_state, step_size, num_leapfrog_steps, seed,
+       current_target_log_prob, current_grads_target_log_prob]):
+    with ops.name_scope("initialize"):
+      [current_state_parts, step_sizes, current_target_log_prob,
+       current_grads_target_log_prob] = _prepare_args(
+           target_log_prob_fn, current_state, step_size,
+           current_target_log_prob, current_grads_target_log_prob,
+           maybe_expand=True)
+      independent_chain_ndims = distributions_util.prefer_static_rank(
+          current_target_log_prob)
+      def init_momentum(s):
+        return random_ops.random_normal(
+            shape=array_ops.shape(s),
+            dtype=s.dtype.base_dtype,
+            seed=distributions_util.gen_new_seed(
+                seed, salt="hmc_kernel_momentums"))
+      current_momentums = [init_momentum(s) for s in current_state_parts]
+
+    [
+        proposed_momentums,
+        proposed_state_parts,
+        proposed_target_log_prob,
+        proposed_grads_target_log_prob,
+    ] = _leapfrog_integrator(current_momentums,
+                             target_log_prob_fn,
+                             current_state_parts,
+                             step_sizes,
+                             num_leapfrog_steps,
+                             current_target_log_prob,
+                             current_grads_target_log_prob)
+
+    energy_change = _compute_energy_change(current_target_log_prob,
+                                           current_momentums,
+                                           proposed_target_log_prob,
+                                           proposed_momentums,
+                                           independent_chain_ndims)
+
+    # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
+    # ==> -log(u) >= max(e, 0)
+    # ==> -log(u) >= e
+    # (Perhaps surprisingly, we don't have a better way to obtain a random
+    # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
+    # maxval=np.inf)` won't work.)
+    random_uniform = random_ops.random_uniform(
+        shape=array_ops.shape(energy_change),
+        dtype=energy_change.dtype,
+        seed=seed)
+    random_positive = -math_ops.log(random_uniform)
+    is_accepted = random_positive >= energy_change
+
+    accepted_target_log_prob = array_ops.where(is_accepted,
+                                               proposed_target_log_prob,
+                                               current_target_log_prob)
+
+    accepted_state_parts = [_choose(is_accepted,
+                                    proposed_state_part,
+                                    current_state_part,
+                                    independent_chain_ndims)
+                            for current_state_part, proposed_state_part
+                            in zip(current_state_parts, proposed_state_parts)]
+
+    accepted_grads_target_log_prob = [
+        _choose(is_accepted,
+                proposed_grad,
+                grad,
+                independent_chain_ndims)
+        for proposed_grad, grad
+        in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)]
+
+    maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
+    return [
+        maybe_flatten(accepted_state_parts),
+        KernelResults(
+            acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
+            current_grads_target_log_prob=accepted_grads_target_log_prob,
+            current_target_log_prob=accepted_target_log_prob,
+            energy_change=energy_change,
+            is_accepted=is_accepted,
+            proposed_grads_target_log_prob=proposed_grads_target_log_prob,
+            proposed_state=maybe_flatten(proposed_state_parts),
+            proposed_target_log_prob=proposed_target_log_prob,
+            random_positive=random_positive,
+        ),
+    ]
+
+
+def _leapfrog_integrator(current_momentums,
+                         target_log_prob_fn,
+                         current_state_parts,
+                         step_sizes,
+                         num_leapfrog_steps,
+                         current_target_log_prob=None,
+                         current_grads_target_log_prob=None,
+                         name=None):
+  """Applies `num_leapfrog_steps` of the leapfrog integrator.
+
+  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.
+
+  #### Examples:
 
-  Example: Simple quadratic potential.
+  ##### Simple quadratic potential.
 
   ```python
-  def potential_and_grad(position):
-    # Simple quadratic potential
-    return tf.reduce_sum(0.5 * tf.square(position)), position
+  tfd = tf.contrib.distributions
+
+  dims = 10
+  num_iter = int(1e3)
+  dtype = np.float32
+
   position = tf.placeholder(np.float32)
   momentum = tf.placeholder(np.float32)
-  potential, grad = potential_and_grad(position)
-  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step(
-    0.1, position, momentum, potential_and_grad, grad)
-
-  sess = tf.Session()
-  position_val = np.random.randn(10)
-  momentum_val = np.random.randn(10)
-  potential_val, grad_val = sess.run([potential, grad],
-                                     {position: position_val})
-  positions = np.zeros([100, 10])
-  for i in xrange(100):
-    position_val, momentum_val, potential_val, grad_val = sess.run(
-      [new_position, new_momentum, new_potential, new_grad],
-      {position: position_val, momentum: momentum_val})
-    positions[i] = position_val
-  # Should trace out sinusoidal dynamics.
-  plt.plot(positions[:, 0])
+
+  [
+      new_momentums,
+      new_positions,
+  ] = hmc._leapfrog_integrator(
+      current_momentums=[momentum],
+      target_log_prob_fn=tfd.MultivariateNormalDiag(
+          loc=tf.zeros(dims, dtype)).log_prob,
+      current_state_parts=[position],
+      step_sizes=0.1,
+      num_leapfrog_steps=3)[:2]
+
+  sess.graph.finalize()  # No more graph building.
+
+  momentum_ = np.random.randn(dims).astype(dtype)
+  position_ = np.random.randn(dims).astype(dtype)
+
+  positions = np.zeros([num_iter, dims], dtype)
+  for i in xrange(num_iter):
+    position_, momentum_ = sess.run(
+        [new_momentums[0], new_position[0]],
+        feed_dict={position: position_, momentum: momentum_})
+    positions[i] = position_
+
+  plt.plot(positions[:, 0]);  # Sinusoidal.
   ```
+
+  Args:
+    current_momentums: Tensor containing the value(s) of the momentum
+      variable(s) to update.
+    target_log_prob_fn: Python callable which takes an argument like
+      `*current_state_parts` and returns its (possibly unnormalized) log-density
+      under the target distribution.
+    current_state_parts: Python `list` of `Tensor`s representing the current
+      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
+      the `Tensor`(s) index different chains.
+    step_sizes: Python `list` of `Tensor`s representing the step size for the
+      leapfrog integrator. Must broadcast with the shape of
+      `current_state_parts`.  Larger step sizes lead to faster progress, but
+      too-large step sizes make rejection exponentially more likely. When
+      possible, it's often helpful to match per-variable step sizes to the
+      standard deviations of the target distribution in each variable.
+    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+      for. Total progress per HMC step is roughly proportional to `step_size *
+      num_leapfrog_steps`.
+    current_target_log_prob: (Optional) `Tensor` representing the value of
+      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
+      this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+      representing gradient of `target_log_prob_fn(*current_state_parts`) wrt
+      `current_state_parts`. Must have same shape as `current_state_parts`. The
+      only reason to specify this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    name: Python `str` name prefixed to Ops created by this function.
+      Default value: `None` (i.e., "hmc_leapfrog_integrator").
+
+  Returns:
+    proposed_momentums: Updated value of the momentum.
+    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
+      state(s) of the Markov chain(s) at each result step. Has same shape as
+      input `current_state_parts`.
+    proposed_target_log_prob: `Tensor` representing the value of
+      `target_log_prob_fn` at `accepted_state`.
+    proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
+      `accepted_state`.
+
+  Raises:
+    ValueError: if `len(momentums) != len(state_parts)`.
+    ValueError: if `len(state_parts) != len(step_sizes)`.
+    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
+    TypeError: if `not target_log_prob.dtype.is_floating`.
   """
-  with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum,
-                                              grad]):
-    momentum -= 0.5 * step_size * grad
-    position += step_size * momentum
-    potential, grad = potential_and_grad(position)
-    momentum -= 0.5 * step_size * grad
-
-  return position, momentum, potential, grad
+  def _loop_body(step,
+                 current_momentums,
+                 current_state_parts,
+                 ignore_current_target_log_prob,  # pylint: disable=unused-argument
+                 current_grads_target_log_prob):
+    return [step + 1] + list(_leapfrog_step(current_momentums,
+                                            target_log_prob_fn,
+                                            current_state_parts,
+                                            step_sizes,
+                                            current_grads_target_log_prob))
+
+  with ops.name_scope(
+      name, "hmc_leapfrog_integrator",
+      [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
+       current_target_log_prob, current_grads_target_log_prob]):
+    if len(current_momentums) != len(current_state_parts):
+      raise ValueError("`momentums` must be in one-to-one correspondence "
+                       "with `state_parts`")
+    num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps,
+                                               name="num_leapfrog_steps")
+    current_target_log_prob, current_grads_target_log_prob = (
+        _maybe_call_fn_and_grads(
+            target_log_prob_fn,
+            current_state_parts,
+            current_target_log_prob,
+            current_grads_target_log_prob))
+    return control_flow_ops.while_loop(
+        cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
+        body=_loop_body,
+        loop_vars=[
+            0,  # iter_
+            current_momentums,
+            current_state_parts,
+            current_target_log_prob,
+            current_grads_target_log_prob,
+        ],
+        back_prop=False)[1:]  # Lop-off "iter_".
+
+
+def _leapfrog_step(current_momentums,
+                   target_log_prob_fn,
+                   current_state_parts,
+                   step_sizes,
+                   current_grads_target_log_prob,
+                   name=None):
+  """Applies one step of the leapfrog integrator."""
+  with ops.name_scope(
+      name, "_leapfrog_step",
+      [current_momentums, current_state_parts, step_sizes,
+       current_grads_target_log_prob]):
+    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+                          in zip(current_momentums,
+                                 step_sizes,
+                                 current_grads_target_log_prob)]
+    proposed_state_parts = [x + ss * m for x, ss, m
+                            in zip(current_state_parts,
+                                   step_sizes,
+                                   proposed_momentums)]
+    proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
+    if not proposed_target_log_prob.dtype.is_floating:
+      raise TypeError("`target_log_prob_fn` must produce a `Tensor` "
+                      "with `float` `dtype`.")
+    proposed_grads_target_log_prob = gradients_ops.gradients(
+        proposed_target_log_prob, proposed_state_parts)
+    if any(g is None for g in proposed_grads_target_log_prob):
+      raise ValueError(
+          "Encountered `None` gradient. Does your target `target_log_prob_fn` "
+          "access all `tf.Variable`s via `tf.get_variable`?\n"
+          "  current_state_parts: {}\n"
+          "  proposed_state_parts: {}\n"
+          "  proposed_grads_target_log_prob: {}".format(
+              current_state_parts,
+              proposed_state_parts,
+              proposed_grads_target_log_prob))
+    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+                          in zip(proposed_momentums,
+                                 step_sizes,
+                                 proposed_grads_target_log_prob)]
+    return [
+        proposed_momentums,
+        proposed_state_parts,
+        proposed_target_log_prob,
+        proposed_grads_target_log_prob,
+    ]
+
+
+def _compute_energy_change(current_target_log_prob,
+                           current_momentums,
+                           proposed_target_log_prob,
+                           proposed_momentums,
+                           independent_chain_ndims,
+                           name=None):
+  """Helper to `kernel` which computes the energy change."""
+  with ops.name_scope(
+      name, "compute_energy_change",
+      ([current_target_log_prob, proposed_target_log_prob,
+        independent_chain_ndims] +
+       current_momentums + proposed_momentums)):
+    # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy
+    # since they're a mouthful and lets us inline more.
+    lk0, lk1 = [], []
+    for current_momentum, proposed_momentum in zip(current_momentums,
+                                                   proposed_momentums):
+      axis = math_ops.range(independent_chain_ndims,
+                            array_ops.rank(current_momentum))
+      lk0.append(_log_sum_sq(current_momentum, axis))
+      lk1.append(_log_sum_sq(proposed_momentum, axis))
+
+    lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1),
+                                                  axis=-1)
+    lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1),
+                                                  axis=-1)
+    lp0 = -current_target_log_prob   # log_potential
+    lp1 = -proposed_target_log_prob  # proposed_log_potential
+    x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
+                        axis=-1)
+
+    # The sum is NaN if any element is NaN or we see both +Inf and -Inf.
+    # Thus we will replace such rows with infinite energy change which implies
+    # rejection. Recall that float-comparisons with NaN are always False.
+    is_sum_determinate = (
+        math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) &
+        math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1))
+    is_sum_determinate = array_ops.tile(
+        is_sum_determinate[..., array_ops.newaxis],
+        multiples=array_ops.concat([
+            array_ops.ones(array_ops.rank(is_sum_determinate),
+                           dtype=dtypes.int32),
+            [4],
+        ], axis=0))
+    x = array_ops.where(is_sum_determinate,
+                        x,
+                        array_ops.fill(array_ops.shape(x),
+                                       value=x.dtype.as_numpy_dtype(np.inf)))
+
+    return math_ops.reduce_sum(x, axis=-1)
+
+
+def _choose(is_accepted,
+            accepted,
+            rejected,
+            independent_chain_ndims,
+            name=None):
+  """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where."""
+  def _expand_is_accepted_like(x):
+    with ops.name_scope("_choose"):
+      expand_shape = array_ops.concat([
+          array_ops.shape(is_accepted),
+          array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)],
+                         dtype=dtypes.int32),
+      ], axis=0)
+      multiples = array_ops.concat([
+          array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32),
+          array_ops.shape(x)[independent_chain_ndims:],
+      ], axis=0)
+      m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape),
+                         multiples)
+      m.set_shape(x.shape)
+      return m
+  with ops.name_scope(name, "_choose", values=[
+      is_accepted, accepted, rejected, independent_chain_ndims]):
+    return array_ops.where(_expand_is_accepted_like(accepted),
+                           accepted,
+                           rejected)
+
+
+def _maybe_call_fn_and_grads(fn,
+                             fn_arg_list,
+                             fn_result=None,
+                             grads_fn_result=None,
+                             description="target_log_prob"):
+  """Helper which computes `fn_result` and `grads` if needed."""
+  fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list)
+                 else [fn_arg_list])
+  if fn_result is None:
+    fn_result = fn(*fn_arg_list)
+  if not fn_result.dtype.is_floating:
+    raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format(
+        description))
+  if grads_fn_result is None:
+    grads_fn_result = gradients_ops.gradients(
+        fn_result, fn_arg_list)
+  if len(fn_arg_list) != len(grads_fn_result):
+    raise ValueError("`{}` must be in one-to-one correspondence with "
+                     "`grads_{}`".format(*[description]*2))
+  if any(g is None for g in grads_fn_result):
+    raise ValueError("Encountered `None` gradient.")
+  return fn_result, grads_fn_result
+
+
+def _prepare_args(target_log_prob_fn, state, step_size,
+                  target_log_prob=None, grads_target_log_prob=None,
+                  maybe_expand=False, description="target_log_prob"):
+  """Helper which processes input args to meet list-like assumptions."""
+  state_parts = list(state) if _is_list_like(state) else [state]
+  state_parts = [ops.convert_to_tensor(s, name="state")
+                 for s in state_parts]
+  target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads(
+      target_log_prob_fn,
+      state_parts,
+      target_log_prob,
+      grads_target_log_prob,
+      description)
+  step_sizes = list(step_size) if _is_list_like(step_size) else [step_size]
+  step_sizes = [
+      ops.convert_to_tensor(
+          s, name="step_size", dtype=target_log_prob.dtype)
+      for s in step_sizes]
+  if len(step_sizes) == 1:
+    step_sizes *= len(state_parts)
+  if len(state_parts) != len(step_sizes):
+    raise ValueError("There should be exactly one `step_size` or it should "
+                     "have same length as `current_state`.")
+  if maybe_expand:
+    maybe_flatten = lambda x: x
+  else:
+    maybe_flatten = lambda x: x if _is_list_like(state) else x[0]
+  return [
+      maybe_flatten(state_parts),
+      maybe_flatten(step_sizes),
+      target_log_prob,
+      grads_target_log_prob,
+  ]
+
+
+def _is_list_like(x):
+  """Helper which returns `True` if input is `list`-like."""
+  return isinstance(x, (tuple, list))
+
+
+def _log_sum_sq(x, axis=None):
+  """Computes log(sum(x**2))."""
+  return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis)