from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
__all__ = [
+ 'group_norm',
'instance_norm',
]
if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
+
+@add_arg_scope
+def group_norm(inputs,
+ groups=32,
+ channels_axis=-1,
+ reduction_axes=(-3, -2),
+ center=True,
+ scale=True,
+ epsilon=1e-6,
+ activation_fn=None,
+ param_initializers=None,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ scope=None):
+ """Functional interface for the group normalization layer.
+
+ Reference: https://arxiv.org/abs/1803.08494.
+
+ "Group Normalization", Yuxin Wu, Kaiming He
+
+ Args:
+    inputs: A Tensor with at least 2 dimensions, one of which is the channels
+      dimension. All shape dimensions must be fully defined.
+    groups: Integer. Divide the channels into this number of groups over which
+      normalization statistics are computed. This number must evenly divide
+      the number of channels in `inputs`.
+    channels_axis: An integer. Specifies the index of the channels axis, which
+      will be broken into `groups`, over each of which statistics will be
+      computed. Must be mutually exclusive with `reduction_axes`. Preferred
+      usage is to specify negative integers to be agnostic as to whether a
+      batch dimension is included.
+    reduction_axes: Tuple of integers. Specifies dimensions over which
+      statistics will be accumulated. Must be mutually exclusive with
+      `channels_axis`. Statistics will not be accumulated across axes not
+      specified in `reduction_axes` nor `channels_axis`. Preferred usage is to
+      specify negative integers to be agnostic to whether a batch dimension is
+      included.
+
+ Some sample usage cases:
+ NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
+ NCHW format: channels_axis=-3, reduction_axes=[-2, -1]
+
+ center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+ is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
+      the next layer is linear (or, e.g., `nn.relu`), this can be disabled
+      since the scaling can be done by the next layer.
+ epsilon: Small float added to variance to avoid dividing by zero.
+ activation_fn: Activation function, default set to None to skip it and
+ maintain a linear activation.
+    param_initializers: Optional initializers for the beta and gamma
+      parameters.
+    reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer, `scope` must be given.
+ variables_collections: Optional collections for the variables.
+ outputs_collections: Collections to add the outputs.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ A `Tensor` representing the output of the operation.
+
+ Raises:
+    ValueError: If the rank or channels dimension of `inputs` is undefined.
+ ValueError: If number of groups is not commensurate with number of channels.
+ ValueError: If reduction_axes or channels_axis are out of bounds.
+ ValueError: If reduction_axes are not mutually exclusive with channels_axis.
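+
+  Example (an illustrative sketch; assumes a 4-D NHWC input and that this
+  function is exposed as `tf.contrib.layers.group_norm`):
+
+    # Normalize 64 channels in 32 groups; statistics are accumulated over
+    # H, W and the channels within each group.
+    images = tf.placeholder(tf.float32, shape=(8, 224, 224, 64))
+    net = tf.contrib.layers.group_norm(
+        images, groups=32, channels_axis=-1, reduction_axes=(-3, -2))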
+ """
+ # TODO(shlens): Support partially defined shapes for the inputs.
+ inputs = ops.convert_to_tensor(inputs)
+ original_shape = inputs.shape
+
+ if inputs.shape.ndims is None:
+ raise ValueError('Inputs %s has undefined rank.' % inputs.name)
+  if (channels_axis < -inputs.shape.ndims or
+      channels_axis > inputs.shape.ndims - 1):
+    raise ValueError('Axis is out of bounds.')
+
+ # Standardize the channels_axis to be positive and identify # of channels.
+ if channels_axis < 0:
+ channels_axis = inputs.shape.ndims + channels_axis
+ channels = inputs.shape[channels_axis].value
+
+ if channels is None:
+ raise ValueError('Inputs %s has undefined channel dimension: %d.' % (
+ inputs.name, channels_axis))
+
+ # Standardize the reduction_axes to be positive.
+ reduction_axes = list(reduction_axes)
+ for i in range(len(reduction_axes)):
+ if reduction_axes[i] < 0:
+ reduction_axes[i] += inputs.shape.ndims
+
+ for a in reduction_axes:
+    if a < 0 or a > inputs.shape.ndims - 1:
+      raise ValueError('Axis is out of bounds.')
+ if inputs.shape[a].value is None:
+ raise ValueError('Inputs %s has undefined dimensions %d.' % (
+ inputs.name, a))
+ if channels_axis == a:
+ raise ValueError('reduction_axis must be mutually exclusive '
+ 'with channels_axis')
+ if groups > channels:
+ raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
+ if channels % groups != 0:
+ raise ValueError('%d channels is not commensurate with %d groups.' %
+ (channels, groups))
+
+ # Determine axes before channels. Some examples of common image formats:
+ # 'NCHW': before = [N], after = [HW]
+ # 'NHWC': before = [NHW], after = []
+ axes_before_channels = inputs.shape.as_list()[:channels_axis]
+ axes_after_channels = inputs.shape.as_list()[channels_axis+1:]
+
+ # Manually broadcast the parameters to conform to the number of groups.
+ params_shape_broadcast = ([1] * len(axes_before_channels) +
+ [groups, channels // groups] +
+ [1] * len(axes_after_channels))
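+  # e.g. for NHWC inputs with G groups: [1, 1, 1, G, C // G].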
+
+ # Reshape the input by the group within the channel dimension.
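+  # e.g. NHWC: [N, H, W, C] -> [N, H, W, G, C // G].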
+ inputs_shape = (axes_before_channels + [groups, channels // groups] +
+ axes_after_channels)
+ inputs = array_ops.reshape(inputs, inputs_shape)
+
+ # Determine the dimensions across which moments are calculated.
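+  # After the reshape, every axis greater than channels_axis is shifted right
+  # by one, so those reduction axes are offset by +1. The within-group channel
+  # axis (channels_axis + 1) is always reduced over.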
+ moments_axes = [channels_axis + 1]
+ for a in reduction_axes:
+ if a > channels_axis:
+ moments_axes.append(a + 1)
+ else:
+ moments_axes.append(a)
+
+ with variable_scope.variable_scope(
+ scope, 'GroupNorm', [inputs], reuse=reuse) as sc:
+    # Note that params_shape is always the total number of channels.
+ params_shape = [channels]
+
+ # Allocate parameters for the beta and gamma of the normalization.
+ beta, gamma = None, None
+ dtype = inputs.dtype.base_dtype
+ if param_initializers is None:
+ param_initializers = {}
+ if center:
+ beta_collections = utils.get_variable_collections(
+ variables_collections, 'beta')
+ beta_initializer = param_initializers.get(
+ 'beta', init_ops.zeros_initializer())
+ beta = variables.model_variable('beta',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=beta_initializer,
+ collections=beta_collections,
+ trainable=trainable)
+ beta = array_ops.reshape(beta, params_shape_broadcast)
+
+ if scale:
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
+ gamma_initializer = param_initializers.get(
+ 'gamma', init_ops.ones_initializer())
+ gamma = variables.model_variable('gamma',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=gamma_initializer,
+ collections=gamma_collections,
+ trainable=trainable)
+ gamma = array_ops.reshape(gamma, params_shape_broadcast)
+
+ # Calculate the moments.
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+
+ # Compute normalization.
+ # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
+ # appropriately so that this operation may be faster.
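+  # Fold the normalization and affine parameters into one multiply-add:
+  #   (x - mean) * rsqrt(variance + epsilon) * gamma + beta == x * gain + offset
+  # where gain = gamma * rsqrt(variance + epsilon) and
+  # offset = beta - mean * gain.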
+ gain = math_ops.rsqrt(variance + epsilon)
+ offset = -mean * gain
+ if gamma is not None:
+ gain *= gamma
+ offset *= gamma
+ if beta is not None:
+ offset += beta
+ outputs = inputs * gain + offset
+
+ # Collapse the groups into the channel dimension.
+ outputs = array_ops.reshape(outputs, original_shape)
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
def testOutputBigInput5DNCHW(self):
self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3)
+
+class GroupNormTest(test.TestCase):
+
+ def testInvalidGroupSize(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10))
+ with self.assertRaisesRegexp(ValueError,
+ 'Invalid groups 10 for 2 channels.'):
+ normalization.group_norm(inputs, groups=10,
+ reduction_axes=[-2, -1], channels_axis=-3)
+
+ def testBadCommensurateGroup(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10))
+ with self.assertRaisesRegexp(ValueError,
+ '4 channels is not commensurate with '
+ '3 groups.'):
+ normalization.group_norm(inputs, groups=3,
+ reduction_axes=[-2, -1], channels_axis=-3)
+
+ def testAxisIsBad(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5))
+ with self.assertRaisesRegexp(ValueError,
+ 'Axis is out of bounds.'):
+ normalization.group_norm(inputs, channels_axis=5)
+ with self.assertRaisesRegexp(ValueError,
+ 'Axis is out of bounds.'):
+ normalization.group_norm(inputs, reduction_axes=[1, 5])
+
+ def testNotMutuallyExclusiveAxis(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32))
+ # Specify axis with negative values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2])
+ # Specify axis with positive values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3])
+ # Specify axis with mixed positive and negative values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2])
+
+ def testUnknownShape(self):
+ inputs = array_ops.placeholder(dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'undefined rank'):
+ normalization.group_norm(inputs)
+
+ def testParamsShapeNotFullyDefinedReductionAxes(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4))
+ with self.assertRaisesRegexp(ValueError, 'undefined dimensions'):
+ normalization.group_norm(inputs)
+
+ def testParamsShapeNotFullyDefinedChannelsAxis(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None))
+ with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'):
+ normalization.group_norm(inputs, channels_axis=-1,
+ reduction_axes=[-3, -2])
+
+ def testCreateOp(self):
+ height, width, groups = 3, 3, 4
+ images = random_ops.random_uniform((5, height, width, 2*groups), seed=1)
+ output = normalization.group_norm(images, groups=groups, channels_axis=-1,
+ reduction_axes=[-3, -2])
+ self.assertListEqual([5, height, width, 2*groups], output.shape.as_list())
+
+ def testCreateOpFloat64(self):
+ height, width, groups = 3, 3, 5
+ images = random_ops.random_uniform(
+ (5, height, width, 4*groups), dtype=dtypes.float64, seed=1)
+ output = normalization.group_norm(images, groups=groups)
+ self.assertEqual(dtypes.float64, output.dtype)
+ self.assertListEqual([5, height, width, 4*groups], output.shape.as_list())
+
+ def testCreateOpNoScaleCenter(self):
+ height, width, groups = 3, 3, 7
+ images = random_ops.random_uniform(
+ (5, height, width, 3*groups), dtype=dtypes.float32, seed=1)
+ output = normalization.group_norm(images, groups=groups, center=False,
+ scale=False)
+ self.assertListEqual([5, height, width, 3*groups], output.shape.as_list())
+ self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
+ self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))
+
+ def testCreateVariables_NHWC(self):
+ height, width = 3, 3
+ images = random_ops.random_uniform((5, height, width, 8), seed=1)
+ normalization.group_norm(images, groups=4,
+ channels_axis=-1, reduction_axes=(-3, -2),
+ center=True, scale=True)
+ beta = contrib_variables.get_variables_by_name('beta')[0]
+ gamma = contrib_variables.get_variables_by_name('gamma')[0]
+ self.assertEqual('GroupNorm/beta', beta.op.name)
+ self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+ def testCreateVariables_NCHW(self):
+ height, width, groups = 3, 3, 4
+ images = random_ops.random_uniform((5, 2*groups, height, width), seed=1)
+ normalization.group_norm(images, groups=4,
+ channels_axis=-3, reduction_axes=(-2, -1),
+ center=True, scale=True)
+ beta = contrib_variables.get_variables_by_name('beta')[0]
+ gamma = contrib_variables.get_variables_by_name('gamma')[0]
+ self.assertEqual('GroupNorm/beta', beta.op.name)
+ self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+ def testReuseVariables(self):
+ height, width = 3, 3
+ images = random_ops.random_uniform((5, height, width, 4), seed=1)
+ normalization.group_norm(images, groups=2, scale=True, scope='IN')
+ normalization.group_norm(images, groups=2, scale=True, scope='IN',
+ reuse=True)
+ beta = contrib_variables.get_variables_by_name('beta')
+ gamma = contrib_variables.get_variables_by_name('gamma')
+ self.assertEqual(1, len(beta))
+ self.assertEqual(1, len(gamma))
+
+ def testValueCorrectWithReuseVars(self):
+ height, width = 3, 3
+ image_shape = (10, height, width, 4)
+ images = random_ops.random_uniform(image_shape, seed=1)
+ output_train = normalization.group_norm(images, groups=2, scope='IN')
+ output_eval = normalization.group_norm(images, groups=2, scope='IN',
+ reuse=True)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ # output_train and output_eval should be the same.
+ train_np, eval_np = sess.run([output_train, output_eval])
+ self.assertAllClose(train_np, eval_np)
+
+  def doOutputTest(self, input_shape, channels_axis, reduction_axes,
+                   groups=2, tol=1e-2):
+ # Select the axis for the channel and the dimensions along which statistics
+ # are accumulated.
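+    # This mirrors group_norm's bookkeeping: once the channels are split into
+    # [groups, channels // groups], every axis greater than channels_axis
+    # shifts right by one.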
+ if channels_axis < 0:
+ channels_axis += len(input_shape)
+ reduced_axes = [channels_axis + 1]
+ for a in reduction_axes:
+ if a < 0:
+ a += len(input_shape)
+ if a < channels_axis:
+ reduced_axes.append(a)
+ else:
+ reduced_axes.append(a+1)
+ reduced_axes = tuple(reduced_axes)
+
+ # Calculate the final shape for the output Tensor.
+ axes_before_channels = input_shape[:channels_axis]
+ axes_after_channels = input_shape[channels_axis+1:]
+ channels = input_shape[channels_axis]
+ outputs_shape = (axes_before_channels + [groups, channels // groups] +
+ axes_after_channels)
+
+ # Calculate the final shape for the output statistics.
+ reduced_shape = []
+ for i, a in enumerate(outputs_shape):
+ if i not in reduced_axes:
+ reduced_shape.append(a)
+
+ for mu in (0.0, 1e2):
+ for sigma in (1.0, 0.1):
+ # Determine shape of Tensor after normalization.
+ expected_mean = np.zeros(reduced_shape)
+ expected_var = np.ones(reduced_shape)
+
+ inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ output_op = normalization.group_norm(
+ inputs, groups=groups, center=False, scale=False,
+ channels_axis=channels_axis,
+ reduction_axes=reduction_axes)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ outputs = sess.run(output_op)
+ # Make sure that there are no NaNs
+ self.assertFalse(np.isnan(outputs).any())
+
+ outputs = np.reshape(outputs, outputs_shape)
+ mean = np.mean(outputs, axis=reduced_axes)
+ var = np.var(outputs, axis=reduced_axes)
+ # The mean and variance of each example should be close to 0 and 1
+ # respectively.
+ self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
+ self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+
+ def testOutputSmallInput4D_NHWC(self):
+ input_shape = [10, 10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+ def testOutputSmallInput3D_NHWC(self):
+ input_shape = [10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+ def testOutputSmallInput4D_NCHW(self):
+ input_shape = [10, 10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+ def testOutputSmallInput3D_NCHW(self):
+ input_shape = [10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+ def testOutputBigInput4D_NHWC(self):
+ self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
+ groups=1)
+
+ def testOutputBigInput4D_NCHW(self):
+ self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
+ groups=4)
+
+ def testOutputSmallInput2D_NC(self):
+ self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+
+ def testOutputSmallInput5D_NCXXX(self):
+ self.doOutputTest([10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+
if __name__ == '__main__':
test.main()