Automated g4 rollback of changelist 184323369
authorJoshua V. Dillon <jvdillon@google.com>
Mon, 5 Feb 2018 18:45:33 +0000 (10:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 18:49:44 +0000 (10:49 -0800)
PiperOrigin-RevId: 184551259

tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
tensorflow/contrib/bayesflow/python/ops/hmc.py
tensorflow/contrib/bayesflow/python/ops/hmc_impl.py

index d9d0dfc..cbc66b6 100644 (file)
@@ -19,36 +19,29 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from scipy import special
 from scipy import stats
 
 from tensorflow.contrib.bayesflow.python.ops import hmc
-from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change
-from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator
 
-from tensorflow.contrib.distributions.python.ops import independent as independent_lib
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl as gradients_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
-from tensorflow.python.ops.distributions import gamma as gamma_lib
-from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging_ops
-
-
-def _reduce_variance(x, axis=None, keepdims=False):
-  sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
-  return math_ops.reduce_mean(
-      math_ops.squared_difference(x, sample_mean), axis, keepdims)
+from tensorflow.python.platform import tf_logging as logging
 
 
+# TODO(b/66964210): Test float16.
 class HMCTest(test.TestCase):
 
   def setUp(self):
     self._shape_param = 5.
     self._rate_param = 10.
+    self._expected_x = (special.digamma(self._shape_param)
+                        - np.log(self._rate_param))
+    self._expected_exp_x = self._shape_param / self._rate_param
 
     random_seed.set_random_seed(10003)
     np.random.seed(10003)
@@ -70,46 +63,63 @@ class HMCTest(test.TestCase):
                                self._rate_param * math_ops.exp(x),
                                event_dims)
 
-  def _integrator_conserves_energy(self, x, independent_chain_ndims, sess,
+  def _log_gamma_log_prob_grad(self, x, event_dims=()):
+    """Computes log-pdf and gradient of a log-gamma random variable.
+
+    Args:
+      x: Value of the random variable.
+      event_dims: Dimensions not to treat as independent. Default is (),
+        i.e., all dimensions are independent.
+
+    Returns:
+      log_prob: The log-pdf up to a normalizing constant.
+      grad: The gradient of the log-pdf with respect to x.
+    """
+    return (math_ops.reduce_sum(self._shape_param * x -
+                                self._rate_param * math_ops.exp(x),
+                                event_dims),
+            self._shape_param - self._rate_param * math_ops.exp(x))
+
+  def _n_event_dims(self, x_shape, event_dims):
+    return np.prod([int(x_shape[i]) for i in event_dims])
+
+  def _integrator_conserves_energy(self, x, event_dims, sess,
                                    feed_dict=None):
-    step_size = array_ops.placeholder(np.float32, [], name="step_size")
-    hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps")
+    def potential_and_grad(x):
+      log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims)
+      return -log_prob, -grad
+
+    step_size = array_ops.placeholder(np.float32, [], name='step_size')
+    hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
 
     if feed_dict is None:
       feed_dict = {}
     feed_dict[hmc_lf_steps] = 1000
 
-    event_dims = math_ops.range(independent_chain_ndims,
-                                array_ops.rank(x))
-
     m = random_ops.random_normal(array_ops.shape(x))
-    log_prob_0 = self._log_gamma_log_prob(x, event_dims)
-    grad_0 = gradients_ops.gradients(log_prob_0, x)
-    old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims)
-
-    new_m, _, log_prob_1, _ = _leapfrog_integrator(
-        current_momentums=[m],
-        target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims),
-        current_state_parts=[x],
-        step_sizes=[step_size],
-        num_leapfrog_steps=hmc_lf_steps,
-        current_target_log_prob=log_prob_0,
-        current_grads_target_log_prob=grad_0)
-    new_m = new_m[0]
-
-    new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
+    potential_0, grad_0 = potential_and_grad(x)
+    old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m,
+                                                         event_dims)
+
+    _, new_m, potential_1, _ = (
+        hmc.leapfrog_integrator(step_size, hmc_lf_steps, x,
+                                m, potential_and_grad, grad_0))
+
+    new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
                                                          event_dims)
 
     x_shape = sess.run(x, feed_dict).shape
-    event_size = np.prod(x_shape[independent_chain_ndims:])
-    feed_dict[step_size] = 0.1 / event_size
-    old_energy_, new_energy_ = sess.run([old_energy, new_energy],
-                                        feed_dict)
-    logging_ops.vlog(1, "average energy relative change: {}".format(
-        (1. - new_energy_ / old_energy_).mean()))
-    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
-
-  def _integrator_conserves_energy_wrapper(self, independent_chain_ndims):
+    n_event_dims = self._n_event_dims(x_shape, event_dims)
+    feed_dict[step_size] = 0.1 / n_event_dims
+    old_energy_val, new_energy_val = sess.run([old_energy, new_energy],
+                                              feed_dict)
+    logging.vlog(1, 'average energy change: {}'.format(
+        abs(old_energy_val - new_energy_val).mean()))
+
+    self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool),
+                        abs(old_energy_val - new_energy_val) < 1.)
+
+  def _integrator_conserves_energy_wrapper(self, event_dims):
     """Tests the long-term energy conservation of the leapfrog integrator.
 
     The leapfrog integrator is symplectic, so for sufficiently small step
@@ -117,167 +127,135 @@ class HMCTest(test.TestCase):
     the energy of the system blowing up or collapsing.
 
     Args:
-      independent_chain_ndims: Python `int` scalar representing the number of
-        dims associated with independent chains.
+      event_dims: A tuple of dimensions that should not be treated as
+        independent. This allows for multiple chains to be run independently
+        in parallel. Default is (), i.e., all dimensions are independent.
     """
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name="x_ph")
-      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
-      self._integrator_conserves_energy(x_ph, independent_chain_ndims,
-                                        sess, feed_dict)
+      x_ph = array_ops.placeholder(np.float32, name='x_ph')
+
+      feed_dict = {x_ph: np.zeros([50, 10, 2])}
+      self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict)
 
   def testIntegratorEnergyConservationNullShape(self):
-    self._integrator_conserves_energy_wrapper(0)
+    self._integrator_conserves_energy_wrapper([])
 
   def testIntegratorEnergyConservation1(self):
-    self._integrator_conserves_energy_wrapper(1)
+    self._integrator_conserves_energy_wrapper([1])
 
   def testIntegratorEnergyConservation2(self):
-    self._integrator_conserves_energy_wrapper(2)
+    self._integrator_conserves_energy_wrapper([2])
 
-  def testIntegratorEnergyConservation3(self):
-    self._integrator_conserves_energy_wrapper(3)
+  def testIntegratorEnergyConservation12(self):
+    self._integrator_conserves_energy_wrapper([1, 2])
 
-  def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
-                                       sess, feed_dict=None):
+  def testIntegratorEnergyConservation012(self):
+    self._integrator_conserves_energy_wrapper([0, 1, 2])
+
+  def _chain_gets_correct_expectations(self, x, event_dims, sess,
+                                       feed_dict=None):
     def log_gamma_log_prob(x):
-      event_dims = math_ops.range(independent_chain_ndims,
-                                  array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
-    num_results = array_ops.placeholder(
-        np.int32, [], name="num_results")
-    step_size = array_ops.placeholder(
-        np.float32, [], name="step_size")
-    num_leapfrog_steps = array_ops.placeholder(
-        np.int32, [], name="num_leapfrog_steps")
+    step_size = array_ops.placeholder(np.float32, [], name='step_size')
+    hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
+    hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps')
 
     if feed_dict is None:
       feed_dict = {}
-    feed_dict.update({num_results: 150,
-                      step_size: 0.1,
-                      num_leapfrog_steps: 2})
-
-    samples, kernel_results = hmc.sample_chain(
-        num_results=num_results,
-        target_log_prob_fn=log_gamma_log_prob,
-        current_state=x,
-        step_size=step_size,
-        num_leapfrog_steps=num_leapfrog_steps,
-        num_burnin_steps=150,
-        seed=42)
-
-    expected_x = (math_ops.digamma(self._shape_param)
-                  - np.log(self._rate_param))
-
-    expected_exp_x = self._shape_param / self._rate_param
-
-    acceptance_probs_, samples_, expected_x_ = sess.run(
-        [kernel_results.acceptance_probs, samples, expected_x],
-        feed_dict)
-
-    actual_x = samples_.mean()
-    actual_exp_x = np.exp(samples_).mean()
-
-    logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
-        expected_x_, expected_exp_x))
-    logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
-        actual_x, actual_exp_x))
-    self.assertNear(actual_x, expected_x_, 2e-2)
-    self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
-    self.assertTrue((acceptance_probs_ > 0.5).all())
-    self.assertTrue((acceptance_probs_ <= 1.0).all())
-
-  def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
+    feed_dict.update({step_size: 0.1,
+                      hmc_lf_steps: 2,
+                      hmc_n_steps: 300})
+
+    sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps],
+                                                    step_size,
+                                                    hmc_lf_steps,
+                                                    x, log_gamma_log_prob,
+                                                    event_dims)
+
+    acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain],
+                                         feed_dict)
+    samples = samples[feed_dict[hmc_n_steps] // 2:]
+    expected_x_est = samples.mean()
+    expected_exp_x_est = np.exp(samples).mean()
+
+    logging.vlog(1, 'True      E[x, exp(x)]: {}\t{}'.format(
+        self._expected_x, self._expected_exp_x))
+    logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format(
+        expected_x_est, expected_exp_x_est))
+    self.assertNear(expected_x_est, self._expected_x, 2e-2)
+    self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2)
+    self.assertTrue((acceptance_probs > 0.5).all())
+    self.assertTrue((acceptance_probs <= 1.0).all())
+
+  def _chain_gets_correct_expectations_wrapper(self, event_dims):
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name="x_ph")
-      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
-      self._chain_gets_correct_expectations(x_ph, independent_chain_ndims,
-                                            sess, feed_dict)
+      x_ph = array_ops.placeholder(np.float32, name='x_ph')
+
+      feed_dict = {x_ph: np.zeros([50, 10, 2])}
+      self._chain_gets_correct_expectations(x_ph, event_dims, sess,
+                                            feed_dict)
 
   def testHMCChainExpectationsNullShape(self):
-    self._chain_gets_correct_expectations_wrapper(0)
+    self._chain_gets_correct_expectations_wrapper([])
 
   def testHMCChainExpectations1(self):
-    self._chain_gets_correct_expectations_wrapper(1)
+    self._chain_gets_correct_expectations_wrapper([1])
 
   def testHMCChainExpectations2(self):
-    self._chain_gets_correct_expectations_wrapper(2)
+    self._chain_gets_correct_expectations_wrapper([2])
+
+  def testHMCChainExpectations12(self):
+    self._chain_gets_correct_expectations_wrapper([1, 2])
 
-  def _kernel_leaves_target_invariant(self, initial_draws,
-                                      independent_chain_ndims,
+  def _kernel_leaves_target_invariant(self, initial_draws, event_dims,
                                       sess, feed_dict=None):
     def log_gamma_log_prob(x):
-      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
     def fake_log_prob(x):
       """Cooled version of the target distribution."""
       return 1.1 * log_gamma_log_prob(x)
 
-    step_size = array_ops.placeholder(np.float32, [], name="step_size")
+    step_size = array_ops.placeholder(np.float32, [], name='step_size')
 
     if feed_dict is None:
       feed_dict = {}
 
     feed_dict[step_size] = 0.4
 
-    sample, kernel_results = hmc.kernel(
-        target_log_prob_fn=log_gamma_log_prob,
-        current_state=initial_draws,
-        step_size=step_size,
-        num_leapfrog_steps=5,
-        seed=43)
-
-    bad_sample, bad_kernel_results = hmc.kernel(
-        target_log_prob_fn=fake_log_prob,
-        current_state=initial_draws,
-        step_size=step_size,
-        num_leapfrog_steps=5,
-        seed=44)
-
-    [
-        acceptance_probs_,
-        bad_acceptance_probs_,
-        initial_draws_,
-        updated_draws_,
-        fake_draws_,
-    ] = sess.run([
-        kernel_results.acceptance_probs,
-        bad_kernel_results.acceptance_probs,
-        initial_draws,
-        sample,
-        bad_sample,
-    ], feed_dict)
-
+    sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws,
+                                                log_gamma_log_prob, event_dims)
+    bad_sample, bad_acceptance_probs, _, _ = hmc.kernel(
+        step_size, 5, initial_draws, fake_log_prob, event_dims)
+    (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val,
+     updated_draws_val, fake_draws_val) = sess.run([acceptance_probs,
+                                                    bad_acceptance_probs,
+                                                    initial_draws, sample,
+                                                    bad_sample], feed_dict)
     # Confirm step size is small enough that we usually accept.
-    self.assertGreater(acceptance_probs_.mean(), 0.5)
-    self.assertGreater(bad_acceptance_probs_.mean(), 0.5)
-
+    self.assertGreater(acceptance_probs_val.mean(), 0.5)
+    self.assertGreater(bad_acceptance_probs_val.mean(), 0.5)
     # Confirm step size is large enough that we sometimes reject.
-    self.assertLess(acceptance_probs_.mean(), 0.99)
-    self.assertLess(bad_acceptance_probs_.mean(), 0.99)
-
-    _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
-                                        updated_draws_.flatten())
-    _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(),
-                                        fake_draws_.flatten())
-
-    logging_ops.vlog(1, "acceptance rate for true target: {}".format(
-        acceptance_probs_.mean()))
-    logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
-        bad_acceptance_probs_.mean()))
-    logging_ops.vlog(1, "K-S p-value for true target: {}".format(
-        ks_p_value_true))
-    logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
-        ks_p_value_fake))
+    self.assertLess(acceptance_probs_val.mean(), 0.99)
+    self.assertLess(bad_acceptance_probs_val.mean(), 0.99)
+    _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(),
+                                        updated_draws_val.flatten())
+    _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(),
+                                        fake_draws_val.flatten())
+    logging.vlog(1, 'acceptance rate for true target: {}'.format(
+        acceptance_probs_val.mean()))
+    logging.vlog(1, 'acceptance rate for fake target: {}'.format(
+        bad_acceptance_probs_val.mean()))
+    logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true))
+    logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake))
     # Make sure that the MCMC update hasn't changed the empirical CDF much.
     self.assertGreater(ks_p_value_true, 1e-3)
     # Confirm that targeting the wrong distribution does
     # significantly change the empirical CDF.
     self.assertLess(ks_p_value_fake, 1e-6)
 
-  def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims):
+  def _kernel_leaves_target_invariant_wrapper(self, event_dims):
     """Tests that the kernel leaves the target distribution invariant.
 
     Draws some independent samples from the target distribution,
@@ -289,116 +267,86 @@ class HMCTest(test.TestCase):
     does change the target distribution. (And that we can detect that.)
 
     Args:
-      independent_chain_ndims: Python `int` scalar representing the number of
-        dims associated with independent chains.
+      event_dims: A tuple of dimensions that should not be treated as
+        independent. This allows for multiple chains to be run independently
+        in parallel. Default is (), i.e., all dimensions are independent.
     """
     with self.test_session() as sess:
       initial_draws = np.log(np.random.gamma(self._shape_param,
                                              size=[50000, 2, 2]))
       initial_draws -= np.log(self._rate_param)
-      x_ph = array_ops.placeholder(np.float32, name="x_ph")
+      x_ph = array_ops.placeholder(np.float32, name='x_ph')
 
       feed_dict = {x_ph: initial_draws}
 
-      self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims,
-                                           sess, feed_dict)
+      self._kernel_leaves_target_invariant(x_ph, event_dims, sess,
+                                           feed_dict)
+
+  def testKernelLeavesTargetInvariantNullShape(self):
+    self._kernel_leaves_target_invariant_wrapper([])
 
   def testKernelLeavesTargetInvariant1(self):
-    self._kernel_leaves_target_invariant_wrapper(1)
+    self._kernel_leaves_target_invariant_wrapper([1])
 
   def testKernelLeavesTargetInvariant2(self):
-    self._kernel_leaves_target_invariant_wrapper(2)
+    self._kernel_leaves_target_invariant_wrapper([2])
 
-  def testKernelLeavesTargetInvariant3(self):
-    self._kernel_leaves_target_invariant_wrapper(3)
+  def testKernelLeavesTargetInvariant12(self):
+    self._kernel_leaves_target_invariant_wrapper([1, 2])
 
-  def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
-                                       sess, feed_dict=None):
+  def _ais_gets_correct_log_normalizer(self, init, event_dims, sess,
+                                       feed_dict=None):
     def proposal_log_prob(x):
-      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
-      return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
-                                        axis=event_dims)
+      return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi),
+                                 event_dims)
 
     def target_log_prob(x):
-      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
       return self._log_gamma_log_prob(x, event_dims)
 
     if feed_dict is None:
       feed_dict = {}
 
-    num_steps = 200
-
-    _, ais_weights, _ = hmc.sample_annealed_importance_chain(
-        proposal_log_prob_fn=proposal_log_prob,
-        num_steps=num_steps,
-        target_log_prob_fn=target_log_prob,
-        step_size=0.5,
-        current_state=init,
-        num_leapfrog_steps=2,
-        seed=45)
-
-    event_shape = array_ops.shape(init)[independent_chain_ndims:]
-    event_size = math_ops.reduce_prod(event_shape)
-
-    log_true_normalizer = (
-        -self._shape_param * math_ops.log(self._rate_param)
-        + math_ops.lgamma(self._shape_param))
-    log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
-
-    log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
-                                - np.log(num_steps))
-
-    ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
-    ais_weights_size = array_ops.size(ais_weights)
-    standard_error = math_ops.sqrt(
-        _reduce_variance(ratio_estimate_true)
-        / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
-
-    [
-        ratio_estimate_true_,
-        log_true_normalizer_,
-        log_estimated_normalizer_,
-        standard_error_,
-        ais_weights_size_,
-        event_size_,
-    ] = sess.run([
-        ratio_estimate_true,
-        log_true_normalizer,
-        log_estimated_normalizer,
-        standard_error,
-        ais_weights_size,
-        event_size,
-    ], feed_dict)
-
-    logging_ops.vlog(1, "        log_true_normalizer: {}\n"
-                        "   log_estimated_normalizer: {}\n"
-                        "           ais_weights_size: {}\n"
-                        "                 event_size: {}\n".format(
-                            log_true_normalizer_,
-                            log_estimated_normalizer_,
-                            ais_weights_size_,
-                            event_size_))
-    self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
-
-  def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
+    w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob,
+                            proposal_log_prob, event_dims)
+
+    w_val = sess.run(w, feed_dict)
+    init_shape = sess.run(init, feed_dict).shape
+    normalizer_multiplier = np.prod([init_shape[i] for i in event_dims])
+
+    true_normalizer = -self._shape_param * np.log(self._rate_param)
+    true_normalizer += special.gammaln(self._shape_param)
+    true_normalizer *= normalizer_multiplier
+
+    n_weights = np.prod(w_val.shape)
+    normalized_w = np.exp(w_val - true_normalizer)
+    standard_error = np.std(normalized_w) / np.sqrt(n_weights)
+    logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format(
+        true_normalizer, np.log(normalized_w.mean()) + true_normalizer,
+        n_weights))
+    self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error)
+
+  def _ais_gets_correct_log_normalizer_wrapper(self, event_dims):
     """Tests that AIS yields reasonable estimates of normalizers."""
     with self.test_session() as sess:
-      x_ph = array_ops.placeholder(np.float32, name="x_ph")
+      x_ph = array_ops.placeholder(np.float32, name='x_ph')
+
       initial_draws = np.random.normal(size=[30, 2, 1])
-      self._ais_gets_correct_log_normalizer(
-          x_ph,
-          independent_chain_ndims,
-          sess,
-          feed_dict={x_ph: initial_draws})
+      feed_dict = {x_ph: initial_draws}
+
+      self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess,
+                                            feed_dict)
+
+  def testAISNullShape(self):
+    self._ais_gets_correct_log_normalizer_wrapper([])
 
   def testAIS1(self):
-    self._ais_gets_correct_log_normalizer_wrapper(1)
+    self._ais_gets_correct_log_normalizer_wrapper([1])
 
   def testAIS2(self):
-    self._ais_gets_correct_log_normalizer_wrapper(2)
+    self._ais_gets_correct_log_normalizer_wrapper([2])
 
-  def testAIS3(self):
-    self._ais_gets_correct_log_normalizer_wrapper(3)
+  def testAIS12(self):
+    self._ais_gets_correct_log_normalizer_wrapper([1, 2])
 
   def testNanRejection(self):
     """Tests that an update that yields NaN potentials gets rejected.
@@ -411,29 +359,24 @@ class HMCTest(test.TestCase):
     """
     def _unbounded_exponential_log_prob(x):
       """An exponential distribution with log-likelihood NaN for x < 0."""
-      per_element_potentials = array_ops.where(
-          x < 0.,
-          array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)),
-          -x)
+      per_element_potentials = array_ops.where(x < 0,
+                                               np.nan * array_ops.ones_like(x),
+                                               -x)
       return math_ops.reduce_sum(per_element_potentials)
 
     with self.test_session() as sess:
       initial_x = math_ops.linspace(0.01, 5, 10)
-      updated_x, kernel_results = hmc.kernel(
-          target_log_prob_fn=_unbounded_exponential_log_prob,
-          current_state=initial_x,
-          step_size=2.,
-          num_leapfrog_steps=5,
-          seed=46)
-      initial_x_, updated_x_, acceptance_probs_ = sess.run(
-          [initial_x, updated_x, kernel_results.acceptance_probs])
-
-      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
-      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
-      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
-
-      self.assertAllEqual(initial_x_, updated_x_)
-      self.assertEqual(acceptance_probs_, 0.)
+      updated_x, acceptance_probs, _, _ = hmc.kernel(
+          2., 5, initial_x, _unbounded_exponential_log_prob, [0])
+      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
+          [initial_x, updated_x, acceptance_probs])
+
+      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
+      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
+      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
+
+      self.assertAllEqual(initial_x_val, updated_x_val)
+      self.assertEqual(acceptance_probs_val, 0.)
 
   def testNanFromGradsDontPropagate(self):
     """Test that update with NaN gradients does not cause NaN in results."""
@@ -442,195 +385,60 @@ class HMCTest(test.TestCase):
 
     with self.test_session() as sess:
       initial_x = math_ops.linspace(0.01, 5, 10)
-      updated_x, kernel_results = hmc.kernel(
-          target_log_prob_fn=_nan_log_prob_with_nan_gradient,
-          current_state=initial_x,
-          step_size=2.,
-          num_leapfrog_steps=5,
-          seed=47)
-      initial_x_, updated_x_, acceptance_probs_ = sess.run(
-          [initial_x, updated_x, kernel_results.acceptance_probs])
-
-      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
-      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
-      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
-
-      self.assertAllEqual(initial_x_, updated_x_)
-      self.assertEqual(acceptance_probs_, 0.)
+      updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
+          2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
+      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
+          [initial_x, updated_x, acceptance_probs])
+
+      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
+      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
+      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
+
+      self.assertAllEqual(initial_x_val, updated_x_val)
+      self.assertEqual(acceptance_probs_val, 0.)
 
       self.assertAllFinite(
-          gradients_ops.gradients(updated_x, initial_x)[0].eval())
-      self.assertAllEqual([True], [g is None for g in gradients_ops.gradients(
-          kernel_results.proposed_grads_target_log_prob, initial_x)])
-      self.assertAllEqual([False], [g is None for g in gradients_ops.gradients(
-          kernel_results.proposed_grads_target_log_prob,
-          kernel_results.proposed_state)])
+          gradients_impl.gradients(updated_x, initial_x)[0].eval())
+      self.assertTrue(
+          gradients_impl.gradients(new_grad, initial_x)[0] is None)
 
       # Gradients of the acceptance probs and new log prob are not finite.
+      _ = new_log_prob  # Prevent unused arg error.
       # self.assertAllFinite(
-      #     gradients_ops.gradients(acceptance_probs, initial_x)[0].eval())
+      #     gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
       # self.assertAllFinite(
-      #     gradients_ops.gradients(new_log_prob, initial_x)[0].eval())
-
-  def _testChainWorksDtype(self, dtype):
-    states, kernel_results = hmc.sample_chain(
-        num_results=10,
-        target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
-        current_state=np.zeros(5).astype(dtype),
-        step_size=0.01,
-        num_leapfrog_steps=10,
-        seed=48)
-    with self.test_session() as sess:
-      states_, acceptance_probs_ = sess.run(
-          [states, kernel_results.acceptance_probs])
-    self.assertEqual(dtype, states_.dtype)
-    self.assertEqual(dtype, acceptance_probs_.dtype)
+      #     gradients_impl.gradients(new_log_prob, initial_x)[0].eval())
 
   def testChainWorksIn64Bit(self):
-    self._testChainWorksDtype(np.float64)
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float64(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float64),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float64, states_.dtype)
+    self.assertEqual(np.float64, acceptance_probs_.dtype)
 
   def testChainWorksIn16Bit(self):
-    self._testChainWorksDtype(np.float16)
-
-
-class _EnergyComputationTest(object):
-
-  def testHandlesNanFromPotential(self):
-    with self.test_session() as sess:
-      x = [1, np.inf, -np.inf, np.nan]
-      target_log_prob, proposed_target_log_prob = [
-          self.dtype(x.flatten()) for x in np.meshgrid(x, x)]
-      num_chains = len(target_log_prob)
-      dummy_momentums = [-1, 1]
-      momentums = [self.dtype([dummy_momentums] * num_chains)]
-      proposed_momentums = [self.dtype([dummy_momentums] * num_chains)]
-
-      target_log_prob = ops.convert_to_tensor(target_log_prob)
-      momentums = [ops.convert_to_tensor(momentums[0])]
-      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
-      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
-
-      energy = _compute_energy_change(
-          target_log_prob,
-          momentums,
-          proposed_target_log_prob,
-          proposed_momentums,
-          independent_chain_ndims=1)
-      grads = gradients_ops.gradients(energy, momentums)
-
-      [actual_energy, grads_] = sess.run([energy, grads])
-
-      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
-      # finite otherwise.
-      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
-      self.assertAllEqual(expected_energy, actual_energy)
-
-      # Ensure gradient is finite.
-      self.assertAllEqual(np.ones_like(grads_).astype(np.bool),
-                          np.isfinite(grads_))
-
-  def testHandlesNanFromKinetic(self):
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float16(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float16),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
     with self.test_session() as sess:
-      x = [1, np.inf, -np.inf, np.nan]
-      momentums, proposed_momentums = [
-          [np.reshape(self.dtype(x), [-1, 1])]
-          for x in np.meshgrid(x, x)]
-      num_chains = len(momentums[0])
-      target_log_prob = np.ones(num_chains, self.dtype)
-      proposed_target_log_prob = np.ones(num_chains, self.dtype)
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float16, states_.dtype)
+    self.assertEqual(np.float16, acceptance_probs_.dtype)
 
-      target_log_prob = ops.convert_to_tensor(target_log_prob)
-      momentums = [ops.convert_to_tensor(momentums[0])]
-      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
-      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
 
-      energy = _compute_energy_change(
-          target_log_prob,
-          momentums,
-          proposed_target_log_prob,
-          proposed_momentums,
-          independent_chain_ndims=1)
-      grads = gradients_ops.gradients(energy, momentums)
-
-      [actual_energy, grads_] = sess.run([energy, grads])
-
-      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
-      # finite otherwise.
-      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
-      self.assertAllEqual(expected_energy, actual_energy)
-
-      # Ensure gradient is finite.
-      g = grads_[0].reshape([len(x), len(x)])[:, 0]
-      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g))
-
-      # The remaining gradients are nan because the momentum was itself nan or
-      # inf.
-      g = grads_[0].reshape([len(x), len(x)])[:, 1:]
-      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g))
-
-
-class EnergyComputationTest16(test.TestCase, _EnergyComputationTest):
-  dtype = np.float16
-
-
-class EnergyComputationTest32(test.TestCase, _EnergyComputationTest):
-  dtype = np.float32
-
-
-class EnergyComputationTest64(test.TestCase, _EnergyComputationTest):
-  dtype = np.float64
-
-
-class _HMCHandlesLists(object):
-
-  def testStateParts(self):
-    with self.test_session() as sess:
-      dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
-      dist_y = independent_lib.Independent(
-          gamma_lib.Gamma(concentration=self.dtype([1, 2]),
-                          rate=self.dtype([0.5, 0.75])),
-          reinterpreted_batch_ndims=1)
-      def target_log_prob(x, y):
-        return dist_x.log_prob(x) + dist_y.log_prob(y)
-      x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
-      samples, _ = hmc.sample_chain(
-          num_results=int(2e3),
-          target_log_prob_fn=target_log_prob,
-          current_state=x0,
-          step_size=0.85,
-          num_leapfrog_steps=3,
-          num_burnin_steps=int(250),
-          seed=49)
-      actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
-      actual_vars = [_reduce_variance(s, axis=0) for s in samples]
-      expected_means = [dist_x.mean(), dist_y.mean()]
-      expected_vars = [dist_x.variance(), dist_y.variance()]
-      [
-          actual_means_,
-          actual_vars_,
-          expected_means_,
-          expected_vars_,
-      ] = sess.run([
-          actual_means,
-          actual_vars,
-          expected_means,
-          expected_vars,
-      ])
-      self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
-      self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.30)
-
-
-class HMCHandlesLists16(_HMCHandlesLists, test.TestCase):
-  dtype = np.float16
-
-
-class HMCHandlesLists32(_HMCHandlesLists, test.TestCase):
-  dtype = np.float32
-
-
-class HMCHandlesLists64(_HMCHandlesLists, test.TestCase):
-  dtype = np.float64
-
-
-if __name__ == "__main__":
+if __name__ == '__main__':
   test.main()
index 7fd5652..977d42f 100644 (file)
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm."""
+"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
+"""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,9 +24,11 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import *  # pylint: disabl
 from tensorflow.python.util import all_util
 
 _allowed_symbols = [
-    "sample_chain",
-    "sample_annealed_importance_chain",
-    "kernel",
+    'chain',
+    'kernel',
+    'leapfrog_integrator',
+    'leapfrog_step',
+    'ais_chain'
 ]
 
 all_util.remove_undocumented(__name__, _allowed_symbols)
index f7a11c2..5685a94 100644 (file)
 # ==============================================================================
 """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
 
-@@sample_chain
-@@sample_annealed_importance_chain
-@@kernel
+@@chain
+@@update
+@@leapfrog_integrator
+@@leapfrog_step
+@@ais_chain
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import numpy as np
 
 from tensorflow.python.framework import dtypes
@@ -31,292 +32,168 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gradients_impl as gradients_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
-from tensorflow.python.ops.distributions import util as distributions_util
+from tensorflow.python.platform import tf_logging as logging
 
 __all__ = [
-    "sample_chain",
-    "sample_annealed_importance_chain",
-    "kernel",
+    'chain',
+    'kernel',
+    'leapfrog_integrator',
+    'leapfrog_step',
+    'ais_chain'
 ]
 
 
-KernelResults = collections.namedtuple(
-    "KernelResults",
-    [
-        "acceptance_probs",
-        "current_grads_target_log_prob",  # "Current result" means "accepted".
-        "current_target_log_prob",  # "Current result" means "accepted".
-        "energy_change",
-        "is_accepted",
-        "proposed_grads_target_log_prob",
-        "proposed_state",
-        "proposed_target_log_prob",
-        "random_positive",
-    ])
-
-
-def _make_dummy_kernel_results(
-    dummy_state,
-    dummy_target_log_prob,
-    dummy_grads_target_log_prob):
-  return KernelResults(
-      acceptance_probs=dummy_target_log_prob,
-      current_grads_target_log_prob=dummy_grads_target_log_prob,
-      current_target_log_prob=dummy_target_log_prob,
-      energy_change=dummy_target_log_prob,
-      is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool),
-      proposed_grads_target_log_prob=dummy_grads_target_log_prob,
-      proposed_state=dummy_state,
-      proposed_target_log_prob=dummy_target_log_prob,
-      random_positive=dummy_target_log_prob,
-  )
-
-
-def sample_chain(
-    num_results,
-    target_log_prob_fn,
-    current_state,
-    step_size,
-    num_leapfrog_steps,
-    num_burnin_steps=0,
-    num_steps_between_results=0,
-    seed=None,
-    current_target_log_prob=None,
-    current_grads_target_log_prob=None,
-    name=None):
+def _make_potential_and_grad(target_log_prob_fn):
+  def potential_and_grad(x):
+    log_prob_result = -target_log_prob_fn(x)
+    grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result),
+                                           x)[0]
+    return log_prob_result, grad_result
+  return potential_and_grad
+
+
+def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
+          target_log_prob_fn, event_dims=(), name=None):
   """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains.
 
-  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
-  that takes a series of gradient-informed steps to produce a Metropolis
-  proposal. This function samples from an HMC Markov chain at `current_state`
-  and whose stationary distribution has log-unnormalized-density
+  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
+  algorithm that takes a series of gradient-informed steps to produce
+  a Metropolis proposal. This function samples from an HMC Markov
+  chain whose initial state is `initial_x` and whose stationary
+  distribution has log-density `target_log_prob_fn()`.
+
+  This function can update multiple chains in parallel. It assumes
+  that all dimensions of `initial_x` not specified in `event_dims` are
+  independent, and should therefore be updated independently. The
+  output of `target_log_prob_fn()` should sum log-probabilities across
+  all event dimensions. Slices along dimensions not in `event_dims`
+  may have different target distributions; this is up to
   `target_log_prob_fn()`.
 
-  This function samples from multiple chains in parallel. It assumes that the
-  the leftmost dimensions of (each) `current_state` (part) index an independent
-  chain.  The function `target_log_prob_fn()` sums log-probabilities across
-  event dimensions (i.e., current state (part) rightmost dimensions). Each
-  element of the output of `target_log_prob_fn()` represents the (possibly
-  unnormalized) log-probability of the joint distribution over (all) the current
-  state (parts).
+  This function basically just wraps `hmc.kernel()` in a tf.scan() loop.
 
-  The `current_state` can be represented as a single `Tensor` or a `list` of
-  `Tensors` which collectively represent the current state. When specifying a
-  `list`, one must also specify a list of `step_size`s.
-
-  Only one out of every `num_steps_between_samples + 1` steps is included in the
-  returned results. This "thinning" comes at a cost of reduced statistical
-  power, while reducing memory requirements and autocorrelation. For more
-  discussion see [1].
+  Args:
+    n_iterations: Integer number of Markov chain updates to run.
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `initial_x`. Larger step sizes lead to faster progress, but
+      too-large step sizes make rejection exponentially more likely.
+      When possible, it's often helpful to match per-variable step
+      sizes to the standard deviations of the target distribution in
+      each variable.
+    n_leapfrog_steps: Integer number of steps to run the leapfrog
+      integrator for. Total progress per HMC step is roughly
+      proportional to step_size * n_leapfrog_steps.
+    initial_x: Tensor of initial state(s) of the Markov chain(s).
+    target_log_prob_fn: Python callable which takes an argument like `initial_x`
+      and returns its (possibly unnormalized) log-density under the target
+      distribution.
+    event_dims: List of dimensions that should not be treated as
+      independent. This allows for multiple chains to be run independently
+      in parallel. Default is (), i.e., all dimensions are independent.
+    name: Python `str` name prefixed to Ops created by this function.
 
-  [1]: "Statistically efficient thinning of a Markov chain sampler."
-       Art B. Owen. April 2017.
-       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
+  Returns:
+    acceptance_probs: Tensor with the acceptance probabilities for each
+      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
+    chain_states: Tensor with the state of the Markov chain at each iteration.
+      Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`.
 
   #### Examples:
 
-  ##### Sample from a diagonal-variance Gaussian.
-
   ```python
-  tfd = tf.contrib.distributions
-
-  def make_likelihood(true_variances):
-    return tfd.MultivariateNormalDiag(
-        scale_diag=tf.sqrt(true_variances))
-
-  dims = 10
-  dtype = np.float32
-  true_variances = tf.linspace(dtype(1), dtype(3), dims)
-  likelihood = make_likelihood(true_variances)
-
-  states, kernel_results = hmc.sample_chain(
-      num_results=1000,
-      target_log_prob_fn=likelihood.log_prob,
-      current_state=tf.zeros(dims),
-      step_size=0.5,
-      num_leapfrog_steps=2,
-      num_burnin_steps=500)
-
-  # Compute sample stats.
-  sample_mean = tf.reduce_mean(states, axis=0)
-  sample_var = tf.reduce_mean(
-      tf.squared_difference(states, sample_mean),
-      axis=0)
+  # Sampling from a standard normal (note `log_joint()` is unnormalized):
+  def log_joint(x):
+    return tf.reduce_sum(-0.5 * tf.square(x))
+  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
+                                      event_dims=[0])
+  # Discard first half of chain as warmup/burn-in
+  warmed_up = chain[500:]
+  mean_est = tf.reduce_mean(warmed_up, 0)
+  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
   ```
 
-  ##### Sampling from factor-analysis posteriors with known factors.
-
-  I.e.,
-
-  ```none
-  for i=1..n:
-    w[i] ~ Normal(0, eye(d))            # prior
-    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
+  ```python
+  # Sampling from a diagonal-variance Gaussian:
+  variances = tf.linspace(1., 3., 10)
+  def log_joint(x):
+    return tf.reduce_sum(-0.5 / variances * tf.square(x))
+  chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
+                                      event_dims=[0])
+  # Discard first half of chain as warmup/burn-in
+  warmed_up = chain[500:]
+  mean_est = tf.reduce_mean(warmed_up, 0)
+  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
   ```
 
-  where `F` denotes factors.
-
   ```python
-  tfd = tf.contrib.distributions
-
-  def make_prior(dims, dtype):
-    return tfd.MultivariateNormalDiag(
-        loc=tf.zeros(dims, dtype))
-
-  def make_likelihood(weights, factors):
-    return tfd.MultivariateNormalDiag(
-        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))
-
-  # Setup data.
-  num_weights = 10
-  num_factors = 4
-  num_chains = 100
-  dtype = np.float32
-
-  prior = make_prior(num_weights, dtype)
-  weights = prior.sample(num_chains)
-  factors = np.random.randn(num_factors, num_weights).astype(dtype)
-  x = make_likelihood(weights, factors).sample(num_chains)
-
-  def target_log_prob(w):
-    # Target joint is: `f(w) = p(w, x | factors)`.
-    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)
-
-  # Get `num_results` samples from `num_chains` independent chains.
-  chains_states, kernels_results = hmc.sample_chain(
-      num_results=1000,
-      target_log_prob_fn=target_log_prob,
-      current_state=tf.zeros([num_chains, dims], dtype),
-      step_size=0.1,
-      num_leapfrog_steps=2,
-      num_burnin_steps=500)
-
-  # Compute sample stats.
-  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
-  sample_var = tf.reduce_mean(
-      tf.squared_difference(chains_states, sample_mean),
-      axis=[0, 1])
+  # Sampling from factor-analysis posteriors with known factors W:
+  # mu[i, j] ~ Normal(0, 1)
+  # x[i] ~ Normal(matmul(mu[i], W), I)
+  def log_joint(mu, x, W):
+    prior = -0.5 * tf.reduce_sum(tf.square(mu), 1)
+    x_mean = tf.matmul(mu, W)
+    likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1)
+    return prior + likelihood
+  chain, acceptance_probs = hmc.chain(1000, 0.1, 2,
+                                      tf.zeros([x.shape[0], W.shape[0]]),
+                                      lambda mu: log_joint(mu, x, W),
+                                      event_dims=[1])
+  # Discard first half of chain as warmup/burn-in
+  warmed_up = chain[500:]
+  mean_est = tf.reduce_mean(warmed_up, 0)
+  var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
   ```
 
-  Args:
-    num_results: Integer number of Markov chain draws.
-    target_log_prob_fn: Python callable which takes an argument like
-      `current_state` (or `*current_state` if it's a list) and returns its
-      (possibly unnormalized) log-density under the target distribution.
-    current_state: `Tensor` or Python `list` of `Tensor`s representing the
-      current state(s) of the Markov chain(s). The first `r` dimensions index
-      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
-    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
-      for the leapfrog integrator. Must broadcast with the shape of
-      `current_state`. Larger step sizes lead to faster progress, but too-large
-      step sizes make rejection exponentially more likely. When possible, it's
-      often helpful to match per-variable step sizes to the standard deviations
-      of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    num_burnin_steps: Integer number of chain steps to take before starting to
-      collect results.
-      Default value: 0 (i.e., no burn-in).
-    num_steps_between_results: Integer number of chain steps between collecting
-      a result. Only one out of every `num_steps_between_samples + 1` steps is
-      included in the returned results. This "thinning" comes at a cost of
-      reduced statistical power, while reducing memory requirements and
-      autocorrelation. For more discussion see [1].
-      Default value: 0 (i.e., no subsampling).
-    seed: Python integer to seed the random number generator.
-    current_target_log_prob: (Optional) `Tensor` representing the value of
-      `target_log_prob_fn` at the `current_state`. The only reason to specify
-      this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
-      representing gradient of `target_log_prob` at the `current_state` and wrt
-      the `current_state`. Must have same shape as `current_state`. The only
-      reason to specify this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    name: Python `str` name prefixed to Ops created by this function.
-      Default value: `None` (i.e., "hmc_sample_chain").
-
-  Returns:
-    accepted_states: Tensor or Python list of `Tensor`s representing the
-      state(s) of the Markov chain(s) at each result step. Has same shape as
-      input `current_state` but with a prepended `num_results`-size dimension.
-    kernel_results: `collections.namedtuple` of internal calculations used to
-      advance the chain.
+  ```python
+  # Sampling from the posterior of a Bayesian regression model.:
+
+  # Run 100 chains in parallel, each with a different initialization.
+  initial_beta = tf.random_normal([100, x.shape[1]])
+  chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta,
+                                      log_joint_partial, event_dims=[1])
+  # Discard first halves of chains as warmup/burn-in
+  warmed_up = chain[500:]
+  # Averaging across samples within a chain and across chains
+  mean_est = tf.reduce_mean(warmed_up, [0, 1])
+  var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est)
+  ```
   """
-  with ops.name_scope(
-      name, "hmc_sample_chain",
-      [num_results, current_state, step_size, num_leapfrog_steps,
-       num_burnin_steps, num_steps_between_results, seed,
-       current_target_log_prob, current_grads_target_log_prob]):
-    with ops.name_scope("initialize"):
-      [
-          current_state,
-          step_size,
-          current_target_log_prob,
-          current_grads_target_log_prob,
-      ] = _prepare_args(
-          target_log_prob_fn, current_state, step_size,
-          current_target_log_prob, current_grads_target_log_prob)
-    def _run_chain(num_steps, current_state, seed, kernel_results):
-      """Runs the chain(s) for `num_steps`."""
-      def _loop_body(iter_, current_state, kernel_results):
-        return [iter_ + 1] + list(kernel(
-            target_log_prob_fn,
-            current_state,
-            step_size,
-            num_leapfrog_steps,
-            seed,
-            kernel_results.current_target_log_prob,
-            kernel_results.current_grads_target_log_prob))
-      return control_flow_ops.while_loop(
-          cond=lambda iter_, *args: iter_ < num_steps,
-          body=_loop_body,
-          loop_vars=[0, current_state, kernel_results])[1:]  # Lop-off "iter_".
-
-    def _scan_body(args_list, _):
-      """Closure which implements `tf.scan` body."""
-      current_state, kernel_results = args_list
-      return _run_chain(num_steps_between_results + 1, current_state, seed,
-                        kernel_results)
-
-    current_state, kernel_results = _run_chain(
-        num_burnin_steps,
-        current_state,
-        distributions_util.gen_new_seed(
-            seed, salt="hmc_sample_chain_burnin"),
-        _make_dummy_kernel_results(
-            current_state,
-            current_target_log_prob,
-            current_grads_target_log_prob))
-
+  with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size,
+                                          n_leapfrog_steps, initial_x]):
+    initial_x = ops.convert_to_tensor(initial_x, name='initial_x')
+    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
+
+    def body(a, _):
+      updated_x, acceptance_probs, log_prob, grad = kernel(
+          step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims,
+          a[2], a[3])
+      return updated_x, acceptance_probs, log_prob, grad
+
+    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
+    potential, grad = potential_and_grad(initial_x)
     return functional_ops.scan(
-        fn=_scan_body,
-        elems=array_ops.zeros(num_results, dtype=dtypes.bool),  # Dummy arg.
-        initializer=[current_state, kernel_results])
-
-
-def sample_annealed_importance_chain(
-    proposal_log_prob_fn,
-    num_steps,
-    target_log_prob_fn,
-    current_state,
-    step_size,
-    num_leapfrog_steps,
-    seed=None,
-    name=None):
+        body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
+        (initial_x,
+         array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+         -potential, -grad))[:2]
+
+
+def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
+              target_log_prob_fn, proposal_log_prob_fn, event_dims=(),
+              name=None):
   """Runs annealed importance sampling (AIS) to estimate normalizing constants.
 
-  This function uses Hamiltonian Monte Carlo to sample from a series of
+  This routine uses Hamiltonian Monte Carlo to sample from a series of
   distributions that slowly interpolates between an initial "proposal"
-  distribution:
+  distribution
 
   `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`
 
-  and the target distribution:
+  and the target distribution
 
   `exp(target_log_prob_fn(x) - target_log_normalizer)`,
 
@@ -325,183 +202,113 @@ def sample_annealed_importance_chain(
   normalizing constants of the initial distribution and the target
   distribution:
 
-  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
+  E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer).
 
-  #### Examples:
+  Args:
+    n_iterations: Integer number of Markov chain updates to run. More
+      iterations means more expense, but smoother annealing between q
+      and p, which in turn means exponentially lower variance for the
+      normalizing constant estimator.
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `initial_x`. Larger step sizes lead to faster progress, but
+      too-large step sizes make rejection exponentially more likely.
+      When possible, it's often helpful to match per-variable step
+      sizes to the standard deviations of the target distribution in
+      each variable.
+    n_leapfrog_steps: Integer number of steps to run the leapfrog
+      integrator for. Total progress per HMC step is roughly
+      proportional to step_size * n_leapfrog_steps.
+    initial_x: Tensor of initial state(s) of the Markov chain(s). Must
+      be a sample from q, or results will be incorrect.
+    target_log_prob_fn: Python callable which takes an argument like `initial_x`
+      and returns its (possibly unnormalized) log-density under the target
+      distribution.
+    proposal_log_prob_fn: Python callable that returns the log density of the
+      initial distribution.
+    event_dims: List of dimensions that should not be treated as
+      independent. This allows for multiple chains to be run independently
+      in parallel. Default is (), i.e., all dimensions are independent.
+    name: Python `str` name prefixed to Ops created by this function.
+
+  Returns:
+    ais_weights: Tensor with the estimated weight(s). Has shape matching
+      `target_log_prob_fn(initial_x)`.
+    chain_states: Tensor with the state(s) of the Markov chain(s) the final
+      iteration. Has shape matching `initial_x`.
+    acceptance_probs: Tensor with the acceptance probabilities for the final
+      iteration. Has shape matching `target_log_prob_fn(initial_x)`.
 
-  ##### Estimate the normalizing constant of a log-gamma distribution.
+  #### Examples:
 
   ```python
-  tfd = tf.contrib.distributions
-
+  # Estimating the normalizing constant of a log-gamma distribution:
+  def proposal_log_prob(x):
+    # Standard normal log-probability. This is properly normalized.
+    return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1)
+  def target_log_prob(x):
+    # Unnormalized log-gamma(2, 3) distribution.
+    # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1]
+    return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1)
   # Run 100 AIS chains in parallel
-  num_chains = 100
-  dims = 20
-  dtype = np.float32
-
-  proposal = tfd.MultivatiateNormalDiag(
-     loc=tf.zeros([dims], dtype=dtype))
-
-  target = tfd.TransformedDistribution(
-    distribution=tfd.Gamma(concentration=dtype(2),
-                           rate=dtype(3)),
-    bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
-    event_shape=[dims])
-
-  chains_state, ais_weights, kernels_results = (
-      hmc.sample_annealed_importance_chain(
-          proposal_log_prob_fn=proposal.log_prob,
-          num_steps=1000,
-          target_log_prob_fn=target.log_prob,
-          step_size=0.2,
-          current_state=proposal.sample(num_chains),
-          num_leapfrog_steps=2))
-
-  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
-                              - np.log(num_chains))
-  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
+  initial_x = tf.random_normal([100, 20])
+  w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob,
+                          proposal_log_prob, event_dims=[1])
+  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
   ```
 
-  ##### Estimate marginal likelihood of a Bayesian regression model.
-
   ```python
-  tfd = tf.contrib.distributions
-
-  def make_prior(dims, dtype):
-    return tfd.MultivariateNormalDiag(
-        loc=tf.zeros(dims, dtype))
-
-  def make_likelihood(weights, x):
-    return tfd.MultivariateNormalDiag(
-        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))
-
+  # Estimating the marginal likelihood of a Bayesian regression model:
+  base_measure = -0.5 * np.log(2 * np.pi)
+  def proposal_log_prob(x):
+    # Standard normal log-probability. This is properly normalized.
+    return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1)
+  def regression_log_joint(beta, x, y):
+    # This function returns a vector whose ith element is log p(beta[i], y | x).
+    # Each row of beta corresponds to the state of an independent Markov chain.
+    log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1)
+    means = tf.matmul(beta, x, transpose_b=True)
+    log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) +
+                                   base_measure, 1)
+    return log_prior + log_likelihood
+  def log_joint_partial(beta):
+    return regression_log_joint(beta, x, y)
   # Run 100 AIS chains in parallel
-  num_chains = 100
-  dims = 10
-  dtype = np.float32
-
-  # Make training data.
-  x = np.random.randn(num_chains, dims).astype(dtype)
-  true_weights = np.random.randn(dims).astype(dtype)
-  y = np.dot(x, true_weights) + np.random.randn(num_chains)
-
-  # Setup model.
-  prior = make_prior(dims, dtype)
-  def target_log_prob_fn(weights):
-    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
-
-  proposal = tfd.MultivariateNormalDiag(
-      loc=tf.zeros(dims, dtype))
-
-  weight_samples, ais_weights, kernel_results = (
-      hmc.sample_annealed_importance_chain(
-        num_steps=1000,
-        proposal_log_prob_fn=proposal.log_prob,
-        target_log_prob_fn=target_log_prob_fn
-        current_state=tf.zeros([num_chains, dims], dtype),
-        step_size=0.1,
-        num_leapfrog_steps=2))
-  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
-                             - np.log(num_chains))
+  initial_beta = tf.random_normal([100, x.shape[1]])
+  w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta,
+                                     log_joint_partial, proposal_log_prob,
+                                     event_dims=[1])
+  log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
   ```
-
-  Args:
-    proposal_log_prob_fn: Python callable that returns the log density of the
-      initial distribution.
-    num_steps: Integer number of Markov chain updates to run. More
-      iterations means more expense, but smoother annealing between q
-      and p, which in turn means exponentially lower variance for the
-      normalizing constant estimator.
-    target_log_prob_fn: Python callable which takes an argument like
-      `current_state` (or `*current_state` if it's a list) and returns its
-      (possibly unnormalized) log-density under the target distribution.
-    current_state: `Tensor` or Python `list` of `Tensor`s representing the
-      current state(s) of the Markov chain(s). The first `r` dimensions index
-      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
-    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
-      for the leapfrog integrator. Must broadcast with the shape of
-      `current_state`. Larger step sizes lead to faster progress, but too-large
-      step sizes make rejection exponentially more likely. When possible, it's
-      often helpful to match per-variable step sizes to the standard deviations
-      of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    seed: Python integer to seed the random number generator.
-    name: Python `str` name prefixed to Ops created by this function.
-      Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
-
-  Returns:
-    accepted_state: `Tensor` or Python list of `Tensor`s representing the
-      state(s) of the Markov chain(s) at the final iteration. Has same shape as
-      input `current_state`.
-    ais_weights: Tensor with the estimated weight(s). Has shape matching
-      `target_log_prob_fn(current_state)`.
   """
-  def make_convex_combined_log_prob_fn(iter_):
-    def _fn(*args):
-      p = proposal_log_prob_fn(*args)
-      t = target_log_prob_fn(*args)
-      dtype = p.dtype.base_dtype
-      beta = (math_ops.cast(iter_ + 1, dtype)
-              / math_ops.cast(num_steps, dtype))
-      return (1. - beta) * p + beta * t
-    return _fn
-
-  with ops.name_scope(
-      name, "hmc_sample_annealed_importance_chain",
-      [num_steps, current_state, step_size, num_leapfrog_steps, seed]):
-    with ops.name_scope("initialize"):
-      [
-          current_state,
-          step_size,
-          current_log_prob,
-          current_grads_log_prob,
-      ] = _prepare_args(
-          make_convex_combined_log_prob_fn(iter_=0),
-          current_state,
-          step_size,
-          description="convex_combined_log_prob")
-    def _loop_body(iter_, ais_weights, current_state, kernel_results):
-      """Closure which implements `tf.while_loop` body."""
-      current_state_parts = (list(current_state)
-                             if _is_list_like(current_state)
-                             else [current_state])
-      ais_weights += ((target_log_prob_fn(*current_state_parts)
-                       - proposal_log_prob_fn(*current_state_parts))
-                      / math_ops.cast(num_steps, ais_weights.dtype))
-      return [iter_ + 1, ais_weights] + list(kernel(
-          make_convex_combined_log_prob_fn(iter_),
-          current_state,
-          step_size,
-          num_leapfrog_steps,
-          seed,
-          kernel_results.current_target_log_prob,
-          kernel_results.current_grads_target_log_prob))
-
-    [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop(
-        cond=lambda iter_, *args: iter_ < num_steps,
-        body=_loop_body,
-        loop_vars=[
-            0,  # iter_
-            array_ops.zeros_like(current_log_prob),  # ais_weights
-            current_state,
-            _make_dummy_kernel_results(current_state,
-                                       current_log_prob,
-                                       current_grads_log_prob),
-        ])[1:]  # Lop-off "iter_".
-
-    return [current_state, ais_weights, kernel_results]
-
-
-def kernel(target_log_prob_fn,
-           current_state,
-           step_size,
-           num_leapfrog_steps,
-           seed=None,
-           current_target_log_prob=None,
-           current_grads_target_log_prob=None,
-           name=None):
+  with ops.name_scope(name, 'hmc_ais_chain',
+                      [n_iterations, step_size, n_leapfrog_steps, initial_x]):
+    non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
+
+    beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:]
+    def _body(a, beta):  # pylint: disable=missing-docstring
+      def log_prob_beta(x):
+        return ((1 - beta) * proposal_log_prob_fn(x) +
+                beta * target_log_prob_fn(x))
+      last_x = a[0]
+      w = a[2]
+      w += (1. / n_iterations) * (target_log_prob_fn(last_x) -
+                                  proposal_log_prob_fn(last_x))
+      # TODO(b/66917083): There's an opportunity for gradient reuse here.
+      updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps,
+                                                 last_x, log_prob_beta,
+                                                 event_dims)
+      return updated_x, acceptance_probs, w
+
+    x, acceptance_probs, w = functional_ops.scan(
+        _body, beta_series,
+        (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+         array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
+  return w[-1], x[-1], acceptance_probs[-1]
+
+
+def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
+           x_log_prob=None, x_grad=None, name=None):
   """Runs one iteration of Hamiltonian Monte Carlo.
 
   Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
@@ -509,625 +316,334 @@ def kernel(target_log_prob_fn,
   a Metropolis proposal. This function applies one step of HMC to
   randomly update the variable `x`.
 
-  This function can update multiple chains in parallel. It assumes that all
-  leftmost dimensions of `current_state` index independent chain states (and are
-  therefore updated independently). The output of `target_log_prob_fn()` should
-  sum log-probabilities across all event dimensions. Slices along the rightmost
-  dimensions may have different target distributions; for example,
-  `current_state[0, :]` could have a different target distribution from
-  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
-  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
+  This function can update multiple chains in parallel. It assumes
+  that all dimensions of `x` not specified in `event_dims` are
+  independent, and should therefore be updated independently. The
+  output of `target_log_prob_fn()` should sum log-probabilities across
+  all event dimensions. Slices along dimensions not in `event_dims`
+  may have different target distributions; for example, if
+  `event_dims == (1,)`, then `x[0, :]` could have a different target
+  distribution from x[1, :]. This is up to `target_log_prob_fn()`.
 
-  #### Examples:
+  Args:
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `x`. Larger step sizes lead to faster progress, but
+      too-large step sizes make rejection exponentially more likely.
+      When possible, it's often helpful to match per-variable step
+      sizes to the standard deviations of the target distribution in
+      each variable.
+    n_leapfrog_steps: Integer number of steps to run the leapfrog
+      integrator for. Total progress per HMC step is roughly
+      proportional to step_size * n_leapfrog_steps.
+    x: Tensor containing the value(s) of the random variable(s) to update.
+    target_log_prob_fn: Python callable which takes an argument like `initial_x`
+      and returns its (possibly unnormalized) log-density under the target
+      distribution.
+    event_dims: List of dimensions that should not be treated as
+      independent. This allows for multiple chains to be run independently
+      in parallel. Default is (), i.e., all dimensions are independent.
+    x_log_prob (optional): Tensor containing the cached output of a previous
+      call to `target_log_prob_fn()` evaluated at `x` (such as that provided by
+      a previous call to `kernel()`). Providing `x_log_prob` and
+      `x_grad` saves one gradient computation per call to `kernel()`.
+    x_grad (optional): Tensor containing the cached gradient of
+      `target_log_prob_fn()` evaluated at `x` (such as that provided by
+      a previous call to `kernel()`). Providing `x_log_prob` and
+      `x_grad` saves one gradient computation per call to `kernel()`.
+    name: Python `str` name prefixed to Ops created by this function.
 
-  ##### Simple chain with warm-up.
+  Returns:
+    updated_x: The updated variable(s) x. Has shape matching `initial_x`.
+    acceptance_probs: Tensor with the acceptance probabilities for the final
+      iteration. This is useful for diagnosing step size problems etc. Has
+      shape matching `target_log_prob_fn(initial_x)`.
+    new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`.
+    new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at
+      `updated_x`.
 
-  ```python
-  tfd = tf.contrib.distributions
+  #### Examples:
 
+  ```python
   # Tuning acceptance rates:
-  dtype = np.float32
   target_accept_rate = 0.631
-  num_warmup_iter = 500
-  num_chain_iter = 500
-
-  x = tf.get_variable(name="x", initializer=dtype(1))
-  step_size = tf.get_variable(name="step_size", initializer=dtype(1))
-
-  target = tfd.Normal(loc=dtype(0), scale=dtype(1))
-
-  new_x, other_results = hmc.kernel(
-      target_log_prob_fn=target.log_prob,
-      current_state=x,
-      step_size=step_size,
-      num_leapfrog_steps=3)[:4]
-
-  x_update = x.assign(new_x)
-
-  step_size_update = step_size.assign_add(
-      step_size * tf.where(
-        other_results.acceptance_probs > target_accept_rate,
-        0.01, -0.01))
-
-  warmup = tf.group([x_update, step_size_update])
-
-  tf.global_variables_initializer().run()
-
-  sess.graph.finalize()  # No more graph building.
-
+  def target_log_prob(x):
+    # Standard normal
+    return tf.reduce_sum(-0.5 * tf.square(x))
+  initial_x = tf.zeros([10])
+  initial_log_prob = target_log_prob(initial_x)
+  initial_grad = tf.gradients(initial_log_prob, initial_x)[0]
+  # Algorithm state
+  x = tf.Variable(initial_x, name='x')
+  step_size = tf.Variable(1., name='step_size')
+  last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob')
+  last_grad = tf.Variable(initial_grad, name='last_grad')
+  # Compute updates
+  new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x,
+                                                      target_log_prob,
+                                                      event_dims=[0],
+                                                      x_log_prob=last_log_prob)
+  x_update = tf.assign(x, new_x)
+  log_prob_update = tf.assign(last_log_prob, log_prob)
+  grad_update = tf.assign(last_grad, grad)
+  step_size_update = tf.assign(step_size,
+                               tf.where(acceptance_prob > target_accept_rate,
+                                        step_size * 1.01, step_size / 1.01))
+  adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update]
+  sampling_updates = [x_update, log_prob_update, grad_update]
+
+  sess = tf.Session()
+  sess.run(tf.global_variables_initializer())
   # Warm up the sampler and adapt the step size
-  for _ in xrange(num_warmup_iter):
-    sess.run(warmup)
-
+  for i in xrange(500):
+    sess.run(adaptive_updates)
   # Collect samples without adapting step size
-  samples = np.zeros([num_chain_iter])
-  for i in xrange(num_chain_iter):
-    _, x_, target_log_prob_, grad_ = sess.run([
-        x_update,
-        x,
-        other_results.target_log_prob,
-        other_results.grads_target_log_prob])
-    samples[i] = x_
-
-  print(samples.mean(), samples.std())
+  samples = np.zeros([500, 10])
+  for i in xrange(500):
+    x_val, _ = sess.run([new_x, sampling_updates])
+    samples[i] = x_val
   ```
 
-  ##### Sample from more complicated posterior.
-
-  I.e.,
-
-  ```none
-    W ~ MVN(loc=0, scale=sigma * eye(dims))
-    for i=1...num_samples:
-        X[i] ~ MVN(loc=0, scale=eye(dims))
-      eps[i] ~ Normal(loc=0, scale=1)
-        Y[i] = X[i].T * W + eps[i]
+  ```python
+  # Empirical-Bayes estimation of a hyperparameter by MCMC-EM:
+
+  # Problem setup
+  N = 150
+  D = 10
+  x = np.random.randn(N, D).astype(np.float32)
+  true_sigma = 0.5
+  true_beta = true_sigma * np.random.randn(D).astype(np.float32)
+  y = x.dot(true_beta) + np.random.randn(N).astype(np.float32)
+
+  def log_prior(beta, log_sigma):
+    return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) -
+                         log_sigma)
+  def regression_log_joint(beta, log_sigma, x, y):
+    # This function returns log p(beta | log_sigma) + log p(y | x, beta).
+    means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True)
+    means = tf.squeeze(means)
+    log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means))
+    return log_prior(beta, log_sigma) + log_likelihood
+  def log_joint_partial(beta):
+    return regression_log_joint(beta, log_sigma, x, y)
+  # Our estimate of log(sigma)
+  log_sigma = tf.Variable(0., name='log_sigma')
+  # The state of the Markov chain
+  beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta')
+  new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial,
+                                 event_dims=[0])
+  beta_update = tf.assign(beta, new_beta)
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+  with tf.control_dependencies([beta_update]):
+    log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma),
+                                          var_list=[log_sigma])
+
+  sess = tf.Session()
+  sess.run(tf.global_variables_initializer())
+  log_sigma_history = np.zeros(1000)
+  for i in xrange(1000):
+    log_sigma_val, _ = sess.run([log_sigma, log_sigma_update])
+    log_sigma_history[i] = log_sigma_val
+  # Should converge to something close to true_sigma
+  plt.plot(np.exp(log_sigma_history))
   ```
+  """
+  with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
+    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
+    x = ops.convert_to_tensor(x, name='x')
+
+    x_shape = array_ops.shape(x)
+    m = random_ops.random_normal(x_shape, dtype=x.dtype)
+
+    kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)
+
+    if (x_log_prob is not None) and (x_grad is not None):
+      log_potential_0, grad_0 = -x_log_prob, -x_grad  # pylint: disable=invalid-unary-operand-type
+    else:
+      if x_log_prob is not None:
+        logging.warn('x_log_prob was provided, but x_grad was not,'
+                     ' so x_log_prob was not used.')
+      if x_grad is not None:
+        logging.warn('x_grad was provided, but x_log_prob was not,'
+                     ' so x_grad was not used.')
+      log_potential_0, grad_0 = potential_and_grad(x)
+
+    new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator(
+        step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0)
+
+    kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims)
+
+    energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0
+    # Treat NaN as infinite energy (and therefore guaranteed rejection).
+    energy_change = array_ops.where(
+        math_ops.is_nan(energy_change),
+        array_ops.fill(array_ops.shape(energy_change),
+                       energy_change.dtype.as_numpy_dtype(np.inf)),
+        energy_change)
+    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
+    accepted = (
+        random_ops.random_uniform(
+            array_ops.shape(acceptance_probs), dtype=x.dtype)
+        < acceptance_probs)
+    new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)
+
+    # TODO(b/65738010): This should work, but it doesn't for now.
+    # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
+    reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
+                                                        keep_dims=True))
+    accepted = array_ops.reshape(accepted, reduced_shape)
+    accepted = math_ops.logical_or(
+        accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
+    new_x = array_ops.where(accepted, new_x, x)
+    new_grad = -array_ops.where(accepted, grad_1, grad_0)
+
+  # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
+  # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate).  This
+  # should be fixed.
+  return new_x, acceptance_probs, new_log_prob, new_grad
+
+
+def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum,
+                        potential_and_grad, initial_grad, name=None):
+  """Applies `n_steps` steps of the leapfrog integrator.
+
+  This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing
+  gradient computations where possible.
 
-  ```python
-  tfd = tf.contrib.distributions
-
-  def make_training_data(num_samples, dims, sigma):
-    dt = np.asarray(sigma).dtype
-    zeros = tf.zeros(dims, dtype=dt)
-    x = tfd.MultivariateNormalDiag(
-        loc=zeros).sample(num_samples, seed=1)
-    w = tfd.MultivariateNormalDiag(
-        loc=zeros,
-        scale_identity_multiplier=sigma).sample(seed=2)
-    noise = tfd.Normal(
-        loc=dt(0),
-        scale=dt(1)).sample(num_samples, seed=3)
-    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
-    return y, x, w
-
-  def make_prior(sigma, dims):
-    # p(w | sigma)
-    return tfd.MultivariateNormalDiag(
-        loc=tf.zeros([dims], dtype=sigma.dtype),
-        scale_identity_multiplier=sigma)
-
-  def make_likelihood(x, w):
-    # p(y | x, w)
-    return tfd.MultivariateNormalDiag(
-        loc=tf.tensordot(x, w, axes=[[1], [0]]))
-
-  # Setup assumptions.
-  dtype = np.float32
-  num_samples = 150
-  dims = 10
-  num_iters = int(5e3)
-
-  true_sigma = dtype(0.5)
-  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)
-
-  # Estimate of `log(true_sigma)`.
-  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
-  sigma = tf.exp(log_sigma)
-
-  # State of the Markov chain.
-  weights = tf.get_variable(
-      name="weights",
-      initializer=np.random.randn(dims).astype(dtype))
-
-  prior = make_prior(sigma, dims)
-
-  def joint_log_prob_fn(w):
-    # f(w) = log p(w, y | x)
-    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)
-
-  weights_update = weights.assign(
-      hmc.kernel(target_log_prob_fn=joint_log_prob,
-                 current_state=weights,
-                 step_size=0.1,
-                 num_leapfrog_steps=5)[0])
-
-  with tf.control_dependencies([weights_update]):
-    loss = -prior.log_prob(weights)
+  Args:
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `initial_position`. Larger step sizes lead to faster progress, but
+      too-large step sizes lead to larger discretization error and
+      worse energy conservation.
+    n_steps: Number of steps to run the leapfrog integrator.
+    initial_position: Tensor containing the value(s) of the position variable(s)
+      to update.
+    initial_momentum: Tensor containing the value(s) of the momentum variable(s)
+      to update.
+    potential_and_grad: Python callable that takes a position tensor like
+      `initial_position` and returns the potential energy and its gradient at
+      that position.
+    initial_grad: Tensor with the value of the gradient of the potential energy
+      at `initial_position`.
+    name: Python `str` name prefixed to Ops created by this function.
 
-  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])
+  Returns:
+    updated_position: Updated value of the position.
+    updated_momentum: Updated value of the momentum.
+    new_potential: Potential energy of the new position. Has shape matching
+      `potential_and_grad(initial_position)`.
+    new_grad: Gradient from potential_and_grad() evaluated at the new position.
+      Has shape matching `initial_position`.
+
+  Example: Simple quadratic potential.
 
-  sess.graph.finalize()  # No more graph building.
+  ```python
+  def potential_and_grad(position):
+    return tf.reduce_sum(0.5 * tf.square(position)), position
+  position = tf.placeholder(np.float32)
+  momentum = tf.placeholder(np.float32)
+  potential, grad = potential_and_grad(position)
+  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator(
+    0.1, 3, position, momentum, potential_and_grad, grad)
+
+  sess = tf.Session()
+  position_val = np.random.randn(10)
+  momentum_val = np.random.randn(10)
+  potential_val, grad_val = sess.run([potential, grad],
+                                     {position: position_val})
+  positions = np.zeros([100, 10])
+  for i in xrange(100):
+    position_val, momentum_val, potential_val, grad_val = sess.run(
+      [new_position, new_momentum, new_potential, new_grad],
+      {position: position_val, momentum: momentum_val})
+    positions[i] = position_val
+  # Should trace out sinusoidal dynamics.
+  plt.plot(positions[:, 0])
+  ```
+  """
+  def leapfrog_wrapper(step_size, x, m, grad, l):
+    x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad)
+    return step_size, x, m, grad, l + 1
 
-  tf.global_variables_initializer().run()
+  def counter_fn(a, b, c, d, counter):  # pylint: disable=unused-argument
+    return counter < n_steps
 
-  sigma_history = np.zeros(num_iters, dtype)
-  weights_history = np.zeros([num_iters, dims], dtype)
+  with ops.name_scope(name, 'leapfrog_integrator',
+                      [step_size, n_steps, initial_position, initial_momentum,
+                       initial_grad]):
+    _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop(
+        counter_fn, leapfrog_wrapper, [step_size, initial_position,
+                                       initial_momentum, initial_grad,
+                                       array_ops.constant(0)], back_prop=False)
+    # We're counting on the runtime to eliminate this redundant computation.
+    new_potential, new_grad = potential_and_grad(new_x)
+  return new_x, new_m, new_potential, new_grad
 
-  for i in xrange(num_iters):
-    _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights])
-    weights_history[i, :] = weights_
-    sigma_history[i] = sigma_
 
-  true_weights_ = sess.run(true_weights)
+def leapfrog_step(step_size, position, momentum, potential_and_grad, grad,
+                  name=None):
+  """Applies one step of the leapfrog integrator.
 
-  # Should converge to something close to true_sigma.
-  plt.plot(sigma_history);
-  plt.ylabel("sigma");
-  plt.xlabel("iteration");
-  ```
+  Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2.
 
   Args:
-    target_log_prob_fn: Python callable which takes an argument like
-      `current_state` (or `*current_state` if it's a list) and returns its
-      (possibly unnormalized) log-density under the target distribution.
-    current_state: `Tensor` or Python `list` of `Tensor`s representing the
-      current state(s) of the Markov chain(s). The first `r` dimensions index
-      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
-    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
-      for the leapfrog integrator. Must broadcast with the shape of
-      `current_state`. Larger step sizes lead to faster progress, but too-large
-      step sizes make rejection exponentially more likely. When possible, it's
-      often helpful to match per-variable step sizes to the standard deviations
-      of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    seed: Python integer to seed the random number generator.
-    current_target_log_prob: (Optional) `Tensor` representing the value of
-      `target_log_prob_fn` at the `current_state`. The only reason to
-      specify this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
-      representing gradient of `current_target_log_prob` at the `current_state`
-      and wrt the `current_state`. Must have same shape as `current_state`. The
-      only reason to specify this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
+    step_size: Scalar step size or array of step sizes for the
+      leapfrog integrator. Broadcasts to the shape of
+      `position`. Larger step sizes lead to faster progress, but
+      too-large step sizes lead to larger discretization error and
+      worse energy conservation.
+    position: Tensor containing the value(s) of the position variable(s)
+      to update.
+    momentum: Tensor containing the value(s) of the momentum variable(s)
+      to update.
+    potential_and_grad: Python callable that takes a position tensor like
+      `position` and returns the potential energy and its gradient at that
+      position.
+    grad: Tensor with the value of the gradient of the potential energy
+      at `position`.
     name: Python `str` name prefixed to Ops created by this function.
-      Default value: `None` (i.e., "hmc_kernel").
 
   Returns:
-    accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
-      of the Markov chain(s) at each result step. Has same shape as
-      `current_state`.
-    acceptance_probs: Tensor with the acceptance probabilities for each
-      iteration. Has shape matching `target_log_prob_fn(current_state)`.
-    accepted_target_log_prob: `Tensor` representing the value of
-      `target_log_prob_fn` at `accepted_state`.
-    accepted_grads_target_log_prob: Python `list` of `Tensor`s representing the
-      gradient of `accepted_target_log_prob` wrt each `accepted_state`.
-
-  Raises:
-    ValueError: if there isn't one `step_size` or a list with same length as
-      `current_state`.
-  """
-  with ops.name_scope(
-      name, "hmc_kernel",
-      [current_state, step_size, num_leapfrog_steps, seed,
-       current_target_log_prob, current_grads_target_log_prob]):
-    with ops.name_scope("initialize"):
-      [current_state_parts, step_sizes, current_target_log_prob,
-       current_grads_target_log_prob] = _prepare_args(
-           target_log_prob_fn, current_state, step_size,
-           current_target_log_prob, current_grads_target_log_prob,
-           maybe_expand=True)
-      independent_chain_ndims = distributions_util.prefer_static_rank(
-          current_target_log_prob)
-      def init_momentum(s):
-        return random_ops.random_normal(
-            shape=array_ops.shape(s),
-            dtype=s.dtype.base_dtype,
-            seed=distributions_util.gen_new_seed(
-                seed, salt="hmc_kernel_momentums"))
-      current_momentums = [init_momentum(s) for s in current_state_parts]
-
-    [
-        proposed_momentums,
-        proposed_state_parts,
-        proposed_target_log_prob,
-        proposed_grads_target_log_prob,
-    ] = _leapfrog_integrator(current_momentums,
-                             target_log_prob_fn,
-                             current_state_parts,
-                             step_sizes,
-                             num_leapfrog_steps,
-                             current_target_log_prob,
-                             current_grads_target_log_prob)
-
-    energy_change = _compute_energy_change(current_target_log_prob,
-                                           current_momentums,
-                                           proposed_target_log_prob,
-                                           proposed_momentums,
-                                           independent_chain_ndims)
-
-    # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
-    # ==> -log(u) >= max(e, 0)
-    # ==> -log(u) >= e
-    # (Perhaps surprisingly, we don't have a better way to obtain a random
-    # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
-    # maxval=np.inf)` won't work.)
-    random_uniform = random_ops.random_uniform(
-        shape=array_ops.shape(energy_change),
-        dtype=energy_change.dtype,
-        seed=seed)
-    random_positive = -math_ops.log(random_uniform)
-    is_accepted = random_positive >= energy_change
-
-    accepted_target_log_prob = array_ops.where(is_accepted,
-                                               proposed_target_log_prob,
-                                               current_target_log_prob)
-
-    accepted_state_parts = [_choose(is_accepted,
-                                    proposed_state_part,
-                                    current_state_part,
-                                    independent_chain_ndims)
-                            for current_state_part, proposed_state_part
-                            in zip(current_state_parts, proposed_state_parts)]
-
-    accepted_grads_target_log_prob = [
-        _choose(is_accepted,
-                proposed_grad,
-                grad,
-                independent_chain_ndims)
-        for proposed_grad, grad
-        in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)]
-
-    maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
-    return [
-        maybe_flatten(accepted_state_parts),
-        KernelResults(
-            acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
-            current_grads_target_log_prob=accepted_grads_target_log_prob,
-            current_target_log_prob=accepted_target_log_prob,
-            energy_change=energy_change,
-            is_accepted=is_accepted,
-            proposed_grads_target_log_prob=proposed_grads_target_log_prob,
-            proposed_state=maybe_flatten(proposed_state_parts),
-            proposed_target_log_prob=proposed_target_log_prob,
-            random_positive=random_positive,
-        ),
-    ]
-
-
-def _leapfrog_integrator(current_momentums,
-                         target_log_prob_fn,
-                         current_state_parts,
-                         step_sizes,
-                         num_leapfrog_steps,
-                         current_target_log_prob=None,
-                         current_grads_target_log_prob=None,
-                         name=None):
-  """Applies `num_leapfrog_steps` of the leapfrog integrator.
-
-  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.
-
-  #### Examples:
+    updated_position: Updated value of the position.
+    updated_momentum: Updated value of the momentum.
+    new_potential: Potential energy of the new position. Has shape matching
+      `potential_and_grad(position)`.
+    new_grad: Gradient from potential_and_grad() evaluated at the new position.
+      Has shape matching `position`.
 
-  ##### Simple quadratic potential.
+  Example: Simple quadratic potential.
 
   ```python
-  tfd = tf.contrib.distributions
-
-  dims = 10
-  num_iter = int(1e3)
-  dtype = np.float32
-
+  def potential_and_grad(position):
+    # Simple quadratic potential
+    return tf.reduce_sum(0.5 * tf.square(position)), position
   position = tf.placeholder(np.float32)
   momentum = tf.placeholder(np.float32)
-
-  [
-      new_momentums,
-      new_positions,
-  ] = hmc._leapfrog_integrator(
-      current_momentums=[momentum],
-      target_log_prob_fn=tfd.MultivariateNormalDiag(
-          loc=tf.zeros(dims, dtype)).log_prob,
-      current_state_parts=[position],
-      step_sizes=0.1,
-      num_leapfrog_steps=3)[:2]
-
-  sess.graph.finalize()  # No more graph building.
-
-  momentum_ = np.random.randn(dims).astype(dtype)
-  position_ = np.random.randn(dims).astype(dtype)
-
-  positions = np.zeros([num_iter, dims], dtype)
-  for i in xrange(num_iter):
-    position_, momentum_ = sess.run(
-        [new_momentums[0], new_position[0]],
-        feed_dict={position: position_, momentum: momentum_})
-    positions[i] = position_
-
-  plt.plot(positions[:, 0]);  # Sinusoidal.
+  potential, grad = potential_and_grad(position)
+  new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step(
+    0.1, position, momentum, potential_and_grad, grad)
+
+  sess = tf.Session()
+  position_val = np.random.randn(10)
+  momentum_val = np.random.randn(10)
+  potential_val, grad_val = sess.run([potential, grad],
+                                     {position: position_val})
+  positions = np.zeros([100, 10])
+  for i in xrange(100):
+    position_val, momentum_val, potential_val, grad_val = sess.run(
+      [new_position, new_momentum, new_potential, new_grad],
+      {position: position_val, momentum: momentum_val})
+    positions[i] = position_val
+  # Should trace out sinusoidal dynamics.
+  plt.plot(positions[:, 0])
   ```
-
-  Args:
-    current_momentums: Tensor containing the value(s) of the momentum
-      variable(s) to update.
-    target_log_prob_fn: Python callable which takes an argument like
-      `*current_state_parts` and returns its (possibly unnormalized) log-density
-      under the target distribution.
-    current_state_parts: Python `list` of `Tensor`s representing the current
-      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
-      the `Tensor`(s) index different chains.
-    step_sizes: Python `list` of `Tensor`s representing the step size for the
-      leapfrog integrator. Must broadcast with the shape of
-      `current_state_parts`.  Larger step sizes lead to faster progress, but
-      too-large step sizes make rejection exponentially more likely. When
-      possible, it's often helpful to match per-variable step sizes to the
-      standard deviations of the target distribution in each variable.
-    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
-      for. Total progress per HMC step is roughly proportional to `step_size *
-      num_leapfrog_steps`.
-    current_target_log_prob: (Optional) `Tensor` representing the value of
-      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
-      this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
-      representing gradient of `target_log_prob_fn(*current_state_parts`) wrt
-      `current_state_parts`. Must have same shape as `current_state_parts`. The
-      only reason to specify this argument is to reduce TF graph size.
-      Default value: `None` (i.e., compute as needed).
-    name: Python `str` name prefixed to Ops created by this function.
-      Default value: `None` (i.e., "hmc_leapfrog_integrator").
-
-  Returns:
-    proposed_momentums: Updated value of the momentum.
-    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
-      state(s) of the Markov chain(s) at each result step. Has same shape as
-      input `current_state_parts`.
-    proposed_target_log_prob: `Tensor` representing the value of
-      `target_log_prob_fn` at `accepted_state`.
-    proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
-      `accepted_state`.
-
-  Raises:
-    ValueError: if `len(momentums) != len(state_parts)`.
-    ValueError: if `len(state_parts) != len(step_sizes)`.
-    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
-    TypeError: if `not target_log_prob.dtype.is_floating`.
   """
-  def _loop_body(step,
-                 current_momentums,
-                 current_state_parts,
-                 ignore_current_target_log_prob,  # pylint: disable=unused-argument
-                 current_grads_target_log_prob):
-    return [step + 1] + list(_leapfrog_step(current_momentums,
-                                            target_log_prob_fn,
-                                            current_state_parts,
-                                            step_sizes,
-                                            current_grads_target_log_prob))
-
-  with ops.name_scope(
-      name, "hmc_leapfrog_integrator",
-      [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
-       current_target_log_prob, current_grads_target_log_prob]):
-    if len(current_momentums) != len(current_state_parts):
-      raise ValueError("`momentums` must be in one-to-one correspondence "
-                       "with `state_parts`")
-    num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps,
-                                               name="num_leapfrog_steps")
-    current_target_log_prob, current_grads_target_log_prob = (
-        _maybe_call_fn_and_grads(
-            target_log_prob_fn,
-            current_state_parts,
-            current_target_log_prob,
-            current_grads_target_log_prob))
-    return control_flow_ops.while_loop(
-        cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
-        body=_loop_body,
-        loop_vars=[
-            0,  # iter_
-            current_momentums,
-            current_state_parts,
-            current_target_log_prob,
-            current_grads_target_log_prob,
-        ],
-        back_prop=False)[1:]  # Lop-off "iter_".
-
-
-def _leapfrog_step(current_momentums,
-                   target_log_prob_fn,
-                   current_state_parts,
-                   step_sizes,
-                   current_grads_target_log_prob,
-                   name=None):
-  """Applies one step of the leapfrog integrator."""
-  with ops.name_scope(
-      name, "_leapfrog_step",
-      [current_momentums, current_state_parts, step_sizes,
-       current_grads_target_log_prob]):
-    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
-                          in zip(current_momentums,
-                                 step_sizes,
-                                 current_grads_target_log_prob)]
-    proposed_state_parts = [x + ss * m for x, ss, m
-                            in zip(current_state_parts,
-                                   step_sizes,
-                                   proposed_momentums)]
-    proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
-    if not proposed_target_log_prob.dtype.is_floating:
-      raise TypeError("`target_log_prob_fn` must produce a `Tensor` "
-                      "with `float` `dtype`.")
-    proposed_grads_target_log_prob = gradients_ops.gradients(
-        proposed_target_log_prob, proposed_state_parts)
-    if any(g is None for g in proposed_grads_target_log_prob):
-      raise ValueError(
-          "Encountered `None` gradient. Does your target `target_log_prob_fn` "
-          "access all `tf.Variable`s via `tf.get_variable`?\n"
-          "  current_state_parts: {}\n"
-          "  proposed_state_parts: {}\n"
-          "  proposed_grads_target_log_prob: {}".format(
-              current_state_parts,
-              proposed_state_parts,
-              proposed_grads_target_log_prob))
-    proposed_momentums = [m + 0.5 * ss * g for m, ss, g
-                          in zip(proposed_momentums,
-                                 step_sizes,
-                                 proposed_grads_target_log_prob)]
-    return [
-        proposed_momentums,
-        proposed_state_parts,
-        proposed_target_log_prob,
-        proposed_grads_target_log_prob,
-    ]
-
-
-def _compute_energy_change(current_target_log_prob,
-                           current_momentums,
-                           proposed_target_log_prob,
-                           proposed_momentums,
-                           independent_chain_ndims,
-                           name=None):
-  """Helper to `kernel` which computes the energy change."""
-  with ops.name_scope(
-      name, "compute_energy_change",
-      ([current_target_log_prob, proposed_target_log_prob,
-        independent_chain_ndims] +
-       current_momentums + proposed_momentums)):
-    # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy
-    # since they're a mouthful and lets us inline more.
-    lk0, lk1 = [], []
-    for current_momentum, proposed_momentum in zip(current_momentums,
-                                                   proposed_momentums):
-      axis = math_ops.range(independent_chain_ndims,
-                            array_ops.rank(current_momentum))
-      lk0.append(_log_sum_sq(current_momentum, axis))
-      lk1.append(_log_sum_sq(proposed_momentum, axis))
-
-    lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1),
-                                                  axis=-1)
-    lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1),
-                                                  axis=-1)
-    lp0 = -current_target_log_prob   # log_potential
-    lp1 = -proposed_target_log_prob  # proposed_log_potential
-    x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
-                        axis=-1)
-
-    # The sum is NaN if any element is NaN or we see both +Inf and -Inf.
-    # Thus we will replace such rows with infinite energy change which implies
-    # rejection. Recall that float-comparisons with NaN are always False.
-    is_sum_determinate = (
-        math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) &
-        math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1))
-    is_sum_determinate = array_ops.tile(
-        is_sum_determinate[..., array_ops.newaxis],
-        multiples=array_ops.concat([
-            array_ops.ones(array_ops.rank(is_sum_determinate),
-                           dtype=dtypes.int32),
-            [4],
-        ], axis=0))
-    x = array_ops.where(is_sum_determinate,
-                        x,
-                        array_ops.fill(array_ops.shape(x),
-                                       value=x.dtype.as_numpy_dtype(np.inf)))
-
-    return math_ops.reduce_sum(x, axis=-1)
-
-
-def _choose(is_accepted,
-            accepted,
-            rejected,
-            independent_chain_ndims,
-            name=None):
-  """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where."""
-  def _expand_is_accepted_like(x):
-    with ops.name_scope("_choose"):
-      expand_shape = array_ops.concat([
-          array_ops.shape(is_accepted),
-          array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)],
-                         dtype=dtypes.int32),
-      ], axis=0)
-      multiples = array_ops.concat([
-          array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32),
-          array_ops.shape(x)[independent_chain_ndims:],
-      ], axis=0)
-      m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape),
-                         multiples)
-      m.set_shape(x.shape)
-      return m
-  with ops.name_scope(name, "_choose", values=[
-      is_accepted, accepted, rejected, independent_chain_ndims]):
-    return array_ops.where(_expand_is_accepted_like(accepted),
-                           accepted,
-                           rejected)
-
-
-def _maybe_call_fn_and_grads(fn,
-                             fn_arg_list,
-                             fn_result=None,
-                             grads_fn_result=None,
-                             description="target_log_prob"):
-  """Helper which computes `fn_result` and `grads` if needed."""
-  fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list)
-                 else [fn_arg_list])
-  if fn_result is None:
-    fn_result = fn(*fn_arg_list)
-  if not fn_result.dtype.is_floating:
-    raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format(
-        description))
-  if grads_fn_result is None:
-    grads_fn_result = gradients_ops.gradients(
-        fn_result, fn_arg_list)
-  if len(fn_arg_list) != len(grads_fn_result):
-    raise ValueError("`{}` must be in one-to-one correspondence with "
-                     "`grads_{}`".format(*[description]*2))
-  if any(g is None for g in grads_fn_result):
-    raise ValueError("Encountered `None` gradient.")
-  return fn_result, grads_fn_result
-
-
-def _prepare_args(target_log_prob_fn, state, step_size,
-                  target_log_prob=None, grads_target_log_prob=None,
-                  maybe_expand=False, description="target_log_prob"):
-  """Helper which processes input args to meet list-like assumptions."""
-  state_parts = list(state) if _is_list_like(state) else [state]
-  state_parts = [ops.convert_to_tensor(s, name="state")
-                 for s in state_parts]
-  target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads(
-      target_log_prob_fn,
-      state_parts,
-      target_log_prob,
-      grads_target_log_prob,
-      description)
-  step_sizes = list(step_size) if _is_list_like(step_size) else [step_size]
-  step_sizes = [
-      ops.convert_to_tensor(
-          s, name="step_size", dtype=target_log_prob.dtype)
-      for s in step_sizes]
-  if len(step_sizes) == 1:
-    step_sizes *= len(state_parts)
-  if len(state_parts) != len(step_sizes):
-    raise ValueError("There should be exactly one `step_size` or it should "
-                     "have same length as `current_state`.")
-  if maybe_expand:
-    maybe_flatten = lambda x: x
-  else:
-    maybe_flatten = lambda x: x if _is_list_like(state) else x[0]
-  return [
-      maybe_flatten(state_parts),
-      maybe_flatten(step_sizes),
-      target_log_prob,
-      grads_target_log_prob,
-  ]
-
-
-def _is_list_like(x):
-  """Helper which returns `True` if input is `list`-like."""
-  return isinstance(x, (tuple, list))
-
-
-def _log_sum_sq(x, axis=None):
-  """Computes log(sum(x**2))."""
-  return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis)
+  with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum,
+                                              grad]):
+    momentum -= 0.5 * step_size * grad
+    position += step_size * momentum
+    potential, grad = potential_and_grad(position)
+    momentum -= 0.5 * step_size * grad
+
+  return position, momentum, potential, grad