Reduce tfp.layers boilerplate via programmable docstrings.
authorDustin Tran <trandustin@google.com>
Tue, 20 Feb 2018 05:39:03 +0000 (21:39 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Feb 2018 05:42:41 +0000 (21:42 -0800)
PiperOrigin-RevId: 186260342

tensorflow/contrib/bayesflow/BUILD
tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/ops/docstring_util.py [new file with mode: 0644]
tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py

index 74712ae..fc04933 100644 (file)
@@ -119,6 +119,16 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "docstring_util_test",
+    size = "small",
+    srcs = ["python/kernel_tests/docstring_util_test.py"],
+    additional_deps = [
+        ":bayesflow_py",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+cuda_py_test(
     name = "layers_dense_variational_test",
     size = "small",
     srcs = ["python/kernel_tests/layers_dense_variational_test.py"],
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py
new file mode 100644 (file)
index 0000000..09ae6f3
--- /dev/null
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for docstring utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
+from tensorflow.python.platform import test
+
+
+class DocstringUtil(test.TestCase):
+
+  def _testFunction(self):
+    doc_args = """  x: Input to return as output.
+  y: Baz."""
+    @docstring_util.expand_docstring(args=doc_args)
+    def foo(x):
+      """Hello world.
+
+      Args:
+      @{args}
+
+      Returns:
+        x.
+      """
+      return x
+
+    true_docstring = """Hello world.
+
+    Args:
+      x: Input to return as output.
+      y: Baz.
+
+    Returns:
+      x.
+    """
+    self.assertEqual(foo.__doc__, true_docstring)
+
+  def _testClassInit(self):
+    doc_args = """  x: Input to return as output.
+  y: Baz."""
+
+    class Foo(object):
+
+      @docstring_util.expand_docstring(args=doc_args)
+      def __init__(self, x, y):
+        """Hello world.
+
+        Args:
+        @{args}
+
+        Bar.
+        """
+        pass
+
+    true_docstring = """Hello world.
+
+    Args:
+      x: Input to return as output.
+      y: Baz.
+
+    Bar.
+    """
+    self.assertEqual(Foo.__doc__, true_docstring)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/docstring_util.py b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py
new file mode 100644 (file)
index 0000000..44a1ea2
--- /dev/null
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for programmable docstrings.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import sys
+import six
+
+
+def expand_docstring(**kwargs):
+  """Decorator to programmatically expand the docstring.
+
+  Args:
+    **kwargs: Keyword arguments to set. For each key-value pair `k` and `v`,
+      the key is found as `@{k}` in the docstring and replaced with `v`.
+
+  Returns:
+    Decorated function.
+  """
+  def _fn_wrapped(fn):
+    """Original function with modified `__doc__` attribute."""
+    doc = _trim(fn.__doc__)
+    for k, v in six.iteritems(kwargs):
+      # Capture each @{k} reference to replace with v.
+      # We wrap the replacement in a function so no backslash escapes
+      # are processed.
+      pattern = r'@\{' + str(k) + r'\}'
+      doc = re.sub(pattern, lambda match: v, doc)  # pylint: disable=cell-var-from-loop
+    fn.__doc__ = doc
+    return fn
+  return _fn_wrapped
+
+
+def _trim(docstring):
+  """Trims docstring indentation.
+
+  In general, multi-line docstrings carry their level of indentation when
+  defined under a function or class method. This function standardizes
+  indentation levels by removing them. Taken from PEP 257 docs.
+
+  Args:
+    docstring: Python string to trim indentation.
+
+  Returns:
+    Trimmed docstring.
+  """
+  if not docstring:
+    return ''
+  # Convert tabs to spaces (following the normal Python rules)
+  # and split into a list of lines:
+  lines = docstring.expandtabs().splitlines()
+  # Determine minimum indentation (first line doesn't count):
+  indent = sys.maxint
+  for line in lines[1:]:
+    stripped = line.lstrip()
+    if stripped:
+      indent = min(indent, len(line) - len(stripped))
+  # Remove indentation (first line is special):
+  trimmed = [lines[0].strip()]
+  if indent < sys.maxint:
+    for line in lines[1:]:
+      trimmed.append(line[indent:].rstrip())
+  # Strip off trailing and leading blank lines:
+  while trimmed and not trimmed[-1]:
+    trimmed.pop()
+  while trimmed and not trimmed[0]:
+    trimmed.pop(0)
+  # Return a single string:
+  return '\n'.join(trimmed)
index 7723cfb..90219fd 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
 from tensorflow.contrib.bayesflow.python.ops import layers_util
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import dtypes
@@ -34,6 +35,45 @@ from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.ops.distributions import util as distribution_util
 
+doc_args = """  activation: Activation function. Set it to None to maintain a
+      linear activation.
+  activity_regularizer: Optional regularizer function for the output.
+  trainable: Boolean, if `True` also add variables to the graph collection
+    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+  kernel_posterior_fn: Python `callable` which creates
+    `tf.distributions.Distribution` instance representing the surrogate
+    posterior of the `kernel` parameter. Default value:
+    `default_mean_field_normal_fn()`.
+  kernel_posterior_tensor_fn: Python `callable` which takes a
+    `tf.distributions.Distribution` instance and returns a representative
+    value. Default value: `lambda d: d.sample()`.
+  kernel_prior_fn: Python `callable` which creates `tf.distributions`
+    instance. See `default_mean_field_normal_fn` docstring for required
+    parameter signature.
+    Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+  kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+    distribution, prior distribution and random variate sample(s) from the
+    surrogate posterior and computes or approximates the KL divergence. The
+    distributions are `tf.distributions.Distribution`-like instances and the
+    sample is a `Tensor`.
+  bias_posterior_fn: Python `callable` which creates
+    `tf.distributions.Distribution` instance representing the surrogate
+    posterior of the `bias` parameter. Default value:
+    `default_mean_field_normal_fn(is_singular=True)` (which creates an
+    instance of `tf.distributions.Deterministic`).
+  bias_posterior_tensor_fn: Python `callable` which takes a
+    `tf.distributions.Distribution` instance and returns a representative
+    value. Default value: `lambda d: d.sample()`.
+  bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+    See `default_mean_field_normal_fn` docstring for required parameter
+    signature. Default value: `None` (no prior, no variational inference)
+  bias_divergence_fn: Python `callable` which takes the surrogate posterior
+    distribution, prior distribution and random variate sample(s) from the
+    surrogate posterior and computes or approximates the KL divergence. The
+    distributions are `tf.distributions.Distribution`-like instances and the
+    sample is a `Tensor`.
+  name: A string, the name of the layer."""
+
 
 class _ConvVariational(layers_lib.Layer):
   """Abstract nD convolution layer (private, used as implementation base).
@@ -55,65 +95,6 @@ class _ConvVariational(layers_lib.Layer):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of n integers, specifying the
-      length of the convolution window.
-    strides: An integer or tuple/list of n integers,
-      specifying the stride length of the convolution.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, ..., channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, ...)`.
-    dilation_rate: An integer or tuple/list of n integers, specifying
-      the dilation rate to use for dilated convolution.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: A string, the name of the layer.
-
   Properties:
     rank: Python integer, dimensionality of convolution.
     filters: Python integer, dimensionality of the output space.
@@ -134,6 +115,7 @@ class _ConvVariational(layers_lib.Layer):
     bias_divergence_fn: `callable` returning divergence.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       rank,
@@ -157,6 +139,31 @@ class _ConvVariational(layers_lib.Layer):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      rank: An integer, the rank of the convolution, e.g. "2" for 2D
+        convolution.
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of n integers, specifying the
+        length of the convolution window.
+      strides: An integer or tuple/list of n integers,
+        specifying the stride length of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, ...,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, ...)`.
+      dilation_rate: An integer or tuple/list of n integers, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    @{args}
+    """
     super(_ConvVariational, self).__init__(
         trainable=trainable,
         name=name,
@@ -371,65 +378,6 @@ class _ConvReparameterization(_ConvVariational):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of n integers, specifying the
-      length of the convolution window.
-    strides: An integer or tuple/list of n integers,
-      specifying the stride length of the convolution.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, ..., channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, ...)`.
-    dilation_rate: An integer or tuple/list of n integers, specifying
-      the dilation rate to use for dilated convolution.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: A string, the name of the layer.
-
   Properties:
     rank: Python integer, dimensionality of convolution.
     filters: Python integer, dimensionality of the output space.
@@ -454,6 +402,7 @@ class _ConvReparameterization(_ConvVariational):
         International Conference on Learning Representations, 2014.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       rank,
@@ -477,6 +426,31 @@ class _ConvReparameterization(_ConvVariational):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      rank: An integer, the rank of the convolution, e.g. "2" for 2D
+        convolution.
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of n integers, specifying the
+        length of the convolution window.
+      strides: An integer or tuple/list of n integers,
+        specifying the stride length of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, ...,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, ...)`.
+      dilation_rate: An integer or tuple/list of n integers, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    @{args}
+    """
     super(_ConvReparameterization, self).__init__(
         rank=rank,
         filters=filters,
@@ -529,63 +503,6 @@ class Conv1DReparameterization(_ConvReparameterization):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of a single integer, specifying the
-      length of the 1D convolution window.
-    strides: An integer or tuple/list of a single integer,
-      specifying the stride length of the convolution.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, length, channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, length)`.
-    dilation_rate: An integer or tuple/list of a single integer, specifying
-      the dilation rate to use for dilated convolution.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -639,6 +556,7 @@ class Conv1DReparameterization(_ConvReparameterization):
         International Conference on Learning Representations, 2014.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -661,6 +579,29 @@ class Conv1DReparameterization(_ConvReparameterization):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of a single integer, specifying the
+        length of the 1D convolution window.
+      strides: An integer or tuple/list of a single integer,
+        specifying the stride length of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, length,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, length)`.
+      dilation_rate: An integer or tuple/list of a single integer, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    @{args}
+    """
     super(Conv1DReparameterization, self).__init__(
         rank=1,
         filters=filters,
@@ -683,6 +624,7 @@ class Conv1DReparameterization(_ConvReparameterization):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv1d_reparameterization(
     inputs,
     filters,
@@ -726,7 +668,7 @@ def conv1d_reparameterization(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -746,43 +688,7 @@ def conv1d_reparameterization(
       the dilation rate to use for dilated convolution.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -874,70 +780,6 @@ class Conv2DReparameterization(_ConvReparameterization):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of 2 integers, specifying the
-      height and width of the 2D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 2 integers,
-      specifying the strides of the convolution along the height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, height, width, channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, height, width)`.
-
-    dilation_rate: An integer or tuple/list of 2 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -994,6 +836,7 @@ class Conv2DReparameterization(_ConvReparameterization):
         International Conference on Learning Representations, 2014.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -1016,6 +859,35 @@ class Conv2DReparameterization(_ConvReparameterization):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of 2 integers, specifying the
+        height and width of the 2D convolution window.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+      strides: An integer or tuple/list of 2 integers,
+        specifying the strides of the convolution along the height and width.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, height,
+        width, channels)` while `channels_first` corresponds to inputs with
+        shape `(batch, channels, height, width)`.
+      dilation_rate: An integer or tuple/list of 2 integers, specifying
+        the dilation rate to use for dilated convolution.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any stride value != 1.
+    @{args}
+    """
     super(Conv2DReparameterization, self).__init__(
         rank=2,
         filters=filters,
@@ -1038,6 +910,7 @@ class Conv2DReparameterization(_ConvReparameterization):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv2d_reparameterization(
     inputs,
     filters,
@@ -1081,7 +954,7 @@ def conv2d_reparameterization(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -1101,50 +974,13 @@ def conv2d_reparameterization(
       `channels_last` corresponds to inputs with shape
       `(batch, height, width, channels)` while `channels_first` corresponds to
       inputs with shape `(batch, channels, height, width)`.
-
     dilation_rate: An integer or tuple/list of 2 integers, specifying
       the dilation rate to use for dilated convolution.
       Can be a single integer to specify the same value for
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -1240,71 +1076,6 @@ class Conv3DReparameterization(_ConvReparameterization):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of 3 integers, specifying the
-      depth, height and width of the 3D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 3 integers,
-      specifying the strides of the convolution along the depth,
-      height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, depth, height, width, channels)` while `channels_first`
-      corresponds to inputs with shape
-      `(batch, channels, depth, height, width)`.
-    dilation_rate: An integer or tuple/list of 3 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -1361,6 +1132,7 @@ class Conv3DReparameterization(_ConvReparameterization):
         International Conference on Learning Representations, 2014.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -1383,6 +1155,36 @@ class Conv3DReparameterization(_ConvReparameterization):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of 3 integers, specifying the
+        depth, height and width of the 3D convolution window.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+      strides: An integer or tuple/list of 3 integers,
+        specifying the strides of the convolution along the depth,
+        height and width.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, depth,
+        height, width, channels)` while `channels_first` corresponds to inputs
+        with shape `(batch, channels, depth, height, width)`.
+      dilation_rate: An integer or tuple/list of 3 integers, specifying
+        the dilation rate to use for dilated convolution.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any stride value != 1.
+    @{args}
+    """
     super(Conv3DReparameterization, self).__init__(
         rank=3,
         filters=filters,
@@ -1405,6 +1207,7 @@ class Conv3DReparameterization(_ConvReparameterization):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv3d_reparameterization(
     inputs,
     filters,
@@ -1448,7 +1251,7 @@ def conv3d_reparameterization(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -1476,43 +1279,7 @@ def conv3d_reparameterization(
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -1611,67 +1378,6 @@ class _ConvFlipout(_ConvVariational):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of n integers, specifying the
-      length of the convolution window.
-    strides: An integer or tuple/list of n integers,
-      specifying the stride length of the convolution.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, ..., channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, ...)`.
-    dilation_rate: An integer or tuple/list of n integers, specifying
-      the dilation rate to use for dilated convolution.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
-
   Properties:
     rank: Python integer, dimensionality of convolution.
     filters: Python integer, dimensionality of the output space.
@@ -1694,10 +1400,11 @@ class _ConvFlipout(_ConvVariational):
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       rank,
@@ -1722,6 +1429,31 @@ class _ConvFlipout(_ConvVariational):
       seed=None,
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      rank: An integer, the rank of the convolution, e.g. "2" for 2D
+        convolution.
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of n integers, specifying the
+        length of the convolution window.
+      strides: An integer or tuple/list of n integers,
+        specifying the stride length of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, ...,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, ...)`.
+      dilation_rate: An integer or tuple/list of n integers, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    @{args}
+    """
     super(_ConvFlipout, self).__init__(
         rank=rank,
         filters=filters,
@@ -1822,65 +1554,6 @@ class Conv1DFlipout(_ConvFlipout):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of a single integer, specifying the
-      length of the 1D convolution window.
-    strides: An integer or tuple/list of a single integer,
-      specifying the stride length of the convolution.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, length, channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, length)`.
-    dilation_rate: An integer or tuple/list of a single integer, specifying
-      the dilation rate to use for dilated convolution.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -1932,10 +1605,11 @@ class Conv1DFlipout(_ConvFlipout):
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -1959,6 +1633,29 @@ class Conv1DFlipout(_ConvFlipout):
       seed=None,
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of a single integer, specifying the
+        length of the 1D convolution window.
+      strides: An integer or tuple/list of a single integer,
+        specifying the stride length of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, length,
+        channels)` while `channels_first` corresponds to inputs with shape
+        `(batch, channels, length)`.
+      dilation_rate: An integer or tuple/list of a single integer, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    @{args}
+    """
     super(Conv1DFlipout, self).__init__(
         rank=1,
         filters=filters,
@@ -1982,6 +1679,7 @@ class Conv1DFlipout(_ConvFlipout):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv1d_flipout(
     inputs,
     filters,
@@ -2029,7 +1727,7 @@ def conv1d_flipout(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -2049,45 +1747,7 @@ def conv1d_flipout(
       the dilation rate to use for dilated convolution.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any `strides` value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -2130,8 +1790,8 @@ def conv1d_flipout(
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
   layer = Conv1DFlipout(
       filters=filters,
@@ -2184,72 +1844,6 @@ class Conv2DFlipout(_ConvFlipout):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of 2 integers, specifying the
-      height and width of the 2D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 2 integers,
-      specifying the strides of the convolution along the height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, height, width, channels)` while `channels_first` corresponds to
-      inputs with shape `(batch, channels, height, width)`.
-
-    dilation_rate: An integer or tuple/list of 2 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -2304,10 +1898,11 @@ class Conv2DFlipout(_ConvFlipout):
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -2331,6 +1926,35 @@ class Conv2DFlipout(_ConvFlipout):
       seed=None,
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of 2 integers, specifying the
+        height and width of the 2D convolution window.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+      strides: An integer or tuple/list of 2 integers,
+        specifying the strides of the convolution along the height and width.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, height,
+        width, channels)` while `channels_first` corresponds to inputs with
+        shape `(batch, channels, height, width)`.
+      dilation_rate: An integer or tuple/list of 2 integers, specifying
+        the dilation rate to use for dilated convolution.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any stride value != 1.
+    @{args}
+    """
     super(Conv2DFlipout, self).__init__(
         rank=2,
         filters=filters,
@@ -2354,6 +1978,7 @@ class Conv2DFlipout(_ConvFlipout):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv2d_flipout(
     inputs,
     filters,
@@ -2401,7 +2026,7 @@ def conv2d_flipout(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -2421,52 +2046,13 @@ def conv2d_flipout(
       `channels_last` corresponds to inputs with shape
       `(batch, height, width, channels)` while `channels_first` corresponds to
       inputs with shape `(batch, channels, height, width)`.
-
     dilation_rate: An integer or tuple/list of 2 integers, specifying
       the dilation rate to use for dilated convolution.
       Can be a single integer to specify the same value for
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -2513,8 +2099,8 @@ def conv2d_flipout(
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
   layer = Conv2DFlipout(
       filters=filters,
@@ -2567,73 +2153,6 @@ class Conv3DFlipout(_ConvFlipout):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
-    filters: Integer, the dimensionality of the output space (i.e. the number
-      of filters in the convolution).
-    kernel_size: An integer or tuple/list of 3 integers, specifying the
-      depth, height and width of the 3D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 3 integers,
-      specifying the strides of the convolution along the depth,
-      height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string, one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch, depth, height, width, channels)` while `channels_first`
-      corresponds to inputs with shape
-      `(batch, channels, depth, height, width)`.
-    dilation_rate: An integer or tuple/list of 3 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
-
   Properties:
     filters: Python integer, dimensionality of the output space.
     kernel_size: Size of the convolution window.
@@ -2688,10 +2207,11 @@ class Conv3DFlipout(_ConvFlipout):
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       filters,
@@ -2715,6 +2235,36 @@ class Conv3DFlipout(_ConvFlipout):
       seed=None,
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+      filters: Integer, the dimensionality of the output space (i.e. the number
+        of filters in the convolution).
+      kernel_size: An integer or tuple/list of 3 integers, specifying the
+        depth, height and width of the 3D convolution window.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+      strides: An integer or tuple/list of 3 integers,
+        specifying the strides of the convolution along the depth,
+        height and width.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+      padding: One of `"valid"` or `"same"` (case-insensitive).
+      data_format: A string, one of `channels_last` (default) or
+        `channels_first`. The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape `(batch, depth,
+        height, width, channels)` while `channels_first` corresponds to inputs
+        with shape `(batch, channels, depth, height, width)`.
+      dilation_rate: An integer or tuple/list of 3 integers, specifying
+        the dilation rate to use for dilated convolution.
+        Can be a single integer to specify the same value for
+        all spatial dimensions.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any stride value != 1.
+    @{args}
+    """
     super(Conv3DFlipout, self).__init__(
         rank=3,
         filters=filters,
@@ -2738,6 +2288,7 @@ class Conv3DFlipout(_ConvFlipout):
         name=name, **kwargs)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def conv3d_flipout(
     inputs,
     filters,
@@ -2785,7 +2336,7 @@ def conv3d_flipout(
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Arguments:
+  Args:
     inputs: Tensor input.
     filters: Integer, the dimensionality of the output space (i.e. the number
       of filters in the convolution).
@@ -2813,45 +2364,7 @@ def conv3d_flipout(
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
-    activation: Activation function. Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Optional regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: A string, the name of the layer.
+  @{args}
     reuse: Boolean, whether to reuse the weights of a previous layer
       by the same name.
 
@@ -2898,8 +2411,8 @@ def conv3d_flipout(
 
   [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
         Mini-Batches."
-        Anonymous. OpenReview, 2017.
-        https://openreview.net/forum?id=rJnpifWAb
+        Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+        International Conference on Learning Representations, 2018.
   """
   layer = Conv3DFlipout(
       filters=filters,
index 591a8e5..1e4a445 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
 from tensorflow.contrib.bayesflow.python.ops import layers_util
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import dtypes
@@ -33,6 +34,53 @@ from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.ops.distributions import util as distribution_util
 
 
+doc_args = """  units: Integer or Long, dimensionality of the output space.
+  activation: Activation function (`callable`). Set it to None to maintain a
+    linear activation.
+  activity_regularizer: Regularizer function for the output.
+  trainable: Boolean, if `True` also add variables to the graph collection
+    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+  kernel_posterior_fn: Python `callable` which creates
+    `tf.distributions.Distribution` instance representing the surrogate
+    posterior of the `kernel` parameter. Default value:
+    `default_mean_field_normal_fn()`.
+  kernel_posterior_tensor_fn: Python `callable` which takes a
+    `tf.distributions.Distribution` instance and returns a representative
+    value. Default value: `lambda d: d.sample()`.
+  kernel_prior_fn: Python `callable` which creates `tf.distributions`
+    instance. See `default_mean_field_normal_fn` docstring for required
+    parameter signature.
+    Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+  kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+    distribution, prior distribution and random variate sample(s) from the
+    surrogate posterior and computes or approximates the KL divergence. The
+    distributions are `tf.distributions.Distribution`-like instances and the
+    sample is a `Tensor`.
+  bias_posterior_fn: Python `callable` which creates
+    `tf.distributions.Distribution` instance representing the surrogate
+    posterior of the `bias` parameter. Default value:
+    `default_mean_field_normal_fn(is_singular=True)` (which creates an
+    instance of `tf.distributions.Deterministic`).
+  bias_posterior_tensor_fn: Python `callable` which takes a
+    `tf.distributions.Distribution` instance and returns a representative
+    value. Default value: `lambda d: d.sample()`.
+  bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+    See `default_mean_field_normal_fn` docstring for required parameter
+    signature. Default value: `None` (no prior, no variational inference)
+  bias_divergence_fn: Python `callable` which takes the surrogate posterior
+    distribution, prior distribution and random variate sample(s) from the
+    surrogate posterior and computes or approximates the KL divergence. The
+    distributions are `tf.distributions.Distribution`-like instances and the
+    sample is a `Tensor`.
+  seed: Python scalar `int` which initializes the random number
+    generator. Default value: `None` (i.e., use global seed).
+  name: Python `str`, the name of the layer. Layers with the same name will
+    share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
+    such cases.
+  reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
+    layer by the same name."""
+
+
 class _DenseVariational(layers_lib.Layer):
   """Abstract densely-connected class (private, used as implementation base).
 
@@ -50,51 +98,6 @@ class _DenseVariational(layers_lib.Layer):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Args:
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
-
   Properties:
     units: Python integer, dimensionality of the output space.
     activation: Activation function (`callable`).
@@ -109,6 +112,7 @@ class _DenseVariational(layers_lib.Layer):
     bias_divergence_fn: `callable` returning divergence.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       units,
@@ -126,6 +130,11 @@ class _DenseVariational(layers_lib.Layer):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+    @{args}
+    """
     super(_DenseVariational, self).__init__(
         trainable=trainable,
         name=name,
@@ -274,51 +283,6 @@ class DenseReparameterization(_DenseVariational):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Args:
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
-
   Properties:
     units: Python integer, dimensionality of the output space.
     activation: Activation function (`callable`).
@@ -363,6 +327,7 @@ class DenseReparameterization(_DenseVariational):
         International Conference on Learning Representations, 2014.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       units,
@@ -381,6 +346,11 @@ class DenseReparameterization(_DenseVariational):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+    @{args}
+    """
     super(DenseReparameterization, self).__init__(
         units=units,
         activation=activation,
@@ -405,6 +375,7 @@ class DenseReparameterization(_DenseVariational):
     return self._matmul(inputs, self.kernel_posterior_tensor)
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def dense_reparameterization(
     inputs,
     units,
@@ -444,49 +415,7 @@ def dense_reparameterization(
 
   Args:
     inputs: Tensor input.
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
+  @{args}
 
   Returns:
     output: `Tensor` representing a the affine transformed input under a random
@@ -563,51 +492,6 @@ class DenseLocalReparameterization(_DenseVariational):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Args:
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
-
   Properties:
     units: Python integer, dimensionality of the output space.
     activation: Activation function (`callable`).
@@ -652,6 +536,7 @@ class DenseLocalReparameterization(_DenseVariational):
         Neural Information Processing Systems, 2015.
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       units,
@@ -670,6 +555,11 @@ class DenseLocalReparameterization(_DenseVariational):
       bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+    @{args}
+    """
     super(DenseLocalReparameterization, self).__init__(
         units=units,
         activation=activation,
@@ -705,6 +595,7 @@ class DenseLocalReparameterization(_DenseVariational):
     return self.kernel_posterior_affine_tensor
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def dense_local_reparameterization(
     inputs,
     units,
@@ -745,49 +636,7 @@ def dense_local_reparameterization(
 
   Args:
     inputs: Tensor input.
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
+  @{args}
 
   Returns:
     output: `Tensor` representing a the affine transformed input under a random
@@ -866,53 +715,6 @@ class DenseFlipout(_DenseVariational):
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
   distributions.
 
-  Args:
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
-
   Properties:
     units: Python integer, dimensionality of the output space.
     activation: Activation function (`callable`).
@@ -959,6 +761,7 @@ class DenseFlipout(_DenseVariational):
         https://openreview.net/forum?id=rJnpifWAb
   """
 
+  @docstring_util.expand_docstring(args=doc_args)
   def __init__(
       self,
       units,
@@ -978,6 +781,11 @@ class DenseFlipout(_DenseVariational):
       seed=None,
       name=None,
       **kwargs):
+    """Construct layer.
+
+    Args:
+    @{args}
+    """
     super(DenseFlipout, self).__init__(
         units=units,
         activation=activation,
@@ -1031,6 +839,7 @@ class DenseFlipout(_DenseVariational):
     return outputs
 
 
+@docstring_util.expand_docstring(args=doc_args)
 def dense_flipout(
     inputs,
     units,
@@ -1074,51 +883,7 @@ def dense_flipout(
 
   Args:
     inputs: Tensor input.
-    units: Integer or Long, dimensionality of the output space.
-    activation: Activation function (`callable`). Set it to None to maintain a
-      linear activation.
-    activity_regularizer: Regularizer function for the output.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
-    kernel_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `kernel` parameter. Default value:
-      `default_mean_field_normal_fn()`.
-    kernel_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    kernel_prior_fn: Python `callable` which creates `tf.distributions`
-      instance. See `default_mean_field_normal_fn` docstring for required
-      parameter signature.
-      Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
-    kernel_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    bias_posterior_fn: Python `callable` which creates
-      `tf.distributions.Distribution` instance representing the surrogate
-      posterior of the `bias` parameter. Default value:
-      `default_mean_field_normal_fn(is_singular=True)` (which creates an
-      instance of `tf.distributions.Deterministic`).
-    bias_posterior_tensor_fn: Python `callable` which takes a
-      `tf.distributions.Distribution` instance and returns a representative
-      value. Default value: `lambda d: d.sample()`.
-    bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
-      See `default_mean_field_normal_fn` docstring for required parameter
-      signature. Default value: `None` (no prior, no variational inference)
-    bias_divergence_fn: Python `callable` which takes the surrogate posterior
-      distribution, prior distribution and random variate sample(s) from the
-      surrogate posterior and computes or approximates the KL divergence. The
-      distributions are `tf.distributions.Distribution`-like instances and the
-      sample is a `Tensor`.
-    seed: Python scalar `int` which initializes the random number
-      generator. Default value: `None` (i.e., use global seed).
-    name: Python `str`, the name of the layer. Layers with the same name will
-      share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
-      such cases.
-    reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
-      layer by the same name.
+  @{args}
 
   Returns:
     output: `Tensor` representing a the affine transformed input under a random