Add probabilistic convolutional layers.
authorDustin Tran <trandustin@google.com>
Tue, 19 Dec 2017 01:15:01 +0000 (17:15 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 19 Dec 2017 01:18:50 +0000 (17:18 -0800)
PiperOrigin-RevId: 179490700

tensorflow/contrib/bayesflow/BUILD
tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
tensorflow/contrib/bayesflow/python/ops/layers.py
tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
tensorflow/contrib/bayesflow/python/ops/layers_util.py [new file with mode: 0644]

index a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9..4e0520fa33a57e2f15c39d362ec3a39945202d46 100644 (file)
@@ -99,6 +99,25 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "layers_conv_variational_test",
+    size = "small",
+    srcs = ["python/kernel_tests/layers_conv_variational_test.py"],
+    additional_deps = [
+        ":bayesflow_py",
+        "//third_party/py/numpy",
+        "//tensorflow/contrib/distributions:distributions_py",
+        "//tensorflow/python/ops/distributions",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
+    ],
+)
+
 cuda_py_test(
     name = "layers_dense_variational_test",
     size = "small",
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py
new file mode 100644 (file)
index 0000000..57f44ae
--- /dev/null
@@ -0,0 +1,289 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for convolutional Bayesian layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib
+from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.platform import test
+
+
+class Counter(object):
+  """Helper class to manage incrementing a counting `int`."""
+
+  def __init__(self):
+    self._value = -1
+
+  @property
+  def value(self):
+    return self._value
+
+  def __call__(self):
+    self._value += 1
+    return self._value
+
+
+class MockDistribution(independent_lib.Independent):
+  """Monitors DenseVariational calls to the underlying distribution."""
+
+  def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
+    self.result_sample = result_sample
+    self.result_log_prob = result_log_prob
+    self.result_loc = loc
+    self.result_scale = scale
+    self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0)
+    if loc is not None and scale is not None:
+      self.result_distribution = normal_lib.Normal(loc=self.result_loc,
+                                                   scale=self.result_scale)
+    self.called_log_prob = Counter()
+    self.called_sample = Counter()
+    self.called_loc = Counter()
+    self.called_scale = Counter()
+
+  def log_prob(self, *args, **kwargs):
+    self.called_log_prob()
+    return self.result_log_prob
+
+  def sample(self, *args, **kwargs):
+    self.called_sample()
+    return self.result_sample
+
+  @property
+  def distribution(self):  # for dummy check on Independent(Normal)
+    return self.result_distribution
+
+  @property
+  def loc(self):
+    self.called_loc()
+    return self.result_loc
+
+  @property
+  def scale(self):
+    self.called_scale()
+    return self.result_scale
+
+
+class MockKLDivergence(object):
+  """Monitors layer calls to the divergence implementation."""
+
+  def __init__(self, result):
+    self.result = result
+    self.args = []
+    self.called = Counter()
+
+  def __call__(self, *args, **kwargs):
+    self.called()
+    self.args.append(args)
+    return self.result
+
+
+class ConvVariational(test.TestCase):
+
+  def _testKLPenaltyKernel(self, layer_class):
+    with self.test_session():
+      layer = layer_class(filters=2, kernel_size=3)
+      if layer_class == prob_layers_lib.Conv1DVariational:
+        inputs = random_ops.random_uniform([2, 3, 1], seed=1)
+      elif layer_class == prob_layers_lib.Conv2DVariational:
+        inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
+      elif layer_class == prob_layers_lib.Conv3DVariational:
+        inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)
+
+      # No keys.
+      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+      self.assertEqual(len(losses), 0)
+      self.assertListEqual(layer.losses, losses)
+
+      _ = layer(inputs)
+
+      # Yes keys.
+      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+      self.assertEqual(len(losses), 1)
+      self.assertListEqual(layer.losses, losses)
+
+  def _testKLPenaltyBoth(self, layer_class):
+    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
+      return normal_lib.Normal(
+          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.))
+    with self.test_session():
+      layer = layer_class(
+          filters=2,
+          kernel_size=3,
+          bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
+          bias_prior_fn=_make_normal)
+      if layer_class == prob_layers_lib.Conv1DVariational:
+        inputs = random_ops.random_uniform([2, 3, 1], seed=1)
+      elif layer_class == prob_layers_lib.Conv2DVariational:
+        inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
+      elif layer_class == prob_layers_lib.Conv3DVariational:
+        inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)
+
+      # No keys.
+      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+      self.assertEqual(len(losses), 0)
+      self.assertListEqual(layer.losses, losses)
+
+      _ = layer(inputs)
+
+      # Yes keys.
+      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+      self.assertEqual(len(losses), 2)
+      self.assertListEqual(layer.losses, losses)
+
+  def _testConvVariational(self, layer_class):
+    batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
+    with self.test_session() as sess:
+      seed = Counter()
+      if layer_class == prob_layers_lib.Conv1DVariational:
+        inputs = random_ops.random_uniform(
+            [batch_size, width, channels], seed=seed())
+        kernel_size = (2,)
+      elif layer_class == prob_layers_lib.Conv2DVariational:
+        inputs = random_ops.random_uniform(
+            [batch_size, height, width, channels], seed=seed())
+        kernel_size = (2, 2)
+      elif layer_class == prob_layers_lib.Conv3DVariational:
+        inputs = random_ops.random_uniform(
+            [batch_size, depth, height, width, channels], seed=seed())
+        kernel_size = (2, 2, 2)
+
+      kernel_shape = kernel_size + (channels, filters)
+      kernel_posterior = MockDistribution(
+          result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
+          result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
+      kernel_prior = MockDistribution(
+          result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
+          result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
+      kernel_divergence = MockKLDivergence(
+          result=random_ops.random_uniform(kernel_shape, seed=seed()))
+
+      bias_size = (filters,)
+      bias_posterior = MockDistribution(
+          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
+          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
+      bias_prior = MockDistribution(
+          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
+          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
+      bias_divergence = MockKLDivergence(
+          result=random_ops.random_uniform(bias_size, seed=seed()))
+
+      convolution_op = nn_ops.Convolution(
+          tensor_shape.TensorShape(inputs.shape),
+          filter_shape=tensor_shape.TensorShape(kernel_shape),
+          padding="SAME")
+      expected_outputs = convolution_op(inputs, kernel_posterior.result_sample)
+      expected_outputs = nn.bias_add(expected_outputs,
+                                     bias_posterior.result_sample,
+                                     data_format="NHWC")
+
+      layer = layer_class(
+          filters=filters,
+          kernel_size=kernel_size,
+          padding="SAME",
+          kernel_posterior_fn=lambda *args: kernel_posterior,
+          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
+          kernel_prior_fn=lambda *args: kernel_prior,
+          kernel_divergence_fn=kernel_divergence,
+          bias_posterior_fn=lambda *args: bias_posterior,
+          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
+          bias_prior_fn=lambda *args: bias_prior,
+          bias_divergence_fn=bias_divergence)
+
+      outputs = layer(inputs)
+
+      kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+
+      [
+          expected_outputs_, actual_outputs_,
+          expected_kernel_, actual_kernel_,
+          expected_kernel_divergence_, actual_kernel_divergence_,
+          expected_bias_, actual_bias_,
+          expected_bias_divergence_, actual_bias_divergence_,
+      ] = sess.run([
+          expected_outputs, outputs,
+          kernel_posterior.result_sample, layer.kernel_posterior_tensor,
+          kernel_divergence.result, kl_penalty[0],
+          bias_posterior.result_sample, layer.bias_posterior_tensor,
+          bias_divergence.result, kl_penalty[1],
+      ])
+
+      self.assertAllClose(
+          expected_kernel_, actual_kernel_,
+          rtol=1e-6, atol=0.)
+      self.assertAllClose(
+          expected_bias_, actual_bias_,
+          rtol=1e-6, atol=0.)
+      self.assertAllClose(
+          expected_outputs_, actual_outputs_,
+          rtol=1e-6, atol=0.)
+      self.assertAllClose(
+          expected_kernel_divergence_, actual_kernel_divergence_,
+          rtol=1e-6, atol=0.)
+      self.assertAllClose(
+          expected_bias_divergence_, actual_bias_divergence_,
+          rtol=1e-6, atol=0.)
+
+      self.assertAllEqual(
+          [[kernel_posterior.distribution,
+            kernel_prior.distribution,
+            kernel_posterior.result_sample]],
+          kernel_divergence.args)
+
+      self.assertAllEqual(
+          [[bias_posterior.distribution,
+            bias_prior.distribution,
+            bias_posterior.result_sample]],
+          bias_divergence.args)
+
+  def testKLPenaltyKernelConv1DVariational(self):
+    self._testKLPenaltyKernel(prob_layers_lib.Conv1DVariational)
+
+  def testKLPenaltyKernelConv2DVariational(self):
+    self._testKLPenaltyKernel(prob_layers_lib.Conv2DVariational)
+
+  def testKLPenaltyKernelConv3DVariational(self):
+    self._testKLPenaltyKernel(prob_layers_lib.Conv3DVariational)
+
+  def testKLPenaltyBothConv1DVariational(self):
+    self._testKLPenaltyBoth(prob_layers_lib.Conv1DVariational)
+
+  def testKLPenaltyBothConv2DVariational(self):
+    self._testKLPenaltyBoth(prob_layers_lib.Conv2DVariational)
+
+  def testKLPenaltyBothConv3DVariational(self):
+    self._testKLPenaltyBoth(prob_layers_lib.Conv3DVariational)
+
+  def testConv1DVariational(self):
+    self._testConvVariational(prob_layers_lib.Conv1DVariational)
+
+  def testConv2DVariational(self):
+    self._testConvVariational(prob_layers_lib.Conv2DVariational)
+
+  def testConv3DVariational(self):
+    self._testConvVariational(prob_layers_lib.Conv3DVariational)
+
+
+if __name__ == "__main__":
+  test.main()
index 5371e912ed700ea4a15bf099ccacf1d0a0cfab69..4e9f1193511c35beead85914ca988fde69b3afde 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
+from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -127,7 +128,7 @@ class DenseVariational(test.TestCase):
     with self.test_session():
       layer = layer_class(
           units=2,
-          bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(),
+          bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
           bias_prior_fn=_make_normal)
       inputs = random_ops.random_uniform([2, 3], seed=1)
 
@@ -345,7 +346,7 @@ class DenseVariational(test.TestCase):
           maxval=2,
           dtype=dtypes.int32,
           seed=distribution_util.gen_new_seed(
-              layer.seed, salt="conv_variational"))
+              layer.seed, salt="dense_flipout"))
       sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype)
       perturbed_inputs = math_ops.matmul(
           inputs * sign_input, expected_kernel_posterior_affine_tensor)
index 121f36ec4ecb317ca0cd85075421ed136721f516..93412afae738564d440065f230c9df0036591467 100644 (file)
@@ -23,11 +23,25 @@ from __future__ import print_function
 
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import *
 from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import *
+from tensorflow.contrib.bayesflow.python.ops.layers_util import *
 # pylint: enable=wildcard-import
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
+    'Convolution1DVariational',
+    'Convolution2DVariational',
+    'Convolution3DVariational',
+    'Conv1DVariational',
+    'Conv2DVariational',
+    'Conv3DVariational',
+    'convolution1d_variational',
+    'convolution2d_variational',
+    'convolution3d_variational',
+    'conv1d_variational',
+    'conv2d_variational',
+    'conv3d_variational',
     'DenseReparameterization',
     'DenseLocalReparameterization',
     'DenseFlipout',
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
new file mode 100644 (file)
index 0000000..6ffb55f
--- /dev/null
@@ -0,0 +1,1415 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Convolutional variational layer classes and their functional aliases.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bayesflow.python.ops import layers_util
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base as layers_lib
+from tensorflow.python.layers import utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
+from tensorflow.python.ops.distributions import normal as normal_lib
+
+
+class _ConvVariational(layers_lib.Layer):
+  """Abstract nD convolution layer (private, used as implementation base).
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of n integers, specifying the
+      length of the convolution window.
+    strides: An integer or tuple/list of n integers,
+      specifying the stride length of the convolution.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, ..., channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, ...)`.
+    dilation_rate: An integer or tuple/list of n integers, specifying
+      the dilation rate to use for dilated convolution.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any `strides` value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    name: A string, the name of the layer.
+
+  Properties:
+    rank: Python integer, dimensionality of convolution.
+    filters: Python integer, dimensionality of the output space.
+    kernel_size: Size of the convolution window.
+    strides: Stride length of convolution.
+    padding: Python string describing padding approach.
+    data_format: Python string describing input data's dimensions.
+    dilation_rate: Dilation rate for an atrous convolution.
+    activation: Activation function (`callable`).
+    activity_regularizer: Regularizer function for the output.
+    kernel_use_local_reparameterization: Python `bool` indicating whether
+      `kernel` calculation should employ the Local Reparameterization Trick.
+    kernel_posterior_fn: `callable` returning posterior.
+    kernel_posterior_tensor_fn: `callable` operating on posterior.
+    kernel_prior_fn: `callable` returning prior.
+    kernel_divergence_fn: `callable` returning divergence.
+    bias_posterior_fn: `callable` returning posterior.
+    bias_posterior_tensor_fn: `callable` operating on posterior.
+    bias_prior_fn: `callable` returning prior.
+    bias_divergence_fn: `callable` returning divergence.
+  """
+
+  def __init__(
+      self,
+      rank,
+      filters,
+      kernel_size,
+      strides=1,
+      padding="valid",
+      data_format="channels_last",
+      dilation_rate=1,
+      activation=None,
+      activity_regularizer=None,
+      trainable=True,
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+      kernel_posterior_tensor_fn=lambda d: d.sample(),
+      kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+      kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+      bias_posterior_tensor_fn=lambda d: d.sample(),
+      bias_prior_fn=None,
+      bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      name=None,
+      **kwargs):
+    super(_ConvVariational, self).__init__(
+        trainable=trainable,
+        name=name,
+        activity_regularizer=activity_regularizer,
+        **kwargs)
+    self.rank = rank
+    self.filters = filters
+    self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size")
+    self.strides = utils.normalize_tuple(strides, rank, "strides")
+    self.padding = utils.normalize_padding(padding)
+    self.data_format = utils.normalize_data_format(data_format)
+    self.dilation_rate = utils.normalize_tuple(
+        dilation_rate, rank, "dilation_rate")
+    self.activation = activation
+    self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2)
+    self.kernel_posterior_fn = kernel_posterior_fn
+    self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
+    self.kernel_prior_fn = kernel_prior_fn
+    self.kernel_divergence_fn = kernel_divergence_fn
+    self.bias_posterior_fn = bias_posterior_fn
+    self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
+    self.bias_prior_fn = bias_prior_fn
+    self.bias_divergence_fn = bias_divergence_fn
+
+  def build(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape)
+    if self.data_format == "channels_first":
+      channel_axis = 1
+    else:
+      channel_axis = -1
+    if input_shape[channel_axis].value is None:
+      raise ValueError("The channel dimension of the inputs "
+                       "should be defined. Found `None`.")
+    input_dim = input_shape[channel_axis].value
+    kernel_shape = self.kernel_size + (input_dim, self.filters)
+    dtype = dtypes.as_dtype(self.dtype)
+
+    # Must have a posterior kernel.
+    self.kernel_posterior = self.kernel_posterior_fn(
+        dtype, kernel_shape, "kernel_posterior",
+        self.trainable, self.add_variable)
+
+    if self.kernel_prior_fn is None:
+      self.kernel_prior = None
+    else:
+      self.kernel_prior = self.kernel_prior_fn(
+          dtype, kernel_shape, "kernel_prior",
+          self.trainable, self.add_variable)
+    self._built_kernel_divergence = False
+
+    if self.bias_posterior_fn is None:
+      self.bias_posterior = None
+    else:
+      self.bias_posterior = self.bias_posterior_fn(
+          dtype, (self.filters,), "bias_posterior",
+          self.trainable, self.add_variable)
+
+    if self.bias_prior_fn is None:
+      self.bias_prior = None
+    else:
+      self.bias_prior = self.bias_prior_fn(
+          dtype, (self.filters,), "bias_prior",
+          self.trainable, self.add_variable)
+    self._built_bias_divergence = False
+
+    self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2,
+                                           axes={channel_axis: input_dim})
+    self._convolution_op = nn_ops.Convolution(
+        input_shape,
+        filter_shape=tensor_shape.TensorShape(kernel_shape),
+        dilation_rate=self.dilation_rate,
+        strides=self.strides,
+        padding=self.padding.upper(),
+        data_format=utils.convert_data_format(self.data_format,
+                                              self.rank + 2))
+
+    self.built = True
+
+  def call(self, inputs):
+    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
+
+    outputs = self._apply_variational_kernel(inputs)
+    outputs = self._apply_variational_bias(outputs)
+    if self.activation is not None:
+      outputs = self.activation(outputs)
+    if not self._built_kernel_divergence:
+      kernel_posterior = self.kernel_posterior
+      kernel_prior = self.kernel_prior
+      if isinstance(self.kernel_posterior, independent_lib.Independent):
+        kernel_posterior = kernel_posterior.distribution
+      if isinstance(self.kernel_prior, independent_lib.Independent):
+        kernel_prior = kernel_prior.distribution
+      self._apply_divergence(self.kernel_divergence_fn,
+                             kernel_posterior,
+                             kernel_prior,
+                             self.kernel_posterior_tensor,
+                             name="divergence_kernel")
+      self._built_kernel_divergence = True
+    if not self._built_bias_divergence:
+      bias_posterior = self.bias_posterior
+      bias_prior = self.bias_prior
+      if isinstance(self.bias_posterior, independent_lib.Independent):
+        bias_posterior = bias_posterior.distribution
+      if isinstance(self.bias_prior, independent_lib.Independent):
+        bias_prior = bias_prior.distribution
+      self._apply_divergence(self.bias_divergence_fn,
+                             bias_posterior,
+                             bias_prior,
+                             self.bias_posterior_tensor,
+                             name="divergence_bias")
+      self._built_bias_divergence = True
+    return outputs
+
+  def _apply_variational_kernel(self, inputs):
+    self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
+        self.kernel_posterior)
+    outputs = self._convolution_op(inputs, self.kernel_posterior_tensor)
+    return outputs
+
+  def _apply_variational_bias(self, inputs):
+    if self.bias_posterior is None:
+      self.bias_posterior_tensor = None
+      return inputs
+    self.bias_posterior_tensor = self.bias_posterior_tensor_fn(
+        self.bias_posterior)
+    outputs = inputs
+    if self.data_format == "channels_first":
+      if self.rank == 1:
+        # nn.bias_add does not accept a 1D input tensor.
+        bias = array_ops.reshape(self.bias_posterior_tensor,
+                                 (1, self.filters, 1))
+        outputs += bias
+      if self.rank == 2:
+        outputs = nn.bias_add(outputs,
+                              self.bias_posterior_tensor,
+                              data_format="NCHW")
+      if self.rank == 3:
+        # As of Mar 2017, direct addition is significantly slower than
+        # bias_add when computing gradients. To use bias_add, we collapse Z
+        # and Y into a single dimension to obtain a 4D input tensor.
+        outputs_shape = outputs.shape.as_list()
+        outputs_4d = array_ops.reshape(outputs,
+                                       [outputs_shape[0], outputs_shape[1],
+                                        outputs_shape[2] * outputs_shape[3],
+                                        outputs_shape[4]])
+        outputs_4d = nn.bias_add(outputs_4d,
+                                 self.bias_posterior_tensor,
+                                 data_format="NCHW")
+        outputs = array_ops.reshape(outputs_4d, outputs_shape)
+    else:
+      outputs = nn.bias_add(outputs,
+                            self.bias_posterior_tensor,
+                            data_format="NHWC")
+    return outputs
+
+  def _apply_divergence(self, divergence_fn, posterior, prior,
+                        posterior_tensor, name):
+    if (divergence_fn is None or
+        posterior is None or
+        prior is None):
+      divergence = None
+      return
+    divergence = standard_ops.identity(
+        divergence_fn(
+            posterior, prior, posterior_tensor),
+        name=name)
+    self.add_loss(divergence)
+
+  def _compute_output_shape(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    if self.data_format == "channels_last":
+      space = input_shape[1:-1]
+      new_space = []
+      for i in range(len(space)):
+        new_dim = utils.conv_output_length(
+            space[i],
+            self.kernel_size[i],
+            padding=self.padding,
+            stride=self.strides[i],
+            dilation=self.dilation_rate[i])
+        new_space.append(new_dim)
+      return tensor_shape.TensorShape([input_shape[0]] + new_space +
+                                      [self.filters])
+    else:
+      space = input_shape[2:]
+      new_space = []
+      for i in range(len(space)):
+        new_dim = utils.conv_output_length(
+            space[i],
+            self.kernel_size[i],
+            padding=self.padding,
+            stride=self.strides[i],
+            dilation=self.dilation_rate[i])
+        new_space.append(new_dim)
+      return tensor_shape.TensorShape([input_shape[0], self.filters] +
+                                      new_space)
+
+
+class Conv1DVariational(_ConvVariational):
+  """1D convolution layer (e.g. temporal convolution).
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of a single integer, specifying the
+      length of the 1D convolution window.
+    strides: An integer or tuple/list of a single integer,
+      specifying the stride length of the convolution.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, length, channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, length)`.
+    dilation_rate: An integer or tuple/list of a single integer, specifying
+      the dilation rate to use for dilated convolution.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any `strides` value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+
+  Properties:
+    filters: Python integer, dimensionality of the output space.
+    kernel_size: Size of the convolution window.
+    strides: Stride length of convolution.
+    padding: Python string describing padding approach.
+    data_format: Python string describing input data's dimensions.
+    dilation_rate: Dilation rate for an atrous convolution.
+    activation: Activation function (`callable`).
+    activity_regularizer: Regularizer function for the output.
+    kernel_use_local_reparameterization: Python `bool` indicating whether
+      `kernel` calculation should employ the Local Reparameterization Trick.
+    kernel_posterior_fn: `callable` returning posterior.
+    kernel_posterior_tensor_fn: `callable` operating on posterior.
+    kernel_prior_fn: `callable` returning prior.
+    kernel_divergence_fn: `callable` returning divergence.
+    bias_posterior_fn: `callable` returning posterior.
+    bias_posterior_tensor_fn: `callable` operating on posterior.
+    bias_prior_fn: `callable` returning prior.
+    bias_divergence_fn: `callable` returning divergence.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 128, 1])
+  net = tfp.layers.Conv1DVariational(64,
+                                     kernel_size=5,
+                                     padding="SAME",
+                                     activation=tf.nn.relu)(net)
+  net = tf.reshape(net, [-1, 128 * 64])
+  logits = tfp.layers.DenseVariational(10)(net)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+
+  def __init__(
+      self,
+      filters,
+      kernel_size,
+      strides=1,
+      padding="valid",
+      data_format="channels_last",
+      dilation_rate=1,
+      activation=None,
+      activity_regularizer=None,
+      trainable=True,
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+      kernel_posterior_tensor_fn=lambda d: d.sample(),
+      kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+      kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+      bias_posterior_tensor_fn=lambda d: d.sample(),
+      bias_prior_fn=None,
+      bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      name=None,
+      **kwargs):
+    super(Conv1DVariational, self).__init__(
+        rank=1,
+        filters=filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+        activation=activation,
+        activity_regularizer=activity_regularizer,
+        trainable=trainable,
+        kernel_posterior_fn=kernel_posterior_fn,
+        kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+        kernel_prior_fn=kernel_prior_fn,
+        kernel_divergence_fn=kernel_divergence_fn,
+        bias_posterior_fn=bias_posterior_fn,
+        bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+        bias_prior_fn=bias_prior_fn,
+        bias_divergence_fn=bias_divergence_fn,
+        name=name, **kwargs)
+
+
+def conv1d_variational(
+    inputs,
+    filters,
+    kernel_size,
+    strides=1,
+    padding="valid",
+    data_format="channels_last",
+    dilation_rate=1,
+    activation=None,
+    activity_regularizer=None,
+    trainable=True,
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+    kernel_posterior_tensor_fn=lambda d: d.sample(),
+    kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+        loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+    bias_posterior_tensor_fn=lambda d: d.sample(),
+    bias_prior_fn=None,
+    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    name=None,
+    reuse=None):
+  """Functional interface for 1D convolution layer (e.g. temporal convolution).
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    inputs: Tensor input.
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of a single integer, specifying the
+      length of the 1D convolution window.
+    strides: An integer or tuple/list of a single integer,
+      specifying the stride length of the convolution.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, length, channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, length)`.
+    dilation_rate: An integer or tuple/list of a single integer, specifying
+      the dilation rate to use for dilated convolution.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any `strides` value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+    reuse: Boolean, whether to reuse the weights of a previous layer
+      by the same name.
+
+  Returns:
+    Output tensor.
+
+  Raises:
+    ValueError: if eager execution is enabled.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 128, 1])
+  net = tfp.layers.conv1d_variational(net,
+                                      64,
+                                      kernel_size=5,
+                                      padding="SAME",
+                                      activation=tf.nn.relu)
+  net = tf.reshape(net, [-1, 128 * 64])
+  logits = tfp.layers.dense_variational(net, 10)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+  layer = Conv1DVariational(
+      filters=filters,
+      kernel_size=kernel_size,
+      strides=strides,
+      padding=padding,
+      data_format=data_format,
+      dilation_rate=dilation_rate,
+      activation=activation,
+      activity_regularizer=activity_regularizer,
+      trainable=trainable,
+      kernel_posterior_fn=kernel_posterior_fn,
+      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+      kernel_prior_fn=kernel_prior_fn,
+      kernel_divergence_fn=kernel_divergence_fn,
+      bias_posterior_fn=bias_posterior_fn,
+      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+      bias_prior_fn=bias_prior_fn,
+      bias_divergence_fn=bias_divergence_fn,
+      name=name,
+      dtype=inputs.dtype.base_dtype,
+      _scope=name,
+      _reuse=reuse)
+  return layer.apply(inputs)
+
+
+class Conv2DVariational(_ConvVariational):
+  """2D convolution layer (e.g. spatial convolution over images).
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of 2 integers, specifying the
+      height and width of the 2D convolution window.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+    strides: An integer or tuple/list of 2 integers,
+      specifying the strides of the convolution along the height and width.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, height, width, channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, height, width)`.
+
+    dilation_rate: An integer or tuple/list of 2 integers, specifying
+      the dilation rate to use for dilated convolution.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any stride value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+
+  Properties:
+    filters: Python integer, dimensionality of the output space.
+    kernel_size: Size of the convolution window.
+    strides: Stride length of convolution.
+    padding: Python string describing padding approach.
+    data_format: Python string describing input data's dimensions.
+    dilation_rate: Dilation rate for an atrous convolution.
+    activation: Activation function (`callable`).
+    activity_regularizer: Regularizer function for the output.
+    kernel_use_local_reparameterization: Python `bool` indicating whether
+      `kernel` calculation should employ the Local Reparameterization Trick.
+    kernel_posterior_fn: `callable` returning posterior.
+    kernel_posterior_tensor_fn: `callable` operating on posterior.
+    kernel_prior_fn: `callable` returning prior.
+    kernel_divergence_fn: `callable` returning divergence.
+    bias_posterior_fn: `callable` returning posterior.
+    bias_posterior_tensor_fn: `callable` operating on posterior.
+    bias_prior_fn: `callable` returning prior.
+    bias_divergence_fn: `callable` returning divergence.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 32, 32, 3])
+  net = tfp.layers.Conv2DVariational(64,
+                                     kernel_size=5,
+                                     padding="SAME",
+                                     activation=tf.nn.relu)(net)
+  net = tf.layers.MaxPooling2D(pool_size=2,
+                               strides=2,
+                               padding="SAME")(net)
+  net = tf.reshape(net, [-1, 8 * 8 * 64])
+  logits = tfp.layers.DenseVariational(10)(net)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+
+  def __init__(
+      self,
+      filters,
+      kernel_size,
+      strides=(1, 1),
+      padding="valid",
+      data_format="channels_last",
+      dilation_rate=(1, 1),
+      activation=None,
+      activity_regularizer=None,
+      trainable=True,
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+      kernel_posterior_tensor_fn=lambda d: d.sample(),
+      kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+      kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+      bias_posterior_tensor_fn=lambda d: d.sample(),
+      bias_prior_fn=None,
+      bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      name=None,
+      **kwargs):
+    super(Conv2DVariational, self).__init__(
+        rank=2,
+        filters=filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+        activation=activation,
+        activity_regularizer=activity_regularizer,
+        trainable=trainable,
+        kernel_posterior_fn=kernel_posterior_fn,
+        kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+        kernel_prior_fn=kernel_prior_fn,
+        kernel_divergence_fn=kernel_divergence_fn,
+        bias_posterior_fn=bias_posterior_fn,
+        bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+        bias_prior_fn=bias_prior_fn,
+        bias_divergence_fn=bias_divergence_fn,
+        name=name, **kwargs)
+
+
+def conv2d_variational(
+    inputs,
+    filters,
+    kernel_size,
+    strides=(1, 1),
+    padding="valid",
+    data_format="channels_last",
+    dilation_rate=(1, 1),
+    activation=None,
+    activity_regularizer=None,
+    trainable=True,
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+    kernel_posterior_tensor_fn=lambda d: d.sample(),
+    kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+        loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+    bias_posterior_tensor_fn=lambda d: d.sample(),
+    bias_prior_fn=None,
+    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    name=None,
+    reuse=None):
+  """Functional interface for the 2D convolution layer.
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    inputs: Tensor input.
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of 2 integers, specifying the
+      height and width of the 2D convolution window.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+    strides: An integer or tuple/list of 2 integers,
+      specifying the strides of the convolution along the height and width.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, height, width, channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, height, width)`.
+
+    dilation_rate: An integer or tuple/list of 2 integers, specifying
+      the dilation rate to use for dilated convolution.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any stride value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+    reuse: Boolean, whether to reuse the weights of a previous layer
+      by the same name.
+
+  Returns:
+    Output tensor.
+
+  Raises:
+    ValueError: if eager execution is enabled.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 32, 32, 3])
+  net = tfp.layers.conv2d_variational(net,
+                                      64,
+                                      kernel_size=5,
+                                      padding="SAME",
+                                      activation=tf.nn.relu)
+  net = tf.layers.max_pooling2d(net,
+                                pool_size=2,
+                                strides=2,
+                                padding="SAME")
+  net = tf.reshape(net, [-1, 8 * 8 * 64])
+  logits = tfp.layers.dense_variational(net, 10)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+  layer = Conv2DVariational(
+      filters=filters,
+      kernel_size=kernel_size,
+      strides=strides,
+      padding=padding,
+      data_format=data_format,
+      dilation_rate=dilation_rate,
+      activation=activation,
+      activity_regularizer=activity_regularizer,
+      trainable=trainable,
+      kernel_posterior_fn=kernel_posterior_fn,
+      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+      kernel_prior_fn=kernel_prior_fn,
+      kernel_divergence_fn=kernel_divergence_fn,
+      bias_posterior_fn=bias_posterior_fn,
+      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+      bias_prior_fn=bias_prior_fn,
+      bias_divergence_fn=bias_divergence_fn,
+      name=name,
+      dtype=inputs.dtype.base_dtype,
+      _scope=name,
+      _reuse=reuse)
+  return layer.apply(inputs)
+
+
+class Conv3DVariational(_ConvVariational):
+  """3D convolution layer (e.g. spatial convolution over volumes).
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of 3 integers, specifying the
+      depth, height and width of the 3D convolution window.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+    strides: An integer or tuple/list of 3 integers,
+      specifying the strides of the convolution along the depth,
+      height and width.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, depth, height, width, channels)` while `channels_first`
+      corresponds to inputs with shape
+      `(batch, channels, depth, height, width)`.
+    dilation_rate: An integer or tuple/list of 3 integers, specifying
+      the dilation rate to use for dilated convolution.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any stride value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+
+  Properties:
+    filters: Python integer, dimensionality of the output space.
+    kernel_size: Size of the convolution window.
+    strides: Stride length of convolution.
+    padding: Python string describing padding approach.
+    data_format: Python string describing input data's dimensions.
+    dilation_rate: Dilation rate for an atrous convolution.
+    activation: Activation function (`callable`).
+    activity_regularizer: Regularizer function for the output.
+    kernel_use_local_reparameterization: Python `bool` indicating whether
+      `kernel` calculation should employ the Local Reparameterization Trick.
+    kernel_posterior_fn: `callable` returning posterior.
+    kernel_posterior_tensor_fn: `callable` operating on posterior.
+    kernel_prior_fn: `callable` returning prior.
+    kernel_divergence_fn: `callable` returning divergence.
+    bias_posterior_fn: `callable` returning posterior.
+    bias_posterior_tensor_fn: `callable` operating on posterior.
+    bias_prior_fn: `callable` returning prior.
+    bias_divergence_fn: `callable` returning divergence.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 256, 32, 32, 3])
+  net = tfp.layers.Conv3DVariational(64,
+                                     kernel_size=5,
+                                     padding="SAME",
+                                     activation=tf.nn.relu)(net)
+  net = tf.layers.MaxPooling2D(pool_size=2,
+                               strides=2,
+                               padding="SAME")(net)
+  net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
+  logits = tfp.layers.DenseVariational(10)(net)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+
+  def __init__(
+      self,
+      filters,
+      kernel_size,
+      strides=(1, 1, 1),
+      padding="valid",
+      data_format="channels_last",
+      dilation_rate=(1, 1, 1),
+      activation=None,
+      activity_regularizer=None,
+      trainable=True,
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+      kernel_posterior_tensor_fn=lambda d: d.sample(),
+      kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+      kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+      bias_posterior_tensor_fn=lambda d: d.sample(),
+      bias_prior_fn=None,
+      bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+      name=None,
+      **kwargs):
+    super(Conv3DVariational, self).__init__(
+        rank=3,
+        filters=filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+        activation=activation,
+        activity_regularizer=activity_regularizer,
+        trainable=trainable,
+        kernel_posterior_fn=kernel_posterior_fn,
+        kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+        kernel_prior_fn=kernel_prior_fn,
+        kernel_divergence_fn=kernel_divergence_fn,
+        bias_posterior_fn=bias_posterior_fn,
+        bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+        bias_prior_fn=bias_prior_fn,
+        bias_divergence_fn=bias_divergence_fn,
+        name=name, **kwargs)
+
+
+def conv3d_variational(
+    inputs,
+    filters,
+    kernel_size,
+    strides=(1, 1, 1),
+    padding="valid",
+    data_format="channels_last",
+    dilation_rate=(1, 1, 1),
+    activation=None,
+    activity_regularizer=None,
+    trainable=True,
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+    kernel_posterior_tensor_fn=lambda d: d.sample(),
+    kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
+        loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
+    bias_posterior_tensor_fn=lambda d: d.sample(),
+    bias_prior_fn=None,
+    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+    name=None,
+    reuse=None):
+  """Functional interface for the 3D convolution layer.
+
+  This layer creates a convolution kernel that is convolved
+  (actually cross-correlated) with the layer input to produce a tensor of
+  outputs. It may also include a bias addition and activation function
+  on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+  distributions.
+
+  By default, the layer implements a stochastic forward pass via
+  sampling from the kernel and bias posteriors,
+  ```none
+  outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+  ```
+  where f denotes the layer's calculation.
+
+  The arguments permit separate specification of the surrogate posterior
+  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+  distributions.
+
+  Arguments:
+    inputs: Tensor input.
+    filters: Integer, the dimensionality of the output space (i.e. the number
+      of filters in the convolution).
+    kernel_size: An integer or tuple/list of 3 integers, specifying the
+      depth, height and width of the 3D convolution window.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+    strides: An integer or tuple/list of 3 integers,
+      specifying the strides of the convolution along the depth,
+      height and width.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Specifying any stride value != 1 is incompatible with specifying
+      any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, depth, height, width, channels)` while `channels_first`
+      corresponds to inputs with shape
+      `(batch, channels, depth, height, width)`.
+    dilation_rate: An integer or tuple/list of 3 integers, specifying
+      the dilation rate to use for dilated convolution.
+      Can be a single integer to specify the same value for
+      all spatial dimensions.
+      Currently, specifying any `dilation_rate` value != 1 is
+      incompatible with specifying any stride value != 1.
+    activation: Activation function. Set it to None to maintain a
+      linear activation.
+    activity_regularizer: Optional regularizer function for the output.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    kernel_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `kernel` parameter. Default value:
+      `default_mean_field_normal_fn()`.
+    kernel_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    kernel_prior_fn: Python `callable` which creates `tf.distributions`
+      instance. See `default_mean_field_normal_fn` docstring for required
+      parameter signature.
+      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+      sample is a `Tensor`.
+    bias_posterior_fn: Python `callable` which creates
+      `tf.distributions.Distribution` instance representing the surrogate
+      posterior of the `bias` parameter. Default value:
+      `default_mean_field_normal_fn(is_singular=True)` (which creates an
+      instance of `tf.distributions.Deterministic`).
+    bias_posterior_tensor_fn: Python `callable` which takes a
+      `tf.distributions.Distribution` instance and returns a representative
+      value. Default value: `lambda d: d.sample()`.
+    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+      See `default_mean_field_normal_fn` docstring for required parameter
+      signature. Default value: `None` (no prior, no variational inference)
+    bias_divergence_fn: Python `callable` which takes the surrogate posterior
+      distribution, prior distribution and random variate sample(s) from the
+      surrogate posterior and computes or approximates the KL divergence. The
+      distributions are `tf.distributions.Distribution`-like instances and the
+    name: A string, the name of the layer.
+    reuse: Boolean, whether to reuse the weights of a previous layer
+      by the same name.
+
+  Returns:
+    Output tensor.
+
+  Raises:
+    ValueError: if eager execution is enabled.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tf.reshape(features, [-1, 256, 32, 32, 3])
+  net = tfp.layers.conv3d_variational(net,
+                                      64,
+                                      kernel_size=5,
+                                      padding="SAME",
+                                      activation=tf.nn.relu)
+  net = tf.layers.max_pooling2d(net,
+                                pool_size=2,
+                                strides=2,
+                                padding="SAME")
+  net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
+  logits = tfp.layers.dense_variational(net, 10)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
+  """
+  layer = Conv3DVariational(
+      filters=filters,
+      kernel_size=kernel_size,
+      strides=strides,
+      padding=padding,
+      data_format=data_format,
+      dilation_rate=dilation_rate,
+      activation=activation,
+      activity_regularizer=activity_regularizer,
+      trainable=trainable,
+      kernel_posterior_fn=kernel_posterior_fn,
+      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+      kernel_prior_fn=kernel_prior_fn,
+      kernel_divergence_fn=kernel_divergence_fn,
+      bias_posterior_fn=bias_posterior_fn,
+      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+      bias_prior_fn=bias_prior_fn,
+      bias_divergence_fn=bias_divergence_fn,
+      name=name,
+      dtype=inputs.dtype.base_dtype,
+      _scope=name,
+      _reuse=reuse)
+  return layer.apply(inputs)
+
+
+# Aliases
+
+Convolution1DVariational = Conv1DVariational
+Convolution2DVariational = Conv2DVariational
+Convolution3DVariational = Conv3DVariational
+convolution1d_variational = conv1d_variational
+convolution2d_variational = conv2d_variational
+convolution3d_variational = conv3d_variational
index 2a260405d0c3c0e06be0369d290797bde8d51925..a749a396f15188ef345b4ae7c53017b6004c5e71 100644 (file)
 @@dense_reparameterization
 @@dense_local_reparameterization
 @@dense_flipout
-
-@@default_loc_scale_fn
-@@default_mean_field_normal_fn
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
-from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
+from tensorflow.contrib.bayesflow.python.ops import layers_util
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import base as layers_lib
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import standard_ops
 from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
@@ -56,162 +49,9 @@ __all__ = [
     "dense_reparameterization",
     "dense_local_reparameterization",
     "dense_flipout",
-    "default_loc_scale_fn",
-    "default_mean_field_normal_fn",
 ]
 
 
-def default_loc_scale_fn(
-    is_singular=False,
-    loc_initializer=init_ops.random_normal_initializer(stddev=0.1),
-    untransformed_scale_initializer=init_ops.random_normal_initializer(
-        mean=-3., stddev=0.1),
-    loc_regularizer=None,
-    untransformed_scale_regularizer=None,
-    loc_constraint=None,
-    untransformed_scale_constraint=None):
-  """Makes closure which creates `loc`, `scale` params from `tf.get_variable`.
-
-  This function produces a closure which produces `loc`, `scale` using
-  `tf.get_variable`. The closure accepts the following arguments:
-
-    dtype: Type of parameter's event.
-    shape: Python `list`-like representing the parameter's event shape.
-    name: Python `str` name prepended to any created (or existing)
-      `tf.Variable`s.
-    trainable: Python `bool` indicating all created `tf.Variable`s should be
-      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
-    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
-      access existing) `tf.Variable`s.
-
-  Args:
-    is_singular: Python `bool` indicating if `scale is None`. Default: `False`.
-    loc_initializer: Initializer function for the `loc` parameters.
-      The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`.
-    untransformed_scale_initializer: Initializer function for the `scale`
-      parameters. Default value: `tf.random_normal_initializer(mean=-3.,
-      stddev=0.1)`. This implies the softplus transformed result has mean
-      approximately `0.05` and std. deviation approximately `0.005`.
-    loc_regularizer: Regularizer function for the `loc` parameters.
-      The default (`None`) is to use the `tf.get_variable` default.
-    untransformed_scale_regularizer: Regularizer function for the `scale`
-      parameters. The default (`None`) is to use the `tf.get_variable` default.
-    loc_constraint: An optional projection function to be applied to the
-      loc after being updated by an `Optimizer`. The function must take as input
-      the unprojected variable and must return the projected variable (which
-      must have the same shape). Constraints are not safe to use when doing
-      asynchronous distributed training.
-      The default (`None`) is to use the `tf.get_variable` default.
-    untransformed_scale_constraint: An optional projection function to be
-      applied to the `scale` parameters after being updated by an `Optimizer`
-      (e.g. used to implement norm constraints or value constraints). The
-      function must take as input the unprojected variable and must return the
-      projected variable (which must have the same shape). Constraints are not
-      safe to use when doing asynchronous distributed training. The default
-      (`None`) is to use the `tf.get_variable` default.
-
-  Returns:
-    default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale`
-    parameters from args: `dtype, shape, name, trainable, add_variable_fn`.
-  """
-  def _fn(dtype, shape, name, trainable, add_variable_fn):
-    """Creates `loc`, `scale` parameters."""
-    loc = add_variable_fn(
-        name=name + "_loc",
-        shape=shape,
-        initializer=loc_initializer,
-        regularizer=loc_regularizer,
-        constraint=loc_constraint,
-        dtype=dtype,
-        trainable=trainable)
-    if is_singular:
-      return loc, None
-    untransformed_scale = add_variable_fn(
-        name=name + "_untransformed_scale",
-        shape=shape,
-        initializer=untransformed_scale_initializer,
-        regularizer=untransformed_scale_regularizer,
-        constraint=untransformed_scale_constraint,
-        dtype=dtype,
-        trainable=trainable)
-    scale = (np.finfo(dtype.as_numpy_dtype).eps +
-             nn_ops.softplus(untransformed_scale))
-    return loc, scale
-  return _fn
-
-
-def default_mean_field_normal_fn(
-    is_singular=False,
-    loc_initializer=None,
-    untransformed_scale_initializer=None,
-    loc_regularizer=None,
-    untransformed_scale_regularizer=None,
-    loc_constraint=None,
-    untransformed_scale_constraint=None):
-  """Creates a function to build Normal distributions with trainable params.
-
-  This function produces a closure which produces `tf.distributions.Normal`
-  parameterized by a loc` and `scale` each created using `tf.get_variable`. The
-  produced closure accepts the following arguments:
-
-    name: Python `str` name prepended to any created (or existing)
-      `tf.Variable`s.
-    shape: Python `list`-like representing the parameter's event shape.
-    dtype: Type of parameter's event.
-    trainable: Python `bool` indicating all created `tf.Variable`s should be
-      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
-    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
-      access existing) `tf.Variable`s.
-
-  Args:
-    is_singular: Python `bool` if `True`, forces the special case limit of
-      `scale->0`, i.e., a `Deterministic` distribution.
-    loc_initializer: Initializer function for the `loc` parameters.
-      If `None` (default), values are initialized using the default
-      initializer used by `tf.get_variable`.
-    untransformed_scale_initializer: Initializer function for the `scale`
-      parameters. If `None` (default), values are initialized using the default
-      initializer used by `tf.get_variable`.
-    loc_regularizer: Regularizer function for the `loc` parameters.
-    untransformed_scale_regularizer: Regularizer function for the `scale`
-      parameters.
-    loc_constraint: An optional projection function to be applied to the
-      loc after being updated by an `Optimizer`. The function must take as input
-      the unprojected variable and must return the projected variable (which
-      must have the same shape). Constraints are not safe to use when doing
-      asynchronous distributed training.
-    untransformed_scale_constraint: An optional projection function to be
-      applied to the `scale` parameters after being updated by an `Optimizer`
-      (e.g. used to implement norm constraints or value constraints). The
-      function must take as input the unprojected variable and must return the
-      projected variable (which must have the same shape). Constraints are not
-      safe to use when doing asynchronous distributed training.
-
-  Returns:
-    make_normal_fn: Python `callable` which creates a `tf.distributions.Normal`
-      using from args: `dtype, shape, name, trainable, add_variable_fn`.
-  """
-  loc_scale_fn_ = default_loc_scale_fn(
-      is_singular,
-      loc_initializer,
-      untransformed_scale_initializer,
-      loc_regularizer,
-      untransformed_scale_regularizer,
-      loc_constraint,
-      untransformed_scale_constraint)
-  def _fn(dtype, shape, name, trainable, add_variable_fn):
-    """Creates multivariate `Deterministic` or `Normal` distribution."""
-    loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
-    if scale is None:
-      dist = deterministic_lib.Deterministic(loc=loc)
-    else:
-      dist = normal_lib.Normal(loc=loc, scale=scale)
-    reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
-    return independent_lib.Independent(
-        dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
-  return _fn
-
-
 class _DenseVariational(layers_lib.Layer):
   """Abstract densely-connected class (private, used as implementation base).
 
@@ -294,12 +134,12 @@ class _DenseVariational(layers_lib.Layer):
       activation=None,
       activity_regularizer=None,
       trainable=True,
-      kernel_posterior_fn=default_mean_field_normal_fn(),
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
       kernel_posterior_tensor_fn=lambda d: d.sample(),
       kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
           loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
       kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-      bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
       bias_posterior_tensor_fn=lambda d: d.sample(),
       bias_prior_fn=None,
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -540,12 +380,13 @@ class DenseReparameterization(_DenseVariational):
       activation=None,
       activity_regularizer=None,
       trainable=True,
-      kernel_posterior_fn=default_mean_field_normal_fn(),
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
       kernel_posterior_tensor_fn=lambda d: d.sample(),
       kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
           loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
       kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-      bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(
+          is_singular=True),
       bias_posterior_tensor_fn=lambda d: d.sample(),
       bias_prior_fn=None,
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -581,12 +422,12 @@ def dense_reparameterization(
     activation=None,
     activity_regularizer=None,
     trainable=True,
-    kernel_posterior_fn=default_mean_field_normal_fn(),
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
         loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
     kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-    bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -812,12 +653,13 @@ class DenseLocalReparameterization(_DenseVariational):
       activation=None,
       activity_regularizer=None,
       trainable=True,
-      kernel_posterior_fn=default_mean_field_normal_fn(),
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
       kernel_posterior_tensor_fn=lambda d: d.sample(),
       kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
           loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
       kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-      bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(
+          is_singular=True),
       bias_posterior_tensor_fn=lambda d: d.sample(),
       bias_prior_fn=None,
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -864,12 +706,13 @@ def dense_local_reparameterization(
     activation=None,
     activity_regularizer=None,
     trainable=True,
-    kernel_posterior_fn=default_mean_field_normal_fn(),
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
         loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
     kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-    bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(
+        is_singular=True),
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -1098,12 +941,13 @@ class DenseFlipout(_DenseVariational):
       activation=None,
       activity_regularizer=None,
       trainable=True,
-      kernel_posterior_fn=default_mean_field_normal_fn(),
+      kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
       kernel_posterior_tensor_fn=lambda d: d.sample(),
       kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
           loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
       kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-      bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+      bias_posterior_fn=layers_util.default_mean_field_normal_fn(
+          is_singular=True),
       bias_posterior_tensor_fn=lambda d: d.sample(),
       bias_prior_fn=None,
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
@@ -1151,7 +995,7 @@ class DenseFlipout(_DenseVariational):
                           array_ops.expand_dims(self.units, 0)], 0),
         dtype=inputs.dtype,
         seed=distribution_util.gen_new_seed(
-            self.seed, salt="conv_variational"))
+            self.seed, salt="dense_flipout"))
     perturbed_inputs = self._matmul(
         inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output
 
@@ -1166,12 +1010,13 @@ def dense_flipout(
     activation=None,
     activity_regularizer=None,
     trainable=True,
-    kernel_posterior_fn=default_mean_field_normal_fn(),
+    kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=lambda dtype, *args: normal_lib.Normal(  # pylint: disable=g-long-lambda
         loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
     kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
-    bias_posterior_fn=default_mean_field_normal_fn(is_singular=True),
+    bias_posterior_fn=layers_util.default_mean_field_normal_fn(
+        is_singular=True),
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py
new file mode 100644 (file)
index 0000000..9a4fecf
--- /dev/null
@@ -0,0 +1,180 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for probabilistic layers.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+
+
+def default_loc_scale_fn(
+    is_singular=False,
+    loc_initializer=init_ops.random_normal_initializer(stddev=0.1),
+    untransformed_scale_initializer=init_ops.random_normal_initializer(
+        mean=-3., stddev=0.1),
+    loc_regularizer=None,
+    untransformed_scale_regularizer=None,
+    loc_constraint=None,
+    untransformed_scale_constraint=None):
+  """Makes closure which creates `loc`, `scale` params from `tf.get_variable`.
+
+  This function produces a closure which produces `loc`, `scale` using
+  `tf.get_variable`. The closure accepts the following arguments:
+
+    dtype: Type of parameter's event.
+    shape: Python `list`-like representing the parameter's event shape.
+    name: Python `str` name prepended to any created (or existing)
+      `tf.Variable`s.
+    trainable: Python `bool` indicating all created `tf.Variable`s should be
+      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
+    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
+      access existing) `tf.Variable`s.
+
+  Args:
+    is_singular: Python `bool` indicating if `scale is None`. Default: `False`.
+    loc_initializer: Initializer function for the `loc` parameters.
+      The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`.
+    untransformed_scale_initializer: Initializer function for the `scale`
+      parameters. Default value: `tf.random_normal_initializer(mean=-3.,
+      stddev=0.1)`. This implies the softplus transformed result has mean
+      approximately `0.05` and std. deviation approximately `0.005`.
+    loc_regularizer: Regularizer function for the `loc` parameters.
+      The default (`None`) is to use the `tf.get_variable` default.
+    untransformed_scale_regularizer: Regularizer function for the `scale`
+      parameters. The default (`None`) is to use the `tf.get_variable` default.
+    loc_constraint: An optional projection function to be applied to the
+      loc after being updated by an `Optimizer`. The function must take as input
+      the unprojected variable and must return the projected variable (which
+      must have the same shape). Constraints are not safe to use when doing
+      asynchronous distributed training.
+      The default (`None`) is to use the `tf.get_variable` default.
+    untransformed_scale_constraint: An optional projection function to be
+      applied to the `scale` parameters after being updated by an `Optimizer`
+      (e.g. used to implement norm constraints or value constraints). The
+      function must take as input the unprojected variable and must return the
+      projected variable (which must have the same shape). Constraints are not
+      safe to use when doing asynchronous distributed training. The default
+      (`None`) is to use the `tf.get_variable` default.
+
+  Returns:
+    default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale`
+    parameters from args: `dtype, shape, name, trainable, add_variable_fn`.
+  """
+  def _fn(dtype, shape, name, trainable, add_variable_fn):
+    """Creates `loc`, `scale` parameters."""
+    loc = add_variable_fn(
+        name=name + "_loc",
+        shape=shape,
+        initializer=loc_initializer,
+        regularizer=loc_regularizer,
+        constraint=loc_constraint,
+        dtype=dtype,
+        trainable=trainable)
+    if is_singular:
+      return loc, None
+    untransformed_scale = add_variable_fn(
+        name=name + "_untransformed_scale",
+        shape=shape,
+        initializer=untransformed_scale_initializer,
+        regularizer=untransformed_scale_regularizer,
+        constraint=untransformed_scale_constraint,
+        dtype=dtype,
+        trainable=trainable)
+    scale = (np.finfo(dtype.as_numpy_dtype).eps +
+             nn_ops.softplus(untransformed_scale))
+    return loc, scale
+  return _fn
+
+
+def default_mean_field_normal_fn(
+    is_singular=False,
+    loc_initializer=None,
+    untransformed_scale_initializer=None,
+    loc_regularizer=None,
+    untransformed_scale_regularizer=None,
+    loc_constraint=None,
+    untransformed_scale_constraint=None):
+  """Creates a function to build Normal distributions with trainable params.
+
+  This function produces a closure which produces `tf.distributions.Normal`
+  parameterized by a loc` and `scale` each created using `tf.get_variable`. The
+  produced closure accepts the following arguments:
+
+    name: Python `str` name prepended to any created (or existing)
+      `tf.Variable`s.
+    shape: Python `list`-like representing the parameter's event shape.
+    dtype: Type of parameter's event.
+    trainable: Python `bool` indicating all created `tf.Variable`s should be
+      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
+    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
+      access existing) `tf.Variable`s.
+
+  Args:
+    is_singular: Python `bool` if `True`, forces the special case limit of
+      `scale->0`, i.e., a `Deterministic` distribution.
+    loc_initializer: Initializer function for the `loc` parameters.
+      If `None` (default), values are initialized using the default
+      initializer used by `tf.get_variable`.
+    untransformed_scale_initializer: Initializer function for the `scale`
+      parameters. If `None` (default), values are initialized using the default
+      initializer used by `tf.get_variable`.
+    loc_regularizer: Regularizer function for the `loc` parameters.
+    untransformed_scale_regularizer: Regularizer function for the `scale`
+      parameters.
+    loc_constraint: An optional projection function to be applied to the
+      loc after being updated by an `Optimizer`. The function must take as input
+      the unprojected variable and must return the projected variable (which
+      must have the same shape). Constraints are not safe to use when doing
+      asynchronous distributed training.
+    untransformed_scale_constraint: An optional projection function to be
+      applied to the `scale` parameters after being updated by an `Optimizer`
+      (e.g. used to implement norm constraints or value constraints). The
+      function must take as input the unprojected variable and must return the
+      projected variable (which must have the same shape). Constraints are not
+      safe to use when doing asynchronous distributed training.
+
+  Returns:
+    make_normal_fn: Python `callable` which creates a `tf.distributions.Normal`
+      using from args: `dtype, shape, name, trainable, add_variable_fn`.
+  """
+  loc_scale_fn_ = default_loc_scale_fn(
+      is_singular,
+      loc_initializer,
+      untransformed_scale_initializer,
+      loc_regularizer,
+      untransformed_scale_regularizer,
+      loc_constraint,
+      untransformed_scale_constraint)
+  def _fn(dtype, shape, name, trainable, add_variable_fn):
+    """Creates multivariate `Deterministic` or `Normal` distribution."""
+    loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
+    if scale is None:
+      dist = deterministic_lib.Deterministic(loc=loc)
+    else:
+      dist = normal_lib.Normal(loc=loc, scale=scale)
+    reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
+    return independent_lib.Independent(
+        dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
+  return _fn