Add BatchNorm bijector.
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Tue, 6 Mar 2018 06:51:17 +0000 (22:51 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 6 Mar 2018 06:55:00 +0000 (22:55 -0800)
PiperOrigin-RevId: 187975255

tensorflow/contrib/distributions/BUILD
tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py [new file with mode: 0644]

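This commit adds a `BatchNormalization` bijector to `tf.contrib.distributions.bijectors`. A minimal usage sketch (not part of the change itself), assuming the common `tfd`/`tfb` aliases for `tf.contrib.distributions` and its `bijectors` module:

```python
import tensorflow as tf

tfd = tf.contrib.distributions
tfb = tfd.bijectors

# Wrap a base distribution with the new bijector: `inverse()` normalizes with
# minibatch statistics (used by log_prob), while `forward()` de-normalizes with
# the layer's moving averages (used by sample).
dist = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[0., 0.]),
    bijector=tfb.BatchNormalization(training=True))

samples = dist.sample(100)
log_prob = dist.log_prob(samples)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(log_prob).shape)  # (100,)
```
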
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index d81dfc2..84f74ce 100644 (file)
@@ -832,6 +832,22 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "batch_normalization_test",
+    size = "small",
+    srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"],
+    additional_deps = [
+        ":bijectors_py",
+        ":distributions_py",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
     name = "chain_test",
     size = "small",
     srcs = ["python/kernel_tests/bijectors/chain_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
new file mode 100644 (file)
index 0000000..a215a4a
--- /dev/null
@@ -0,0 +1,236 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BatchNorm Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import distributions
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import BatchNormalization
+from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+
+
+class BatchNormTest(test_util.VectorDistributionTestHelpers,
+                    test.TestCase):
+
+  def _reduction_axes(self, input_shape, event_dims):
+    if isinstance(event_dims, int):
+      event_dims = [event_dims]
+    ndims = len(input_shape)
+    # Convert event_dims to non-negative indexing.
+    event_dims = list(event_dims)
+    for idx, x in enumerate(event_dims):
+      if x < 0:
+        event_dims[idx] = ndims + x
+    return tuple(i for i in range(ndims) if i not in event_dims)
+
+  def testForwardInverse(self):
+    """Tests forward and backward passes with different event shapes.
+
+    input_shape: Tuple of shapes for input tensor.
+    event_dims: Tuple of dimension indices that will be normalized.
+    training: Boolean of whether bijector runs in training or inference mode.
+    """
+    params = [
+        ((5*2, 4), [-1], False),
+        ((5, 2, 4), [-1], False),
+        ((5, 2, 4), [1, 2], False),
+        ((5, 2, 4), [0, 1], False),
+        ((5*2, 4), [-1], True),
+        ((5, 2, 4), [-1], True),
+        ((5, 2, 4), [1, 2], True),
+        ((5, 2, 4), [0, 1], True)
+    ]
+    for input_shape, event_dims, training in params:
+      x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape)
+      with self.test_session() as sess:
+        x = constant_op.constant(x_)
+        # With momentum=0., the moving statistics memorize the exact mean and
+        # variance of the last minibatch normalized (instead of an exponential
+        # moving average).
+        layer = normalization.BatchNormalization(
+            axis=event_dims, momentum=0., epsilon=0.)
+        batch_norm = BatchNormalization(
+            batchnorm_layer=layer, training=training)
+        # Minibatch statistics are saved only after norm_x has been computed.
+        norm_x = batch_norm.inverse(x)
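+        # `batch_norm.batchnorm.updates` holds the moving-average assign ops
+        # created by the layer; gating the reads below on them ensures they
+        # observe the post-update statistics.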
+        with ops.control_dependencies(batch_norm.batchnorm.updates):
+          moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
+          moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
+          denorm_x = batch_norm.forward(array_ops.identity(norm_x))
+          fldj = batch_norm.forward_log_det_jacobian(x)
+          # Use identity to invalidate cache.
+          ildj = batch_norm.inverse_log_det_jacobian(
+              array_ops.identity(denorm_x))
+        variables.global_variables_initializer().run()
+        # Update variables.
+        norm_x_ = sess.run(norm_x)
+        [
+            norm_x_,
+            moving_mean_,
+            moving_var_,
+            denorm_x_,
+            ildj_,
+            fldj_,
+        ] = sess.run([
+            norm_x,
+            moving_mean,
+            moving_var,
+            denorm_x,
+            ildj,
+            fldj,
+        ])
+        self.assertEqual("batch_normalization", batch_norm.name)
+
+        reduction_axes = self._reduction_axes(input_shape, event_dims)
+        keepdims = len(event_dims) > 1
+
+        expected_batch_mean = np.mean(
+            x_, axis=reduction_axes, keepdims=keepdims)
+        expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims)
+
+        if training:
+          # When training=True, values are normalized across the batch
+          # dimension, and the original values are recovered after
+          # de-normalizing.
+          zeros = np.zeros_like(norm_x_)
+          self.assertAllClose(np.mean(zeros, axis=reduction_axes),
+                              np.mean(norm_x_, axis=reduction_axes))
+
+          self.assertAllClose(expected_batch_mean, moving_mean_)
+          self.assertAllClose(expected_batch_var, moving_var_)
+          self.assertAllClose(x_, denorm_x_, atol=1e-5)
+          # Since moving statistics are set to batch statistics after
+          # normalization, ildj and -fldj should match.
+          self.assertAllClose(ildj_, -fldj_)
+          # ildj is computed with minibatch statistics.
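+          # gamma is initialized to 1., so its log contribution below is
+          # np.log(1.) = 0. and only the log-variance term remains.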
+          expected_ildj = np.sum(np.log(1.) - .5 * np.log(
+              expected_batch_var + batch_norm.batchnorm.epsilon))
+          self.assertAllClose(expected_ildj, ildj_)
+        else:
+          # When training=False, moving_mean, moving_var remain at their
+          # initialized values (0., 1.), resulting in no scale/shift (a small
+          # shift occurs if epsilon > 0.)
+          self.assertAllClose(x_, norm_x_)
+          self.assertAllClose(x_, denorm_x_, atol=1e-5)
+          # ildj is computed with saved statistics.
+          expected_ildj = np.sum(
+              np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon))
+          self.assertAllClose(expected_ildj, ildj_)
+
+  def testMaximumLikelihoodTraining(self):
+    # Test Maximum Likelihood training with default bijector.
+    with self.test_session() as sess:
+      base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
+      batch_norm = BatchNormalization(training=True)
+      dist = transformed_distribution_lib.TransformedDistribution(
+          distribution=base_dist,
+          bijector=batch_norm)
+      target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
+      target_samples = target_dist.sample(100)
+      dist_samples = dist.sample(3000)
+      loss = -math_ops.reduce_mean(dist.log_prob(target_samples))
+      with ops.control_dependencies(batch_norm.batchnorm.updates):
+        train_op = adam.AdamOptimizer(1e-2).minimize(loss)
+        moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
+        moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
+      variables.global_variables_initializer().run()
+      for _ in range(3000):
+        sess.run(train_op)
+      [
+          dist_samples_,
+          moving_mean_,
+          moving_var_
+      ] = sess.run([
+          dist_samples,
+          moving_mean,
+          moving_var
+      ])
+      self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2)
+      self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
+      self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
+
+  def testLogProb(self):
+    with self.test_session() as sess:
+      layer = normalization.BatchNormalization(epsilon=0.)
+      batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
+      base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
+      dist = transformed_distribution_lib.TransformedDistribution(
+          distribution=base_dist,
+          bijector=batch_norm,
+          validate_args=True)
+      samples = dist.sample(int(1e5))
+      # No volume distortion since training=False and the bijector is
+      # initialized to the identity transformation.
+      base_log_prob = base_dist.log_prob(samples)
+      dist_log_prob = dist.log_prob(samples)
+      variables.global_variables_initializer().run()
+      base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob])
+      self.assertAllClose(base_log_prob_, dist_log_prob_)
+
+  def testMutuallyConsistent(self):
+    # BatchNorm bijector is only mutually consistent when training=False.
+    dims = 4
+    with self.test_session() as sess:
+      layer = normalization.BatchNormalization(epsilon=0.)
+      batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
+      dist = transformed_distribution_lib.TransformedDistribution(
+          distribution=normal_lib.Normal(loc=0., scale=1.),
+          bijector=batch_norm,
+          event_shape=[dims],
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          sess_run_fn=sess.run,
+          dist=dist,
+          num_samples=int(1e5),
+          radius=2.,
+          center=0.,
+          rtol=0.02)
+
+  def testInvertMutuallyConsistent(self):
+    # BatchNorm bijector is only mutually consistent when training=False.
+    dims = 4
+    with self.test_session() as sess:
+      layer = normalization.BatchNormalization(epsilon=0.)
+      batch_norm = Invert(
+          BatchNormalization(batchnorm_layer=layer, training=False))
+      dist = transformed_distribution_lib.TransformedDistribution(
+          distribution=normal_lib.Normal(loc=0., scale=1.),
+          bijector=batch_norm,
+          event_shape=[dims],
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          sess_run_fn=sess.run,
+          dist=dist,
+          num_samples=int(1e5),
+          radius=2.,
+          center=0.,
+          rtol=0.02)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 9437f56..46ec497 100644 (file)
@@ -18,6 +18,7 @@
 @@Affine
 @@AffineLinearOperator
 @@Bijector
+@@BatchNormalization
 @@Chain
 @@CholeskyOuterProduct
 @@ConditionalBijector
@@ -53,6 +54,7 @@ from __future__ import print_function
 from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import *
 from tensorflow.contrib.distributions.python.ops.bijectors.affine import *
 from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import *
+from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import *
 from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
 from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
 from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py
new file mode 100644 (file)
index 0000000..e47a3e0
--- /dev/null
@@ -0,0 +1,259 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batch Norm bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+    "BatchNormalization",
+]
+
+
+def _undo_batch_normalization(x,
+                              mean,
+                              variance,
+                              offset,
+                              scale,
+                              variance_epsilon,
+                              name=None):
+  r"""Inverse of tf.nn.batch_normalization.
+
+  Args:
+    x: Input `Tensor` of arbitrary dimensionality.
+    mean: A mean `Tensor`.
+    variance: A variance `Tensor`.
+    offset: An offset `Tensor`, often denoted `beta` in equations, or
+      None. If present, will be added to the normalized tensor.
+    scale: A scale `Tensor`, often denoted `gamma` in equations, or
+      `None`. If present, the scale is applied to the normalized tensor.
+    variance_epsilon: A small `float` added to the minibatch `variance` to
+      prevent dividing by zero.
+    name: A name for this operation (optional).
+
+  Returns:
+    batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`.
+  """
+  with ops.name_scope(
+      name, "undo_batchnorm", [x, mean, variance, scale, offset]):
+    # For reference, `tf.nn.batch_normalization` computes
+    #   normalized = (input - mean) * inv + offset,
+    # where inv = scale * rsqrt(variance + variance_epsilon). Its inverse,
+    # applied to `x` below, is therefore
+    #   x * rescale + (mean - offset * rescale),  with  rescale = 1 / inv.
+    rescale = math_ops.sqrt(variance + variance_epsilon)
+    if scale is not None:
+      rescale /= scale
+    batch_unnormalized = x * rescale + (
+        mean - offset * rescale if offset is not None else mean)
+    return batch_unnormalized
+
+
+class BatchNormalization(bijector.Bijector):
+  """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`.
+
+  Applies Batch Normalization [1] to samples from a data distribution. This can
+  be used to stabilize training of normalizing flows [2, 3].
+
+  When training Deep Neural Networks (DNNs), it is common practice to
+  normalize or whiten features by shifting them to have zero mean and
+  scaling them to have unit variance.
+
+  The `inverse()` method of the BatchNorm bijector, which is used in the
+  log-likelihood computation of data samples, implements the normalization
+  procedure (shift-and-scale) using the mean and standard deviation of the
+  current minibatch.
+
+  Conversely, the `forward()` method of the bijector de-normalizes samples
+  (e.g. `X*std(Y) + mean(Y)`) using the running-average mean and standard
+  deviation computed at training time. De-normalization is useful for sampling.
+
+
+  ```python
+  dist = tfd.TransformedDistribution(
+      distribution=tfd.Normal(loc=0., scale=1.),
+      bijector=tfb.BatchNormalization())
+
+  y = tfd.MultivariateNormalDiag(loc=[1.], scale_diag=[2.]).sample(100)  # ~ N(1, 2)
+  x = dist.bijector.inverse(y)  # ~ N(0, 1)
+  y = dist.sample()  # ~ N(1, 2)
+  ```
+
+  During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not
+  guaranteed to be inverses of each other because `inverse(y)` uses statistics
+  of the current minibatch, while `forward(x)` uses running-average statistics
+  accumulated from training. In other words,
+  `BatchNorm.inverse(BatchNorm.forward(...))` and
+  `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when
+  `training=False` but may be different when `training=True`.
+
+  [1]: "Batch Normalization: Accelerating Deep Network Training by Reducing
+       Internal Covariate Shift."
+       Sergey Ioffe, Christian Szegedy. Arxiv. 2015.
+       https://arxiv.org/abs/1502.03167
+
+  [2]: "Density Estimation using Real NVP."
+       Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017.
+       https://arxiv.org/abs/1605.08803
+
+  [3]: "Masked Autoregressive Flow for Density Estimation."
+       George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
+       https://arxiv.org/abs/1705.07057
+
+  """
+
+  def __init__(self,
+               batchnorm_layer=None,
+               training=True,
+               validate_args=False,
+               name="batch_normalization"):
+    """Instantiates the `BatchNorm` bijector.
+
+    Args:
+      batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`,
+        defaults to a `tf.layers.BatchNormalization` layer constructed with
+        `gamma_constraint=lambda x: nn.relu(x) + 1e-6`, which ensures
+        positivity of the scale variable.
+      training: If True, updates running-average statistics during calls to
+        `inverse()`.
+      validate_args: Python `bool` indicating whether arguments should be
+        checked for correctness.
+      name: Python `str` name given to ops managed by this object.
+
+    Raises:
+      ValueError: If `batchnorm_layer` is not an instance of
+        `tf.layers.BatchNormalization`, or if it is specified with
+        `renorm=True` or a virtual batch size.
+    """
+    # Scale must be positive.
+    g_constraint = lambda x: nn.relu(x) + 1e-6
+    self.batchnorm = batchnorm_layer or normalization.BatchNormalization(
+        gamma_constraint=g_constraint)
+    self._validate_bn_layer(self.batchnorm)
+    self._training = training
+    super(BatchNormalization, self).__init__(
+        validate_args=validate_args, name=name)
+
+  def _validate_bn_layer(self, layer):
+    """Check for valid BatchNormalization layer.
+
+    Args:
+      layer: Instance of `tf.layers.BatchNormalization`.
+    Raises:
+      ValueError: If the `batchnorm_layer` argument is not an instance of
+        `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True`,
+        or if `batchnorm_layer.virtual_batch_size` is specified.
+    """
+    if not isinstance(layer, normalization.BatchNormalization):
+      raise ValueError(
+          "batchnorm_layer must be an instance of BatchNormalization layer.")
+    if layer.renorm:
+      raise ValueError("BatchNorm Bijector does not support renormalization.")
+    if layer.virtual_batch_size:
+      raise ValueError(
+          "BatchNorm Bijector does not support virtual batch sizes.")
+
+  def _get_broadcast_fn(self, x):
+    # Compute shape to broadcast scale/shift parameters to.
+    if not x.shape.is_fully_defined():
+      raise ValueError("Input must have shape known at graph construction.")
+    input_shape = np.int32(x.shape.as_list())
+
+    ndims = len(input_shape)
+    reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis]
+    # Broadcasting is only necessary for single-axis batch norm where the axis
+    # is not the last dimension.
+    broadcast_shape = [1] * ndims
+    broadcast_shape[self.batchnorm.axis[0]] = (
+        input_shape[self.batchnorm.axis[0]])
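+    # E.g., for axis=[1] and an input of shape [batch, C, H, W], the
+    # broadcast_shape is [1, C, 1, 1] so that the per-axis parameters
+    # broadcast against the input.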
+    def _broadcast(v):
+      if (v is not None and
+          len(v.get_shape()) != ndims and
+          reduction_axes != list(range(ndims - 1))):
+        return array_ops.reshape(v, broadcast_shape)
+      return v
+    return _broadcast
+
+  def _normalize(self, y):
+    return self.batchnorm.apply(y, training=self._training)
+
+  def _de_normalize(self, x):
+    # Uses the saved statistics.
+    if not self.batchnorm.built:
+      input_shape = x.get_shape()
+      self.batchnorm.build(input_shape)
+    broadcast_fn = self._get_broadcast_fn(x)
+    mean = broadcast_fn(self.batchnorm.moving_mean)
+    variance = broadcast_fn(self.batchnorm.moving_variance)
+    beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None
+    gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None
+    return _undo_batch_normalization(
+        x, mean, variance, beta, gamma, self.batchnorm.epsilon)
+
+  def _forward(self, x):
+    return self._de_normalize(x)
+
+  def _inverse(self, y):
+    return self._normalize(y)
+
+  def _forward_log_det_jacobian(self, x):
+    # Uses saved statistics to compute volume distortion.
+    return -self._inverse_log_det_jacobian(x, use_saved_statistics=True)
+
+  def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
+    if not y.shape.is_fully_defined():
+      raise ValueError("Input must have shape known at graph construction.")
+    input_shape = np.int32(y.shape.as_list())
+
+    if not self.batchnorm.built:
+      # Create variables.
+      self.batchnorm.build(input_shape)
+
+    event_dims = self.batchnorm.axis
+    reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims]
+
+    if use_saved_statistics or not self._training:
+      log_variance = math_ops.log(
+          self.batchnorm.moving_variance + self.batchnorm.epsilon)
+    else:
+      # At training time, the ildj is computed from the variance of the
+      # current minibatch.
+      _, v = nn.moments(y, axes=reduction_axes, keep_dims=True)
+      log_variance = math_ops.log(v + self.batchnorm.epsilon)
+
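+    # The inverse (normalization) scales each event dimension by
+    # gamma / sqrt(Var(y) + epsilon), so
+    #   log|det J_inverse(y)| = sum(log(gamma)) - 0.5 * sum(log(Var(y) + eps)),
+    # with sums taken over the event dimensions.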
+    # `gamma` and `log Var(y)` reductions over event_dims.
+    # Log(total change in area from gamma term).
+    log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma))
+
+    # Log(total change in area from log-variance term).
+    log_total_variance = math_ops.reduce_sum(log_variance)
+    # The ildj is a scalar; it does not depend on the values of y and is
+    # constant across minibatch elements.
+    return log_total_gamma - 0.5 * log_total_variance