Add Group Normalization to tf.contrib.layers.
author Jon Shlens <shlens@google.com>
Tue, 3 Apr 2018 22:55:59 +0000 (15:55 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 22:58:32 +0000 (15:58 -0700)
# Example usage: NHWC
outputs = tf.contrib.layers.group_norm(inputs, groups=32, channels_axis=-1, reduction_axes=[-3, -2])
# Example usage: NCHW
outputs = tf.contrib.layers.group_norm(inputs, groups=32, channels_axis=-3, reduction_axes=[-2, -1])
PiperOrigin-RevId: 191513496
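
For reference, the layer's math is roughly equivalent to the following NumPy
sketch for NHWC inputs (illustration only, not part of this change; the
function name and NHWC-only layout are assumptions of the sketch):

    import numpy as np

    def group_norm_reference(x, gamma, beta, groups=32, epsilon=1e-6):
      # x: [N, H, W, C] with C divisible by `groups`; gamma, beta: [C].
      n, h, w, c = x.shape
      x = x.reshape(n, h, w, groups, c // groups)
      # Statistics are computed per example and per group, over the spatial
      # dimensions and the channels within the group.
      mean = x.mean(axis=(1, 2, 4), keepdims=True)
      var = x.var(axis=(1, 2, 4), keepdims=True)
      x = (x - mean) / np.sqrt(var + epsilon)
      return x.reshape(n, h, w, c) * gamma + beta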

tensorflow/contrib/layers/__init__.py
tensorflow/contrib/layers/python/layers/normalization.py
tensorflow/contrib/layers/python/layers/normalization_test.py

diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 337c9e0..00f03a1 100644
@@ -104,6 +104,7 @@ See the @{$python/contrib.layers} guide.
 @@infer_real_valued_columns
 @@sequence_input_from_feature_columns
 
+@@group_norm
 @@instance_norm
 """
 
@@ -122,6 +123,7 @@ _allowed_symbols = ['bias_add',
                     'conv3d',
                     'elu',
                     'feature_column',
+                    'group_norm',
                     'instance_norm',
                     'legacy_fully_connected',
                     'legacy_linear',
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index e7d4080..c807ab0 100644
@@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import variable_scope
 
 
 __all__ = [
+    'group_norm',
     'instance_norm',
 ]
 
@@ -158,3 +160,196 @@ def instance_norm(inputs,
     if activation_fn is not None:
       outputs = activation_fn(outputs)
     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
+
+@add_arg_scope
+def group_norm(inputs,
+               groups=32,
+               channels_axis=-1,
+               reduction_axes=(-3, -2),
+               center=True,
+               scale=True,
+               epsilon=1e-6,
+               activation_fn=None,
+               param_initializers=None,
+               reuse=None,
+               variables_collections=None,
+               outputs_collections=None,
+               trainable=True,
+               scope=None):
+  """Functional interface for the group normalization layer.
+
+  Reference: https://arxiv.org/abs/1803.08494.
+
+    "Group Normalization", Yuxin Wu, Kaiming He
+
+  Args:
+    inputs: A Tensor with at least 2 dimensions, one of which is the channels
+      dimension. All shape dimensions must be fully defined.
+    groups: Integer. Divide the channels into this number of groups, over which
+      normalization statistics are computed. This number must evenly divide the
+      number of channels in `inputs`.
+    channels_axis: An integer. Specifies the index of the channels axis, which
+      will be split into `groups` groups; statistics are computed within each
+      group. Must be mutually exclusive with `reduction_axes`. Preferred usage
+      is to specify negative integers to be agnostic as to whether a batch
+      dimension is included.
+    reduction_axes: Tuple of integers. Specifies the dimensions over which
+      statistics will be accumulated. Must be mutually exclusive with
+      `channels_axis`. Statistics will not be accumulated across axes that
+      appear in neither `reduction_axes` nor `channels_axis`. Preferred usage
+      is to specify negative integers to be agnostic to whether a batch
+      dimension is included.
+
+      Some sample usage cases:
+        NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
+        NCHW format: channels_axis=-3, reduction_axes=[-2, -1]
+
+    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+      is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is
+      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
+      disabled since the scaling can be done by the next layer.
+    epsilon: Small float added to variance to avoid dividing by zero.
+    activation_fn: Activation function, default set to None to skip it and
+      maintain a linear activation.
+    param_initializers: Optional initializers for the beta and gamma
+      parameters.
+    reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer, the variable `scope` must be given.
+    variables_collections: Optional collections for the variables.
+    outputs_collections: Collections to add the outputs.
+    trainable: If `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    scope: Optional scope for `variable_scope`.
+
+  Returns:
+    A `Tensor` representing the output of the operation.
+
+  Raises:
+    ValueError: If the rank of `inputs` is undefined.
+    ValueError: If the channels dimension or any reduction dimension of
+      `inputs` is undefined.
+    ValueError: If number of groups is not commensurate with number of channels.
+    ValueError: If reduction_axes or channels_axis are out of bounds.
+    ValueError: If reduction_axes are not mutually exclusive with channels_axis.
+  """
+  # TODO(shlens): Support partially defined shapes for the inputs.
+  inputs = ops.convert_to_tensor(inputs)
+  original_shape = inputs.shape
+
+  if inputs.shape.ndims is None:
+    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
+  if channels_axis > (inputs.shape.ndims - 1):
+    raise ValueError('Axis is out of bounds.')
+
+  # Standardize the channels_axis to be positive and identify # of channels.
+  if channels_axis < 0:
+    channels_axis = inputs.shape.ndims + channels_axis
+  channels = inputs.shape[channels_axis].value
+
+  if channels is None:
+    raise ValueError('Inputs %s has undefined channel dimension: %d.' % (
+        inputs.name, channels_axis))
+
+  # Standardize the reduction_axes to be positive.
+  reduction_axes = list(reduction_axes)
+  for i in range(len(reduction_axes)):
+    if reduction_axes[i] < 0:
+      reduction_axes[i] += inputs.shape.ndims
+
+  for a in reduction_axes:
+    if a > inputs.shape.ndims - 1:
+      raise ValueError('Axis is out of bounds.')
+    if inputs.shape[a].value is None:
+      raise ValueError('Inputs %s has undefined dimensions %d.' % (
+          inputs.name, a))
+    if channels_axis == a:
+      raise ValueError('reduction_axes must be mutually exclusive '
+                       'with channels_axis')
+  if groups > channels:
+    raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
+  if channels % groups != 0:
+    raise ValueError('%d channels is not commensurate with %d groups.' %
+                     (channels, groups))
+
+  # Determine axes before channels. Some examples of common image formats:
+  #  'NCHW': before = [N], after = [HW]
+  #  'NHWC': before = [NHW], after = []
+  axes_before_channels = inputs.shape.as_list()[:channels_axis]
+  axes_after_channels = inputs.shape.as_list()[channels_axis+1:]
+
+  # Manually broadcast the parameters to conform to the number of groups.
+  params_shape_broadcast = ([1] * len(axes_before_channels) +
+                            [groups, channels // groups] +
+                            [1] * len(axes_after_channels))
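+  # For example, NHWC inputs with channels_axis=-1 yield a broadcast shape of
+  # [1, 1, 1, groups, channels // groups].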
+
+  # Reshape the input by the group within the channel dimension.
+  inputs_shape = (axes_before_channels + [groups, channels // groups] +
+                  axes_after_channels)
+  inputs = array_ops.reshape(inputs, inputs_shape)
+
+  # Determine the dimensions across which moments are calculated.
+  moments_axes = [channels_axis + 1]
+  for a in reduction_axes:
+    if a > channels_axis:
+      moments_axes.append(a + 1)
+    else:
+      moments_axes.append(a)
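+  # For example, NHWC inputs with channels_axis=3 and reduction_axes=[1, 2]
+  # yield moments_axes=[4, 1, 2] in the reshaped [N, H, W, G, C // G] tensor.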
+
+  with variable_scope.variable_scope(
+      scope, 'GroupNorm', [inputs], reuse=reuse) as sc:
+    # Note that the params_shape is the number of channels always.
+    params_shape = [channels]
+
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    dtype = inputs.dtype.base_dtype
+    if param_initializers is None:
+      param_initializers = {}
+    if center:
+      beta_collections = utils.get_variable_collections(
+          variables_collections, 'beta')
+      beta_initializer = param_initializers.get(
+          'beta', init_ops.zeros_initializer())
+      beta = variables.model_variable('beta',
+                                      shape=params_shape,
+                                      dtype=dtype,
+                                      initializer=beta_initializer,
+                                      collections=beta_collections,
+                                      trainable=trainable)
+      beta = array_ops.reshape(beta, params_shape_broadcast)
+
+    if scale:
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
+      gamma_initializer = param_initializers.get(
+          'gamma', init_ops.ones_initializer())
+      gamma = variables.model_variable('gamma',
+                                       shape=params_shape,
+                                       dtype=dtype,
+                                       initializer=gamma_initializer,
+                                       collections=gamma_collections,
+                                       trainable=trainable)
+      gamma = array_ops.reshape(gamma, params_shape_broadcast)
+
+    # Calculate the moments.
+    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+
+    # Compute normalization.
+    # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
+    # appropriately so that this operation may be faster.
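+    # The multiply-add below is algebraically
+    #   (inputs - mean) * rsqrt(variance + epsilon) * gamma + beta.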
+    gain = math_ops.rsqrt(variance + epsilon)
+    offset = -mean * gain
+    if gamma is not None:
+      gain *= gamma
+      offset *= gamma
+    if beta is not None:
+      offset += beta
+    outputs = inputs * gain + offset
+
+    # Collapse the groups into the channel dimension.
+    outputs = array_ops.reshape(outputs, original_shape)
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index 5cff1bf..b6e9635 100644
@@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase):
   def testOutputBigInput5DNCHW(self):
     self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3)
 
+
+class GroupNormTest(test.TestCase):
+
+  def testInvalidGroupSize(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10))
+    with self.assertRaisesRegexp(ValueError,
+                                 'Invalid groups 10 for 2 channels.'):
+      normalization.group_norm(inputs, groups=10,
+                               reduction_axes=[-2, -1], channels_axis=-3)
+
+  def testBadCommensurateGroup(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10))
+    with self.assertRaisesRegexp(ValueError,
+                                 '4 channels is not commensurate with '
+                                 '3 groups.'):
+      normalization.group_norm(inputs, groups=3,
+                               reduction_axes=[-2, -1], channels_axis=-3)
+
+  def testAxisIsBad(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5))
+    with self.assertRaisesRegexp(ValueError,
+                                 'Axis is out of bounds.'):
+      normalization.group_norm(inputs, channels_axis=5)
+    with self.assertRaisesRegexp(ValueError,
+                                 'Axis is out of bounds.'):
+      normalization.group_norm(inputs, reduction_axes=[1, 5])
+
+  def testNotMutuallyExclusiveAxis(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32))
+    # Specify axis with negative values.
+    with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+      normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2])
+    # Specify axis with positive values.
+    with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+      normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3])
+    # Specify axis with mixed positive and negative values.
+    with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+      normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2])
+
+  def testUnknownShape(self):
+    inputs = array_ops.placeholder(dtypes.float32)
+    with self.assertRaisesRegexp(ValueError, 'undefined rank'):
+      normalization.group_norm(inputs)
+
+  def testParamsShapeNotFullyDefinedReductionAxes(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4))
+    with self.assertRaisesRegexp(ValueError, 'undefined dimensions'):
+      normalization.group_norm(inputs)
+
+  def testParamsShapeNotFullyDefinedChannelsAxis(self):
+    inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None))
+    with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'):
+      normalization.group_norm(inputs, channels_axis=-1,
+                               reduction_axes=[-3, -2])
+
+  def testCreateOp(self):
+    height, width, groups = 3, 3, 4
+    images = random_ops.random_uniform((5, height, width, 2*groups), seed=1)
+    output = normalization.group_norm(images, groups=groups, channels_axis=-1,
+                                      reduction_axes=[-3, -2])
+    print('name: ', output.op.name)
+    self.assertListEqual([5, height, width, 2*groups], output.shape.as_list())
+
+  def testCreateOpFloat64(self):
+    height, width, groups = 3, 3, 5
+    images = random_ops.random_uniform(
+        (5, height, width, 4*groups), dtype=dtypes.float64, seed=1)
+    output = normalization.group_norm(images, groups=groups)
+    self.assertEqual(dtypes.float64, output.dtype)
+    self.assertListEqual([5, height, width, 4*groups], output.shape.as_list())
+
+  def testCreateOpNoScaleCenter(self):
+    height, width, groups = 3, 3, 7
+    images = random_ops.random_uniform(
+        (5, height, width, 3*groups), dtype=dtypes.float32, seed=1)
+    output = normalization.group_norm(images, groups=groups, center=False,
+                                      scale=False)
+    self.assertListEqual([5, height, width, 3*groups], output.shape.as_list())
+    self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
+    self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))
+
+  def testCreateVariables_NHWC(self):
+    height, width = 3, 3
+    images = random_ops.random_uniform((5, height, width, 8), seed=1)
+    normalization.group_norm(images, groups=4,
+                             channels_axis=-1, reduction_axes=(-3, -2),
+                             center=True, scale=True)
+    beta = contrib_variables.get_variables_by_name('beta')[0]
+    gamma = contrib_variables.get_variables_by_name('gamma')[0]
+    self.assertEqual('GroupNorm/beta', beta.op.name)
+    self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+  def testCreateVariables_NCHW(self):
+    height, width, groups = 3, 3, 4
+    images = random_ops.random_uniform((5, 2*groups, height, width), seed=1)
+    normalization.group_norm(images, groups=4,
+                             channels_axis=-3, reduction_axes=(-2, -1),
+                             center=True, scale=True)
+    beta = contrib_variables.get_variables_by_name('beta')[0]
+    gamma = contrib_variables.get_variables_by_name('gamma')[0]
+    self.assertEqual('GroupNorm/beta', beta.op.name)
+    self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+  def testReuseVariables(self):
+    height, width = 3, 3
+    images = random_ops.random_uniform((5, height, width, 4), seed=1)
+    normalization.group_norm(images, groups=2, scale=True, scope='IN')
+    normalization.group_norm(images, groups=2, scale=True, scope='IN',
+                             reuse=True)
+    beta = contrib_variables.get_variables_by_name('beta')
+    gamma = contrib_variables.get_variables_by_name('gamma')
+    self.assertEqual(1, len(beta))
+    self.assertEqual(1, len(gamma))
+
+  def testValueCorrectWithReuseVars(self):
+    height, width = 3, 3
+    image_shape = (10, height, width, 4)
+    images = random_ops.random_uniform(image_shape, seed=1)
+    output_train = normalization.group_norm(images, groups=2, scope='IN')
+    output_eval = normalization.group_norm(images, groups=2, scope='IN',
+                                           reuse=True)
+    with self.test_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      # output_train and output_eval should be the same.
+      train_np, eval_np = sess.run([output_train, output_eval])
+      self.assertAllClose(train_np, eval_np)
+
+  def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
+                   groups=2, tol=1e-2):
+    # Select the axis for the channel and the dimensions along which statistics
+    # are accumulated.
+    if channels_axis < 0:
+      channels_axis += len(input_shape)
+    reduced_axes = [channels_axis + 1]
+    for a in reduction_axes:
+      if a < 0:
+        a += len(input_shape)
+      if a < channels_axis:
+        reduced_axes.append(a)
+      else:
+        reduced_axes.append(a+1)
+    reduced_axes = tuple(reduced_axes)
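+    # For example, NHWC inputs with channels_axis=3 and reduction_axes=[1, 2]
+    # map to reduced_axes=(4, 1, 2) in the grouped output.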
+
+    # Calculate the final shape for the output Tensor.
+    axes_before_channels = input_shape[:channels_axis]
+    axes_after_channels = input_shape[channels_axis+1:]
+    channels = input_shape[channels_axis]
+    outputs_shape = (axes_before_channels + [groups, channels // groups] +
+                     axes_after_channels)
+
+    # Calculate the final shape for the output statistics.
+    reduced_shape = []
+    for i, a in enumerate(outputs_shape):
+      if i not in reduced_axes:
+        reduced_shape.append(a)
+
+    for mu in (0.0, 1e2):
+      for sigma in (1.0, 0.1):
+        # Expected moments of the normalized output.
+        expected_mean = np.zeros(reduced_shape)
+        expected_var = np.ones(reduced_shape)
+
+        inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+        output_op = normalization.group_norm(
+            inputs, groups=groups, center=False, scale=False,
+            channels_axis=channels_axis,
+            reduction_axes=reduction_axes)
+        with self.test_session() as sess:
+          sess.run(variables.global_variables_initializer())
+          outputs = sess.run(output_op)
+          # Make sure that there are no NaNs
+          self.assertFalse(np.isnan(outputs).any())
+
+          outputs = np.reshape(outputs, outputs_shape)
+          mean = np.mean(outputs, axis=reduced_axes)
+          var = np.var(outputs, axis=reduced_axes)
+          # The mean and variance of each example should be close to 0 and 1
+          # respectively.
+          self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
+          self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+
+  def testOutputSmallInput4D_NHWC(self):
+    input_shape = [10, 10, 10, 30]
+    # Specify axes with positive values.
+    self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
+    # Specify axes with negative values.
+    self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+  def testOutputSmallInput3D_NHWC(self):
+    input_shape = [10, 10, 30]
+    # Specify axes with positive values.
+    self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
+    # Specify axes with negative values.
+    self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+  def testOutputSmallInput4D_NCHW(self):
+    input_shape = [10, 10, 10, 30]
+    # Specify axes with positive values.
+    self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
+    # Specify axes with negative values.
+    self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+  def testOutputSmallInput3D_NCHW(self):
+    input_shape = [10, 10, 30]
+    # Specify axes with positive values.
+    self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
+    # Specify axes with negative values.
+    self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+  def testOutputBigInput4D_NHWC(self):
+    self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
+                      groups=1)
+
+  def testOutputBigInput4D_NCHW(self):
+    self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
+                      groups=4)
+
+  def testOutputSmallInput2D_NC(self):
+    self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+
+  def testOutputSmallInput5D_NCXXX(self):
+    self.doOutputTest([10, 10, 20, 40, 5],
+                      channels_axis=1,
+                      reduction_axes=[2, 3, 4],
+                      groups=5)
+
 if __name__ == '__main__':
   test.main()