L2HMC trained with strongly correlated Gaussian. Simple testing
and benchmark with eager and graph mode execution.

author     A. Unique TensorFlower <gardener@tensorflow.org>
           Tue, 29 May 2018 19:01:11 +0000 (12:01 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 29 May 2018 19:03:50 +0000 (12:03 -0700)

PiperOrigin-RevId: 198433911

tensorflow/contrib/eager/python/examples/BUILD
tensorflow/contrib/eager/python/examples/l2hmc/BUILD [new file with mode: 0644]
tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py [new file with mode: 0644]
tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py [new file with mode: 0644]
tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py [new file with mode: 0644]
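
For orientation, a minimal usage sketch (not part of the commit) of how the new module can be
driven under eager execution; it uses only functions defined in l2hmc.py below, and the
hyperparameter values are illustrative:

    import tensorflow as tf
    from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc

    tf.enable_eager_execution()

    # Sampler targeting the 2-D strongly correlated Gaussian defined in l2hmc.py.
    dynamics = l2hmc.Dynamics(
        x_dim=2,
        loglikelihood_fn=l2hmc.get_scg_energy_fn(),
        n_steps=10,
        eps=.1)
    optimizer = tf.train.AdamOptimizer(learning_rate=.001)

    # fit() draws its own batch of initial samples and trains the sampler in place.
    l2hmc.fit(dynamics, optimizer, n_samples=200, n_iters=100)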

diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index c1fd9e0..1d9371c 100644
@@ -7,6 +7,8 @@ py_library(
     name = "examples_pip",
     deps = [
         "//tensorflow/contrib/eager/python/examples/gan:mnist",
+        "//tensorflow/contrib/eager/python/examples/l2hmc",
+        "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
         "//tensorflow/contrib/eager/python/examples/linear_regression",
         "//tensorflow/contrib/eager/python/examples/resnet50",
         "//tensorflow/contrib/eager/python/examples/rnn_colorbot",
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD
new file mode 100644
index 0000000..7bdf905
--- /dev/null
@@ -0,0 +1,39 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_library(
+    name = "neural_nets",
+    srcs = ["neural_nets.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/eager/python:tfe",
+    ],
+)
+
+py_library(
+    name = "l2hmc",
+    srcs = ["l2hmc.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":neural_nets",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//third_party/py/numpy",
+    ],
+)
+
+cuda_py_test(
+    name = "l2hmc_test",
+    size = "large",
+    srcs = ["l2hmc_test.py"],
+    additional_deps = [
+        ":l2hmc",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
new file mode 100644
index 0000000..98b4ce1
--- /dev/null
@@ -0,0 +1,382 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""L2HMC compatible with TensorFlow's eager execution.
+
+Reference [Generalizing Hamiltonian Monte Carlo with Neural
+Networks](https://arxiv.org/pdf/1711.09268.pdf)
+
+Code adapted from the released TensorFlow graph implementation by original
+authors https://github.com/brain-research/l2hmc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets
+
+
+class Dynamics(tf.keras.Model):
+  """Dynamics engine of naive L2HMC sampler.
+
+  Args:
+    x_dim: dimensionality of observed data
+    loglikelihood_fn: energy function of the target distribution (the potential)
+    n_steps: number of leapfrog steps within each transition
+    eps: initial value of the learnable leapfrog step size
+  """
+
+  def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1):
+    super(Dynamics, self).__init__()
+
+    self.x_dim = x_dim
+    self.potential = loglikelihood_fn
+    self.n_steps = n_steps
+
+    self._construct_time()
+    self._construct_masks()
+
+    self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
+    self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
+
+    self.eps = tfe.Variable(
+        initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
+
+    # TODO(lxuechen): Remove this after model.add_weight is in place
+    self.vars_not_in_layers = [self.eps]
+    self.vars_not_in_layers += self.position_fn.vars_not_in_layers
+    self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers
+
+  def apply_transition(self, position):
+    """Propose a new state and perform the accept or reject step."""
+
+    # Simulate dynamics both forward and backward;
+    # sampled Bernoulli masks pick the direction each chain actually takes
+    position_f, momentum_f, accept_prob_f = self.transition_kernel(
+        position, forward=True)
+    position_b, momentum_b, accept_prob_b = self.transition_kernel(
+        position, forward=False)
+
+    # Decide direction uniformly
+    forward_mask = tf.cast(
+        tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32)
+    backward_mask = 1. - forward_mask
+
+    # Obtain proposed states
+    position_post = (
+        forward_mask[:, None] * position_f +
+        backward_mask[:, None] * position_b)
+    momentum_post = (
+        forward_mask[:, None] * momentum_f +
+        backward_mask[:, None] * momentum_b)
+
+    # Probability of accepting the proposed states
+    accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b
+
+    # Accept or reject step
+    accept_mask = tf.cast(
+        accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32)
+    reject_mask = 1. - accept_mask
+
+    # Samples after accept/reject step
+    position_out = (
+        accept_mask[:, None] * position_post + reject_mask[:, None] * position)
+
+    return position_post, momentum_post, accept_prob, position_out
+
+  def transition_kernel(self, position, forward=True):
+    """Transition kernel of augmented leapfrog integrator."""
+
+    lf_fn = self._forward_lf if forward else self._backward_lf
+
+    # Resample momentum
+    momentum = tf.random_normal(tf.shape(position))
+    position_post, momentum_post = position, momentum
+    sumlogdet = 0.
+    # Apply augmented leapfrog steps
+    for i in range(self.n_steps):
+      position_post, momentum_post, logdet = lf_fn(position_post, momentum_post,
+                                                   i)
+      sumlogdet += logdet
+
+    accept_prob = self._compute_accept_prob(position, momentum, position_post,
+                                            momentum_post, sumlogdet)
+
+    return position_post, momentum_post, accept_prob
+
+  def _forward_lf(self, position, momentum, i):
+    """One forward augmented leapfrog step. See eq (5-6) in paper."""
+
+    t = self._get_time(i)
+    mask, mask_inv = self._get_mask(i)
+    sumlogdet = 0.
+
+    momentum, logdet = self._update_momentum_forward(position, momentum, t)
+    sumlogdet += logdet
+
+    position, logdet = self._update_position_forward(position, momentum, t,
+                                                     mask)
+    sumlogdet += logdet
+
+    position, logdet = self._update_position_forward(position, momentum, t,
+                                                     mask_inv)
+    sumlogdet += logdet
+
+    momentum, logdet = self._update_momentum_forward(position, momentum, t)
+    sumlogdet += logdet
+
+    return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+
+  def _backward_lf(self, position, momentum, i):
+    """One backward augmented leapfrog step. See Appendix A in paper."""
+
+    # Reversed index/sinusoidal time
+    t = self._get_time(self.n_steps - i - 1)
+    mask, mask_inv = self._get_mask(self.n_steps - i - 1)
+    sumlogdet = 0.
+
+    momentum, logdet = self._update_momentum_backward(position, momentum, t)
+    sumlogdet += logdet
+
+    position, logdet = self._update_position_backward(position, momentum, t,
+                                                      mask)
+    sumlogdet += logdet
+
+    position, logdet = self._update_position_backward(position, momentum, t,
+                                                      mask_inv)
+    sumlogdet += logdet
+
+    momentum, logdet = self._update_momentum_backward(position, momentum, t)
+    sumlogdet += logdet
+
+    return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+
+  def _update_momentum_forward(self, position, momentum, t):
+    """Update v in the forward leapfrog step."""
+
+    grad = self.grad_potential(position)
+    scale, translation, transformed = self.momentum_fn([position, grad, t])
+    scale *= .5 * self.eps
+    transformed *= self.eps
+    momentum = (
+        momentum * tf.exp(scale) -
+        .5 * self.eps * (tf.exp(transformed) * grad - translation))
+
+    return momentum, scale
+
+  def _update_position_forward(self, position, momentum, t, mask):
+    """Update x in the forward leapfrog step."""
+
+    mask_inv = 1. - mask
+    scale, translation, transformed = self.position_fn(
+        [momentum, mask * position, t])
+    scale *= self.eps
+    transformed *= self.eps
+    position = (
+        mask * position +
+        mask_inv * (position * tf.exp(scale) + self.eps *
+                    (tf.exp(transformed) * momentum + translation)))
+
+    return position, mask_inv * scale
+
+  def _update_momentum_backward(self, position, momentum, t):
+    """Update v in the backward leapfrog step. Inverting the forward update."""
+
+    grad = self.grad_potential(position)
+    scale, translation, transformed = self.momentum_fn([position, grad, t])
+    scale *= -.5 * self.eps
+    transformed *= self.eps
+    momentum = (
+        tf.exp(scale) * (momentum + .5 * self.eps *
+                         (tf.exp(transformed) * grad - translation)))
+
+    return momentum, scale
+
+  def _update_position_backward(self, position, momentum, t, mask):
+    """Update x in the backward leapfrog step. Inverting the forward update."""
+
+    mask_inv = 1. - mask
+    scale, translation, transformed = self.position_fn(
+        [momentum, mask_inv * position, t])
+    scale *= -self.eps
+    transformed *= self.eps
+    position = (
+        mask_inv * position + mask * tf.exp(scale) *
+        (position - self.eps * (tf.exp(transformed) * momentum + translation)))
+
+    return position, mask * scale
+
+  def _compute_accept_prob(self, position, momentum, position_post,
+                           momentum_post, sumlogdet):
+    """Compute the prob of accepting the proposed state given old state."""
+
+    old_hamil = self.hamiltonian(position, momentum)
+    new_hamil = self.hamiltonian(position_post, momentum_post)
+
+    return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
+
+  def _construct_time(self):
+    """Convert leapfrog step index into sinusoidal time."""
+
+    self.ts = []
+    for i in range(self.n_steps):
+      t = tf.constant(
+          [
+              np.cos(2 * np.pi * i / self.n_steps),
+              np.sin(2 * np.pi * i / self.n_steps)
+          ],
+          dtype=tf.float32)
+      self.ts.append(t[None, :])
+
+  def _get_time(self, i):
+    """Get sinusoidal time for i-th augmented leapfrog step."""
+
+    return self.ts[i]
+
+  def _construct_masks(self):
+    """Construct different binary masks for different time steps."""
+
+    self.masks = []
+    for _ in range(self.n_steps):
+      idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
+      mask = np.zeros((self.x_dim,))
+      mask[idx] = 1.
+      mask = tf.constant(mask, dtype=tf.float32)
+      self.masks.append(mask[None, :])
+
+  def _get_mask(self, i):
+    """Get binary masks for i-th augmented leapfrog step."""
+
+    m = self.masks[i]
+    return m, 1. - m
+
+  def kinetic(self, v):
+    """Compute the kinetic energy."""
+
+    return .5 * tf.reduce_sum(v**2, axis=1)
+
+  def hamiltonian(self, position, momentum):
+    """Compute the overall Hamiltonian."""
+
+    return self.potential(position) + self.kinetic(momentum)
+
+  def grad_potential(self, position, check_numerics=True):
+    """Get gradient of potential function at current location."""
+
+    if not tf.executing_eagerly():
+      # TODO(lxuechen): Change this to tfe.gradients_function when it works
+      grad = tf.gradients(self.potential(position), position)[0]
+    else:
+      grad = tfe.gradients_function(self.potential)(position)[0]
+
+    if check_numerics:
+      return tf.check_numerics(grad, message="gradient of potential")
+
+    return grad
+
+
+# Defining loss and grads for training
+def compute_loss(x, dynamics, scale=.1, eps=1e-4):
+  """Compute loss defined in equation (8)."""
+
+  z = tf.random_normal(tf.shape(x))
+  x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+  z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+  # Add eps for numerical stability; following released impl
+  x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+  z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+  loss = tf.reduce_mean(
+      (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+  return loss, x_out
+
+
+def loss_and_grads(x, dynamics):
+  """Obtain loss value and gradients."""
+
+  with tf.GradientTape() as tape:
+    loss_val, x_out = compute_loss(x, dynamics)
+
+  vars_ = dynamics.variables + dynamics.vars_not_in_layers
+  grads = tape.gradient(loss_val, vars_)
+
+  return loss_val, grads, x_out
+
+
+def warmup(dynamics, optimizer, n_iters=1, n_samples=200):
+  """Warmup optimization to reduce overhead."""
+
+  samples = tf.random_normal(
+      shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+  for _ in range(n_iters):
+    _, grads, samples = loss_and_grads(samples, dynamics)
+    vars_ = dynamics.variables + dynamics.vars_not_in_layers
+    optimizer.apply_gradients(zip(grads, vars_))
+
+
+def fit(dynamics,
+        optimizer,
+        n_samples=200,
+        n_iters=5000,
+        verbose=True,
+        logdir=None):
+  """Fit L2HMC sampler with given log-likelihood function."""
+
+  if logdir:
+    summary_writer = tf.contrib.summary.create_file_writer(logdir)
+
+  samples = tf.random_normal(
+      shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+  tf.train.get_or_create_global_step()
+  for i in range(n_iters):
+    loss, grads, samples = loss_and_grads(samples, dynamics)
+    # TODO(lxuechen): Proper learning rate decay
+    grads_ = [grad * .96**(i // 1000) for grad in grads]
+    vars_ = dynamics.variables + dynamics.vars_not_in_layers
+    optimizer.apply_gradients(
+        zip(grads_, vars_), global_step=tf.train.get_global_step())
+
+    if verbose:
+      print("Iteration %d: loss %.4f" % (i, loss))
+
+    if logdir:
+      with summary_writer.as_default():
+        with tf.contrib.summary.always_record_summaries():
+          tf.contrib.summary.scalar("loss", loss)
+
+
+def get_scg_energy_fn():
+  """Get energy function for 2d strongly correlated Gaussian."""
+
+  # Avoid recreating tf constants on each invocation of gradients
+  mu = tf.constant([0., 0.])
+  sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+  sigma_inv = tf.matrix_inverse(sigma)
+
+  def energy(x):
+    """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+
+    xmmu = x - mu
+    return .5 * tf.diag_part(
+        tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
+
+  return energy
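
For reference (not part of the commit), a quick numpy check of what makes this target
"strongly correlated": the covariance above has eigenvalues 100 and 0.1 and a correlation
coefficient of roughly -0.998, the kind of ill-conditioning the referenced paper uses to
stress plain HMC:

    import numpy as np

    sigma = np.array([[50.05, -49.95], [-49.95, 50.05]])
    print(np.linalg.eigvalsh(sigma))                         # eigenvalues ~ [0.1, 100.0]
    print(sigma[0, 1] / np.sqrt(sigma[0, 0] * sigma[1, 1]))  # correlation ~ -0.998
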
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
new file mode 100644
index 0000000..522a7c9
--- /dev/null
@@ -0,0 +1,162 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy.random as npr
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc
+
+
+def get_default_hparams():
+  return tf.contrib.training.HParams(
+      x_dim=2,
+      n_samples=200,
+      n_steps=10,
+      eps=.1,
+      n_iters=5,
+      learning_rate=.001,
+      n_warmup_iters=1)
+
+
+class L2hmcTest(tf.test.TestCase):
+  """Unit tests for l2hmc in both eager and graph mode."""
+
+  def testComputeLoss(self):
+    """Testing function l2hmc.compute_loss in both graph and eager mode."""
+
+    # Eager mode testing
+    hparams = get_default_hparams()
+    dynamics = l2hmc.Dynamics(
+        x_dim=hparams.x_dim,
+        loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+        n_steps=hparams.n_steps,
+        eps=hparams.eps)
+    samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim])
+    loss, x_out = l2hmc.compute_loss(samples, dynamics)
+
+    # Check shape and numerical stability
+    self.assertEqual(x_out.shape, samples.shape)
+    self.assertEqual(loss.shape, [])
+    self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5)
+
+    # Graph mode testing
+    with tf.Graph().as_default():
+      dynamics = l2hmc.Dynamics(
+          x_dim=hparams.x_dim,
+          loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+          n_steps=hparams.n_steps,
+          eps=hparams.eps)
+      x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
+      loss, x_out = l2hmc.compute_loss(x, dynamics)
+      samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
+
+      with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples})
+
+        # Check shape and numerical stability
+        self.assertEqual(x_out_np.shape, samples.shape)
+        self.assertEqual(loss_np.shape, ())
+        self.assertAllClose(loss_np, loss_np, rtol=1e-5)
+
+
+class L2hmcBenchmark(tf.test.Benchmark):
+  """Eager and graph benchmarks for l2hmc."""
+
+  def benchmarkEagerL2hmc(self):
+    """Benchmark Eager performance."""
+
+    hparams = get_default_hparams()
+    dynamics = l2hmc.Dynamics(
+        x_dim=hparams.x_dim,
+        loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+        n_steps=hparams.n_steps,
+        eps=hparams.eps)
+    # TODO(lxuechen): Add learning rate decay
+    optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+
+    # Warmup to reduce initialization effect when timing
+    l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters)
+
+    # Time
+    start_time = time.time()
+    l2hmc.fit(
+        dynamics,
+        optimizer,
+        n_samples=hparams.n_samples,
+        n_iters=hparams.n_iters)
+    wall_time = time.time() - start_time
+    examples_per_sec = hparams.n_samples / wall_time
+
+    self.report_benchmark(
+        name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"),
+        iters=hparams.n_iters,
+        extras={"examples_per_sec": examples_per_sec},
+        wall_time=wall_time)
+
+  def benchmarkGraphL2hmc(self):
+    """Benchmark Graph performance."""
+
+    hparams = get_default_hparams()
+    with tf.Graph().as_default():
+      dynamics = l2hmc.Dynamics(
+          x_dim=hparams.x_dim,
+          loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+          n_steps=hparams.n_steps,
+          eps=hparams.eps)
+      x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
+      loss, x_out = l2hmc.compute_loss(x, dynamics)
+
+      global_step = tf.Variable(0., name="global_step", trainable=False)
+      learning_rate = tf.train.exponential_decay(
+          hparams.learning_rate, global_step, 1000, 0.96, staircase=True)
+      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
+      train_op = optimizer.minimize(loss, global_step=global_step)
+
+      with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+
+        # Warmup to reduce initialization effect when timing
+        samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
+        for _ in range(hparams.n_warmup_iters):
+          samples, _, _, _ = sess.run(
+              [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+
+        # Time
+        start_time = time.time()
+        for _ in range(hparams.n_iters):
+          samples, _, _, _ = sess.run(
+              [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+        wall_time = time.time() - start_time
+        examples_per_sec = hparams.n_samples / wall_time
+
+        self.report_benchmark(
+            name="graph_train_%s" % ("gpu"
+                                     if tf.test.is_gpu_available() else "cpu"),
+            iters=hparams.n_iters,
+            extras={"examples_per_sec": examples_per_sec},
+            wall_time=wall_time)
+
+
+if __name__ == "__main__":
+  tf.enable_eager_execution()
+  tf.test.main()
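
The new test target would presumably be run in the usual way, for example

    bazel test //tensorflow/contrib/eager/python/examples/l2hmc:l2hmc_test

while the Benchmark classes above are only exercised when the test binary is invoked with
TensorFlow's --benchmarks flag (for example --benchmarks=L2hmcBenchmark), which is how the
eager versus graph numbers would be collected.
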
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
new file mode 100644
index 0000000..c902e1f
--- /dev/null
@@ -0,0 +1,86 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution.
+
+Reference [Generalizing Hamiltonian Monte Carlo with Neural
+Networks](https://arxiv.org/pdf/1711.09268.pdf)
+
+Code adapted from the released TensorFlow graph implementation by original
+authors https://github.com/brain-research/l2hmc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+
+
+class GenericNet(tf.keras.Model):
+  """Generic neural net with different initialization scale based on input.
+
+  Args:
+    x_dim: dimensionality of observed data
+    factor: factor of variance scaling initializer
+    n_hidden: number of hidden units
+  """
+
+  def __init__(self, x_dim, factor, n_hidden=10):
+    super(GenericNet, self).__init__()
+
+    self.v_layer = _custom_dense(n_hidden, 1. / 3.)
+    self.x_layer = _custom_dense(n_hidden, factor / 3.)
+    self.t_layer = _custom_dense(n_hidden, 1. / 3.)
+    self.h_layer = _custom_dense(n_hidden)
+
+    # Scale
+    self.scale_layer = _custom_dense(x_dim, .001)
+    self.coeff_scale = tfe.Variable(
+        initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True)
+    # Translation
+    self.translation_layer = _custom_dense(x_dim, factor=.001)
+    # Transformation
+    self.transformation_layer = _custom_dense(x_dim, .001)
+    self.coeff_transformation = tfe.Variable(
+        initial_value=tf.zeros([1, x_dim]),
+        name='coeff_transformation',
+        trainable=True)
+    # TODO(lxuechen): Remove this after model.add_weight is in place
+    self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation]
+
+  def call(self, inputs):
+    v, x, t = inputs
+    h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t)
+    h = tf.nn.relu(h)
+    h = self.h_layer(h)
+    h = tf.nn.relu(h)
+    scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale)
+    translation = self.translation_layer(h)
+    transformation = (
+        tf.nn.tanh(self.transformation_layer(h)) * tf.exp(
+            self.coeff_transformation))
+
+    return scale, translation, transformation
+
+
+def _custom_dense(units, factor=1.):
+  """Custom dense layer with specified weight initialization."""
+
+  return tf.keras.layers.Dense(
+      units=units,
+      use_bias=True,
+      kernel_initializer=tf.contrib.layers.variance_scaling_initializer(
+          factor=factor * 2., mode='FAN_IN', uniform=False),
+      bias_initializer=tf.constant_initializer(0., dtype=tf.float32))
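
To make the interface concrete, a minimal shape-check sketch (not part of the commit, assuming
eager execution): Dynamics calls this network on a three-tensor input (for the position
network: momentum, masked position, and the sinusoidal time encoding) and consumes the
resulting scale/translation/transformation triple in its leapfrog updates, each output having
shape [batch, x_dim]:

    import tensorflow as tf
    from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets

    tf.enable_eager_execution()

    net = neural_nets.GenericNet(x_dim=2, factor=1.)
    v = tf.random_normal([200, 2])  # momentum
    x = tf.random_normal([200, 2])  # (masked) position
    t = tf.constant([[1., 0.]])     # sinusoidal time encoding, broadcast over the batch
    scale, translation, transformation = net([v, x, t])
    print(scale.shape, translation.shape, transformation.shape)  # (200, 2) each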