)
cuda_py_test(
+ name = "batch_normalization_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "chain_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/chain_test.py"],
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BatchNorm Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import distributions
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import BatchNormalization
+from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+
+
+class BatchNormTest(test_util.VectorDistributionTestHelpers,
+ test.TestCase):
+
+ def _reduction_axes(self, input_shape, event_dims):
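+ """Returns the axes to reduce over: every axis not in `event_dims`.
+
+ For example, `input_shape=(5, 2, 4)` with `event_dims=[-1]` yields `(0, 1)`.
+ """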
+ if isinstance(event_dims, int):
+ event_dims = [event_dims]
+ ndims = len(input_shape)
+ # Convert event_dims to non-negative indexing.
+ event_dims = list(event_dims)
+ for idx, x in enumerate(event_dims):
+ if x < 0:
+ event_dims[idx] = ndims + x
+ return tuple(i for i in range(ndims) if i not in event_dims)
+
+ def testForwardInverse(self):
+ """Tests forward and backward passes with different event shapes.
+
+ input_shape: Tuple of shapes for input tensor.
+ event_dims: Tuple of dimension indices that will be normalized.
+ training: Boolean of whether bijector runs in training or inference mode.
+ """
+ params = [
+ ((5*2, 4), [-1], False),
+ ((5, 2, 4), [-1], False),
+ ((5, 2, 4), [1, 2], False),
+ ((5, 2, 4), [0, 1], False),
+ ((5*2, 4), [-1], True),
+ ((5, 2, 4), [-1], True),
+ ((5, 2, 4), [1, 2], True),
+ ((5, 2, 4), [0, 1], True)
+ ]
+ for input_shape, event_dims, training in params:
+ x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape)
+ with self.test_session() as sess:
+ x = constant_op.constant(x_)
+ # With momentum=0., the moving statistics are set to the exact statistics
+ # of the last normalized minibatch, rather than an exponential moving
+ # average.
+ layer = normalization.BatchNormalization(
+ axis=event_dims, momentum=0., epsilon=0.)
+ batch_norm = BatchNormalization(
+ batchnorm_layer=layer, training=training)
+ # Minibatch statistics are saved only after norm_x has been computed.
+ norm_x = batch_norm.inverse(x)
+ with ops.control_dependencies(batch_norm.batchnorm.updates):
+ moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
+ moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
+ denorm_x = batch_norm.forward(array_ops.identity(norm_x))
+ fldj = batch_norm.forward_log_det_jacobian(x)
+ # Use identity to invalidate cache.
+ ildj = batch_norm.inverse_log_det_jacobian(
+ array_ops.identity(denorm_x))
+ variables.global_variables_initializer().run()
+ # Update variables.
+ norm_x_ = sess.run(norm_x)
+ [
+ norm_x_,
+ moving_mean_,
+ moving_var_,
+ denorm_x_,
+ ildj_,
+ fldj_,
+ ] = sess.run([
+ norm_x,
+ moving_mean,
+ moving_var,
+ denorm_x,
+ ildj,
+ fldj,
+ ])
+ self.assertEqual("batch_normalization", batch_norm.name)
+
+ reduction_axes = self._reduction_axes(input_shape, event_dims)
+ keepdims = len(event_dims) > 1
+
+ expected_batch_mean = np.mean(
+ x_, axis=reduction_axes, keepdims=keepdims)
+ expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims)
+
+ if training:
+ # When training=True, values are normalized across the batch dimension(s),
+ # and the original values are recovered after de-normalizing.
+ zeros = np.zeros_like(norm_x_)
+ self.assertAllClose(np.mean(zeros, axis=reduction_axes),
+ np.mean(norm_x_, axis=reduction_axes))
+
+ self.assertAllClose(expected_batch_mean, moving_mean_)
+ self.assertAllClose(expected_batch_var, moving_var_)
+ self.assertAllClose(x_, denorm_x_, atol=1e-5)
+ # Since moving statistics are set to batch statistics after
+ # normalization, ildj and -fldj should match.
+ self.assertAllClose(ildj_, -fldj_)
+ # ildj is computed with minibatch statistics; gamma is initialized to one,
+ # so its contribution is log(1.) = 0.
+ expected_ildj = np.sum(np.log(1.) - .5 * np.log(
+ expected_batch_var + batch_norm.batchnorm.epsilon))
+ self.assertAllClose(expected_ildj, ildj_)
+ else:
+ # When training=False, moving_mean, moving_var remain at their
+ # initialized values (0., 1.), resulting in no scale/shift (a small
+ # shift occurs if epsilon > 0.)
+ self.assertAllClose(x_, norm_x_)
+ self.assertAllClose(x_, denorm_x_, atol=1e-5)
+ # ildj is computed with saved statistics.
+ expected_ildj = np.sum(
+ np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon))
+ self.assertAllClose(expected_ildj, ildj_)
+
+ def testMaximumLikelihoodTraining(self):
+ # Test Maximum Likelihood training with default bijector.
+ with self.test_session() as sess:
+ base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
+ batch_norm = BatchNormalization(training=True)
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=base_dist,
+ bijector=batch_norm)
+ target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
+ target_samples = target_dist.sample(100)
+ dist_samples = dist.sample(3000)
+ loss = -math_ops.reduce_mean(dist.log_prob(target_samples))
+ with ops.control_dependencies(batch_norm.batchnorm.updates):
+ train_op = adam.AdamOptimizer(1e-2).minimize(loss)
+ moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
+ moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
+ variables.global_variables_initializer().run()
+ for _ in range(3000):
+ sess.run(train_op)
+ [
+ dist_samples_,
+ moving_mean_,
+ moving_var_
+ ] = sess.run([
+ dist_samples,
+ moving_mean,
+ moving_var
+ ])
+ self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2)
+ self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
+ self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
+
+ def testLogProb(self):
+ with self.test_session() as sess:
+ layer = normalization.BatchNormalization(epsilon=0.)
+ batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
+ base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=base_dist,
+ bijector=batch_norm,
+ validate_args=True)
+ samples = dist.sample(int(1e5))
+ # No volume distortion: since training=False, the bijector is initialized
+ # to the identity transformation.
+ base_log_prob = base_dist.log_prob(samples)
+ dist_log_prob = dist.log_prob(samples)
+ variables.global_variables_initializer().run()
+ base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob])
+ self.assertAllClose(base_log_prob_, dist_log_prob_)
+
+ def testMutuallyConsistent(self):
+ # BatchNorm bijector is only mutually consistent when training=False.
+ dims = 4
+ with self.test_session() as sess:
+ layer = normalization.BatchNormalization(epsilon=0.)
+ batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=batch_norm,
+ event_shape=[dims],
+ validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess_run_fn=sess.run,
+ dist=dist,
+ num_samples=int(1e5),
+ radius=2.,
+ center=0.,
+ rtol=0.02)
+
+ def testInvertMutuallyConsistent(self):
+ # BatchNorm bijector is only mutually consistent when training=False.
+ dims = 4
+ with self.test_session() as sess:
+ layer = normalization.BatchNormalization(epsilon=0.)
+ batch_norm = Invert(
+ BatchNormalization(batchnorm_layer=layer, training=False))
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=batch_norm,
+ event_shape=[dims],
+ validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess_run_fn=sess.run,
+ dist=dist,
+ num_samples=int(1e5),
+ radius=2.,
+ center=0.,
+ rtol=0.02)
+
+
+if __name__ == "__main__":
+ test.main()
@@Affine
@@AffineLinearOperator
+@@BatchNormalization
@@Bijector
@@Chain
@@CholeskyOuterProduct
@@ConditionalBijector
from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import *
from tensorflow.contrib.distributions.python.ops.bijectors.affine import *
from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import *
+from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import *
from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batch Norm bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+ "BatchNormalization",
+]
+
+
+def _undo_batch_normalization(x,
+ mean,
+ variance,
+ offset,
+ scale,
+ variance_epsilon,
+ name=None):
+ r"""Inverse of tf.nn.batch_normalization.
+
+ Args:
+ x: Input `Tensor` of arbitrary dimensionality.
+ mean: A mean `Tensor`.
+ variance: A variance `Tensor`.
+ offset: An offset `Tensor`, often denoted `beta` in equations, or
+ None. If present, will be added to the normalized tensor.
+ scale: A scale `Tensor`, often denoted `gamma` in equations, or
+ `None`. If present, the scale is applied to the normalized tensor.
+ variance_epsilon: A small `float` added to the minibatch `variance` to
+ prevent dividing by zero.
+ name: A name for this operation (optional).
+
+ Returns:
+ batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`.
+ """
+ with ops.name_scope(
+ name, "undo_batchnorm", [x, mean, variance, scale, offset]):
+ # tf.nn.batch_normalization computes
+ # normalized = (x - mean) * inv + offset,
+ # where inv = rsqrt(variance + variance_epsilon), multiplied by scale if
+ # scale is not None. Solving for x gives the inverse below.
+ rescale = math_ops.sqrt(variance + variance_epsilon)
+ if scale is not None:
+ rescale /= scale
+ batch_unnormalized = x * rescale + (
+ mean - offset * rescale if offset is not None else mean)
+ return batch_unnormalized
+
+
+class BatchNormalization(bijector.Bijector):
+ """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`.
+
+ Applies Batch Normalization [1] to samples from a data distribution. This can
+ be used to stabilize training of normalizing flows [2, 3].
+
+ When training Deep Neural Networks (DNNs), it is common practice to
+ normalize or whiten features by shifting them to have zero mean and
+ scaling them to have unit variance.
+
+ The `inverse()` method of the BatchNorm bijector, which is used in the
+ log-likelihood computation of data samples, implements the normalization
+ procedure (shift-and-scale) using the mean and standard deviation of the
+ current minibatch.
+
+ Conversely, the `forward()` method of the bijector de-normalizes samples
+ (e.g. `X * std(Y) + mean(Y)`) using the running-average mean and standard
+ deviation computed at training time. De-normalization is useful for
+ sampling.
+
+ ```python
+ dist = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=0., scale=1.),
+ bijector=tfb.BatchNormalization())
+
+ y = tfd.MultivariateNormalDiag(loc=[1.], scale_diag=[2.]).sample(100) # ~ N(1, 2)
+ x = dist.bijector.inverse(y) # ~ N(0, 1)
+ y = dist.sample() # ~ N(1, 2)
+ ```
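+
+ To control the normalization axis or epsilon, pass in a custom layer; a
+ minimal sketch using the same `tf.layers` API:
+
+ ```python
+ layer = tf.layers.BatchNormalization(axis=-1, epsilon=1e-5)
+ dist = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=0., scale=1.),
+ bijector=tfb.BatchNormalization(batchnorm_layer=layer, training=False))
+ ```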
+
+ During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not
+ guaranteed to be inverses of each other because `inverse(y)` uses statistics
+ of the current minibatch, while `forward(x)` uses running-average statistics
+ accumulated from training. In other words,
+ `BatchNorm.inverse(BatchNorm.forward(...))` and
+ `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when
+ `training=False` but may be different when `training=True`.
+
+ [1]: "Batch Normalization: Accelerating Deep Network Training by Reducing
+ Internal Covariate Shift."
+ Sergey Ioffe, Christian Szegedy. arXiv. 2015.
+ https://arxiv.org/abs/1502.03167
+
+ [2]: "Density Estimation using Real NVP."
+ Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017.
+ https://arxiv.org/abs/1605.08803
+
+ [3]: "Masked Autoregressive Flow for Density Estimation."
+ George Papamakarios, Theo Pavlakou, Iain Murray. arXiv. 2017.
+ https://arxiv.org/abs/1705.07057
+
+ """
+
+ def __init__(self,
+ batchnorm_layer=None,
+ training=True,
+ validate_args=False,
+ name="batch_normalization"):
+ """Instantiates the `BatchNorm` bijector.
+
+ Args:
+ batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`,
+ defaults to a layer constructed with
+ `gamma_constraint=lambda x: nn.relu(x) + 1e-6`, which constrains the
+ scale variable `gamma` to be positive.
+ training: If True, updates running-average statistics during call to
+ `inverse()`.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+
+ Raises:
+ ValueError: If `batchnorm_layer` is not an instance of
+ `tf.layers.BatchNormalization`, or if it is specified with `renorm=True`
+ or a virtual batch size.
+ """
+ # Scale must be positive.
+ g_constraint = lambda x: nn.relu(x) + 1e-6
+ self.batchnorm = batchnorm_layer or normalization.BatchNormalization(
+ gamma_constraint=g_constraint)
+ self._validate_bn_layer(self.batchnorm)
+ self._training = training
+ super(BatchNormalization, self).__init__(
+ validate_args=validate_args, name=name)
+
+ def _validate_bn_layer(self, layer):
+ """Check for valid BatchNormalization layer.
+
+ Args:
+ layer: Instance of `tf.layers.BatchNormalization`.
+ Raises:
+ ValueError: If batchnorm_layer argument is not an instance of
+ `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or
+ if `batchnorm_layer.virtual_batch_size` is specified.
+ """
+ if not isinstance(layer, normalization.BatchNormalization):
+ raise ValueError(
+ "batchnorm_layer must be an instance of BatchNormalization layer.")
+ if layer.renorm:
+ raise ValueError("BatchNorm Bijector does not support renormalization.")
+ if layer.virtual_batch_size:
+ raise ValueError(
+ "BatchNorm Bijector does not support virtual batch sizes.")
+
+ def _get_broadcast_fn(self, x):
+ # Compute shape to broadcast scale/shift parameters to.
+ if not x.shape.is_fully_defined():
+ raise ValueError("Input must have shape known at graph construction.")
+ input_shape = np.int32(x.shape.as_list())
+
+ ndims = len(input_shape)
+ reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis]
+ # Broadcasting is only necessary for single-axis batch norm where the axis
+ # is not the last dimension.
+ broadcast_shape = [1] * ndims
+ broadcast_shape[self.batchnorm.axis[0]] = (
+ input_shape[self.batchnorm.axis[0]])
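+ # For example, input_shape=(5, 2, 4) with axis=[1] yields
+ # broadcast_shape=[1, 2, 1].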
+ def _broadcast(v):
+ if (v is not None and
+ len(v.get_shape()) != ndims and
+ reduction_axes != list(range(ndims - 1))):
+ return array_ops.reshape(v, broadcast_shape)
+ return v
+ return _broadcast
+
+ def _normalize(self, y):
+ return self.batchnorm.apply(y, training=self._training)
+
+ def _de_normalize(self, x):
+ # Uses the saved statistics.
+ if not self.batchnorm.built:
+ input_shape = x.get_shape()
+ self.batchnorm.build(input_shape)
+ broadcast_fn = self._get_broadcast_fn(x)
+ mean = broadcast_fn(self.batchnorm.moving_mean)
+ variance = broadcast_fn(self.batchnorm.moving_variance)
+ beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None
+ gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None
+ return _undo_batch_normalization(
+ x, mean, variance, beta, gamma, self.batchnorm.epsilon)
+
+ def _forward(self, x):
+ return self._de_normalize(x)
+
+ def _inverse(self, y):
+ return self._normalize(y)
+
+ def _forward_log_det_jacobian(self, x):
+ # Uses saved statistics to compute volume distortion.
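+ # Since forward() de-normalizes with the same saved statistics, its log
+ # det Jacobian is the negative of the normalization direction's.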
+ return -self._inverse_log_det_jacobian(x, use_saved_statistics=True)
+
+ def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
+ if not y.shape.is_fully_defined():
+ raise ValueError("Input must have shape known at graph construction.")
+ input_shape = np.int32(y.shape.as_list())
+
+ if not self.batchnorm.built:
+ # Create variables.
+ self.batchnorm.build(input_shape)
+
+ event_dims = self.batchnorm.axis
+ reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims]
+
+ if use_saved_statistics or not self._training:
+ log_variance = math_ops.log(
+ self.batchnorm.moving_variance + self.batchnorm.epsilon)
+ else:
+ # At training time, the ildj is computed from the log-variance of the
+ # current minibatch (the minibatch mean does not affect the Jacobian).
+ _, v = nn.moments(y, axes=reduction_axes, keep_dims=True)
+ log_variance = math_ops.log(v + self.batchnorm.epsilon)
+
+ # The inverse map is elementwise affine:
+ # x = (y - mean) / sqrt(var + eps) * gamma + beta,
+ # so the ildj is the reduction of log(gamma) - 0.5 * log(var + eps) over
+ # event_dims.
+ # Log of the total change in volume from the gamma term.
+ log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma))
+
+ # Log of the total change in volume from the log-variance term.
+ log_total_variance = math_ops.reduce_sum(log_variance)
+ # The ildj is a scalar: it does not depend on the values of y and is
+ # constant across minibatch elements.
+ return log_total_gamma - 0.5 * log_total_variance