Fixed linter errors.
authorJianwei Xie <xiejw@google.com>
Wed, 24 Jan 2018 19:31:06 +0000 (11:31 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 24 Jan 2018 19:35:51 +0000 (11:35 -0800)
PiperOrigin-RevId: 183115307

26 files changed:
tensorflow/contrib/cmake/python_sanity_test.py
tensorflow/contrib/layers/python/layers/layers.py
tensorflow/contrib/opt/python/training/model_average_optimizer.py
tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
tensorflow/contrib/rnn/python/ops/rnn_cell.py
tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
tensorflow/contrib/verbs/rdma.cc
tensorflow/contrib/verbs/rdma.h
tensorflow/contrib/verbs/rdma_mgr.cc
tensorflow/core/kernels/mkl_aggregate_ops.cc
tensorflow/core/kernels/mkl_softmax_op.cc
tensorflow/core/kernels/spectrogram_test_utils.cc
tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/examples/tutorials/word2vec/word2vec_basic.py
tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
tensorflow/python/ops/histogram_ops.py
tensorflow/python/ops/histogram_ops_test.py
tensorflow/python/ops/image_ops_impl.py
tensorflow/python/ops/metrics_impl.py
tensorflow/python/ops/nn_impl.py
tensorflow/python/ops/nn_test.py
tensorflow/python/util/compat.py
tensorflow/tools/pip_package/pip_smoke_test.py

index 3be5bd1b23af34561e06472f5712bdc633a817ea..e0056823a80833329bcb1f275a3384a33127bb40 100644 (file)
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
-Complain about invalid or missing entries in python_*.txt files.
+"""Complain about invalid or missing entries in python_*.txt files.
+
 Problematic entries can be commented for temporary whitelisting.
 """
 
@@ -35,6 +35,7 @@ def abs_path(path):
   path = os.path.abspath(path)
   return path
 
+
 def read_entries(test):
   with open(abs_path(test.entries_file), "r") as f:
     lines = f.readlines()
@@ -47,25 +48,28 @@ def read_entries(test):
 
   for line in lines:
     # line is comment
-    if line.startswith('#'):
+    if line.startswith("#"):
       line = line[1:].strip()
       # whitelist entry
-      if line.startswith('tensorflow/'):
+      if line.startswith("tensorflow/"):
         test.whitelist.append(line)
     # line has comment -> strip comment
-    elif line.find('#') != -1:
-      line = line[:line.find('#')].strip()
+    elif line.find("#") != -1:
+      line = line[:line.find("#")].strip()
       test.entries.append(line)
     else:
       test.entries.append(line)
 
+
 def test_invalid_directories(test):
   for entry in test.entries:
     if not os.path.isdir(abs_path(entry)):
       problem = "'" + test.entries_file + "' contains invalid '" + entry + "'"
-      solution = "Please remove the invalid entry (or add the missing directory)."
+      solution = ("Please remove the invalid entry (or add the missing "
+                  "directory).")
       raise AssertionError(problem + "\n" + solution)
 
+
 def test_missing_directory(test, path):
   if path in test.whitelist:
     return
index ef2b67307472c39f7469d2aeacdeb979b4bb8e5d..7c52da7b494bdb35f71303811de32962a4ed0fa9 100644 (file)
@@ -54,47 +54,17 @@ from tensorflow.python.layers.maxout import maxout
 
 # TODO(b/28426988): Replace legacy_* fns migrated from slim.
 # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
-__all__ = ['avg_pool2d',
-           'avg_pool3d',
-           'batch_norm',
-           'bias_add',
-           'conv2d',
-           'conv3d',
-           'conv2d_in_plane',
-           'conv2d_transpose',
-           'conv3d_transpose',
-           'convolution',
-           'convolution2d',
-           'convolution2d_in_plane',
-           'convolution2d_transpose',
-           'convolution3d',
-           'convolution3d_transpose',
-           'dropout',
-           'elu',
-           'flatten',
-           'fully_connected',
-           'GDN',
-           'gdn',
-           'layer_norm',
-           'linear',
-           'pool',
-           'max_pool2d',
-           'max_pool3d',
-           'one_hot_encoding',
-           'relu',
-           'relu6',
-           'repeat',
-           'scale_gradient',
-           'separable_conv2d',
-           'separable_convolution2d',
-           'softmax',
-           'spatial_softmax',
-           'stack',
-           'unit_norm',
-           'legacy_fully_connected',
-           'legacy_linear',
-           'legacy_relu',
-           'maxout']
+__all__ = [
+    'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d',
+    'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution',
+    'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose',
+    'convolution3d', 'convolution3d_transpose', 'dropout', 'elu', 'flatten',
+    'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool',
+    'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat',
+    'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax',
+    'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected',
+    'legacy_linear', 'legacy_relu', 'maxout'
+]
 
 DATA_FORMAT_NCHW = 'NCHW'
 DATA_FORMAT_NHWC = 'NHWC'
@@ -139,13 +109,14 @@ def avg_pool2d(inputs,
     raise ValueError('data_format has to be either NCHW or NHWC.')
   with ops.name_scope(scope, 'AvgPool2D', [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
-    layer = pooling_layers.AveragePooling2D(pool_size=kernel_size,
-                                            strides=stride,
-                                            padding=padding,
-                                            data_format=df,
-                                            _scope=sc)
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
+    layer = pooling_layers.AveragePooling2D(
+        pool_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        _scope=sc)
     outputs = layer.apply(inputs)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
@@ -187,13 +158,14 @@ def avg_pool3d(inputs,
     raise ValueError('data_format has to be either NCDHW or NDHWC.')
   with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
-    layer = pooling_layers.AveragePooling3D(pool_size=kernel_size,
-                                            strides=stride,
-                                            padding=padding,
-                                            data_format=df,
-                                            _scope=sc)
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
+    layer = pooling_layers.AveragePooling3D(
+        pool_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        _scope=sc)
     outputs = layer.apply(inputs)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
@@ -298,8 +270,8 @@ def _fused_batch_norm(inputs,
       raise ValueError('Inputs %s has undefined rank' % inputs.name)
     elif original_rank not in [2, 4]:
       raise ValueError('Inputs %s has unsupported rank.'
-                       ' Expected 2 or 4 but got %d' % (
-                           inputs.name, original_rank))
+                       ' Expected 2 or 4 but got %d' % (inputs.name,
+                                                        original_rank))
     if original_rank == 2:
       channels = inputs.get_shape()[-1].value
       if channels is None:
@@ -393,6 +365,7 @@ def _fused_batch_norm(inputs,
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
+
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
           inputs,
@@ -403,9 +376,9 @@ def _fused_batch_norm(inputs,
           epsilon=epsilon,
           is_training=False,
           data_format=data_format)
-    outputs, mean, variance = utils.smart_cond(is_training,
-                                               _fused_batch_norm_training,
-                                               _fused_batch_norm_inference)
+
+    outputs, mean, variance = utils.smart_cond(
+        is_training, _fused_batch_norm_training, _fused_batch_norm_inference)
 
     # If `is_training` doesn't have a constant value, because it is a `Tensor`,
     # a `Variable` or `Placeholder` then is_training_value will be None and
@@ -415,6 +388,7 @@ def _fused_batch_norm(inputs,
     if need_updates:
       if updates_collections is None:
         no_updates = lambda: outputs
+
         def _force_updates():
           """Internal function forces updates moving_vars if is_training."""
           update_moving_mean = moving_averages.assign_moving_average(
@@ -424,9 +398,11 @@ def _fused_batch_norm(inputs,
           with ops.control_dependencies(
               [update_moving_mean, update_moving_variance]):
             return array_ops.identity(outputs)
+
         outputs = utils.smart_cond(is_training, _force_updates, no_updates)
       else:
         moving_vars_fn = lambda: (moving_mean, moving_variance)
+
         def _delay_updates():
           """Internal function that delay updates moving_vars if is_training."""
           update_moving_mean = moving_averages.assign_moving_average(
@@ -434,9 +410,9 @@ def _fused_batch_norm(inputs,
           update_moving_variance = moving_averages.assign_moving_average(
               moving_variance, variance, decay, zero_debias=False)
           return update_moving_mean, update_moving_variance
-        update_mean, update_variance = utils.smart_cond(is_training,
-                                                        _delay_updates,
-                                                        moving_vars_fn)
+
+        update_mean, update_variance = utils.smart_cond(
+            is_training, _delay_updates, moving_vars_fn)
         ops.add_to_collections(updates_collections, update_mean)
         ops.add_to_collections(updates_collections, update_variance)
 
@@ -482,9 +458,10 @@ def batch_norm(inputs,
   Can be used as a normalizer function for conv2d and fully_connected. The
   normalization is over all but the last dimension if `data_format` is `NHWC`
   and all but the second dimension if `data_format` is `NCHW`.  In case of a 2D
-  tensor this corresponds to the batch dimension, while in case of a 4D tensor this
+  tensor this corresponds to the batch dimension, while in case of a 4D tensor
+  this
   corresponds to the batch and space dimensions.
-  
+
   Note: when training, the moving_mean and moving_variance need to be updated.
   By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
   need to be added as a dependency to the `train_op`. For example:
@@ -592,10 +569,9 @@ def batch_norm(inputs,
   #   implementation in normalization_layers.BatchNormalization.
   inputs = ops.convert_to_tensor(inputs)
   rank = inputs.get_shape().ndims
-  possible_to_fuse = (batch_weights is None and
-                      not renorm and
-                      rank in [2, 4] and
-                      adjustment is None)
+  possible_to_fuse = (
+      batch_weights is None and not renorm and rank in [2, 4] and
+      adjustment is None)
   if fused and possible_to_fuse and (
       zero_debias_moving_mean or rank == 2 or
       updates_collections is not ops.GraphKeys.UPDATE_OPS):
@@ -623,7 +599,9 @@ def batch_norm(inputs,
 
   layer_variable_getter = _build_variable_getter()
   with variable_scope.variable_scope(
-      scope, 'BatchNorm', [inputs], reuse=reuse,
+      scope,
+      'BatchNorm', [inputs],
+      reuse=reuse,
       custom_getter=layer_variable_getter) as sc:
     inputs = ops.convert_to_tensor(inputs)
 
@@ -671,15 +649,15 @@ def batch_norm(inputs,
       outputs = layer.apply(inputs, training=is_training)
 
       # Add variables to collections.
-      _add_variable_to_collections(
-          layer.moving_mean, variables_collections, 'moving_mean')
-      _add_variable_to_collections(
-          layer.moving_variance, variables_collections, 'moving_variance')
+      _add_variable_to_collections(layer.moving_mean, variables_collections,
+                                   'moving_mean')
+      _add_variable_to_collections(layer.moving_variance, variables_collections,
+                                   'moving_variance')
       if layer.beta is not None:
         _add_variable_to_collections(layer.beta, variables_collections, 'beta')
       if layer.gamma is not None:
-        _add_variable_to_collections(
-            layer.gamma, variables_collections, 'gamma')
+        _add_variable_to_collections(layer.gamma, variables_collections,
+                                     'gamma')
 
       if activation_fn is not None:
         outputs = activation_fn(outputs)
@@ -719,8 +697,8 @@ def batch_norm(inputs,
       params_shape = inputs_shape[-1:]
       params_shape_broadcast = None
     if not params_shape.is_fully_defined():
-      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
-          inputs.name, params_shape))
+      raise ValueError('Inputs %s has undefined channels dimension %s.' %
+                       (inputs.name, params_shape))
 
     # Allocate parameters for the beta and gamma of the normalization.
     beta, gamma = None, None
@@ -731,23 +709,25 @@ def batch_norm(inputs,
                                                         'beta')
       beta_initializer = param_initializers.get('beta',
                                                 init_ops.zeros_initializer())
-      beta = variables.model_variable('beta',
-                                      shape=params_shape,
-                                      dtype=dtype,
-                                      initializer=beta_initializer,
-                                      collections=beta_collections,
-                                      trainable=trainable)
+      beta = variables.model_variable(
+          'beta',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=beta_initializer,
+          collections=beta_collections,
+          trainable=trainable)
     if scale:
-      gamma_collections = utils.get_variable_collections(variables_collections,
-                                                         'gamma')
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
       gamma_initializer = param_initializers.get('gamma',
                                                  init_ops.ones_initializer())
-      gamma = variables.model_variable('gamma',
-                                       shape=params_shape,
-                                       dtype=dtype,
-                                       initializer=gamma_initializer,
-                                       collections=gamma_collections,
-                                       trainable=trainable)
+      gamma = variables.model_variable(
+          'gamma',
+          shape=params_shape,
+          dtype=dtype,
+          initializer=gamma_initializer,
+          collections=gamma_collections,
+          trainable=trainable)
 
     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections. We disable variable partitioning while creating
@@ -796,8 +776,8 @@ def batch_norm(inputs,
           mean, variance = nn.moments(inputs, moments_axes)
       else:
         if data_format == DATA_FORMAT_NCHW:
-          mean, variance = nn.weighted_moments(inputs, moments_axes,
-                                               batch_weights, keep_dims=True)
+          mean, variance = nn.weighted_moments(
+              inputs, moments_axes, batch_weights, keep_dims=True)
           mean = array_ops.reshape(mean, [-1])
           variance = array_ops.reshape(variance, [-1])
         else:
@@ -806,19 +786,21 @@ def batch_norm(inputs,
 
       moving_vars_fn = lambda: (moving_mean, moving_variance)
       if updates_collections is None:
+
         def _force_updates():
           """Internal function forces updates moving_vars if is_training."""
           update_moving_mean = moving_averages.assign_moving_average(
               moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
           update_moving_variance = moving_averages.assign_moving_average(
               moving_variance, variance, decay, zero_debias=False)
-          with ops.control_dependencies([update_moving_mean,
-                                         update_moving_variance]):
+          with ops.control_dependencies(
+              [update_moving_mean, update_moving_variance]):
             return array_ops.identity(mean), array_ops.identity(variance)
-        mean, variance = utils.smart_cond(is_training,
-                                          _force_updates,
+
+        mean, variance = utils.smart_cond(is_training, _force_updates,
                                           moving_vars_fn)
       else:
+
         def _delay_updates():
           """Internal function that delay updates moving_vars if is_training."""
           update_moving_mean = moving_averages.assign_moving_average(
@@ -827,9 +809,8 @@ def batch_norm(inputs,
               moving_variance, variance, decay, zero_debias=False)
           return update_moving_mean, update_moving_variance
 
-        update_mean, update_variance = utils.smart_cond(is_training,
-                                                        _delay_updates,
-                                                        moving_vars_fn)
+        update_mean, update_variance = utils.smart_cond(
+            is_training, _delay_updates, moving_vars_fn)
         ops.add_to_collections(updates_collections, update_mean)
         ops.add_to_collections(updates_collections, update_variance)
         # Use computed moments during training and moving_vars otherwise.
@@ -897,8 +878,8 @@ def bias_add(inputs,
   """
   if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
     raise ValueError('data_format has to be either NCHW or NHWC.')
-  with variable_scope.variable_scope(scope, 'BiasAdd', [inputs],
-                                     reuse=reuse) as sc:
+  with variable_scope.variable_scope(
+      scope, 'BiasAdd', [inputs], reuse=reuse) as sc:
     inputs = ops.convert_to_tensor(inputs)
     dtype = inputs.dtype.base_dtype
     inputs_shape = inputs.get_shape()
@@ -913,13 +894,16 @@ def bias_add(inputs,
       raise ValueError('`C` dimension must be known but is None')
     biases_collections = utils.get_variable_collections(variables_collections,
                                                         'biases')
-    biases = variables.model_variable('biases',
-                                      shape=[num_features,],
-                                      dtype=dtype,
-                                      initializer=initializer,
-                                      regularizer=regularizer,
-                                      collections=biases_collections,
-                                      trainable=trainable)
+    biases = variables.model_variable(
+        'biases',
+        shape=[
+            num_features,
+        ],
+        dtype=dtype,
+        initializer=initializer,
+        regularizer=regularizer,
+        collections=biases_collections,
+        trainable=trainable)
     outputs = nn.bias_add(inputs, biases, data_format=data_format)
     if activation_fn is not None:
       outputs = activation_fn(outputs)
@@ -1019,8 +1003,10 @@ def convolution(inputs,
   if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
     raise ValueError('Invalid data_format: %r' % (data_format,))
 
-  layer_variable_getter = _build_variable_getter(
-      {'bias': 'biases', 'kernel': 'weights'})
+  layer_variable_getter = _build_variable_getter({
+      'bias': 'biases',
+      'kernel': 'weights'
+  })
 
   with variable_scope.variable_scope(
       scope, 'Conv', [inputs], reuse=reuse,
@@ -1038,26 +1024,27 @@ def convolution(inputs,
       raise ValueError('Convolution not supported for input with rank',
                        input_rank)
 
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
-    layer = layer_class(filters=num_outputs,
-                        kernel_size=kernel_size,
-                        strides=stride,
-                        padding=padding,
-                        data_format=df,
-                        dilation_rate=rate,
-                        activation=None,
-                        use_bias=not normalizer_fn and biases_initializer,
-                        kernel_initializer=weights_initializer,
-                        bias_initializer=biases_initializer,
-                        kernel_regularizer=weights_regularizer,
-                        bias_regularizer=biases_regularizer,
-                        activity_regularizer=None,
-                        trainable=trainable,
-                        name=sc.name,
-                        dtype=inputs.dtype.base_dtype,
-                        _scope=sc,
-                        _reuse=reuse)
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
+    layer = layer_class(
+        filters=num_outputs,
+        kernel_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        dilation_rate=rate,
+        activation=None,
+        use_bias=not normalizer_fn and biases_initializer,
+        kernel_initializer=weights_initializer,
+        bias_initializer=biases_initializer,
+        kernel_regularizer=weights_regularizer,
+        bias_regularizer=biases_regularizer,
+        activity_regularizer=None,
+        trainable=trainable,
+        name=sc.name,
+        dtype=inputs.dtype.base_dtype,
+        _scope=sc,
+        _reuse=reuse)
     outputs = layer.apply(inputs)
 
     # Add variables to collections.
@@ -1073,6 +1060,7 @@ def convolution(inputs,
       outputs = activation_fn(outputs)
     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
 
+
 convolution2d = convolution
 convolution3d = convolution
 
@@ -1148,13 +1136,14 @@ def convolution2d_in_plane(
     weights_shape = [kernel_h, kernel_w, 1, 1]
     weights_collections = utils.get_variable_collections(
         variables_collections, 'weights')
-    weights = variables.model_variable('weights',
-                                       shape=weights_shape,
-                                       dtype=dtype,
-                                       initializer=weights_initializer,
-                                       regularizer=weights_regularizer,
-                                       collections=weights_collections,
-                                       trainable=trainable)
+    weights = variables.model_variable(
+        'weights',
+        shape=weights_shape,
+        dtype=dtype,
+        initializer=weights_initializer,
+        regularizer=weights_regularizer,
+        collections=weights_collections,
+        trainable=trainable)
     depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1])
     outputs = nn.depthwise_conv2d(inputs, depthwise_weights,
                                   [1, stride_h, stride_w, 1], padding)
@@ -1165,13 +1154,16 @@ def convolution2d_in_plane(
       if biases_initializer is not None:
         biases_collections = utils.get_variable_collections(
             variables_collections, 'biases')
-        biases = variables.model_variable('biases',
-                                          shape=[num_filters_in,],
-                                          dtype=dtype,
-                                          initializer=biases_initializer,
-                                          regularizer=biases_regularizer,
-                                          collections=biases_collections,
-                                          trainable=trainable)
+        biases = variables.model_variable(
+            'biases',
+            shape=[
+                num_filters_in,
+            ],
+            dtype=dtype,
+            initializer=biases_initializer,
+            regularizer=biases_regularizer,
+            collections=biases_collections,
+            trainable=trainable)
         outputs = nn.bias_add(outputs, biases)
 
     if activation_fn is not None:
@@ -1244,19 +1236,23 @@ def convolution2d_transpose(
     ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
     ValueError: If `C` dimension of `inputs` is None.
   """
-  layer_variable_getter = _build_variable_getter(
-      {'bias': 'biases', 'kernel': 'weights'})
+  layer_variable_getter = _build_variable_getter({
+      'bias': 'biases',
+      'kernel': 'weights'
+  })
 
   with variable_scope.variable_scope(
-      scope, 'Conv2d_transpose', [inputs], reuse=reuse,
+      scope,
+      'Conv2d_transpose', [inputs],
+      reuse=reuse,
       custom_getter=layer_variable_getter) as sc:
     if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
       raise ValueError('data_format has to be either NCHW or NHWC.')
 
     inputs = ops.convert_to_tensor(inputs)
 
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
     layer = convolutional_layers.Convolution2DTranspose(
         filters=num_outputs,
         kernel_size=kernel_size,
@@ -1353,19 +1349,23 @@ def convolution3d_transpose(
     ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
     ValueError: If `C` dimension of `inputs` is None.
   """
-  layer_variable_getter = _build_variable_getter(
-      {'bias': 'biases', 'kernel': 'weights'})
+  layer_variable_getter = _build_variable_getter({
+      'bias': 'biases',
+      'kernel': 'weights'
+  })
 
   with variable_scope.variable_scope(
-      scope, 'Conv3d_transpose', [inputs], reuse=reuse,
+      scope,
+      'Conv3d_transpose', [inputs],
+      reuse=reuse,
       custom_getter=layer_variable_getter) as sc:
     if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
       raise ValueError('data_format has to be either NCDHW or NDHWC.')
 
     inputs = ops.convert_to_tensor(inputs)
 
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
     layer = convolutional_layers.Convolution3DTranspose(
         filters=num_outputs,
         kernel_size=kernel_size,
@@ -1434,19 +1434,18 @@ def dropout(inputs,
   with variable_scope.variable_scope(
       scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    layer = core_layers.Dropout(rate=1 - keep_prob,
-                                noise_shape=noise_shape,
-                                seed=seed,
-                                name=sc.name,
-                                _scope=sc)
+    layer = core_layers.Dropout(
+        rate=1 - keep_prob,
+        noise_shape=noise_shape,
+        seed=seed,
+        name=sc.name,
+        _scope=sc)
     outputs = layer.apply(inputs, training=is_training)
     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
 
 
 @add_arg_scope
-def flatten(inputs,
-            outputs_collections=None,
-            scope=None):
+def flatten(inputs, outputs_collections=None, scope=None):
   """Flattens the input while maintaining the batch_size.
 
     Assumes that the first dimension represents the batch.
@@ -1478,8 +1477,8 @@ def _sparse_inner_flatten(inputs, new_rank):
 
   outer_dimensions = inputs.dense_shape[:new_rank - 1]
   inner_dimensions = inputs.dense_shape[new_rank - 1:]
-  new_shape = array_ops.concat((outer_dimensions,
-                                [math_ops.reduce_prod(inner_dimensions)]), 0)
+  new_shape = array_ops.concat(
+      (outer_dimensions, [math_ops.reduce_prod(inner_dimensions)]), 0)
   flattened = sparse_ops.sparse_reshape(inputs, new_shape)
   return flattened
 
@@ -1545,10 +1544,18 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
   return utils.collect_named_outputs(output_collections, sc, flattened)
 
 
-def _model_variable_getter(getter, name, shape=None, dtype=None,
-                           initializer=None, regularizer=None, trainable=True,
-                           collections=None, caching_device=None,
-                           partitioner=None, rename=None, use_resource=None,
+def _model_variable_getter(getter,
+                           name,
+                           shape=None,
+                           dtype=None,
+                           initializer=None,
+                           regularizer=None,
+                           trainable=True,
+                           collections=None,
+                           caching_device=None,
+                           partitioner=None,
+                           rename=None,
+                           use_resource=None,
                            **_):
   """Getter that uses model_variable for compatibility with core layers."""
   short_name = name.split('/')[-1]
@@ -1557,25 +1564,34 @@ def _model_variable_getter(getter, name, shape=None, dtype=None,
     name_components[-1] = rename[short_name]
     name = '/'.join(name_components)
   return variables.model_variable(
-      name, shape=shape, dtype=dtype, initializer=initializer,
-      regularizer=regularizer, collections=collections, trainable=trainable,
-      caching_device=caching_device, partitioner=partitioner,
-      custom_getter=getter, use_resource=use_resource)
+      name,
+      shape=shape,
+      dtype=dtype,
+      initializer=initializer,
+      regularizer=regularizer,
+      collections=collections,
+      trainable=trainable,
+      caching_device=caching_device,
+      partitioner=partitioner,
+      custom_getter=getter,
+      use_resource=use_resource)
 
 
 def _build_variable_getter(rename=None):
   """Build a model variable getter that respects scope getter and renames."""
+
   # VariableScope will nest the getters
   def layer_variable_getter(getter, *args, **kwargs):
     kwargs['rename'] = rename
     return _model_variable_getter(getter, *args, **kwargs)
+
   return layer_variable_getter
 
 
 def _add_variable_to_collections(variable, collections_set, collections_name):
   """Adds variable (or all its parts) to all collections with that name."""
-  collections = utils.get_variable_collections(
-      collections_set, collections_name) or []
+  collections = utils.get_variable_collections(collections_set,
+                                               collections_name) or []
   variables_list = [variable]
   if isinstance(variable, tf_variables.PartitionedVariable):
     variables_list = [v for v in variable]
@@ -1644,15 +1660,19 @@ def fully_connected(inputs,
     ValueError: If x has rank less than 2 or if its last dimension is not set.
   """
   if not isinstance(num_outputs, six.integer_types):
-    raise ValueError(
-        'num_outputs should be int or long, got %s.' % (num_outputs,))
+    raise ValueError('num_outputs should be int or long, got %s.' %
+                     (num_outputs,))
 
-  layer_variable_getter = _build_variable_getter({'bias': 'biases',
-                                                  'kernel': 'weights'})
+  layer_variable_getter = _build_variable_getter({
+      'bias': 'biases',
+      'kernel': 'weights'
+  })
 
   with variable_scope.variable_scope(
-      scope, 'fully_connected', [inputs],
-      reuse=reuse, custom_getter=layer_variable_getter) as sc:
+      scope,
+      'fully_connected', [inputs],
+      reuse=reuse,
+      custom_getter=layer_variable_getter) as sc:
     inputs = ops.convert_to_tensor(inputs)
     layer = core_layers.Dense(
         units=num_outputs,
@@ -1758,15 +1778,17 @@ class GDN(base.Layer):
                inverse=False,
                beta_min=1e-6,
                gamma_init=.1,
-               reparam_offset=2 ** -18,
+               reparam_offset=2**-18,
                data_format='channels_last',
                activity_regularizer=None,
                trainable=True,
                name=None,
                **kwargs):
-    super(GDN, self).__init__(trainable=trainable, name=name,
-                              activity_regularizer=activity_regularizer,
-                              **kwargs)
+    super(GDN, self).__init__(
+        trainable=trainable,
+        name=name,
+        activity_regularizer=activity_regularizer,
+        **kwargs)
     self.inverse = inverse
     self._beta_min = beta_min
     self._gamma_init = gamma_init
@@ -1801,8 +1823,9 @@ class GDN(base.Layer):
     with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
       inputs = ops.convert_to_tensor(inputs, name='inputs')
       bound = ops.convert_to_tensor(bound, name='bound')
-      with ops.get_default_graph().gradient_override_map(
-          {'Maximum': 'GDNLowerBound'}):
+      with ops.get_default_graph().gradient_override_map({
+          'Maximum': 'GDNLowerBound'
+      }):
         return math_ops.maximum(inputs, bound, name=scope)
 
   @staticmethod
@@ -1829,12 +1852,14 @@ class GDN(base.Layer):
       raise ValueError('The channel dimension of the inputs to `GDN` '
                        'must be defined.')
     self._input_rank = input_shape.ndims
-    self.input_spec = base.InputSpec(ndim=input_shape.ndims,
-                                     axes={channel_axis: num_channels})
+    self.input_spec = base.InputSpec(
+        ndim=input_shape.ndims, axes={
+            channel_axis: num_channels
+        })
 
-    pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
+    pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype)
     beta_bound = array_ops.constant(
-        (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
+        (self._beta_min + self._reparam_offset**2)**.5, dtype=self.dtype)
     gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)
 
     def beta_initializer(shape, dtype=None, partition_info=None):
@@ -1848,19 +1873,21 @@ class GDN(base.Layer):
       eye = linalg_ops.eye(shape[0], dtype=dtype)
       return math_ops.sqrt(self._gamma_init * eye + pedestal)
 
-    beta = self.add_variable('reparam_beta',
-                             shape=[num_channels],
-                             initializer=beta_initializer,
-                             dtype=self.dtype,
-                             trainable=True)
+    beta = self.add_variable(
+        'reparam_beta',
+        shape=[num_channels],
+        initializer=beta_initializer,
+        dtype=self.dtype,
+        trainable=True)
     beta = self._lower_bound(beta, beta_bound)
     self.beta = math_ops.square(beta) - pedestal
 
-    gamma = self.add_variable('reparam_gamma',
-                              shape=[num_channels, num_channels],
-                              initializer=gamma_initializer,
-                              dtype=self.dtype,
-                              trainable=True)
+    gamma = self.add_variable(
+        'reparam_gamma',
+        shape=[num_channels, num_channels],
+        initializer=gamma_initializer,
+        dtype=self.dtype,
+        trainable=True)
     gamma = self._lower_bound(gamma, gamma_bound)
     self.gamma = math_ops.square(gamma) - pedestal
 
@@ -1875,8 +1902,11 @@ class GDN(base.Layer):
 
     # Compute normalization pool.
     if self.data_format == 'channels_first':
-      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
-                                 data_format='NC' + 'DHW'[-(ndim - 2):])
+      norm_pool = nn.convolution(
+          math_ops.square(inputs),
+          gamma,
+          'VALID',
+          data_format='NC' + 'DHW' [-(ndim - 2):])
       if ndim == 3:
         norm_pool = array_ops.expand_dims(norm_pool, 2)
         norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
@@ -1918,7 +1948,7 @@ def gdn(inputs,
         inverse=False,
         beta_min=1e-6,
         gamma_init=.1,
-        reparam_offset=2 ** -18,
+        reparam_offset=2**-18,
         data_format='channels_last',
         activity_regularizer=None,
         trainable=True,
@@ -1984,17 +2014,18 @@ def gdn(inputs,
   Returns:
     Output tensor.
   """
-  layer = GDN(inverse=inverse,
-              beta_min=beta_min,
-              gamma_init=gamma_init,
-              reparam_offset=reparam_offset,
-              data_format=data_format,
-              activity_regularizer=activity_regularizer,
-              trainable=trainable,
-              name=name,
-              dtype=inputs.dtype.base_dtype,
-              _scope=name,
-              _reuse=reuse)
+  layer = GDN(
+      inverse=inverse,
+      beta_min=beta_min,
+      gamma_init=gamma_init,
+      reparam_offset=reparam_offset,
+      data_format=data_format,
+      activity_regularizer=activity_regularizer,
+      trainable=trainable,
+      name=name,
+      dtype=inputs.dtype.base_dtype,
+      _scope=name,
+      _reuse=reuse)
   return layer.apply(inputs)
 
 
@@ -2070,8 +2101,8 @@ def layer_norm(inputs,
       or if `inputs.shape[begin_params_axis:]` is not fully defined at
       graph build time.
   """
-  with variable_scope.variable_scope(scope, 'LayerNorm', [inputs],
-                                     reuse=reuse) as sc:
+  with variable_scope.variable_scope(
+      scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
     inputs = ops.convert_to_tensor(inputs)
     inputs_shape = inputs.shape
     inputs_rank = inputs_shape.ndims
@@ -2081,15 +2112,14 @@ def layer_norm(inputs,
     if begin_norm_axis < 0:
       begin_norm_axis = inputs_rank + begin_norm_axis
     if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
-      raise ValueError(
-          'begin_params_axis (%d) and begin_norm_axis (%d) '
-          'must be < rank(inputs) (%d)'
-          % (begin_params_axis, begin_norm_axis, inputs_rank))
+      raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
+                       'must be < rank(inputs) (%d)' %
+                       (begin_params_axis, begin_norm_axis, inputs_rank))
     params_shape = inputs_shape[begin_params_axis:]
     if not params_shape.is_fully_defined():
       raise ValueError(
-          'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % (
-              inputs.name, begin_params_axis, inputs_shape))
+          'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
+          (inputs.name, begin_params_axis, inputs_shape))
     # Allocate parameters for the beta and gamma of the normalization.
     beta, gamma = None, None
     if center:
@@ -2103,8 +2133,8 @@ def layer_norm(inputs,
           collections=beta_collections,
           trainable=trainable)
     if scale:
-      gamma_collections = utils.get_variable_collections(variables_collections,
-                                                         'gamma')
+      gamma_collections = utils.get_variable_collections(
+          variables_collections, 'gamma')
       gamma = variables.model_variable(
           'gamma',
           shape=params_shape,
@@ -2118,7 +2148,11 @@ def layer_norm(inputs,
     # Compute layer normalization using the batch_normalization function.
     variance_epsilon = 1e-12
     outputs = nn.batch_normalization(
-        inputs, mean, variance, offset=beta, scale=gamma,
+        inputs,
+        mean,
+        variance,
+        offset=beta,
+        scale=gamma,
         variance_epsilon=variance_epsilon)
     outputs.set_shape(inputs_shape)
     if activation_fn is not None:
@@ -2164,13 +2198,14 @@ def max_pool2d(inputs,
     raise ValueError('data_format has to be either NCHW or NHWC.')
   with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
-    layer = pooling_layers.MaxPooling2D(pool_size=kernel_size,
-                                        strides=stride,
-                                        padding=padding,
-                                        data_format=df,
-                                        _scope=sc)
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
+    layer = pooling_layers.MaxPooling2D(
+        pool_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        _scope=sc)
     outputs = layer.apply(inputs)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
@@ -2213,13 +2248,14 @@ def max_pool3d(inputs,
     raise ValueError('data_format has to be either NCDHW or NDHWC.')
   with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
-    layer = pooling_layers.MaxPooling3D(pool_size=kernel_size,
-                                        strides=stride,
-                                        padding=padding,
-                                        data_format=df,
-                                        _scope=sc)
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
+    layer = pooling_layers.MaxPooling3D(
+        pool_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        _scope=sc)
     outputs = layer.apply(inputs)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
@@ -2272,8 +2308,8 @@ def pool(inputs,
 
   """
   # pylint: enable=line-too-long
-  with ops.name_scope(scope, '%s_pool' %
-                      (pooling_type.lower()), [inputs]) as sc:
+  with ops.name_scope(scope, '%s_pool' % (pooling_type.lower()),
+                      [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
     input_rank = inputs.get_shape().ndims
     if input_rank is None:
@@ -2318,18 +2354,16 @@ def one_hot_encoding(labels,
     labels = ops.convert_to_tensor(labels)
     if labels.dtype == dtypes.int32:
       labels = standard_ops.to_int64(labels)
-    outputs = standard_ops.one_hot(labels,
-                                   num_classes,
-                                   on_value=on_value,
-                                   off_value=off_value)
+    outputs = standard_ops.one_hot(
+        labels, num_classes, on_value=on_value, off_value=off_value)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
 
 def _apply_activation(y, activation_fn, output_collections):
   if activation_fn is not None:
     y = activation_fn(y)
-  ops.add_to_collections(list(output_collections or []) +
-                         [ops.GraphKeys.ACTIVATIONS], y)
+  ops.add_to_collections(
+      list(output_collections or []) + [ops.GraphKeys.ACTIVATIONS], y)
   return y
 
 
@@ -2374,7 +2408,7 @@ def repeat(inputs, repetitions, layer, *args, **kwargs):
         scope = 'repeat'
     outputs = inputs
     for i in range(repetitions):
-      kwargs['scope'] = scope + '_' + str(i+1)
+      kwargs['scope'] = scope + '_' + str(i + 1)
       outputs = layer(outputs, *args, **kwargs)
     return outputs
 
@@ -2389,8 +2423,8 @@ def _scale_gradient_grad(op, grad):
   return [grad * op.inputs[1], None]
 
 
-@function.Defun(python_grad_func=_scale_gradient_grad,
-                shape_func=_scale_gradient_shape)
+@function.Defun(
+    python_grad_func=_scale_gradient_grad, shape_func=_scale_gradient_shape)
 def scale_gradient(inputs, gradient_multiplier):
   """Identity operation, but with the gradient multiplied by a tensor.
 
@@ -2495,18 +2529,21 @@ def separable_convolution2d(
   """
   if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
     raise ValueError('data_format has to be either NCHW or NHWC.')
-  layer_variable_getter = _build_variable_getter(
-      {'bias': 'biases',
-       'depthwise_kernel': 'depthwise_weights',
-       'pointwise_kernel': 'pointwise_weights'})
+  layer_variable_getter = _build_variable_getter({
+      'bias': 'biases',
+      'depthwise_kernel': 'depthwise_weights',
+      'pointwise_kernel': 'pointwise_weights'
+  })
 
   with variable_scope.variable_scope(
-      scope, 'SeparableConv2d', [inputs], reuse=reuse,
+      scope,
+      'SeparableConv2d', [inputs],
+      reuse=reuse,
       custom_getter=layer_variable_getter) as sc:
     inputs = ops.convert_to_tensor(inputs)
 
-    df = ('channels_first' if data_format and data_format.startswith('NC')
-          else 'channels_last')
+    df = ('channels_first'
+          if data_format and data_format.startswith('NC') else 'channels_last')
     if num_outputs is not None:
       # Apply separable conv using the SeparableConvolution2D layer.
       layer = convolutional_layers.SeparableConvolution2D(
@@ -2539,8 +2576,8 @@ def separable_convolution2d(
       _add_variable_to_collections(layer.pointwise_kernel,
                                    variables_collections, 'weights')
       if layer.bias is not None:
-        _add_variable_to_collections(layer.bias,
-                                     variables_collections, 'biases')
+        _add_variable_to_collections(layer.bias, variables_collections,
+                                     'biases')
 
       if normalizer_fn is not None:
         normalizer_params = normalizer_params or {}
@@ -2555,8 +2592,7 @@ def separable_convolution2d(
       weights_collections = utils.get_variable_collections(
           variables_collections, 'weights')
 
-      depthwise_shape = [kernel_h, kernel_w,
-                         num_filters_in, depth_multiplier]
+      depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier]
       depthwise_weights = variables.model_variable(
           'depthwise_weights',
           shape=depthwise_shape,
@@ -2570,9 +2606,13 @@ def separable_convolution2d(
                      1, stride_h, stride_w, 1
                  ]
 
-      outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding,
-                                    rate=utils.two_element_tuple(rate),
-                                    data_format=data_format)
+      outputs = nn.depthwise_conv2d(
+          inputs,
+          depthwise_weights,
+          strides,
+          padding,
+          rate=utils.two_element_tuple(rate),
+          data_format=data_format)
       num_outputs = depth_multiplier * num_filters_in
 
       if normalizer_fn is not None:
@@ -2582,13 +2622,16 @@ def separable_convolution2d(
         if biases_initializer is not None:
           biases_collections = utils.get_variable_collections(
               variables_collections, 'biases')
-          biases = variables.model_variable('biases',
-                                            shape=[num_outputs,],
-                                            dtype=dtype,
-                                            initializer=biases_initializer,
-                                            regularizer=biases_regularizer,
-                                            trainable=trainable,
-                                            collections=biases_collections)
+          biases = variables.model_variable(
+              'biases',
+              shape=[
+                  num_outputs,
+              ],
+              dtype=dtype,
+              initializer=biases_initializer,
+              regularizer=biases_regularizer,
+              trainable=trainable,
+              collections=biases_collections)
           outputs = nn.bias_add(outputs, biases, data_format=data_format)
 
     if activation_fn is not None:
@@ -2673,23 +2716,24 @@ def spatial_softmax(features,
 
     with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]):
       # Create tensors for x and y coordinate values, scaled to range [-1, 1].
-      pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height),
-                                        math_ops.lin_space(-1., 1., num=width),
-                                        indexing='ij')
+      pos_x, pos_y = array_ops.meshgrid(
+          math_ops.lin_space(-1., 1., num=height),
+          math_ops.lin_space(-1., 1., num=width),
+          indexing='ij')
       pos_x = array_ops.reshape(pos_x, [height * width])
       pos_y = array_ops.reshape(pos_y, [height * width])
-      
+
       if temperature is None:
         temp_initializer = init_ops.ones_initializer()
       else:
         temp_initializer = init_ops.constant_initializer(temperature)
-          
+
       if not trainable:
         temp_collections = None
       else:
         temp_collections = utils.get_variable_collections(
-              variables_collections, 'temperature')
-      
+            variables_collections, 'temperature')
+
       temperature = variables.model_variable(
           'temperature',
           shape=(),
@@ -2703,14 +2747,14 @@ def spatial_softmax(features,
         features = array_ops.reshape(
             array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width])
 
-      softmax_attention = nn.softmax(features/temperature)
+      softmax_attention = nn.softmax(features / temperature)
       expected_x = math_ops.reduce_sum(
           pos_x * softmax_attention, [1], keep_dims=True)
       expected_y = math_ops.reduce_sum(
           pos_y * softmax_attention, [1], keep_dims=True)
       expected_xy = array_ops.concat([expected_x, expected_y], 1)
-      feature_keypoints = array_ops.reshape(
-          expected_xy, [-1, num_channels.value * 2])
+      feature_keypoints = array_ops.reshape(expected_xy,
+                                            [-1, num_channels.value * 2])
       feature_keypoints.set_shape([None, num_channels.value * 2])
   return feature_keypoints
 
@@ -2762,7 +2806,7 @@ def stack(inputs, layer, stack_args, **kwargs):
         scope = 'stack'
     outputs = inputs
     for i in range(len(stack_args)):
-      kwargs['scope'] = scope + '_' + str(i+1)
+      kwargs['scope'] = scope + '_' + str(i + 1)
       layer_args = stack_args[i]
       if not isinstance(layer_args, (list, tuple)):
         layer_args = [layer_args]
@@ -2793,11 +2837,10 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None):
       raise ValueError('The input rank must be known.')
     input_rank = len(inputs.get_shape().as_list())
     if dim < 0 or dim >= input_rank:
-      raise ValueError(
-          'dim must be positive but smaller than the input rank.')
+      raise ValueError('dim must be positive but smaller than the input rank.')
 
-    lengths = math_ops.sqrt(epsilon + math_ops.reduce_sum(
-        math_ops.square(inputs), dim, True))
+    lengths = math_ops.sqrt(
+        epsilon + math_ops.reduce_sum(math_ops.square(inputs), dim, True))
     multiples = []
     if dim > 0:
       multiples.append(array_ops.ones([dim], dtypes.int32))
@@ -2938,29 +2981,31 @@ def legacy_fully_connected(x,
       raise ValueError('last dimension of x must be known but is None')
     dtype = x.dtype.base_dtype
 
-    weight_collections = set(list(weight_collections or []) +
-                             [ops.GraphKeys.GLOBAL_VARIABLES])
-    w = variable_scope.get_variable('weights',
-                                    shape=[num_input_units, num_output_units],
-                                    dtype=dtype,
-                                    initializer=weight_init,
-                                    collections=weight_collections,
-                                    regularizer=weight_regularizer,
-                                    trainable=trainable)
-    x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x,
-                                                         [-1, num_input_units])
+    weight_collections = set(
+        list(weight_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+    w = variable_scope.get_variable(
+        'weights',
+        shape=[num_input_units, num_output_units],
+        dtype=dtype,
+        initializer=weight_init,
+        collections=weight_collections,
+        regularizer=weight_regularizer,
+        trainable=trainable)
+    x_2_dim = x if len(dims) <= 2 else array_ops.reshape(
+        x, [-1, num_input_units])
     y = standard_ops.matmul(x_2_dim, w)
 
     if bias_init is not None:
-      bias_collections = set(list(bias_collections or []) +
-                             [ops.GraphKeys.GLOBAL_VARIABLES])
-      b = variable_scope.get_variable('bias',
-                                      shape=[num_output_units],
-                                      dtype=dtype,
-                                      initializer=bias_init,
-                                      collections=bias_collections,
-                                      regularizer=bias_regularizer,
-                                      trainable=trainable)
+      bias_collections = set(
+          list(bias_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+      b = variable_scope.get_variable(
+          'bias',
+          shape=[num_output_units],
+          dtype=dtype,
+          initializer=bias_init,
+          collections=bias_collections,
+          regularizer=bias_regularizer,
+          trainable=trainable)
 
       y = nn.bias_add(y, b)
 
index 47509ecca6f6eb59b2b463b7ccd95cb33328ad1a..a7c97a1da2baf29914337094c6153447c997af08 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-"""Wrapper optimizer for Model Average """
+"""Wrapper optimizer for Model Average."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import constant_op
-from tensorflow.python.training import optimizer
-from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import math_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import session_run_hook
 
-GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_VARIABLE_NAME = "global_center_variable"
 
 
 class ModelAverageCustomGetter(object):
-  """Custom_getter class is used to do:
+  """Custom_getter class is used to do.
+
   1. Change trainable variables to local collection and place them at worker
     device
   2. Generate global variables
@@ -73,15 +73,18 @@ class ModelAverageCustomGetter(object):
   def __call__(self, getter, name, trainable, collections, *args, **kwargs):
     if trainable:
       with ops.device(self._worker_device):
-        local_var = getter(name, trainable=True,
-                           collections=[ops.GraphKeys.LOCAL_VARIABLES],
-                           *args, **kwargs)
+        local_var = getter(
+            name,
+            trainable=True,
+            collections=[ops.GraphKeys.LOCAL_VARIABLES],
+            *args,
+            **kwargs)
 
       global_variable = variable_scope.variable(
-        name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
-        initial_value=local_var.initialized_value(),
-        trainable=False,
-        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+          name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
+          initial_value=local_var.initialized_value(),
+          trainable=False,
+          collections=[ops.GraphKeys.GLOBAL_VARIABLES])
 
       self._local_2_global[local_var] = global_variable
       return local_var
@@ -91,6 +94,7 @@ class ModelAverageCustomGetter(object):
 
 class ModelAverageOptimizer(optimizer.Optimizer):
   """Wrapper optimizer that implements the Model Average algorithm.
+
   This is a sync optimizer. During the training, each worker will update
   the local variables and maintains its own local_step, which starts from 0
   and is incremented by 1 after each update of local variables. Whenever the
@@ -99,15 +103,14 @@ class ModelAverageOptimizer(optimizer.Optimizer):
   local variables will be assigned by global center variables.
   """
 
-  def __init__(
-      self,
-      opt,
-      num_worker,
-      is_chief,
-      ma_custom_getter,
-      interval_steps=100,
-      use_locking=True,
-      name="ModelAverageOptimizer"):
+  def __init__(self,
+               opt,
+               num_worker,
+               is_chief,
+               ma_custom_getter,
+               interval_steps=100,
+               use_locking=True,
+               name="ModelAverageOptimizer"):
     """Construct a new model average optimizer.
 
     Args:
@@ -124,18 +127,18 @@ class ModelAverageOptimizer(optimizer.Optimizer):
     self._opt = opt
     self._num_worker = num_worker
     self._is_chief = is_chief
-    self._local_2_global = ma_custom_getter._local_2_global
+    self._local_2_global = ma_custom_getter._local_2_global  # pylint:disable=protected-access
     self._interval_steps = interval_steps
     self._accumulator_list = []
     self._chief_init_op = None
 
     self._local_step = variable_scope.get_variable(
-      initializer=0,
-      trainable=False,
-      collections=[ops.GraphKeys.LOCAL_VARIABLES],
-      name="local_step")
+        initializer=0,
+        trainable=False,
+        collections=[ops.GraphKeys.LOCAL_VARIABLES],
+        name="local_step")
 
-    self._opt._prepare()
+    self._opt._prepare()  # pylint:disable=protected-access
 
   def compute_gradients(self, *args, **kwargs):
     """Compute gradients of "loss" for the variables in "var_list".
@@ -159,10 +162,12 @@ class ModelAverageOptimizer(optimizer.Optimizer):
 
     Returns:
       An update op
+
+    Raises:
+      ValueError: if var_list is empty.
     """
     if not var_list:
-      raise ValueError(
-        'The list of local_variables should not be empty')
+      raise ValueError("The list of local_variables should not be empty")
     update_ops = []
     global_center_vars = [self._local_2_global[var] for var in var_list]
     for lvar, gvar in zip(var_list, global_center_vars):
@@ -204,28 +209,29 @@ class ModelAverageOptimizer(optimizer.Optimizer):
     apply_updates = self._opt.apply_gradients(grads_and_vars)
     with ops.control_dependencies([apply_updates]):
       local_update = state_ops.assign_add(
-        self._local_step, 1, name='local_step_update').op
+          self._local_step, 1, name="local_step_update").op
 
     # update global variables.
-    def _Update_global_variables():
+    def _update_global_variables():  # pylint: disable=missing-docstring
       local_vars = [v for g, v in grads_and_vars if g is not None]
       global_vars = [self._local_2_global[v] for v in local_vars]
       # sync queue
       with ops.colocate_with(global_step):
-        sync_queue = data_flow_ops.FIFOQueue(-1, [dtypes.bool], shapes=[[]],
-                                             shared_name='sync_queue')
+        sync_queue = data_flow_ops.FIFOQueue(
+            -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
       train_ops = []
       aggregated_vars = []
-      with ops.name_scope(None, self._name + '/global'):
+      with ops.name_scope(None, self._name + "/global"):
         for var, gvar in zip(local_vars, global_vars):
+          # pylint: disable=protected-access
           with ops.device(gvar.device):
             if isinstance(var._ref(), ops.Tensor):
               var_accum = data_flow_ops.ConditionalAccumulator(
-                var.dtype,
-                shape=var.get_shape(),
-                shared_name=gvar.name + "/var_accum")
+                  var.dtype,
+                  shape=var.get_shape(),
+                  shared_name=gvar.name + "/var_accum")
               train_ops.append(
-                var_accum.apply_grad(var._ref(), local_step=global_step))
+                  var_accum.apply_grad(var._ref(), local_step=global_step))
               aggregated_vars.append(var_accum.take_grad(self._num_worker))
             else:
               raise ValueError("Unknown local variable type!")
@@ -254,24 +260,26 @@ class ModelAverageOptimizer(optimizer.Optimizer):
       return local_update_op
 
     with ops.control_dependencies([local_update]):
-      condition = math_ops.equal(math_ops.mod(
-        self._local_step, self._interval_steps), 0)
+      condition = math_ops.equal(
+          math_ops.mod(self._local_step, self._interval_steps), 0)
       conditional_update = control_flow_ops.cond(
-        condition, _Update_global_variables, control_flow_ops.no_op)
+          condition, _update_global_variables, control_flow_ops.no_op)
 
     chief_init_ops = []
     for accum, dev in self._accumulator_list:
       with ops.device(dev):
         chief_init_ops.append(
-          accum.set_global_step(
-            global_step, name="SetGlobalStep"))
+            accum.set_global_step(global_step, name="SetGlobalStep"))
     self._chief_init_op = control_flow_ops.group(*(chief_init_ops))
 
     return conditional_update
 
   def get_init_op(self):
-    """Returns the op to let all the local variables equal to the global
-     variables before the training begins"""
+    """Returns the op.
+
+    This method lets all the local variables equal to the global
+    variables before the training begins.
+    """
     return self._local_vars_update(variables.trainable_variables())
 
   def make_session_run_hook(self):
@@ -279,12 +287,13 @@ class ModelAverageOptimizer(optimizer.Optimizer):
     return _ModelAverageOptimizerHook(self, self._is_chief)
 
 
-class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):
+class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):  # pylint: disable=missing-docstring
+
   def __init__(self, ma_optimizer, is_chief):
     """Creates hook to handle ModelAverageOptimizer initialization ops.
 
     Args:
-      ea_optimizer: `ModelAverageOptimizer` which this hook will initialize.
+      ma_optimizer: `ModelAverageOptimizer` which this hook will initialize.
       is_chief: `Bool`, whether is this a chief replica or not.
     """
     self._ma_optimizer = ma_optimizer
@@ -295,5 +304,5 @@ class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):
     self._global_init_op = None
     if self._is_chief:
       self._global_init_op = variables.global_variables_initializer()
-      self._chief_init_op = self._ma_optimizer._chief_init_op
+      self._chief_init_op = self._ma_optimizer._chief_init_op  # pylint: disable=protected-access
     self._variable_init_op = self._ma_optimizer.get_init_op()
index a73aa772bbd211d439d93c7feda8fb4d7f1dc75e..29ecd228390ff9e736a0f3bc69bc019f728f5f78 100644 (file)
@@ -18,18 +18,18 @@ from __future__ import division
 from __future__ import print_function
 
 import portpicker
+
+from tensorflow.contrib.opt.python.training import model_average_optimizer
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import server_lib
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
-from tensorflow.contrib.opt.python.training.model_average_optimizer import \
-  ModelAverageOptimizer, ModelAverageCustomGetter, GLOBAL_VARIABLE_NAME
 
 
 def create_local_cluster(num_workers, num_ps, protocol="grpc"):
@@ -37,20 +37,20 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
   worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
   ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
   cluster_dict = {
-    "worker": ["localhost:%s" % port for port in worker_ports],
-    "ps": ["localhost:%s" % port for port in ps_ports]
+      "worker": ["localhost:%s" % port for port in worker_ports],
+      "ps": ["localhost:%s" % port for port in ps_ports]
   }
   cs = server_lib.ClusterSpec(cluster_dict)
 
   workers = [
-    server_lib.Server(
-      cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
-    for ix in range(num_workers)
+      server_lib.Server(
+          cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+      for ix in range(num_workers)
   ]
   ps_servers = [
-    server_lib.Server(
-      cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
-    for ix in range(num_ps)
+      server_lib.Server(
+          cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+      for ix in range(num_ps)
   ]
 
   return cluster_dict, workers, ps_servers
@@ -67,16 +67,16 @@ def _get_workers(num_workers, steps, workers):
     is_chief = (worker_id == 0)
     with graph.as_default():
       worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
-      ma_coustom = ModelAverageCustomGetter(
-        worker_device=worker_device)
-      with variable_scope.variable_scope('',
-                                         custom_getter=ma_coustom), ops.device(
-        device_setter.replica_device_setter(worker_device=worker_device,
-                                            ps_device="/job:ps/task:0/cpu:0",
-                                            ps_tasks=1)):
-
-        global_step = variables.Variable(0, name='global_step',
-                                         trainable=False)
+      ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+          worker_device=worker_device)
+      with variable_scope.variable_scope(
+          "", custom_getter=ma_coustom), ops.device(
+              device_setter.replica_device_setter(
+                  worker_device=worker_device,
+                  ps_device="/job:ps/task:0/cpu:0",
+                  ps_tasks=1)):
+
+        global_step = variables.Variable(0, name="global_step", trainable=False)
         var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
         var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
 
@@ -88,22 +88,20 @@ def _get_workers(num_workers, steps, workers):
           grads_0 = constant_op.constant(-2.0)
           grads_1 = constant_op.constant(-2.0)
         sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
-        opt = ModelAverageOptimizer(
-          opt=sgd_opt,
-          num_worker=num_workers,
-          ma_custom_getter=ma_coustom,
-          is_chief=is_chief,
-          interval_steps=steps
-        )
+        opt = model_average_optimizer.ModelAverageOptimizer(
+            opt=sgd_opt,
+            num_worker=num_workers,
+            ma_custom_getter=ma_coustom,
+            is_chief=is_chief,
+            interval_steps=steps)
         train_op = [
-          opt.apply_gradients(
-            [[grads_0, var_0],
-             [grads_1, var_1]], global_step)
+            opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+                                global_step)
         ]
       easgd_hook = opt.make_session_run_hook()
       # Creates MonitoredSession
-      sess = training.MonitoredTrainingSession(workers[worker_id].target,
-                                               hooks=[easgd_hook])
+      sess = training.MonitoredTrainingSession(
+          workers[worker_id].target, hooks=[easgd_hook])
 
     sessions.append(sess)
     graphs.append(graph)
@@ -112,6 +110,7 @@ def _get_workers(num_workers, steps, workers):
 
 
 class ModelAverageOptimizerTest(test.TestCase):
+
   def _run(self, train_op, sess):
     sess.run(train_op)
 
@@ -119,18 +118,18 @@ class ModelAverageOptimizerTest(test.TestCase):
     num_workers = 2
     steps = 2
     num_ps = 1
-    cluster, workers, _ = create_local_cluster(num_workers=num_workers,
-                                               num_ps=num_ps)
+    _, workers, _ = create_local_cluster(
+        num_workers=num_workers, num_ps=num_ps)
 
-    sessions, graphs, train_ops = _get_workers(num_workers,
-                                               steps,
-                                               workers)
+    sessions, graphs, train_ops = _get_workers(num_workers, steps, workers)
 
-    var_0 = graphs[0].get_tensor_by_name('v0:0')
-    var_1 = graphs[0].get_tensor_by_name('v1:0')
+    var_0 = graphs[0].get_tensor_by_name("v0:0")
+    var_1 = graphs[0].get_tensor_by_name("v1:0")
     global_step = training_util.get_global_step(graphs[0])
-    global_var_0 = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
-    global_var_1 = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+    global_var_0 = graphs[0].get_tensor_by_name(
+        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+    global_var_1 = graphs[0].get_tensor_by_name(
+        model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
 
     # Verify the initialized value.
     self.assertAllEqual(0.0, sessions[0].run(var_0))
@@ -150,9 +149,9 @@ class ModelAverageOptimizerTest(test.TestCase):
 
     # iteration 2, global varibale update
     thread_0 = self.checkedThread(
-      target=self._run, args=(train_ops[0], sessions[0]))
+        target=self._run, args=(train_ops[0], sessions[0]))
     thread_1 = self.checkedThread(
-      target=self._run, args=(train_ops[1], sessions[1]))
+        target=self._run, args=(train_ops[1], sessions[1]))
     thread_0.start()
     thread_1.start()
     thread_0.join()
@@ -175,20 +174,20 @@ class ModelAverageOptimizerTest(test.TestCase):
 
   def testPS2TasksWithClusterSpecClass(self):
     cluster_spec = server_lib.ClusterSpec({
-      "ps": ["ps0:2222", "ps1:2222"],
-      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+        "ps": ["ps0:2222", "ps1:2222"],
+        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     })
     worker_device = "/job:worker/task:0"
-    ma_coustom = ModelAverageCustomGetter(
-      worker_device=worker_device)
+    ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+        worker_device=worker_device)
     from tensorflow.python.training import device_setter
     with ops.device(
         device_setter.replica_device_setter(cluster=cluster_spec,
                                             worker_device=worker_device,
                                             ps_device="/job:ps")), \
-         variable_scope.variable_scope('', custom_getter=ma_coustom):
+         variable_scope.variable_scope("", custom_getter=ma_coustom):
       v = variable_scope.get_variable(initializer=[1, 2], name="v")
-      w = variable_scope.get_variable(initializer=[2, 1], name='w')
+      w = variable_scope.get_variable(initializer=[2, 1], name="w")
       v_g, w_g = ma_coustom._local_2_global[v], ma_coustom._local_2_global[w]
       self.assertDeviceEqual("/job:worker/task:0", v.device)
       self.assertDeviceEqual("job:ps/task:0", v_g.device)
@@ -196,5 +195,5 @@ class ModelAverageOptimizerTest(test.TestCase):
       self.assertDeviceEqual("job:ps/task:1", w_g.device)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   test.main()
index 30a207757040ebd34b65175364f768c8c832d548..a25de55e18b223db2b724aafb54b18d8f48a5baa 100644 (file)
@@ -53,12 +53,11 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
 
   def testPeriodicResampleBasic3D(self):
 
-    input_tensor = numpy.arange(2*2*4).reshape((2, 2, 4))
+    input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4))
     desired_shape = numpy.array([4, 4, None])
-    output_tensor = numpy.array([[[0], [2], [4], [6]],
-                                 [[1], [3], [5], [7]],
-                                 [[8], [10], [12], [14]],
-                                 [[9], [11], [13], [15]]])
+    output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]],
+                                 [[8], [10], [12], [14]], [[9], [11], [13],
+                                                           [15]]])
 
     # NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
     with self.test_session():
@@ -72,24 +71,18 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
 
   def testPeriodicResampleBasic4D(self):
 
-    input_tensor = numpy.arange(2*2*2*8).reshape((2, 2, 2, 8))
+    input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8))
     desired_shape = numpy.array([4, 4, 4, None])
-    output_tensor = numpy.array([[[[0], [4], [8], [12]],
-                                  [[2], [6], [10], [14]],
-                                  [[16], [20], [24], [28]],
-                                  [[18], [22], [26], [30]]],
-                                 [[[1], [5], [9], [13]],
-                                  [[3], [7], [11], [15]],
-                                  [[17], [21], [25], [29]],
-                                  [[19], [23], [27], [31]]],
-                                 [[[32], [36], [40], [44]],
-                                  [[34], [38], [42], [46]],
-                                  [[48], [52], [56], [60]],
-                                  [[50], [54], [58], [62]]],
-                                 [[[33], [37], [41], [45]],
-                                  [[35], [39], [43], [47]],
-                                  [[49], [53], [57], [61]],
-                                  [[51], [55], [59], [63]]]])
+    output_tensor = numpy.array(
+        [[[[0], [4], [8], [12]], [[2], [6], [10], [14]],
+          [[16], [20], [24], [28]], [[18], [22], [26], [30]]],
+         [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25],
+                                                          [29]],
+          [[19], [23], [27],
+           [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]],
+                    [[48], [52], [56], [60]], [[50], [54], [58], [62]]],
+         [[[33], [37], [41], [45]], [[35], [39], [43], [47]],
+          [[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
 
     # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
     with self.test_session():
@@ -111,5 +104,5 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
         periodic_resample(input_tensor, [None, 4, 4]).eval()
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   googletest.main()
index 70aaba172814ff9f7866889bb59a75c1c4204469..c780e85d72bc5b86b96f90f88f4ec158966962c2 100644 (file)
@@ -53,14 +53,12 @@ class RNNCellTest(test.TestCase):
       batch_size = 3
       input_size = 4
       expected_output = np.array(
-          [[0.121753, 0.121753],
-           [0.103349, 0.103349],
-           [0.100178, 0.100178]],
+          [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]],
           dtype=np.float32)
       expected_state = np.array(
-          [[0.137523, 0.137523, 0.121753, 0.121753],
-           [0.105450, 0.105450, 0.103349, 0.103349],
-           [0.100742, 0.100742, 0.100178, 0.100178]],
+          [[0.137523, 0.137523, 0.121753, 0.121753], [
+              0.105450, 0.105450, 0.103349, 0.103349
+          ], [0.100742, 0.100742, 0.100178, 0.100178]],
           dtype=np.float32)
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
@@ -69,14 +67,14 @@ class RNNCellTest(test.TestCase):
         output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
             num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([output, state], {
-            x.name:
-                np.array([[1., 1., 1., 1.],
-                          [2., 2., 2., 2.],
-                          [3., 3., 3., 3.]]),
-            m.name:
-                0.1 * np.ones((batch_size, state_size))
-        })
+        res = sess.run(
+            [output, state], {
+                x.name:
+                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+                              [3., 3., 3., 3.]]),
+                m.name:
+                    0.1 * np.ones((batch_size, state_size))
+            })
         # This is a smoke test: Only making sure expected values didn't change.
         self.assertEqual(len(res), 2)
         self.assertAllClose(res[0], expected_output)
@@ -101,14 +99,14 @@ class RNNCellTest(test.TestCase):
             frequency_skip=frequency_skip,
             forget_bias=1.0)(x, m)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([output, state], {
-            x.name:
-                np.array([[1., 1., 1., 1.],
-                          [2., 2., 2., 2.],
-                          [3., 3., 3., 3.]]),
-            m.name:
-                0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
-        })
+        res = sess.run(
+            [output, state], {
+                x.name:
+                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+                              [3., 3., 3., 3.]]),
+                m.name:
+                    0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
+            })
         self.assertEqual(len(res), 2)
         # The numbers in results were not calculated, this is mostly just a
         # smoke test.
@@ -141,17 +139,14 @@ class RNNCellTest(test.TestCase):
             state_is_tuple=True)
         inputs = constant_op.constant(
             np.array(
-                [[1., 1., 1., 1.],
-                 [2., 2., 2., 2.],
-                 [3., 3., 3., 3.]],
+                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                 dtype=np.float32),
             dtype=dtypes.float32)
         state_value = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
-        init_state = cell.state_tuple_type(
-            *([state_value, state_value] * num_shifts))
+        init_state = cell.state_tuple_type(*(
+            [state_value, state_value] * num_shifts))
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
         res = sess.run([output, state])
@@ -198,11 +193,10 @@ class RNNCellTest(test.TestCase):
                 dtype=np.float32),
             dtype=dtypes.float32)
         state_value = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
-        init_state = cell.state_tuple_type(
-            *([state_value, state_value] * total_blocks))
+        init_state = cell.state_tuple_type(*(
+            [state_value, state_value] * total_blocks))
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
         res = sess.run([output, state])
@@ -230,20 +224,28 @@ class RNNCellTest(test.TestCase):
     frequency_skip = 1
     num_shifts = int((input_size - feature_size) / frequency_skip + 1)
     expected_output = np.array(
-        [[0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
-          0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699],
-         [0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
-          0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171],
-         [0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
-          0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250]],
+        [[
+            0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
+            0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699
+        ], [
+            0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
+            0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171
+        ], [
+            0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
+            0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250
+        ]],
         dtype=np.float32)
     expected_state = np.array(
-        [[0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
-          0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865],
-         [0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
-          0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245],
-         [0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
-          0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390]],
+        [[
+            0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
+            0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865
+        ], [
+            0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
+            0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245
+        ], [
+            0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
+            0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390
+        ]],
         dtype=np.float32)
     for state_is_tuple in [False, True]:
       with self.test_session() as sess:
@@ -259,18 +261,16 @@ class RNNCellTest(test.TestCase):
               couple_input_forget_gates=True,
               state_is_tuple=state_is_tuple)
           inputs = constant_op.constant(
-              np.array([[1., 1., 1., 1.],
-                        [2., 2., 2., 2.],
-                        [3., 3., 3., 3.]],
-                       dtype=np.float32),
+              np.array(
+                  [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+                  dtype=np.float32),
               dtype=dtypes.float32)
           if state_is_tuple:
             state_value = constant_op.constant(
-                0.1 * np.ones(
-                    (batch_size, num_units), dtype=np.float32),
+                0.1 * np.ones((batch_size, num_units), dtype=np.float32),
                 dtype=dtypes.float32)
-            init_state = cell.state_tuple_type(
-                *([state_value, state_value] * num_shifts))
+            init_state = cell.state_tuple_type(*(
+                [state_value, state_value] * num_shifts))
           else:
             init_state = constant_op.constant(
                 0.1 * np.ones(
@@ -302,32 +302,40 @@ class RNNCellTest(test.TestCase):
       frequency_skip = 1
       num_shifts = int((input_size - feature_size) / frequency_skip + 1)
       expected_output = np.array(
-          [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
-            0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
-            0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
-            0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218],
-           [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
-            0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
-            0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
-            0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840],
-           [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
-            0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
-            0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
-            0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731]],
+          [[
+              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+              0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
+              0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218
+          ], [
+              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+              0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
+              0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840
+          ], [
+              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+              0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
+              0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731
+          ]],
           dtype=np.float32)
       expected_state = np.array(
-          [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
-            0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
-            0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
-            0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773],
-           [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
-            0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
-            0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
-            0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984],
-           [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
-            0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
-            1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
-            0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035]],
+          [[
+              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+              0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
+              0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773
+          ], [
+              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+              0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
+              0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984
+          ], [
+              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+              1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
+              0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035
+          ]],
           dtype=np.float32)
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
@@ -339,17 +347,16 @@ class RNNCellTest(test.TestCase):
             forget_bias=1.0,
             num_frequency_blocks=[num_shifts])
         inputs = constant_op.constant(
-            np.array([[1.0, 1.1, 1.2, 1.3],
-                      [2.0, 2.1, 2.2, 2.3],
-                      [3.0, 3.1, 3.2, 3.3]],
-                     dtype=np.float32),
+            np.array(
+                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+                 [3.0, 3.1, 3.2, 3.3]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         state_value = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
-        init_state = cell.state_tuple_type(
-            *([state_value, state_value] * num_shifts * 2))
+        init_state = cell.state_tuple_type(*(
+            [state_value, state_value] * num_shifts * 2))
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
         res = sess.run([output, state])
@@ -375,32 +382,40 @@ class RNNCellTest(test.TestCase):
       frequency_skip = 1
       num_shifts = int((input_size - feature_size) / frequency_skip + 1)
       expected_output = np.array(
-          [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
-            0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
-            0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
-            0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071],
-           [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
-            0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
-            0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
-            0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958],
-           [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
-            0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
-            0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
-            0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858]],
+          [[
+              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+              0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
+              0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071
+          ], [
+              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+              0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
+              0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958
+          ], [
+              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+              0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
+              0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858
+          ]],
           dtype=np.float32)
       expected_state = np.array(
-          [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
-            0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
-            0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
-            0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446],
-           [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
-            0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
-            0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
-            0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429],
-           [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
-            0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
-            0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
-            0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597]],
+          [[
+              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+              0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
+              0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446
+          ], [
+              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+              0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
+              0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429
+          ], [
+              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+              0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
+              0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597
+          ]],
           dtype=np.float32)
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
@@ -413,17 +428,16 @@ class RNNCellTest(test.TestCase):
             num_frequency_blocks=[num_shifts],
             backward_slice_offset=1)
         inputs = constant_op.constant(
-            np.array([[1.0, 1.1, 1.2, 1.3],
-                      [2.0, 2.1, 2.2, 2.3],
-                      [3.0, 3.1, 3.2, 3.3]],
-                     dtype=np.float32),
+            np.array(
+                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+                 [3.0, 3.1, 3.2, 3.3]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         state_value = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
-        init_state = cell.state_tuple_type(
-            *([state_value, state_value] * num_shifts * 2))
+        init_state = cell.state_tuple_type(*(
+            [state_value, state_value] * num_shifts * 2))
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
         res = sess.run([output, state])
@@ -474,8 +488,8 @@ class RNNCellTest(test.TestCase):
     for state_is_tuple in [False, True]:
       with ops.Graph().as_default():
         with self.test_session() as sess:
-          with variable_scope.variable_scope("state_is_tuple_" + str(
-              state_is_tuple)):
+          with variable_scope.variable_scope(
+              "state_is_tuple_" + str(state_is_tuple)):
             lstm_cell = rnn_cell.BasicLSTMCell(
                 num_units, state_is_tuple=state_is_tuple)
             cell = contrib_rnn_cell.AttentionCellWrapper(
@@ -525,16 +539,15 @@ class RNNCellTest(test.TestCase):
     for state_is_tuple in [False, True]:
       with ops.Graph().as_default():
         with self.test_session() as sess:
-          with variable_scope.variable_scope("state_is_tuple_" + str(
-              state_is_tuple)):
+          with variable_scope.variable_scope(
+              "state_is_tuple_" + str(state_is_tuple)):
             lstm_cell = rnn_cell.BasicLSTMCell(
                 num_units, state_is_tuple=state_is_tuple)
             cell = contrib_rnn_cell.AttentionCellWrapper(
                 lstm_cell, attn_length, state_is_tuple=state_is_tuple)
             if state_is_tuple:
               zeros = constant_op.constant(
-                  0.1 * np.ones(
-                      [batch_size, num_units], dtype=np.float32),
+                  0.1 * np.ones([batch_size, num_units], dtype=np.float32),
                   dtype=dtypes.float32)
               attn_state_zeros = constant_op.constant(
                   0.1 * np.ones(
@@ -579,22 +592,25 @@ class RNNCellTest(test.TestCase):
          [1.018088, 0.378983, -0.572179, 0.268591]],
         dtype=np.float32)
     expected_state = np.array(
-        [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
-          0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
-          0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
-          0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
-          0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
-          0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
-          0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
-          0.51843399],
-         [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
-          0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
-          0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
-          0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
-          0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
-          0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
-          0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
-          0.70582712]],
+        [[
+            0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
+            0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
+            0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
+            0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
+            0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
+            0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
+            0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
+            0.51843399
+        ], [
+            0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
+            0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
+            0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
+            0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
+            0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
+            0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
+            0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
+            0.70582712
+        ]],
         dtype=np.float32)
     seed = 12345
     random_seed.set_random_seed(seed)
@@ -602,7 +618,8 @@ class RNNCellTest(test.TestCase):
     for state_is_tuple in [False, True]:
       with session.Session() as sess:
         with variable_scope.variable_scope(
-            "state_is_tuple", reuse=state_is_tuple,
+            "state_is_tuple",
+            reuse=state_is_tuple,
             initializer=init_ops.glorot_uniform_initializer()):
           lstm_cell = rnn_cell.BasicLSTMCell(
               num_units, state_is_tuple=state_is_tuple)
@@ -646,36 +663,31 @@ class RNNCellTest(test.TestCase):
   def testNASCell(self):
     num_units = 6
     batch_size = 3
-    expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751,
-                                 0.576751, 0.576751],
-                                [0.618936, 0.618936, 0.618936, 0.618936,
-                                 0.618936, 0.618936],
-                                [0.627393, 0.627393, 0.627393, 0.627393,
-                                 0.627393, 0.627393]])
-    expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772,
-                                0.71579772, 0.71579772, 0.57675087, 0.57675087,
-                                0.57675087, 0.57675087, 0.57675087, 0.57675087],
-                               [0.78041625, 0.78041625, 0.78041625, 0.78041625,
-                                0.78041625, 0.78041625, 0.6189357, 0.6189357,
-                                0.61893570, 0.6189357, 0.6189357, 0.6189357],
-                               [0.79457647, 0.79457647, 0.79457647, 0.79457647,
-                                0.79457653, 0.79457653, 0.62739348, 0.62739348,
-                                0.62739348, 0.62739348, 0.62739348, 0.62739348]
-                              ])
+    expected_output = np.array(
+        [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751],
+         [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936],
+         [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]])
+    expected_state = np.array([[
+        0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772,
+        0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087
+    ], [
+        0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625,
+        0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357
+    ], [
+        0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
+        0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
+    ]])
     with self.test_session() as sess:
       with variable_scope.variable_scope(
-          "nas_test",
-          initializer=init_ops.constant_initializer(0.5)):
+          "nas_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.NASCell(num_units=num_units)
         inputs = constant_op.constant(
-            np.array([[1., 1., 1., 1.],
-                      [2., 2., 2., 2.],
-                      [3., 3., 3., 3.]],
-                     dtype=np.float32),
+            np.array(
+                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         state_value = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
         init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
         output, state = cell(inputs, init_state)
@@ -699,39 +711,34 @@ class RNNCellTest(test.TestCase):
     num_units = 6
     batch_size = 3
     num_proj = 5
-    expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418,
-                                 1.697418],
-                                [1.840037, 1.840037, 1.840037, 1.840037,
-                                 1.840037],
-                                [1.873985, 1.873985, 1.873985, 1.873985,
-                                 1.873985]])
-    expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207,
-                                0.69855207, 0.69855207, 1.69741797, 1.69741797,
-                                1.69741797, 1.69741797, 1.69741797],
-                               [0.77073824, 0.77073824, 0.77073824, 0.77073824,
-                                0.77073824, 0.77073824, 1.84003687, 1.84003687,
-                                1.84003687, 1.84003687, 1.84003687],
-                               [0.78973997, 0.78973997, 0.78973997, 0.78973997,
-                                0.78973997, 0.78973997, 1.87398517, 1.87398517,
-                                1.87398517, 1.87398517, 1.87398517]])
+    expected_output = np.array(
+        [[1.697418, 1.697418, 1.697418, 1.697418,
+          1.697418], [1.840037, 1.840037, 1.840037, 1.840037, 1.840037],
+         [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]])
+    expected_state = np.array([[
+        0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207,
+        1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797
+    ], [
+        0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824,
+        1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687
+    ], [
+        0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
+        1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
+    ]])
     with self.test_session() as sess:
       with variable_scope.variable_scope(
-          "nas_proj_test",
-          initializer=init_ops.constant_initializer(0.5)):
+          "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
         inputs = constant_op.constant(
-            np.array([[1., 1., 1., 1.],
-                      [2., 2., 2., 2.],
-                      [3., 3., 3., 3.]],
-                     dtype=np.float32),
+            np.array(
+                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         state_value_c = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
         state_value_h = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_proj), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_proj), dtype=np.float32),
             dtype=dtypes.float32)
         init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
         output, state = cell(inputs, init_state)
@@ -755,24 +762,20 @@ class RNNCellTest(test.TestCase):
     num_units = 2
     batch_size = 3
     expected_state_and_output = np.array(
-        [[0.13752282, 0.13752282],
-         [0.10545051, 0.10545051],
+        [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
          [0.10074195, 0.10074195]],
         dtype=np.float32)
     with self.test_session() as sess:
       with variable_scope.variable_scope(
-          "ugrnn_cell_test",
-          initializer=init_ops.constant_initializer(0.5)):
+          "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
         inputs = constant_op.constant(
-            np.array([[1., 1., 1., 1.],
-                      [2., 2., 2., 2.],
-                      [3., 3., 3., 3.]],
-                     dtype=np.float32),
+            np.array(
+                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         init_state = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
@@ -786,13 +789,11 @@ class RNNCellTest(test.TestCase):
     num_units = 2
     batch_size = 3
     expected_state = np.array(
-        [[0.13752282, 0.13752282],
-         [0.10545051, 0.10545051],
+        [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
          [0.10074195, 0.10074195]],
         dtype=np.float32)
     expected_output = np.array(
-        [[2.00431061, 2.00431061],
-         [4.00060606, 4.00060606],
+        [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
          [6.00008249, 6.00008249]],
         dtype=np.float32)
     with self.test_session() as sess:
@@ -802,14 +803,12 @@ class RNNCellTest(test.TestCase):
         cell = contrib_rnn_cell.IntersectionRNNCell(
             num_units=num_units, num_in_proj=num_units)
         inputs = constant_op.constant(
-            np.array([[1., 1., 1., 1.],
-                      [2., 2., 2., 2.],
-                      [3., 3., 3., 3.]],
-                     dtype=np.float32),
+            np.array(
+                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+                dtype=np.float32),
             dtype=dtypes.float32)
         init_state = constant_op.constant(
-            0.1 * np.ones(
-                (batch_size, num_units), dtype=np.float32),
+            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
             dtype=dtypes.float32)
         output, state = cell(inputs, init_state)
         sess.run([variables.global_variables_initializer()])
@@ -824,19 +823,17 @@ class RNNCellTest(test.TestCase):
     batch_size = 3
     cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
     inputs = constant_op.constant(
-        np.array([[1., 1., 1., 1.],
-                  [2., 2., 2., 2.],
-                  [3., 3., 3., 3.]],
-                 dtype=np.float32),
+        np.array(
+            [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+            dtype=np.float32),
         dtype=dtypes.float32)
     init_state = constant_op.constant(
-        0.1 * np.ones(
-            (batch_size, num_units), dtype=np.float32),
+        0.1 * np.ones((batch_size, num_units), dtype=np.float32),
         dtype=dtypes.float32)
-    with self.assertRaisesRegexp(
-        ValueError, "Must have input size == output size for "
-                    "Intersection RNN. To fix, num_in_proj should "
-                    "be set to num_units at cell init."):
+    with self.assertRaisesRegexp(ValueError,
+                                 "Must have input size == output size for "
+                                 "Intersection RNN. To fix, num_in_proj should "
+                                 "be set to num_units at cell init."):
       cell(inputs, init_state)
 
   def testPhasedLSTMCell(self):
@@ -845,13 +842,11 @@ class RNNCellTest(test.TestCase):
       batch_size = 3
       input_size = 4
       expected_state_c = np.array(
-          [[6.450831e-04, 4.697885e-04],
-           [9.862894e-05, 7.212213e-04],
+          [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
            [4.401947e-04, 9.143004e-04]],
           dtype=np.float32)
       expected_state_h = np.array(
-          [[4.621217e-04, 3.365449e-04],
-           [7.438179e-05, 5.439147e-04],
+          [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
            [3.347936e-04, 6.953785e-04]],
           dtype=np.float32)
       with variable_scope.variable_scope(
@@ -864,14 +859,14 @@ class RNNCellTest(test.TestCase):
         output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
             (t, x), state0)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([output, state], {
-            t.name:
-                np.array([[1.], [2.], [3.]]),
-            x.name:
-                np.array([[1., 1., 1., 1.],
-                          [2., 2., 2., 2.],
-                          [3., 3., 3., 3.]]),
-        })
+        res = sess.run(
+            [output, state], {
+                t.name:
+                    np.array([[1.], [2.], [3.]]),
+                x.name:
+                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+                              [3., 3., 3., 3.]]),
+            })
         # This is a smoke test, making sure expected values are unchanged.
         self.assertEqual(len(res), 2)
         self.assertAllClose(res[0], res[1].h)
@@ -880,36 +875,32 @@ class RNNCellTest(test.TestCase):
 
   def testConv1DLSTMCell(self):
     with self.test_session() as sess:
-      shape = [2,1]
+      shape = [2, 1]
       filter_size = [3]
       num_features = 1
       batch_size = 2
       expected_state_c = np.array(
-          [[[1.4375670191], [1.4375670191]],
-           [[2.7542609292], [2.7542609292]]],
+          [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]],
           dtype=np.float32)
       expected_state_h = np.array(
-          [[[0.6529865603], [0.6529865603]],
-           [[0.8736877431], [0.8736877431]]],
+          [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]],
           dtype=np.float32)
       with variable_scope.variable_scope(
-          "root", initializer=init_ops.constant_initializer(1.0/2.0)):
+          "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
         x = array_ops.placeholder(dtypes.float32, [None, None, 1])
-        cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
-                                               kernel_shape=filter_size,
-                                               output_channels=num_features)
+        cell = contrib_rnn_cell.Conv1DLSTMCell(
+            input_shape=shape,
+            kernel_shape=filter_size,
+            output_channels=num_features)
         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
         output, state = cell(x, hidden)
 
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([output, state], {
-            hidden[0].name:
-                np.array([[[1.],[1.]],
-                          [[2.],[2.]]]),
-            x.name:
-                np.array([[[1.],[1.]],
-                          [[2.],[2.]]]),
-        })
+        res = sess.run(
+            [output, state], {
+                hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+                x.name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+            })
         # This is a smoke test, making sure expected values are unchanged.
         self.assertEqual(len(res), 2)
         self.assertAllClose(res[0], res[1].h)
@@ -918,44 +909,40 @@ class RNNCellTest(test.TestCase):
 
   def testConv2DLSTMCell(self):
     with self.test_session() as sess:
-      shape = [2,2,1]
-      filter_size = [3,3]
+      shape = [2, 2, 1]
+      filter_size = [3, 3]
       num_features = 1
       batch_size = 2
       expected_state_c = np.array(
-          [[[[1.4375670191], [1.4375670191]],
-            [[1.4375670191], [1.4375670191]]],
-           [[[2.7542609292], [2.7542609292]],
-            [[2.7542609292], [2.7542609292]]]],
+          [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]],
+           [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+           ]],
           dtype=np.float32)
       expected_state_h = np.array(
-          [[[[0.6529865603], [0.6529865603]],
-            [[0.6529865603], [0.6529865603]]],
-           [[[0.8736877431], [0.8736877431]],
-            [[0.8736877431], [0.8736877431]]]],
+          [[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]],
+           [[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+           ]],
           dtype=np.float32)
       with variable_scope.variable_scope(
-          "root", initializer=init_ops.constant_initializer(1.0/4.0)):
+          "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
         x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
-        cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
-                                               kernel_shape=filter_size,
-                                               output_channels=num_features)
+        cell = contrib_rnn_cell.Conv2DLSTMCell(
+            input_shape=shape,
+            kernel_shape=filter_size,
+            output_channels=num_features)
         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
         output, state = cell(x, hidden)
 
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([output, state], {
-            hidden[0].name:
-                np.array([[[[1.],[1.]],
-                           [[1.],[1.]]],
-                          [[[2.],[2.]],
-                           [[2.],[2.]]]]),
-            x.name:
-                np.array([[[[1.],[1.]],
-                           [[1.],[1.]]],
-                          [[[2.],[2.]],
-                           [[2.],[2.]]]]),
-        })
+        res = sess.run(
+            [output, state], {
+                hidden[0].name:
+                    np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+                                                             [[2.], [2.]]]]),
+                x.name:
+                    np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+                                                             [[2.], [2.]]]]),
+            })
         # This is a smoke test, making sure expected values are unchanged.
         self.assertEqual(len(res), 2)
         self.assertAllClose(res[0], res[1].h)
@@ -964,36 +951,33 @@ class RNNCellTest(test.TestCase):
 
   def testConv3DLSTMCell(self):
     with self.test_session() as sess:
-      shape = [2,2,2,1]
-      filter_size = [3,3,3]
+      shape = [2, 2, 2, 1]
+      filter_size = [3, 3, 3]
       num_features = 1
       batch_size = 2
       expected_state_c = np.array(
-         [[[[[1.4375670191], [1.4375670191]],
-            [[1.4375670191], [1.4375670191]]],
-           [[[1.4375670191], [1.4375670191]],
-            [[1.4375670191], [1.4375670191]]]],
-          [[[[2.7542609292], [2.7542609292]],
-            [[2.7542609292], [2.7542609292]]],
-           [[[2.7542609292], [2.7542609292]],
-            [[2.7542609292], [2.7542609292]]]]],
+          [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]
+            ], [[[1.4375670191], [1.4375670191]], [[1.4375670191],
+                                                   [1.4375670191]]]],
+           [[[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+            ], [[[2.7542609292], [2.7542609292]], [[2.7542609292],
+                                                   [2.7542609292]]]]],
           dtype=np.float32)
       expected_state_h = np.array(
-         [[[[[0.6529865603], [0.6529865603]],
-            [[0.6529865603], [0.6529865603]]],
-           [[[0.6529865603], [0.6529865603]],
-            [[0.6529865603], [0.6529865603]]]],
-          [[[[0.8736877431], [0.8736877431]],
-            [[0.8736877431], [0.8736877431]]],
-           [[[0.8736877431], [0.8736877431]],
-            [[0.8736877431], [0.8736877431]]]]],
+          [[[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]
+            ], [[[0.6529865603], [0.6529865603]], [[0.6529865603],
+                                                   [0.6529865603]]]],
+           [[[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+            ], [[[0.8736877431], [0.8736877431]], [[0.8736877431],
+                                                   [0.8736877431]]]]],
           dtype=np.float32)
       with variable_scope.variable_scope(
-          "root", initializer=init_ops.constant_initializer(1.0/8.0)):
+          "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
         x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
-        cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
-                                               kernel_shape=filter_size,
-                                               output_channels=num_features)
+        cell = contrib_rnn_cell.Conv3DLSTMCell(
+            input_shape=shape,
+            kernel_shape=filter_size,
+            output_channels=num_features)
         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
         output, state = cell(x, hidden)
 
@@ -1056,8 +1040,8 @@ class RNNCellTest(test.TestCase):
             num_units=num_units, number_of_groups=number_of_groups)
         cell = rnn_cell.LSTMCell(num_units=num_units)
         self.assertTrue(isinstance(gcell.state_size, tuple))
-        zero_state = gcell.zero_state(batch_size=batch_size,
-                                      dtype=dtypes.float32)
+        zero_state = gcell.zero_state(
+            batch_size=batch_size, dtype=dtypes.float32)
         gh, gs = gcell(x, zero_state)
         h, g = cell(x, zero_state)
 
@@ -1080,16 +1064,16 @@ class RNNCellTest(test.TestCase):
         glstm_input = array_ops.ones([batch_size, num_units])
         gcell = contrib_rnn_cell.GLSTMCell(
             num_units=num_units, number_of_groups=number_of_groups)
-        gcell_zero_state = gcell.zero_state(batch_size=batch_size,
-                                            dtype=dtypes.float32)
+        gcell_zero_state = gcell.zero_state(
+            batch_size=batch_size, dtype=dtypes.float32)
         gh, gs = gcell(glstm_input, gcell_zero_state)
 
         # input for LSTM cell simulating single G-LSTM group
         lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
         # note division by number_of_groups. This cell one simulates G-LSTM group
         cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
-        cell_zero_state = cell.zero_state(batch_size=batch_size,
-                                          dtype=dtypes.float32)
+        cell_zero_state = cell.zero_state(
+            batch_size=batch_size, dtype=dtypes.float32)
         h, g = cell(lstm_input, cell_zero_state)
 
         sess.run([variables.global_variables_initializer()])
@@ -1099,6 +1083,7 @@ class RNNCellTest(test.TestCase):
         self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
                             h_res, 1e-5)
 
+
 class LayerNormBasicLSTMCellTest(test.TestCase):
 
   # NOTE: all the values in the current test case have been calculated.
@@ -1119,13 +1104,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
         g, out_m = cell(x, state)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([g, out_m], {
-            x.name: np.array([[1., 1.]]),
-            c0.name: 0.1 * np.asarray([[0, 1]]),
-            h0.name: 0.1 * np.asarray([[2, 3]]),
-            c1.name: 0.1 * np.asarray([[4, 5]]),
-            h1.name: 0.1 * np.asarray([[6, 7]]),
-        })
+        res = sess.run(
+            [g, out_m], {
+                x.name: np.array([[1., 1.]]),
+                c0.name: 0.1 * np.asarray([[0, 1]]),
+                h0.name: 0.1 * np.asarray([[2, 3]]),
+                c1.name: 0.1 * np.asarray([[4, 5]]),
+                h1.name: 0.1 * np.asarray([[6, 7]]),
+            })
 
         expected_h = np.array([[-0.38079708, 0.38079708]])
         expected_state0_c = np.array([[-1.0, 1.0]])
@@ -1155,11 +1141,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
         g, out_m = cell(x, state)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([g, out_m], {
-            x.name: np.array([[1., 1., 1.]]),
-            c.name: 0.1 * np.asarray([[0, 1]]),
-            h.name: 0.1 * np.asarray([[2, 3]]),
-        })
+        res = sess.run(
+            [g, out_m], {
+                x.name: np.array([[1., 1., 1.]]),
+                c.name: 0.1 * np.asarray([[0, 1]]),
+                h.name: 0.1 * np.asarray([[2, 3]]),
+            })
 
         expected_h = np.array([[-0.38079708, 0.38079708]])
         expected_c = np.array([[-1.0, 1.0]])
@@ -1168,7 +1155,6 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         self.assertAllClose(res[1].c, expected_c, 1e-5)
         self.assertAllClose(res[1].h, expected_h, 1e-5)
 
-
   def testBasicLSTMCellWithoutNorm(self):
     """Tests that BasicLSTMCell with layer_norm=False."""
     with self.test_session() as sess:
@@ -1186,19 +1172,20 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
         g, out_m = cell(x, state)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([g, out_m], {
-          x.name: np.array([[1., 1.]]),
-          c0.name: 0.1 * np.asarray([[0, 1]]),
-          h0.name: 0.1 * np.asarray([[2, 3]]),
-          c1.name: 0.1 * np.asarray([[4, 5]]),
-          h1.name: 0.1 * np.asarray([[6, 7]]),
-        })
-
-        expected_h = np.array([[ 0.70230919, 0.72581059]])
-        expected_state0_c = np.array([[ 0.8020075,  0.89599884]])
-        expected_state0_h = np.array([[ 0.56668288,  0.60858738]])
-        expected_state1_c = np.array([[ 1.17500675,  1.26892781]])
-        expected_state1_h = np.array([[ 0.70230919,  0.72581059]])
+        res = sess.run(
+            [g, out_m], {
+                x.name: np.array([[1., 1.]]),
+                c0.name: 0.1 * np.asarray([[0, 1]]),
+                h0.name: 0.1 * np.asarray([[2, 3]]),
+                c1.name: 0.1 * np.asarray([[4, 5]]),
+                h1.name: 0.1 * np.asarray([[6, 7]]),
+            })
+
+        expected_h = np.array([[0.70230919, 0.72581059]])
+        expected_state0_c = np.array([[0.8020075, 0.89599884]])
+        expected_state0_h = np.array([[0.56668288, 0.60858738]])
+        expected_state1_c = np.array([[1.17500675, 1.26892781]])
+        expected_state1_h = np.array([[0.70230919, 0.72581059]])
 
         actual_h = res[0]
         actual_state0_c = res[1][0].c
@@ -1215,21 +1202,22 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
       with variable_scope.variable_scope(
           "other", initializer=init_ops.constant_initializer(0.5)) as vs:
         x = array_ops.zeros(
-          [1, 3])  # Test BasicLSTMCell with input_size != num_units.
+            [1, 3])  # Test BasicLSTMCell with input_size != num_units.
         c = array_ops.zeros([1, 2])
         h = array_ops.zeros([1, 2])
         state = rnn_cell.LSTMStateTuple(c, h)
         cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
         g, out_m = cell(x, state)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([g, out_m], {
-          x.name: np.array([[1., 1., 1.]]),
-          c.name: 0.1 * np.asarray([[0, 1]]),
-          h.name: 0.1 * np.asarray([[2, 3]]),
-        })
-
-        expected_h = np.array([[ 0.64121795, 0.68166804]])
-        expected_c = np.array([[ 0.88477188, 0.98103917]])
+        res = sess.run(
+            [g, out_m], {
+                x.name: np.array([[1., 1., 1.]]),
+                c.name: 0.1 * np.asarray([[0, 1]]),
+                h.name: 0.1 * np.asarray([[2, 3]]),
+            })
+
+        expected_h = np.array([[0.64121795, 0.68166804]])
+        expected_c = np.array([[0.88477188, 0.98103917]])
         self.assertEqual(len(res), 2)
         self.assertAllClose(res[0], expected_h, 1e-5)
         self.assertAllClose(res[1].c, expected_c, 1e-5)
@@ -1250,13 +1238,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
             [contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
         h, (s0, s1) = cell(x, (state0, state1))
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([h, s0, s1], {
-            x.name: np.array([[1., 1.]]),
-            c0.name: 0.1 * np.asarray([[0, 1]]),
-            h0.name: 0.1 * np.asarray([[2, 3]]),
-            c1.name: 0.1 * np.asarray([[4, 5]]),
-            h1.name: 0.1 * np.asarray([[6, 7]]),
-        })
+        res = sess.run(
+            [h, s0, s1], {
+                x.name: np.array([[1., 1.]]),
+                c0.name: 0.1 * np.asarray([[0, 1]]),
+                h0.name: 0.1 * np.asarray([[2, 3]]),
+                c1.name: 0.1 * np.asarray([[4, 5]]),
+                h1.name: 0.1 * np.asarray([[6, 7]]),
+            })
 
         expected_h = np.array([[-0.38079708, 0.38079708]])
         expected_h0 = np.array([[-0.38079708, 0.38079708]])
@@ -1344,11 +1333,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
 
         g, s = cell(x, state)
         sess.run([variables.global_variables_initializer()])
-        res = sess.run([g, s], {
-            x.name: np.ones([1, 5]),
-            c.name: np.ones([1, 5]),
-            h.name: np.ones([1, 5]),
-        })
+        res = sess.run(
+            [g, s], {
+                x.name: np.ones([1, 5]),
+                c.name: np.ones([1, 5]),
+                h.name: np.ones([1, 5]),
+            })
 
         # Since the returned tensors are of size [1,n]
         # get the first component right now.
@@ -1374,35 +1364,35 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         self.assertIn(dropped_count, allowed_low)
 
 
-def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
-                                num_layers, max_time, compiled):
+def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers,
+                                max_time, compiled):
   with variable_scope.variable_scope(
       "root",
       initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
     inputs = variable_scope.get_variable(
-        "inputs", initializer=random_ops.random_uniform(
+        "inputs",
+        initializer=random_ops.random_uniform(
             (max_time, batch_size, input_depth), seed=1))
     maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
     cell = rnn_cell.MultiRNNCell(
         [maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
-    initial_state = cell.zero_state(
-        batch_size=batch_size, dtype=dtypes.float32)
+    initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32)
     outputs, final_state = rnn.dynamic_rnn(
-        cell=cell, inputs=inputs, initial_state=initial_state,
-        time_major=True)
+        cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)
     flat_final_state = nest.flatten(final_state)
     trainable_variables = variables.trainable_variables()
     outputs_grad = gradients_impl.gradients(
-        [outputs],
-        trainable_variables + [inputs] + nest.flatten(initial_state))
+        [outputs], trainable_variables + [inputs] + nest.flatten(initial_state))
     final_state_grad = gradients_impl.gradients(
         flat_final_state,
         trainable_variables + [inputs] + nest.flatten(initial_state))
 
-    return {"outputs": outputs,
-            "final_state": flat_final_state,
-            "outputs_grad": outputs_grad,
-            "final_state_grad": final_state_grad}
+    return {
+        "outputs": outputs,
+        "final_state": flat_final_state,
+        "outputs_grad": outputs_grad,
+        "final_state_grad": final_state_grad
+    }
 
 
 class CompiledWrapperTest(test.TestCase):
@@ -1420,8 +1410,10 @@ class CompiledWrapperTest(test.TestCase):
     random_seed.set_random_seed(1234)
     with self.test_session(graph=ops.Graph()) as sess:
       xla_ops = _create_multi_lstm_cell_ops(
-          batch_size=batch_size, num_units=num_units,
-          input_depth=input_depth, num_layers=num_layers,
+          batch_size=batch_size,
+          num_units=num_units,
+          input_depth=input_depth,
+          num_layers=num_layers,
           max_time=max_time,
           compiled=True)
       sess.run([variables.global_variables_initializer()])
@@ -1430,8 +1422,10 @@ class CompiledWrapperTest(test.TestCase):
     random_seed.set_random_seed(1234)
     with self.test_session(graph=ops.Graph()) as sess:
       non_xla_ops = _create_multi_lstm_cell_ops(
-          batch_size=batch_size, num_units=num_units,
-          input_depth=input_depth, num_layers=num_layers,
+          batch_size=batch_size,
+          num_units=num_units,
+          input_depth=input_depth,
+          num_layers=num_layers,
           max_time=max_time,
           compiled=False)
       sess.run([variables.global_variables_initializer()])
@@ -1440,16 +1434,16 @@ class CompiledWrapperTest(test.TestCase):
     self.assertAllClose(
         non_xla_results["outputs"], xla_results["outputs"], atol=atol)
 
-    for xla_value, non_xla_value in zip(
-        xla_results["final_state"], non_xla_results["final_state"]):
+    for xla_value, non_xla_value in zip(xla_results["final_state"],
+                                        non_xla_results["final_state"]):
       self.assertAllClose(xla_value, non_xla_value, atol=atol)
 
-    for xla_g, non_xla_g in zip(
-        xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
+    for xla_g, non_xla_g in zip(xla_results["outputs_grad"],
+                                non_xla_results["outputs_grad"]):
       self.assertAllClose(xla_g, non_xla_g, atol=atol)
 
-    for xla_g, non_xla_g in zip(
-        xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
+    for xla_g, non_xla_g in zip(xla_results["final_state_grad"],
+                                non_xla_results["final_state_grad"]):
       self.assertAllClose(xla_g, non_xla_g, atol=atol)
 
   def testMultiRNNCellWithStateTuple(self):
@@ -1463,19 +1457,20 @@ class CompiledWrapperTest(test.TestCase):
         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
           rnn_cell.MultiRNNCell(
-              [rnn_cell.GRUCell(2)
-               for _ in range(2)], state_is_tuple=True)(x, m_bad)
+              [rnn_cell.GRUCell(2) for _ in range(2)],
+              state_is_tuple=True)(x, m_bad)
 
         _, ml = rnn_cell.MultiRNNCell(
-            [rnn_cell.GRUCell(2)
-             for _ in range(2)], state_is_tuple=True)(x, m_good)
+            [rnn_cell.GRUCell(2) for _ in range(2)],
+            state_is_tuple=True)(x, m_good)
 
         sess.run([variables.global_variables_initializer()])
-        res = sess.run(ml, {
-            x.name: np.array([[1., 1.]]),
-            m_good[0].name: np.array([[0.1, 0.1]]),
-            m_good[1].name: np.array([[0.1, 0.1]])
-        })
+        res = sess.run(
+            ml, {
+                x.name: np.array([[1., 1.]]),
+                m_good[0].name: np.array([[0.1, 0.1]]),
+                m_good[1].name: np.array([[0.1, 0.1]])
+            })
 
         # The numbers in results were not calculated, this is just a
         # smoke test.  However, these numbers should match those of
@@ -1490,24 +1485,20 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
     num_layers = 3
     max_time = 50
     print("benchmarkDynamicRNNWithMultiLSTMCell")
-    print("\t" +
-          "\t".join(["inter_th", "intra_th",
-                     "batch_size", "num_units", "input_depth", "device",
-                     "compiled", "wall_time"]))
+    print("\t" + "\t".join([
+        "inter_th", "intra_th", "batch_size", "num_units", "input_depth",
+        "device", "compiled", "wall_time"
+    ]))
 
     warmup_run = True
-    for (threads,
-         device,
-         num_units,
-         batch_size,
-         input_depth,
-         compiled) in itertools.product(
-             [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
-             ["cpu", "gpu"],
-             [32, 512],
-             [1, 32, 256],
-             [32, 512],
-             [False, True]):
+    for (threads, device, num_units, batch_size, input_depth,
+         compiled) in itertools.product([{
+             "inter": 0,
+             "intra": 0
+         }, {
+             "inter": 1,
+             "intra": 4
+         }], ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512], [False, True]):
       if threads["inter"] != 0:
         # We only care about testing inter/intra op limitations on
         # CPU with small batch size, to mimic embedded devices.
@@ -1523,30 +1514,35 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
       with session.Session(config=config, graph=ops.Graph()) as sess:
         with ops.device("/%s:0" % device):
           ops_dict = _create_multi_lstm_cell_ops(
-              batch_size=batch_size, num_units=num_units,
-              input_depth=input_depth, num_layers=num_layers,
+              batch_size=batch_size,
+              num_units=num_units,
+              input_depth=input_depth,
+              num_layers=num_layers,
               max_time=max_time,
               compiled=compiled)
         sess.run([variables.global_variables_initializer()])
         all_ops = nest.flatten(ops_dict.values())
         all_ops_group = control_flow_ops.group(*all_ops)
-        name_suffix = (
-            "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
-            "_device_%s_xla_%s" % (
-                threads["inter"], threads["intra"],
-                batch_size, num_units, input_depth, device, compiled))
+        name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
+                       "_device_%s_xla_%s" %
+                       (threads["inter"], threads["intra"], batch_size,
+                        num_units, input_depth, device, compiled))
         if warmup_run:
           self.run_op_benchmark(
               sess, all_ops_group, min_iters=30, name="ignore_warmup")
           warmup_run = False
         benchmark_results = self.run_op_benchmark(
-            sess, all_ops_group, min_iters=50,
+            sess,
+            all_ops_group,
+            min_iters=50,
             name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
-        print("\t" +
-              "\t".join(["%s" % x for x in [
-                  threads["inter"], threads["intra"],
-                  batch_size, num_units, input_depth, device, compiled,
-                  benchmark_results["wall_time"]]]))
+        print("\t" + "\t".join([
+            "%s" % x
+            for x in [
+                threads["inter"], threads["intra"], batch_size, num_units,
+                input_depth, device, compiled, benchmark_results["wall_time"]
+            ]
+        ]))
 
 
 class WeightNormLSTMCellTest(test.TestCase):
@@ -1557,8 +1553,7 @@ class WeightNormLSTMCellTest(test.TestCase):
 
     with self.test_session() as sess:
       init = init_ops.constant_initializer(0.5)
-      with variable_scope.variable_scope("root",
-                                         initializer=init):
+      with variable_scope.variable_scope("root", initializer=init):
         x = array_ops.zeros([1, 2])
         c0 = array_ops.zeros([1, 2])
         h0 = array_ops.zeros([1, 2])
@@ -1568,11 +1563,12 @@ class WeightNormLSTMCellTest(test.TestCase):
         xout, sout = cell()(x, state0)
 
       sess.run([variables.global_variables_initializer()])
-      res = sess.run([xout, sout], {
-          x.name: np.array([[1., 1.]]),
-          c0.name: 0.1 * np.asarray([[0, 1]]),
-          h0.name: 0.1 * np.asarray([[2, 3]]),
-      })
+      res = sess.run(
+          [xout, sout], {
+              x.name: np.array([[1., 1.]]),
+              c0.name: 0.1 * np.asarray([[0, 1]]),
+              h0.name: 0.1 * np.asarray([[2, 3]]),
+          })
 
     actual_state_c = res[1].c
     actual_state_h = res[1].h
@@ -1583,9 +1579,8 @@ class WeightNormLSTMCellTest(test.TestCase):
     """Tests cell w/o peepholes and w/o normalisation"""
 
     def cell():
-      return contrib_rnn_cell.WeightNormLSTMCell(2,
-                                                 norm=False,
-                                                 use_peepholes=False)
+      return contrib_rnn_cell.WeightNormLSTMCell(
+          2, norm=False, use_peepholes=False)
 
     actual_c, actual_h = self._cell_output(cell)
 
@@ -1599,9 +1594,8 @@ class WeightNormLSTMCellTest(test.TestCase):
     """Tests cell with peepholes and w/o normalisation"""
 
     def cell():
-      return contrib_rnn_cell.WeightNormLSTMCell(2,
-                                                 norm=False,
-                                                 use_peepholes=True)
+      return contrib_rnn_cell.WeightNormLSTMCell(
+          2, norm=False, use_peepholes=True)
 
     actual_c, actual_h = self._cell_output(cell)
 
@@ -1611,14 +1605,12 @@ class WeightNormLSTMCellTest(test.TestCase):
     self.assertAllClose(expected_c, actual_c, 1e-5)
     self.assertAllClose(expected_h, actual_h, 1e-5)
 
-
   def testBasicCellWithNorm(self):
     """Tests cell w/o peepholes and with normalisation"""
 
     def cell():
-      return contrib_rnn_cell.WeightNormLSTMCell(2,
-                                                 norm=True,
-                                                 use_peepholes=False)
+      return contrib_rnn_cell.WeightNormLSTMCell(
+          2, norm=True, use_peepholes=False)
 
     actual_c, actual_h = self._cell_output(cell)
 
@@ -1632,9 +1624,8 @@ class WeightNormLSTMCellTest(test.TestCase):
     """Tests cell with peepholes and with normalisation"""
 
     def cell():
-      return contrib_rnn_cell.WeightNormLSTMCell(2,
-                                                 norm=True,
-                                                 use_peepholes=True)
+      return contrib_rnn_cell.WeightNormLSTMCell(
+          2, norm=True, use_peepholes=True)
 
     actual_c, actual_h = self._cell_output(cell)
 
@@ -1644,5 +1635,6 @@ class WeightNormLSTMCellTest(test.TestCase):
     self.assertAllClose(expected_c, actual_c, 1e-5)
     self.assertAllClose(expected_h, actual_h, 1e-5)
 
+
 if __name__ == "__main__":
   test.main()
index d7ae6621dba8c380af9a61612a675c66bfeaccc3..8adf5dce6ec76d8ac4f182929e0dfc81be946277 100644 (file)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Module for constructing RNN Cells."""
 from __future__ import absolute_import
 from __future__ import division
@@ -56,16 +55,15 @@ def _get_concat_variable(name, shape, dtype, num_shards):
       return value
 
   concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
-  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
-                        concat_variable)
+  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
   return concat_variable
 
 
 def _get_sharded_variable(name, shape, dtype, num_shards):
   """Get a list of sharded variables with the given dtype."""
   if num_shards > shape[0]:
-    raise ValueError("Too many shards: shape=%s, num_shards=%d" %
-                     (shape, num_shards))
+    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
+                                                                   num_shards))
   unit_shard_size = int(math.floor(shape[0] / num_shards))
   remaining_rows = shape[0] - unit_shard_size * num_shards
 
@@ -74,8 +72,9 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
     current_size = unit_shard_size
     if i < remaining_rows:
       current_size += 1
-    shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
-                                  dtype=dtype))
+    shards.append(
+        vs.get_variable(
+            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
   return shards
 
 
@@ -177,9 +176,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
     """
     super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
     if not state_is_tuple:
-      logging.warn(
-          "%s: Using a concatenated state is slower and will soon be "
-          "deprecated.  Use state_is_tuple=True.", self)
+      logging.warn("%s: Using a concatenated state is slower and will soon be "
+                   "deprecated.  Use state_is_tuple=True.", self)
     self._num_units = num_units
     self._use_peepholes = use_peepholes
     self._initializer = initializer
@@ -196,12 +194,14 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
     self._norm_shift = norm_shift
 
     if num_proj:
-      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
-                          if state_is_tuple else num_units + num_proj)
+      self._state_size = (
+          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+          if state_is_tuple else num_units + num_proj)
       self._output_size = num_proj
     else:
-      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
-                          if state_is_tuple else 2 * num_units)
+      self._state_size = (
+          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+          if state_is_tuple else 2 * num_units)
       self._output_size = num_units
 
   @property
@@ -251,8 +251,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
     concat_w = _get_concat_variable(
-        "W", [input_size.value + num_proj, 3 * self._num_units],
-        dtype, self._num_unit_shards)
+        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
+        self._num_unit_shards)
 
     b = vs.get_variable(
         "B",
@@ -299,9 +299,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
       m = sigmoid(o) * self._activation(c)
 
     if self._num_proj is not None:
-      concat_w_proj = _get_concat_variable(
-          "W_P", [self._num_units, self._num_proj],
-          dtype, self._num_proj_shards)
+      concat_w_proj = _get_concat_variable("W_P",
+                                           [self._num_units, self._num_proj],
+                                           dtype, self._num_proj_shards)
 
       m = math_ops.matmul(m, concat_w_proj)
       if self._proj_clip is not None:
@@ -309,8 +309,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
         m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
         # pylint: enable=invalid-unary-operand-type
 
-    new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
-                 if self._state_is_tuple else array_ops.concat([c, m], 1))
+    new_state = (
+        rnn_cell_impl.LSTMStateTuple(c, m)
+        if self._state_is_tuple else array_ops.concat([c, m], 1))
     return m, new_state
 
 
@@ -326,10 +327,15 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
   It uses peep-hole connections and optional cell clipping.
   """
 
-  def __init__(self, num_units, use_peepholes=False,
-               cell_clip=None, initializer=None,
-               num_unit_shards=1, forget_bias=1.0,
-               feature_size=None, frequency_skip=1,
+  def __init__(self,
+               num_units,
+               use_peepholes=False,
+               cell_clip=None,
+               initializer=None,
+               num_unit_shards=1,
+               forget_bias=1.0,
+               feature_size=None,
+               frequency_skip=1,
                reuse=None):
     """Initialize the parameters for an LSTM cell.
 
@@ -399,7 +405,7 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
     actual_input_size = freq_inputs[0].get_shape().as_list()[1]
 
     concat_w = _get_concat_variable(
-        "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
+        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
         dtype, self._num_unit_shards)
 
     b = vs.get_variable(
@@ -418,23 +424,23 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
           "W_O_diag", shape=[self._num_units], dtype=dtype)
 
     # initialize the first freq state to be zero
-    m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
-                                   self._num_units], dtype)
+    m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
+                                  dtype)
     for fq in range(len(freq_inputs)):
-      c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
+      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                                [-1, self._num_units])
-      m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
+      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                                [-1, self._num_units])
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
-                                     1)
+      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
       lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
       i, j, f, o = array_ops.split(
           value=lstm_matrix, num_or_size_splits=4, axis=1)
 
       if self._use_peepholes:
-        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
-             sigmoid(i + w_i_diag * c_prev) * tanh(j))
+        c = (
+            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+            sigmoid(i + w_i_diag * c_prev) * tanh(j))
       else:
         c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
 
@@ -472,11 +478,11 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
     input_size = input_feat.get_shape().with_rank(2)[-1].value
     if input_size is None:
       raise ValueError("Cannot infer input_size from static shape inference.")
-    num_feats = int((input_size - self._feature_size) / (
-        self._frequency_skip)) + 1
+    num_feats = int(
+        (input_size - self._feature_size) / (self._frequency_skip)) + 1
     freq_inputs = []
     for f in range(num_feats):
-      cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
+      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
                                   [-1, self._feature_size])
       freq_inputs.append(cur_input)
     return freq_inputs
@@ -498,11 +504,16 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
   The code uses optional peephole connections, shared_weights and cell clipping.
   """
 
-  def __init__(self, num_units, use_peepholes=False,
+  def __init__(self,
+               num_units,
+               use_peepholes=False,
                share_time_frequency_weights=False,
-               cell_clip=None, initializer=None,
-               num_unit_shards=1, forget_bias=1.0,
-               feature_size=None, frequency_skip=None,
+               cell_clip=None,
+               initializer=None,
+               num_unit_shards=1,
+               forget_bias=1.0,
+               feature_size=None,
+               frequency_skip=None,
                num_frequency_blocks=None,
                start_freqindex_list=None,
                end_freqindex_list=None,
@@ -580,10 +591,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         for freq_index in range(self._num_frequency_blocks[block_index]):
           name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
           state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
-      self._state_tuple_type = collections.namedtuple(
-          "GridLSTMStateTuple", state_names.strip(","))
-      self._state_size = self._state_tuple_type(
-          *([num_units, num_units] * self._total_blocks))
+      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
+                                                      state_names.strip(","))
+      self._state_size = self._state_tuple_type(*(
+          [num_units, num_units] * self._total_blocks))
     else:
       self._state_tuple_type = None
       self._state_size = num_units * self._total_blocks * 2
@@ -626,7 +637,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
     state_out_lst = []
     for block in range(len(freq_inputs)):
       m_out_lst_current, state_out_lst_current = self._compute(
-          freq_inputs[block], block, state, batch_size,
+          freq_inputs[block],
+          block,
+          state,
+          batch_size,
           state_is_tuple=self._state_is_tuple)
       m_out_lst.extend(m_out_lst_current)
       state_out_lst.extend(state_out_lst_current)
@@ -637,7 +651,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
     m_out = array_ops.concat(m_out_lst, 1)
     return m_out, state_out
 
-  def _compute(self, freq_inputs, block, state, batch_size,
+  def _compute(self,
+               freq_inputs,
+               block,
+               state,
+               batch_size,
                state_prefix="state",
                state_is_tuple=True):
     """Run the actual computation of one step LSTM.
@@ -666,8 +684,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
     actual_input_size = freq_inputs[0].get_shape().as_list()[1]
 
     concat_w_f = _get_concat_variable(
-        "W_f_%d" % block, [actual_input_size + 2 * self._num_units,
-                           num_gates * self._num_units],
+        "W_f_%d" % block,
+        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
         dtype, self._num_unit_shards)
     b_f = vs.get_variable(
         "B_f_%d" % block,
@@ -675,10 +693,9 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         initializer=init_ops.zeros_initializer(),
         dtype=dtype)
     if not self._share_time_frequency_weights:
-      concat_w_t = _get_concat_variable(
-          "W_t_%d" % block, [actual_input_size + 2 * self._num_units,
-                             num_gates * self._num_units],
-          dtype, self._num_unit_shards)
+      concat_w_t = _get_concat_variable("W_t_%d" % block, [
+          actual_input_size + 2 * self._num_units, num_gates * self._num_units
+      ], dtype, self._num_unit_shards)
       b_t = vs.get_variable(
           "B_t_%d" % block,
           shape=[num_gates * self._num_units],
@@ -691,7 +708,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         w_f_diag_freqf = vs.get_variable(
             "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
         w_f_diag_freqt = vs.get_variable(
-            "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
+            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
       w_i_diag_freqf = vs.get_variable(
           "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
       w_i_diag_freqt = vs.get_variable(
@@ -725,8 +742,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         m_prev_time = getattr(state, name_prefix + "_m")
       else:
         c_prev_time = array_ops.slice(
-            state, [0, 2 * freq_index * self._num_units],
-            [-1, self._num_units])
+            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
         m_prev_time = array_ops.slice(
             state, [0, (2 * freq_index + 1) * self._num_units],
             [-1, self._num_units])
@@ -736,8 +752,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
           [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
 
       # F-LSTM
-      lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
-                                                         concat_w_f), b_f)
+      lstm_matrix_freq = nn_ops.bias_add(
+          math_ops.matmul(cell_inputs, concat_w_f), b_f)
       if self._couple_input_forget_gates:
         i_freq, j_freq, o_freq = array_ops.split(
             value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
@@ -752,8 +768,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         f_time = f_freq
         o_time = o_freq
       else:
-        lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
-                                                           concat_w_t), b_t)
+        lstm_matrix_time = nn_ops.bias_add(
+            math_ops.matmul(cell_inputs, concat_w_t), b_t)
         if self._couple_input_forget_gates:
           i_time, j_time, o_time = array_ops.split(
               value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
@@ -765,8 +781,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
       # F-LSTM c_freq
       # input gate activations
       if self._use_peepholes:
-        i_freq_g = sigmoid(i_freq +
-                           w_i_diag_freqf * c_prev_freq +
+        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                            w_i_diag_freqt * c_prev_time)
       else:
         i_freq_g = sigmoid(i_freq)
@@ -775,9 +790,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         f_freq_g = 1.0 - i_freq_g
       else:
         if self._use_peepholes:
-          f_freq_g = sigmoid(f_freq + self._forget_bias +
-                             w_f_diag_freqf * c_prev_freq +
-                             w_f_diag_freqt * c_prev_time)
+          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
+                             c_prev_freq + w_f_diag_freqt * c_prev_time)
         else:
           f_freq_g = sigmoid(f_freq + self._forget_bias)
       # cell state
@@ -792,12 +806,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
       # input gate activations
       if self._use_peepholes:
         if self._share_time_frequency_weights:
-          i_time_g = sigmoid(i_time +
-                             w_i_diag_freqf * c_prev_freq +
+          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
                              w_i_diag_freqt * c_prev_time)
         else:
-          i_time_g = sigmoid(i_time +
-                             w_i_diag_timef * c_prev_freq +
+          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                              w_i_diag_timet * c_prev_time)
       else:
         i_time_g = sigmoid(i_time)
@@ -807,13 +819,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
       else:
         if self._use_peepholes:
           if self._share_time_frequency_weights:
-            f_time_g = sigmoid(f_time + self._forget_bias +
-                               w_f_diag_freqf * c_prev_freq +
-                               w_f_diag_freqt * c_prev_time)
+            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
+                               c_prev_freq + w_f_diag_freqt * c_prev_time)
           else:
-            f_time_g = sigmoid(f_time + self._forget_bias +
-                               w_f_diag_timef * c_prev_freq +
-                               w_f_diag_timet * c_prev_time)
+            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
+                               c_prev_freq + w_f_diag_timet * c_prev_time)
         else:
           f_time_g = sigmoid(f_time + self._forget_bias)
       # cell state
@@ -826,8 +836,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
 
       # F-LSTM m_freq
       if self._use_peepholes:
-        m_freq = sigmoid(o_freq +
-                         w_o_diag_freqf * c_freq +
+        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                          w_o_diag_freqt * c_time) * tanh(c_freq)
       else:
         m_freq = sigmoid(o_freq) * tanh(c_freq)
@@ -835,12 +844,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
       # T-LSTM m_time
       if self._use_peepholes:
         if self._share_time_frequency_weights:
-          m_time = sigmoid(o_time +
-                           w_o_diag_freqf * c_freq +
+          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
                            w_o_diag_freqt * c_time) * tanh(c_time)
         else:
-          m_time = sigmoid(o_time +
-                           w_o_diag_timef * c_freq +
+          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                            w_o_diag_timet * c_time) * tanh(c_time)
       else:
         m_time = sigmoid(o_time) * tanh(c_time)
@@ -879,16 +886,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
       raise ValueError("Cannot infer input_size from static shape inference.")
     if slice_offset > 0:
       # Padding to the end
-      inputs = array_ops.pad(
-          input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2],
-                                         dtype=dtypes.int32),
-          "CONSTANT")
+      inputs = array_ops.pad(input_feat,
+                             array_ops.constant(
+                                 [0, 0, 0, slice_offset],
+                                 shape=[2, 2],
+                                 dtype=dtypes.int32), "CONSTANT")
     elif slice_offset < 0:
       # Padding to the front
-      inputs = array_ops.pad(
-          input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2],
-                                         dtype=dtypes.int32),
-          "CONSTANT")
+      inputs = array_ops.pad(input_feat,
+                             array_ops.constant(
+                                 [0, 0, -slice_offset, 0],
+                                 shape=[2, 2],
+                                 dtype=dtypes.int32), "CONSTANT")
       slice_offset = 0
     else:
       inputs = input_feat
@@ -898,13 +907,13 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         raise ValueError("Length of num_frequency_blocks"
                          " is not 1, but instead is %d",
                          len(self._num_frequency_blocks))
-      num_feats = int((input_size - self._feature_size) / (
-          self._frequency_skip)) + 1
+      num_feats = int(
+          (input_size - self._feature_size) / (self._frequency_skip)) + 1
       if num_feats != self._num_frequency_blocks[0]:
         raise ValueError(
             "Invalid num_frequency_blocks, requires %d but gets %d, please"
-            " check the input size and filter config are correct." % (
-                self._num_frequency_blocks[0], num_feats))
+            " check the input size and filter config are correct." %
+            (self._num_frequency_blocks[0], num_feats))
       block_inputs = []
       for f in range(num_feats):
         cur_input = array_ops.slice(
@@ -927,18 +936,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
         start_index = self._start_freqindex_list[b]
         end_index = self._end_freqindex_list[b]
         cur_size = end_index - start_index
-        block_feats = int((cur_size - self._feature_size) / (
-            self._frequency_skip)) + 1
+        block_feats = int(
+            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
         if block_feats != self._num_frequency_blocks[b]:
           raise ValueError(
               "Invalid num_frequency_blocks, requires %d but gets %d, please"
-              " check the input size and filter config are correct." % (
-                  self._num_frequency_blocks[b], block_feats))
+              " check the input size and filter config are correct." %
+              (self._num_frequency_blocks[b], block_feats))
         block_inputs = []
         for f in range(block_feats):
           cur_input = array_ops.slice(
-              inputs, [0, start_index + slice_offset + f *
-                       self._frequency_skip],
+              inputs,
+              [0, start_index + slice_offset + f * self._frequency_skip],
               [-1, self._feature_size])
           block_inputs.append(cur_input)
         freq_inputs.append(block_inputs)
@@ -954,11 +963,16 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
   The current implementation uses different weights for the two directions.
   """
 
-  def __init__(self, num_units, use_peepholes=False,
+  def __init__(self,
+               num_units,
+               use_peepholes=False,
                share_time_frequency_weights=False,
-               cell_clip=None, initializer=None,
-               num_unit_shards=1, forget_bias=1.0,
-               feature_size=None, frequency_skip=None,
+               cell_clip=None,
+               initializer=None,
+               num_unit_shards=1,
+               forget_bias=1.0,
+               feature_size=None,
+               frequency_skip=None,
                num_frequency_blocks=None,
                start_freqindex_list=None,
                end_freqindex_list=None,
@@ -1017,8 +1031,8 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
           state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
     self._state_tuple_type = collections.namedtuple(
         "BidirectionalGridLSTMStateTuple", state_names.strip(","))
-    self._state_size = self._state_tuple_type(
-        *([num_units, num_units] * self._total_blocks * 2))
+    self._state_size = self._state_tuple_type(*(
+        [num_units, num_units] * self._total_blocks * 2))
     self._output_size = 2 * num_units * self._total_blocks * 2
 
   def call(self, inputs, state):
@@ -1052,8 +1066,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
       fwd_state_out_lst = []
       for block in range(len(fwd_inputs)):
         fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
-            fwd_inputs[block], block, state, batch_size,
-            state_prefix="fwd_state", state_is_tuple=True)
+            fwd_inputs[block],
+            block,
+            state,
+            batch_size,
+            state_prefix="fwd_state",
+            state_is_tuple=True)
         fwd_m_out_lst.extend(fwd_m_out_lst_current)
         fwd_state_out_lst.extend(fwd_state_out_lst_current)
     # Backward processing
@@ -1064,8 +1082,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
         # Reverse the blocks
         bwd_inputs_reverse = bwd_inputs[block][::-1]
         bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
-            bwd_inputs_reverse, block, state, batch_size,
-            state_prefix="bwd_state", state_is_tuple=True)
+            bwd_inputs_reverse,
+            block,
+            state,
+            batch_size,
+            state_prefix="bwd_state",
+            state_is_tuple=True)
         bwd_m_out_lst.extend(bwd_m_out_lst_current)
         bwd_state_out_lst.extend(bwd_state_out_lst_current)
     state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
@@ -1076,6 +1098,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
 
 # pylint: disable=protected-access
 _Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
+
 # pylint: enable=protected-access
 
 
@@ -1085,8 +1108,14 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
   Implementation based on https://arxiv.org/abs/1409.0473.
   """
 
-  def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
-               input_size=None, state_is_tuple=True, reuse=None):
+  def __init__(self,
+               cell,
+               attn_length,
+               attn_size=None,
+               attn_vec_size=None,
+               input_size=None,
+               state_is_tuple=True,
+               reuse=None):
     """Create a cell with attention.
 
     Args:
@@ -1116,16 +1145,15 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
       raise TypeError("The parameter cell is not RNNCell.")
     if nest.is_sequence(cell.state_size) and not state_is_tuple:
-      raise ValueError("Cell returns tuple of states, but the flag "
-                       "state_is_tuple is not set. State size is: %s"
-                       % str(cell.state_size))
+      raise ValueError(
+          "Cell returns tuple of states, but the flag "
+          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
     if attn_length <= 0:
-      raise ValueError("attn_length should be greater than zero, got %s"
-                       % str(attn_length))
+      raise ValueError(
+          "attn_length should be greater than zero, got %s" % str(attn_length))
     if not state_is_tuple:
-      logging.warn(
-          "%s: Using a concatenated state is slower and will soon be "
-          "deprecated.  Use state_is_tuple=True.", self)
+      logging.warn("%s: Using a concatenated state is slower and will soon be "
+                   "deprecated.  Use state_is_tuple=True.", self)
     if attn_size is None:
       attn_size = cell.output_size
     if attn_vec_size is None:
@@ -1161,8 +1189,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
     else:
       states = state
       state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
-      attns = array_ops.slice(
-          states, [0, self._cell.state_size], [-1, self._attn_size])
+      attns = array_ops.slice(states, [0, self._cell.state_size],
+                              [-1, self._attn_size])
       attn_states = array_ops.slice(
           states, [0, self._cell.state_size + self._attn_size],
           [-1, self._attn_size * self._attn_length])
@@ -1200,8 +1228,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
     tanh = math_ops.tanh
 
     with vs.variable_scope("attention"):
-      k = vs.get_variable(
-          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+      k = vs.get_variable("attn_w",
+                          [1, 1, self._attn_size, self._attn_vec_size])
       v = vs.get_variable("attn_v", [self._attn_vec_size])
       hidden = array_ops.reshape(attn_states,
                                  [-1, self._attn_length, 1, self._attn_size])
@@ -1228,7 +1256,8 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
     https://arxiv.org/abs/1505.00387
   """
 
-  def __init__(self, cell,
+  def __init__(self,
+               cell,
                couple_carry_transform_gates=True,
                carry_bias_init=1.0):
     """Constructs a `HighwayWrapper` for `cell`.
@@ -1260,8 +1289,7 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
     carry_weight = vs.get_variable("carry_w", [input_size, input_size])
     carry_bias = vs.get_variable(
         "carry_b", [input_size],
-        initializer=init_ops.constant_initializer(
-            self._carry_bias_init))
+        initializer=init_ops.constant_initializer(self._carry_bias_init))
     carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
     if self._couple_carry_transform_gates:
       transform = 1 - carry
@@ -1270,11 +1298,9 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
                                          [input_size, input_size])
       transform_bias = vs.get_variable(
           "transform_b", [input_size],
-          initializer=init_ops.constant_initializer(
-              -self._carry_bias_init))
-      transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
-                                                    transform_weight,
-                                                    transform_bias))
+          initializer=init_ops.constant_initializer(-self._carry_bias_init))
+      transform = math_ops.sigmoid(
+          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
     return inp * carry + out * transform
 
   def __call__(self, inputs, state, scope=None):
@@ -1294,9 +1320,11 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
     """
     outputs, new_state = self._cell(inputs, state, scope=scope)
     nest.assert_same_structure(inputs, outputs)
+
     # Ensure shapes match
     def assert_shape_match(inp, out):
       inp.get_shape().assert_is_compatible_with(out.get_shape())
+
     nest.map_structure(assert_shape_match, inputs, outputs)
     res_outputs = nest.map_structure(self._highway, inputs, outputs)
     return (res_outputs, new_state)
@@ -1322,10 +1350,16 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
   Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
   """
 
-  def __init__(self, num_units, forget_bias=1.0,
-               input_size=None, activation=math_ops.tanh,
-               layer_norm=True, norm_gain=1.0, norm_shift=0.0,
-               dropout_keep_prob=1.0, dropout_prob_seed=None,
+  def __init__(self,
+               num_units,
+               forget_bias=1.0,
+               input_size=None,
+               activation=math_ops.tanh,
+               layer_norm=True,
+               norm_gain=1.0,
+               norm_shift=0.0,
+               dropout_keep_prob=1.0,
+               dropout_prob_seed=None,
                reuse=None):
     """Initializes the basic LSTM cell.
 
@@ -1410,8 +1444,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
     if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
       g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
 
-    new_c = (c * math_ops.sigmoid(f + self._forget_bias)
-             + math_ops.sigmoid(i) * g)
+    new_c = (
+        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
     if self._layer_norm:
       new_c = self._norm(new_c, "state", dtype=dtype)
     new_h = self._activation(new_c) * math_ops.sigmoid(o)
@@ -1433,8 +1467,7 @@ class NASCell(rnn_cell_impl.RNNCell):
   The class uses an optional projection layer.
   """
 
-  def __init__(self, num_units, num_proj=None,
-               use_biases=False, reuse=None):
+  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
     """Initialize the parameters for a NAS cell.
 
     Args:
@@ -1504,12 +1537,10 @@ class NASCell(rnn_cell_impl.RNNCell):
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
     # Variables for the NAS cell. W_m is all matrices multiplying the
     # hiddenstate and W_inputs is all matrices multiplying the inputs.
-    concat_w_m = vs.get_variable(
-        "recurrent_kernel", [num_proj, 8 * self._num_units],
-        dtype)
+    concat_w_m = vs.get_variable("recurrent_kernel",
+                                 [num_proj, 8 * self._num_units], dtype)
     concat_w_inputs = vs.get_variable(
-        "kernel", [input_size.value, 8 * self._num_units],
-        dtype)
+        "kernel", [input_size.value, 8 * self._num_units], dtype)
 
     m_matrix = math_ops.matmul(m_prev, concat_w_m)
     inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
@@ -1524,10 +1555,10 @@ class NASCell(rnn_cell_impl.RNNCell):
 
     # The NAS cell branches into 8 different splits for both the hiddenstate
     # and the input
-    m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
-                                      value=m_matrix)
-    inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
-                                           value=inputs_matrix)
+    m_matrix_splits = array_ops.split(
+        axis=1, num_or_size_splits=8, value=m_matrix)
+    inputs_matrix_splits = array_ops.split(
+        axis=1, num_or_size_splits=8, value=inputs_matrix)
 
     # First layer
     layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
@@ -1559,9 +1590,8 @@ class NASCell(rnn_cell_impl.RNNCell):
 
     # Projection layer if specified
     if self._num_proj is not None:
-      concat_w_proj = vs.get_variable(
-          "projection_weights", [self._num_units, self._num_proj],
-          dtype)
+      concat_w_proj = vs.get_variable("projection_weights",
+                                      [self._num_units, self._num_proj], dtype)
       new_m = math_ops.matmul(new_m, concat_w_proj)
 
     new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
@@ -1584,8 +1614,12 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
   "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
   """
 
-  def __init__(self, num_units, initializer=None, forget_bias=1.0,
-               activation=math_ops.tanh, reuse=None):
+  def __init__(self,
+               num_units,
+               initializer=None,
+               forget_bias=1.0,
+               activation=math_ops.tanh,
+               reuse=None):
     """Initialize the parameters for an UGRNN cell.
 
     Args:
@@ -1640,8 +1674,8 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
 
-    with vs.variable_scope(vs.get_variable_scope(),
-                           initializer=self._initializer):
+    with vs.variable_scope(
+        vs.get_variable_scope(), initializer=self._initializer):
       cell_inputs = array_ops.concat([inputs, state], 1)
       if self._linear is None:
         self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
@@ -1681,9 +1715,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
   RNNs so it may not achieve best performance with depth 1.
   """
 
-  def __init__(self, num_units, num_in_proj=None,
-               initializer=None, forget_bias=1.0,
-               y_activation=nn_ops.relu, reuse=None):
+  def __init__(self,
+               num_units,
+               num_in_proj=None,
+               initializer=None,
+               forget_bias=1.0,
+               y_activation=nn_ops.relu,
+               reuse=None):
     """Initialize the parameters for an +RNN cell.
 
     Args:
@@ -1747,8 +1785,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
 
-    with vs.variable_scope(vs.get_variable_scope(),
-                           initializer=self._initializer):
+    with vs.variable_scope(
+        vs.get_variable_scope(), initializer=self._initializer):
       # read-in projections (should be used for first layer in deep +RNN
       # to transform size of inputs from I --> N)
       if input_size.value != self._num_units:
@@ -1765,13 +1803,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
       n_dim = i_dim = self._num_units
       cell_inputs = array_ops.concat([inputs, state], 1)
       if self._linear2 is None:
-        self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True)
+        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
       rnn_matrix = self._linear2(cell_inputs)
 
-      gh_act = rnn_matrix[:, :n_dim]                           # b x n
-      h_act = rnn_matrix[:, n_dim:2*n_dim]                     # b x n
-      gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim]            # b x i
-      y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim]     # b x i
+      gh_act = rnn_matrix[:, :n_dim]  # b x n
+      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
+      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
+      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i
 
       h = tanh(h_act)
       y = self._y_activation(y_act)
@@ -1817,6 +1855,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
     if self._compile_stateful:
       compile_ops = True
     else:
+
       def compile_ops(node_def):
         global _REGISTERED_OPS
         if _REGISTERED_OPS is None:
@@ -1827,10 +1866,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
       return self._cell(inputs, state, scope=scope)
 
 
-def _random_exp_initializer(minval,
-                            maxval,
-                            seed=None,
-                            dtype=dtypes.float32):
+def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
   """Returns an exponential distribution initializer.
 
   Args:
@@ -1849,10 +1885,7 @@ def _random_exp_initializer(minval,
     del partition_info  # Unused.
     return math_ops.exp(
         random_ops.random_uniform(
-            shape,
-            math_ops.log(minval),
-            math_ops.log(maxval),
-            dtype,
+            shape, math_ops.log(minval), math_ops.log(maxval), dtype,
             seed=seed))
 
   return _initializer
@@ -1956,8 +1989,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
       if self._linear1 is None:
         self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)
 
-      mask_gates = math_ops.sigmoid(
-          self._linear1(in_mask_gates))
+      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
       [input_gate, forget_gate] = array_ops.split(
           axis=1, num_or_size_splits=2, value=mask_gates)
 
@@ -1981,12 +2013,12 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
 
     period = vs.get_variable(
         "period", [self._num_units],
-        initializer=_random_exp_initializer(
-            self._period_init_min, self._period_init_max))
+        initializer=_random_exp_initializer(self._period_init_min,
+                                            self._period_init_max))
     phase = vs.get_variable(
         "phase", [self._num_units],
-        initializer=init_ops.random_uniform_initializer(
-            0., period.initial_value))
+        initializer=init_ops.random_uniform_initializer(0.,
+                                                        period.initial_value))
     ratio_on = vs.get_variable(
         "ratio_on", [self._num_units],
         initializer=init_ops.constant_initializer(self._ratio_on),
@@ -2008,6 +2040,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
 
     return new_h, new_state
 
+
 class ConvLSTMCell(rnn_cell_impl.RNNCell):
   """Convolutional LSTM recurrent network cell.
 
@@ -2041,7 +2074,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
     """
     super(ConvLSTMCell, self).__init__(name=name)
 
-    if conv_ndims != len(input_shape)-1:
+    if conv_ndims != len(input_shape) - 1:
       raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
           input_shape, conv_ndims))
 
@@ -2060,8 +2093,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
     state_size = tensor_shape.TensorShape(
         self._input_shape[:-1] + [self._output_channels])
     self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
-    self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
-                                                 + [self._total_output_channels])
+    self._output_size = tensor_shape.TensorShape(
+        self._input_shape[:-1] + [self._total_output_channels])
 
   @property
   def output_size(self):
@@ -2073,13 +2106,10 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
 
   def call(self, inputs, state, scope=None):
     cell, hidden = state
-    new_hidden = _conv([inputs, hidden],
-                       self._kernel_shape,
-                       4*self._output_channels,
-                       self._use_bias)
-    gates = array_ops.split(value=new_hidden,
-                            num_or_size_splits=4,
-                            axis=self._conv_ndims+1)
+    new_hidden = _conv([inputs, hidden], self._kernel_shape,
+                       4 * self._output_channels, self._use_bias)
+    gates = array_ops.split(
+        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)
 
     input_gate, new_input, forget_gate, output_gate = gates
     new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
@@ -2091,29 +2121,35 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
     new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
     return output, new_state
 
+
 class Conv1DLSTMCell(ConvLSTMCell):
   """1D Convolutional LSTM recurrent network cell.
 
   https://arxiv.org/pdf/1506.04214v1.pdf
   """
+
   def __init__(self, name="conv_1d_lstm_cell", **kwargs):
     """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
     super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
 
+
 class Conv2DLSTMCell(ConvLSTMCell):
   """2D Convolutional LSTM recurrent network cell.
 
   https://arxiv.org/pdf/1506.04214v1.pdf
   """
+
   def __init__(self, name="conv_2d_lstm_cell", **kwargs):
     """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
     super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
 
+
 class Conv3DLSTMCell(ConvLSTMCell):
   """3D Convolutional LSTM recurrent network cell.
 
   https://arxiv.org/pdf/1506.04214v1.pdf
   """
+
   def __init__(self, name="conv_3d_lstm_cell", **kwargs):
     """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
     super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
@@ -2138,7 +2174,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
   shapes = [a.get_shape().as_list() for a in args]
   shape_length = len(shapes[0])
   for shape in shapes:
-    if len(shape) not in [3,4,5]:
+    if len(shape) not in [3, 4, 5]:
       raise ValueError("Conv Linear expects 3D, 4D "
                        "or 5D arguments: %s" % str(shapes))
     if len(shape) != len(shapes[0]):
@@ -2149,40 +2185,36 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
   dtype = [a.dtype for a in args][0]
 
   # determine correct conv operation
-  if   shape_length == 3:
+  if shape_length == 3:
     conv_op = nn_ops.conv1d
     strides = 1
   elif shape_length == 4:
     conv_op = nn_ops.conv2d
-    strides = shape_length*[1]
+    strides = shape_length * [1]
   elif shape_length == 5:
     conv_op = nn_ops.conv3d
-    strides = shape_length*[1]
+    strides = shape_length * [1]
 
   # Now the computation.
   kernel = vs.get_variable(
-      "kernel",
-      filter_size + [total_arg_size_depth, num_features],
-      dtype=dtype)
+      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
   if len(args) == 1:
-    res = conv_op(args[0],
-                  kernel,
-                  strides,
-                  padding='SAME')
+    res = conv_op(args[0], kernel, strides, padding="SAME")
   else:
-    res = conv_op(array_ops.concat(axis=shape_length-1, values=args),
-                  kernel,
-                  strides,
-                  padding='SAME')
+    res = conv_op(
+        array_ops.concat(axis=shape_length - 1, values=args),
+        kernel,
+        strides,
+        padding="SAME")
   if not bias:
     return res
   bias_term = vs.get_variable(
       "biases", [num_features],
       dtype=dtype,
-      initializer=init_ops.constant_initializer(
-          bias_start, dtype=dtype))
+      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
   return res + bias_term
 
+
 class GLSTMCell(rnn_cell_impl.RNNCell):
   """Group LSTM cell (G-LSTM).
 
@@ -2194,8 +2226,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
   "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
   """
 
-  def __init__(self, num_units, initializer=None, num_proj=None,
-               number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
+  def __init__(self,
+               num_units,
+               initializer=None,
+               num_proj=None,
+               number_of_groups=1,
+               forget_bias=1.0,
+               activation=math_ops.tanh,
                reuse=None):
     """Initialize the parameters of G-LSTM cell.
 
@@ -2232,11 +2269,15 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
     if self._num_proj:
       if self._num_proj % self._number_of_groups != 0:
         raise ValueError("num_proj must be divisible by number_of_groups")
-      self._group_shape = [int(self._num_proj / self._number_of_groups),
-                           int(self._num_units / self._number_of_groups)]
+      self._group_shape = [
+          int(self._num_proj / self._number_of_groups),
+          int(self._num_units / self._number_of_groups)
+      ]
     else:
-      self._group_shape = [int(self._num_units / self._number_of_groups),
-                           int(self._num_units / self._number_of_groups)]
+      self._group_shape = [
+          int(self._num_units / self._number_of_groups),
+          int(self._num_units / self._number_of_groups)
+      ]
 
     if num_proj:
       self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
@@ -2268,10 +2309,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
       subset of inputs corresponding to group "group_id",
       a Tensor, 2D, [batch x num_units/number_of_groups]
     """
-    return array_ops.slice(input_=inputs,
-                           begin=[0, group_id * group_size],
-                           size=[self._batch_size, group_size],
-                           name=("GLSTM_group%d_input_generation" % group_id))
+    return array_ops.slice(
+        input_=inputs,
+        begin=[0, group_id * group_size],
+        size=[self._batch_size, group_size],
+        name=("GLSTM_group%d_input_generation" % group_id))
 
   def call(self, inputs, state):
     """Run one step of G-LSTM.
@@ -2310,10 +2352,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
       for group_id in range(self._number_of_groups):
         with vs.variable_scope("group%d" % group_id):
           x_g_id = array_ops.concat(
-            [self._get_input_for_group(inputs, group_id,
-                                       self._group_shape[0]),
-             self._get_input_for_group(m_prev, group_id,
-                                       self._group_shape[0])], axis=1)
+              [
+                  self._get_input_for_group(inputs, group_id,
+                                            self._group_shape[0]),
+                  self._get_input_for_group(m_prev, group_id,
+                                            self._group_shape[0])
+              ],
+              axis=1)
           if self._linear1 is None:
             self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False)
           R_k = self._linear1(x_g_id)  # pylint: disable=invalid-name
@@ -2324,34 +2369,35 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
         f_parts.append(f_k)
         o_parts.append(o_k)
 
-      bi = vs.get_variable(name="bias_i",
-                           shape=[self._num_units],
-                           dtype=dtype,
-                           initializer=
-                           init_ops.constant_initializer(0.0, dtype=dtype))
-      bj = vs.get_variable(name="bias_j",
-                           shape=[self._num_units],
-                           dtype=dtype,
-                           initializer=
-                           init_ops.constant_initializer(0.0, dtype=dtype))
-      bf = vs.get_variable(name="bias_f",
-                           shape=[self._num_units],
-                           dtype=dtype,
-                           initializer=
-                           init_ops.constant_initializer(0.0, dtype=dtype))
-      bo = vs.get_variable(name="bias_o",
-                           shape=[self._num_units],
-                           dtype=dtype,
-                           initializer=
-                           init_ops.constant_initializer(0.0, dtype=dtype))
+      bi = vs.get_variable(
+          name="bias_i",
+          shape=[self._num_units],
+          dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bj = vs.get_variable(
+          name="bias_j",
+          shape=[self._num_units],
+          dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bf = vs.get_variable(
+          name="bias_f",
+          shape=[self._num_units],
+          dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bo = vs.get_variable(
+          name="bias_o",
+          shape=[self._num_units],
+          dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
 
       i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
       j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
       f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
       o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
 
-    c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
-         math_ops.sigmoid(i) * math_ops.tanh(j))
+    c = (
+        math_ops.sigmoid(f + self._forget_bias) * c_prev +
+        math_ops.sigmoid(i) * math_ops.tanh(j))
     m = math_ops.sigmoid(o) * self._activation(c)
 
     if self._num_proj is not None:
@@ -2636,10 +2682,12 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
 
 class SRUCell(rnn_cell_impl._LayerRNNCell):
   """SRU, Simple Recurrent Unit
+
      Implementation based on
      Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
 
-     This variation of RNN cell is characterized by the simplified data dependence
+     This variation of RNN cell is characterized by the simplified data
+     dependence
      between hidden states of two consecutive time steps. Traditionally, hidden
      states from a cell at time step t-1 needs to be multiplied with a matrix
      W_hh before being fed into the ensuing cell at time step t.
@@ -2657,8 +2705,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
       will share weights, but to avoid mistakes we require reuse=True in such
       cases.
   """
-  def __init__(self, num_units,
-               activation=None, reuse=None, name=None):
+
+  def __init__(self, num_units, activation=None, reuse=None, name=None):
     super(SRUCell, self).__init__(_reuse=reuse, name=name)
     self._num_units = num_units
     self._activation = activation or math_ops.tanh
@@ -2676,8 +2724,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
 
   def build(self, inputs_shape):
     if inputs_shape[1].value is None:
-      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % inputs_shape)
+      raise ValueError(
+          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
 
     input_depth = inputs_shape[1].value
 
@@ -2712,12 +2760,12 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
     """Simple recurrent unit (SRU) with num_units cells."""
 
     U = math_ops.matmul(inputs, self._kernel)
-    x_bar, f_intermediate, r_intermediate = array_ops.split(value=U,
-                                                            num_or_size_splits=3,
-                                                            axis=1)
+    x_bar, f_intermediate, r_intermediate = array_ops.split(
+        value=U, num_or_size_splits=3, axis=1)
 
-    f_r = math_ops.sigmoid(nn_ops.bias_add(array_ops.concat(
-        [f_intermediate, r_intermediate], 1), self._bias))
+    f_r = math_ops.sigmoid(
+        nn_ops.bias_add(
+            array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
     f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
 
     c = f * state + (1.0 - f) * x_bar
@@ -2750,9 +2798,16 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
     large scale acoustic modeling." INTERSPEECH, 2014.
   """
 
-  def __init__(self, num_units, norm=True, use_peepholes=False,
-               cell_clip=None, initializer=None, num_proj=None,
-               proj_clip=None, forget_bias=1, activation=None,
+  def __init__(self,
+               num_units,
+               norm=True,
+               use_peepholes=False,
+               cell_clip=None,
+               initializer=None,
+               num_proj=None,
+               proj_clip=None,
+               forget_bias=1,
+               activation=None,
                reuse=None):
     """Initialize the parameters of a weight-normalized LSTM cell.
 
@@ -2779,7 +2834,7 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
     """
     super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
 
-    self._scope = 'wn_lstm_cell'
+    self._scope = "wn_lstm_cell"
     self._num_units = num_units
     self._norm = norm
     self._initializer = initializer
@@ -2822,7 +2877,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
     g = vs.get_variable(name, [output_size], dtype=weight.dtype)
     return nn_impl.l2_normalize(weight, dim=0) * g
 
-  def _linear(self, args,
+  def _linear(self,
+              args,
               output_size,
               norm,
               bias,
@@ -2877,8 +2933,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
         with ops.control_dependencies(None):
           for i in range(len(args)):
             en = st + shapes[i][1].value
-            wn.append(self._normalize(weights[st:en, :],
-                                      name='norm_{}'.format(i)))
+            wn.append(
+                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
             st = en
 
           weights = array_ops.concat(wn, axis=0)
@@ -2936,8 +2992,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
 
     with vs.variable_scope(self._scope, initializer=self._initializer):
 
-      concat = self._linear([inputs, h], 4 * num_units,
-                            norm=self._norm, bias=True)
+      concat = self._linear(
+          [inputs, h], 4 * num_units, norm=self._norm, bias=True)
 
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
@@ -2947,11 +3003,13 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
         w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
         w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
 
-        new_c = (c * sigmoid(f + self._forget_bias + w_f_diag * c)
-                 + sigmoid(i + w_i_diag * c) * self._activation(j))
+        new_c = (
+            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
+            sigmoid(i + w_i_diag * c) * self._activation(j))
       else:
-        new_c = (c * sigmoid(f + self._forget_bias)
-                 + sigmoid(i) * self._activation(j))
+        new_c = (
+            c * sigmoid(f + self._forget_bias) +
+            sigmoid(i) * self._activation(j))
 
       if self._cell_clip is not None:
         # pylint: disable=invalid-unary-operand-type
@@ -2964,15 +3022,12 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
 
       if self._num_proj is not None:
         with vs.variable_scope("projection"):
-          new_h = self._linear(new_h,
-                               self._num_proj,
-                               norm=self._norm,
-                               bias=False)
+          new_h = self._linear(
+              new_h, self._num_proj, norm=self._norm, bias=False)
 
         if self._proj_clip is not None:
           # pylint: disable=invalid-unary-operand-type
-          new_h = clip_ops.clip_by_value(new_h,
-                                         -self._proj_clip,
+          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
                                          self._proj_clip)
           # pylint: enable=invalid-unary-operand-type
 
index f498b2bb5709ea28faca1c5cfa21ad30aac14ab7..926554031775202d7f7d9018cf6ae4efb34fe96b 100644 (file)
@@ -46,20 +46,18 @@ class TestGatherTree(test.TestCase):
 
     # create (batch_size, max_time, beam_width) matrix and transpose it
     predicted_ids = np.array(
-        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-         [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
+        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
         dtype=np.int32).transpose([1, 0, 2])
     parent_ids = np.array(
-        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
-         [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
+        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
         dtype=np.int32).transpose([1, 0, 2])
 
     # sequence_lengths is shaped (batch_size = 3)
     max_sequence_lengths = [3, 3]
 
-    expected_result = np.array(
-        [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
-         [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
+    expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
+                                [[2, 4, 4], [7, 6, 6],
+                                 [8, 9, 10]]]).transpose([1, 0, 2])
 
     res = beam_search_ops.gather_tree(
         predicted_ids,
@@ -157,8 +155,8 @@ class TestBeamStep(test.TestCase):
     self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
     self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
     self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
-    self.assertAllEqual(next_state_.finished, [[False, False, False],
-                                               [False, False, False]])
+    self.assertAllEqual(next_state_.finished,
+                        [[False, False, False], [False, False, False]])
 
     expected_log_probs = []
     expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -212,8 +210,8 @@ class TestBeamStep(test.TestCase):
     self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
     self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
     self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
-    self.assertAllEqual(next_state_.finished, [[True, False, False],
-                                               [False, True, False]])
+    self.assertAllEqual(next_state_.finished,
+                        [[True, False, False], [False, True, False]])
 
     expected_log_probs = []
     expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -226,9 +224,10 @@ class TestBeamStep(test.TestCase):
 
 
 class TestLargeBeamStep(test.TestCase):
-  """
-  Tests a single step of beam search in such
-  case that beam size is larger than vocabulary size.
+  """Tests large beam step.
+
+  Tests a single step of beam search in such case that beam size is larger than
+  vocabulary size.
   """
 
   def setUp(self):
@@ -239,19 +238,21 @@ class TestLargeBeamStep(test.TestCase):
     self.end_token = 0
     self.length_penalty_weight = 0.6
 
-
   def test_step(self):
-    def get_probs():
-      """this simulates the initialize method in BeamSearchDecoder"""
-      log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size],
-                                                        dtype=dtypes.int32),
-                                        depth=self.beam_width, on_value=True,
-                                        off_value=False, dtype=dtypes.bool)
 
-      log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width],
-                                       dtype=dtypes.float32)
-      log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width],
-                                        dtype=dtypes.float32) * -np.Inf
+    def get_probs():
+      """this simulates the initialize method in BeamSearchDecoder."""
+      log_prob_mask = array_ops.one_hot(
+          array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+          depth=self.beam_width,
+          on_value=True,
+          off_value=False,
+          dtype=dtypes.bool)
+
+      log_prob_zeros = array_ops.zeros(
+          [self.batch_size, self.beam_width], dtype=dtypes.float32)
+      log_prob_neg_inf = array_ops.ones(
+          [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf
 
       log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
                                   log_prob_neg_inf)
@@ -260,12 +261,15 @@ class TestLargeBeamStep(test.TestCase):
     log_probs = get_probs()
     dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
 
+    # pylint: disable=invalid-name
     _finished = array_ops.one_hot(
         array_ops.zeros([self.batch_size], dtype=dtypes.int32),
-        depth=self.beam_width, on_value=False,
-        off_value=True, dtype=dtypes.bool)
+        depth=self.beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
     _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
-    _lengths[:, 0]=2
+    _lengths[:, 0] = 2
     _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
 
     beam_state = beam_search_decoder.BeamSearchDecoderState(
@@ -298,20 +302,20 @@ class TestLargeBeamStep(test.TestCase):
         length_penalty_weight=self.length_penalty_weight)
 
     with self.test_session() as sess:
-      outputs_, next_state_, state_, log_probs_ = sess.run(
+      outputs_, next_state_, _, _ = sess.run(
           [outputs, next_beam_state, beam_state, log_probs])
 
     self.assertEqual(outputs_.predicted_ids[0, 0], 3)
     self.assertEqual(outputs_.predicted_ids[0, 1], 2)
     self.assertEqual(outputs_.predicted_ids[1, 0], 1)
     neg_inf = -np.Inf
-    self.assertAllEqual(next_state_.log_probs[:, -3:],
-                        [[neg_inf, neg_inf, neg_inf],
-                         [neg_inf, neg_inf, neg_inf]])
+    self.assertAllEqual(
+        next_state_.log_probs[:, -3:],
+        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
     self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
     self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
-    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0],
-                                                      [0, 0, 0]])
+    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
+
 
 class BeamSearchDecoderTest(test.TestCase):
 
@@ -338,8 +342,8 @@ class BeamSearchDecoderTest(test.TestCase):
       initial_state = cell.zero_state(batch_size, dtypes.float32)
       if has_attention:
         inputs = array_ops.placeholder_with_default(
-            np.random.randn(batch_size, decoder_max_time,
-                            input_depth).astype(np.float32),
+            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+                np.float32),
             shape=(None, None, input_depth))
         tiled_inputs = beam_search_decoder.tile_batch(
             inputs, multiplier=beam_width)
@@ -359,8 +363,7 @@ class BeamSearchDecoderTest(test.TestCase):
       cell_state = cell.zero_state(
           dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
       if has_attention:
-        cell_state = cell_state.clone(
-            cell_state=initial_state)
+        cell_state = cell_state.clone(cell_state=initial_state)
       bsd = beam_search_decoder.BeamSearchDecoder(
           cell=cell,
           embedding=embedding,
index a5f7169c3106d12cd22e822dca96c6adf43a45fe..d6184d61095f727f9dcab56fe59e2601868c1624 100644 (file)
@@ -37,7 +37,6 @@ from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest
 
-
 __all__ = [
     "BeamSearchDecoderOutput",
     "BeamSearchDecoderState",
@@ -48,8 +47,8 @@ __all__ = [
 
 
 class BeamSearchDecoderState(
-    collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs",
-                                                      "finished", "lengths"))):
+    collections.namedtuple("BeamSearchDecoderState",
+                           ("cell_state", "log_probs", "finished", "lengths"))):
   pass
 
 
@@ -85,11 +84,12 @@ def _tile_batch(t, multiplier):
   tiled_static_batch_size = (
       t.shape[0].value * multiplier if t.shape[0].value is not None else None)
   tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
-  tiled = array_ops.reshape(
-      tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+  tiled = array_ops.reshape(tiled,
+                            array_ops.concat(
+                                ([shape_t[0] * multiplier], shape_t[1:]), 0))
   tiled.set_shape(
-      tensor_shape.TensorShape(
-          [tiled_static_batch_size]).concatenate(t.shape[1:]))
+      tensor_shape.TensorShape([tiled_static_batch_size]).concatenate(
+          t.shape[1:]))
   return tiled
 
 
@@ -197,8 +197,8 @@ class BeamSearchDecoder(decoder.Decoder):
     """
     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
       raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
-    if (output_layer is not None
-        and not isinstance(output_layer, layers_base.Layer)):
+    if (output_layer is not None and
+        not isinstance(output_layer, layers_base.Layer)):
       raise TypeError(
           "output_layer must be a Layer, received: %s" % type(output_layer))
     self._cell = cell
@@ -223,16 +223,17 @@ class BeamSearchDecoder(decoder.Decoder):
     self._beam_width = beam_width
     self._length_penalty_weight = length_penalty_weight
     self._initial_cell_state = nest.map_structure(
-        self._maybe_split_batch_beams,
-        initial_state, self._cell.state_size)
+        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
     self._start_tokens = array_ops.tile(
         array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
     self._start_inputs = self._embedding_fn(self._start_tokens)
-    
+
     self._finished = array_ops.one_hot(
         array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width, on_value=False,
-        off_value=True, dtype=dtypes.bool)
+        depth=self._beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
 
   @property
   def batch_size(self):
@@ -250,8 +251,7 @@ class BeamSearchDecoder(decoder.Decoder):
       # dimensions to get the output size of the rnn with the layer
       # applied to the top.
       output_shape_with_unknown_batch = nest.map_structure(
-          lambda s: tensor_shape.TensorShape([None]).concatenate(s),
-          size)
+          lambda s: tensor_shape.TensorShape([None]).concatenate(s), size)
       layer_output_shape = self._output_layer.compute_output_shape(
           output_shape_with_unknown_batch)
       return nest.map_structure(lambda s: s[1:], layer_output_shape)
@@ -302,10 +302,11 @@ class BeamSearchDecoder(decoder.Decoder):
 
     log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
         array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width, on_value=0.0, off_value=-np.Inf,
+        depth=self._beam_width,
+        on_value=0.0,
+        off_value=-np.Inf,
         dtype=nest.flatten(self._initial_cell_state)[0].dtype)
 
-
     initial_state = BeamSearchDecoderState(
         cell_state=self._initial_cell_state,
         log_probs=log_probs,
@@ -365,11 +366,12 @@ class BeamSearchDecoder(decoder.Decoder):
     t_shape = array_ops.shape(t)
     static_batch_size = tensor_util.constant_value(self._batch_size)
     batch_size_beam_width = (
-        None if static_batch_size is None
-        else static_batch_size * self._beam_width)
+        None
+        if static_batch_size is None else static_batch_size * self._beam_width)
     reshaped_t = array_ops.reshape(
-        t, array_ops.concat(
-            ([self._batch_size * self._beam_width], t_shape[2:]), 0))
+        t,
+        array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]),
+                         0))
     reshaped_t.set_shape(
         (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
     return reshaped_t
@@ -398,8 +400,9 @@ class BeamSearchDecoder(decoder.Decoder):
       s = tensor_shape.TensorShape(s)
     t_shape = array_ops.shape(t)
     reshaped_t = array_ops.reshape(
-        t, array_ops.concat(
-            ([self._batch_size, self._beam_width], t_shape[1:]), 0))
+        t,
+        array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]),
+                         0))
     static_batch_size = tensor_util.constant_value(self._batch_size)
     expected_reshaped_shape = tensor_shape.TensorShape(
         [static_batch_size, self._beam_width]).concatenate(s)
@@ -409,8 +412,8 @@ class BeamSearchDecoder(decoder.Decoder):
                        "We expected it to have shape "
                        "(batch_size, beam_width, depth) == %s.  Perhaps you "
                        "forgot to create a zero_state with "
-                       "batch_size=encoder_batch_size * beam_width?"
-                       (reshaped_t.shape, expected_reshaped_shape))
+                       "batch_size=encoder_batch_size * beam_width?" %
+                       (reshaped_t.shape, expected_reshaped_shape))
     reshaped_t.set_shape(expected_reshaped_shape)
     return reshaped_t
 
@@ -482,15 +485,13 @@ class BeamSearchDecoder(decoder.Decoder):
       cell_state = state.cell_state
       inputs = nest.map_structure(
           lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
-      cell_state = nest.map_structure(
-          self._maybe_merge_batch_beams,
-          cell_state, self._cell.state_size)
+      cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
+                                      self._cell.state_size)
       cell_outputs, next_cell_state = self._cell(inputs, cell_state)
       cell_outputs = nest.map_structure(
           lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
       next_cell_state = nest.map_structure(
-          self._maybe_split_batch_beams,
-          next_cell_state, self._cell.state_size)
+          self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)
 
       if self._output_layer is not None:
         cell_outputs = self._output_layer(cell_outputs)
@@ -553,7 +554,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   lengths_to_add = array_ops.one_hot(
       indices=array_ops.fill([batch_size, beam_width], end_token),
       depth=vocab_size,
-      on_value=np.int64(0), off_value=np.int64(1),
+      on_value=np.int64(0),
+      off_value=np.int64(1),
       dtype=dtypes.int64)
   add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
   lengths_to_add *= array_ops.expand_dims(add_mask, 2)
@@ -572,8 +574,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   scores_flat = array_ops.reshape(scores, [batch_size, -1])
 
   # Pick the next beams according to the specified successors function
-  next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32,
-                                         name="beam_width")
+  next_beam_size = ops.convert_to_tensor(
+      beam_width, dtype=dtypes.int32, name="beam_width")
   next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
 
   next_beam_scores.set_shape([static_batch_size, beam_width])
@@ -592,11 +594,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   #       name="next_beam_word_ids")
   # would be a lot cleaner but for reasons unclear, that hides the results of
   # the op which prevents capturing it with tfdbg debug ops.
-  raw_next_word_ids = math_ops.mod(word_indices, vocab_size,
-                                   name="next_beam_word_ids")
+  raw_next_word_ids = math_ops.mod(
+      word_indices, vocab_size, name="next_beam_word_ids")
   next_word_ids = math_ops.to_int32(raw_next_word_ids)
-  next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
-                                    name="next_beam_parent_ids")
+  next_beam_ids = math_ops.to_int32(
+      word_indices / vocab_size, name="next_beam_parent_ids")
 
   # Append new ids to current predictions
   previously_finished = _tensor_gather_helper(
@@ -605,9 +607,10 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
       batch_size=batch_size,
       range_size=beam_width,
       gather_shape=[-1])
-  next_finished = math_ops.logical_or(previously_finished,
-                                      math_ops.equal(next_word_ids, end_token),
-                                      name="next_beam_finished")
+  next_finished = math_ops.logical_or(
+      previously_finished,
+      math_ops.equal(next_word_ids, end_token),
+      name="next_beam_finished")
 
   # Calculate the length of the next predictions.
   # 1. Finished beams remain unchanged.
@@ -768,8 +771,12 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
     return gather_from
 
 
-def _tensor_gather_helper(gather_indices, gather_from, batch_size,
-                          range_size, gather_shape, name=None):
+def _tensor_gather_helper(gather_indices,
+                          gather_from,
+                          batch_size,
+                          range_size,
+                          gather_shape,
+                          name=None):
   """Helper for gathering the right indices from the tensor.
 
   This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
@@ -800,9 +807,9 @@ def _tensor_gather_helper(gather_indices, gather_from, batch_size,
         array_ops.reshape(gather_from, gather_shape), gather_indices)
     final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
     static_batch_size = tensor_util.constant_value(batch_size)
-    final_static_shape = (tensor_shape.TensorShape([static_batch_size])
-                          .concatenate(
-                              gather_from.shape[1:1 + len(gather_shape)]))
+    final_static_shape = (
+        tensor_shape.TensorShape([static_batch_size]).concatenate(
+            gather_from.shape[1:1 + len(gather_shape)]))
     output = array_ops.reshape(output, final_shape, name="output")
     output.set_shape(final_static_shape)
     return output
index ec5271abe04e51f7c4c3fb4358f8ed79835b74c9..7d95b6522c5149aaf960acc43bc78806c3bd5954 100644 (file)
@@ -15,10 +15,11 @@ limitations under the License.
 
 #ifdef TENSORFLOW_USE_VERBS
 
+#include <fcntl.h>
+#include <cstdlib>
+
 #include "tensorflow/contrib/verbs/rdma.h"
 #include "tensorflow/contrib/verbs/verbs_service.pb.h"
-#include <cstdlib>
-#include <fcntl.h>
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
@@ -27,15 +28,15 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/process_state.h"
 #endif
 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
-#include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/core/threadpool.h"
 
 namespace tensorflow {
 
@@ -447,9 +448,9 @@ void RdmaAdapter::Process_CQ() {
     CHECK_GE(ne, 0);
     for (int i = 0; i < ne; ++i) {
       CHECK(wc_[i].status == IBV_WC_SUCCESS)
-          << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
-          << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
-          << wc_[i].vendor_err;
+          << "Failed status \n"
+          << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+          << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
       if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
         RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
         // put back a recv wr.
@@ -538,7 +539,7 @@ int RdmaChannel::PingPostRecv() {
 int RdmaChannel::PingPostSend() {
   struct ibv_send_wr wr, *bad_wr;
   memset(&wr, 0, sizeof(wr));
-  wr.wr_id = (uint64_t) this;
+  wr.wr_id = (uint64_t)this;
   wr.sg_list = &ping_sge_list_;
   wr.num_sge = 1;
   wr.opcode = IBV_WR_SEND;
@@ -658,7 +659,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
 void RdmaChannel::Recv() {
   struct ibv_recv_wr wr;
   memset(&wr, 0, sizeof(wr));
-  wr.wr_id = (uint64_t) this;
+  wr.wr_id = (uint64_t)this;
   struct ibv_recv_wr* bad_wr;
   CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
 }
@@ -729,11 +730,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
     attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
 
     int r;
-    CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
-                                              IBV_QP_PATH_MTU |
-                                              IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
-                                              IBV_QP_MAX_DEST_RD_ATOMIC |
-                                              IBV_QP_MIN_RNR_TIMER)))
+    CHECK(!(r = ibv_modify_qp(qp_, &attr,
+                              IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+                                  IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+                                  IBV_QP_MAX_DEST_RD_ATOMIC |
+                                  IBV_QP_MIN_RNR_TIMER)))
         << "QP to Ready to Receive " << r;
 
     memset(&attr, 0, sizeof(ibv_qp_attr));
@@ -744,10 +745,10 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
     attr.rnr_retry = 7; /* infinite */
     attr.max_rd_atomic = 1;
 
-    CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
-                                              IBV_QP_RETRY_CNT |
-                                              IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
-                                              IBV_QP_MAX_QP_RD_ATOMIC)))
+    CHECK(!(r = ibv_modify_qp(qp_, &attr,
+                              IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+                                  IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+                                  IBV_QP_MAX_QP_RD_ATOMIC)))
         << "QP to Ready to Send " << r;
 
     connected_ = true;
@@ -897,16 +898,16 @@ static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
   }
   if ((++numTotalCopies % 0x400) == 0) {
     RDMA_LOG(0) << "Tensor copies:"
-                << " GPU to CPU: " << numGPUToCPUCopies
-                << " (" << numGPUToCPUCopiedBytes << " Bytes)"
-                << " CPU to GPU: " << numCPUToGPUCopies
-                << " (" << numCPUToGPUCopiedBytes << " Bytes)";
+                << " GPU to CPU: " << numGPUToCPUCopies << " ("
+                << numGPUToCPUCopiedBytes << " Bytes)"
+                << " CPU to GPU: " << numCPUToGPUCopies << " ("
+                << numCPUToGPUCopiedBytes << " Bytes)";
   }
-  RDMA_LOG(2) << "Copying tensor " << key
-              << " From: " << src_addr << " To: " << dst_addr;
-#endif // RDMA_COUNT_COPIES
+  RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
+              << " To: " << dst_addr;
+#endif  // RDMA_COUNT_COPIES
 }
-#endif // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA
 
 #ifdef RDMA_DATA_VALIDATION
 static uint64_t Checksum(Device* device, const DeviceContext* device_context,
@@ -920,7 +921,7 @@ static uint64_t Checksum(Device* device, const DeviceContext* device_context,
     checksum = (device_context != nullptr)
                    ? GPUUtil::Checksum(device, device_context, in)
                    : GPUUtil::Checksum(in);
-#endif // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA
   } else {
     string s = in.SummarizeValue(999999);
     checksum = Hash64(s.c_str(), s.size(), 0);
@@ -955,17 +956,16 @@ static void ValidateChecksum(uint64_t expected, uint64_t actual,
     }
   }
 }
-#endif // RDMA_DATA_VALIDATION
+#endif  // RDMA_DATA_VALIDATION
 
 #if GOOGLE_CUDA
 // Sync the 'done' operation on the GPU stream, but without all the data
 // copying.
-static void StreamGPUOp(Device* gpu_device,
-                        const DeviceContext* device_context,
+static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
                         StatusCallback done) {
   Tensor dummy1, dummy2;
-  GPUUtil::CopyGPUTensorToCPU(
-      gpu_device, device_context, &dummy1, &dummy2, done);
+  GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
+                              done);
 }
 #endif  // GOOGLE_CUDA
 
@@ -1072,7 +1072,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
       // skip the copy here as well.
       if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
           (RdmaMemoryMgr::Singleton().FindMemoryRegion(
-              (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
+               (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
         StreamGPUOp(src_dev_, send_dev_context,
                     [this, in, proto, is_dead](const Status& s) {
                       Send(in, proto, is_dead, s);
@@ -1118,8 +1118,8 @@ void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
     return;
   }
   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
-  bool proto_size_changed = (!can_memcpy) &&
-                            (proto.ByteSize() != rm_.tensor_bytes_);
+  bool proto_size_changed =
+      (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
   if (meta_data_changed_ || proto_size_changed) {
     Clone(in, proto, is_dead);
     SendMetaData(in, proto, is_dead);
@@ -1238,9 +1238,8 @@ void RdmaTensorResponse::SendErrorStatus(const Status& status) {
   rm.request_index_ = rm_.request_index_;
   rm.status_ = status;
   LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
-             << ": Sending RDMA_MESSAGE_ERROR_STATUS #"
-             << rm.request_index_ << ": " << rm.name_
-             << ". Status: " << status.ToString();
+             << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
+             << ": " << rm.name_ << ". Status: " << status.ToString();
 
   string message = RdmaMessage::CreateMessage(rm);
   channel_->tx_message_buffer_->EnqueueItem(message);
@@ -1336,14 +1335,13 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
     uint32_t gsProtoSize = gsProto.ByteSize();
     if (gsProtoSize + 4 > kErrorStatusMaxSize) {
       LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
-                 << "is too big to fit in RDMA message ("
-                 << kErrorStatusMaxSize << " bytes). Truncated.";
+                 << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
+                 << " bytes). Truncated.";
       gsProtoSize = kErrorStatusMaxSize - 4;
     }
     uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
     *proto_size = gsProtoSize;
-    gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4],
-                             gsProtoSize);
+    gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
     message_size += gsProtoSize + 4;
   }
   return string(message, message_size);
@@ -1393,8 +1391,8 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
   if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
     ErrorStatusProto gsProto;
     uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
-    CHECK(ParseProtoUnlimited(
-        &gsProto, &message[kErrorStatusStartIndex + 4], gsProtoSize))
+    CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
+                              gsProtoSize))
         << "Failed to parse error status proto from message. Aborting.";
     ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
                       gsProto.error_message(), gsProto.error_details());
@@ -1566,8 +1564,8 @@ void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
   if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
       (proxy_tensor_ == nullptr)) {
 #if GOOGLE_CUDA
-        // We need to sync the memory allocation on the GPU:
-        StreamGPUOp(dst_dev_, recv_args_.device_context, done);
+    // We need to sync the memory allocation on the GPU:
+    StreamGPUOp(dst_dev_, recv_args_.device_context, done);
 #endif
   } else {
     done(Status::OK());
@@ -1594,9 +1592,8 @@ void RdmaTensorRequest::Send(RdmaMessageType message_type) {
   rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;
 
   RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
-              << ": Sending  " << MessageTypeToString(message_type)
-              << " #" << index_ << ": "
-              << rm.name_ << " on " << rdma_addr_
+              << ": Sending  " << MessageTypeToString(message_type) << " #"
+              << index_ << ": " << rm.name_ << " on " << rdma_addr_
               << " (rkey: 0x" << std::hex << rm.rkey_ << ")";
 
   string message = RdmaMessage::CreateMessage(rm);
@@ -1610,9 +1607,8 @@ void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
       key_, dtype, shape, is_dead, proto_size);
 
   DeallocateTensors();
-  AllocateTensorsAsync([this](const Status& s) {
-    Send(RDMA_MESSAGE_TENSOR_RE_REQUEST);
-  });
+  AllocateTensorsAsync(
+      [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
 }
 
 void RdmaTensorRequest::RecvTensorContent() {
@@ -1620,8 +1616,8 @@ void RdmaTensorRequest::RecvTensorContent() {
   size_t message_size =
       can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
   RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
-              << ": Received tensor content #" << index_ << ": "
-              << key_ << " (Size: 0x" << std::hex << message_size << ")";
+              << ": Received tensor content #" << index_ << ": " << key_
+              << " (Size: 0x" << std::hex << message_size << ")";
 
   Tensor val;
 
@@ -1667,9 +1663,8 @@ void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
 void RdmaTensorRequest::Start() {
   meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
   if (meta_data_ != nullptr) {
-    AllocateTensorsAsync([this](const Status& s) {
-      Send(RDMA_MESSAGE_TENSOR_REQUEST);
-    });
+    AllocateTensorsAsync(
+        [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
   } else {
     Send(RDMA_MESSAGE_TENSOR_REQUEST);
   }
index 68b3d59f56bbaee1182ade10ff78d624f386a6c9..b6c41de6eea1195df5b9f870301106f7f30e4532 100644 (file)
@@ -73,15 +73,8 @@ struct RemoteMR {
   uint64_t remote_addr;
   uint32_t rkey;
 };
-enum BufferStatus {
-  none,
-  idle,
-  busy
-};
-enum Location {
-  local,
-  remote
-};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
 
 enum RdmaMessageType {
   RDMA_MESSAGE_META_DATA_UPDATE,
index f3644af0b4e1dcb735bb93a18158a425897dd037..369bd986df5313955bc22d6e5c6d38815908ada3 100644 (file)
@@ -116,9 +116,9 @@ void RdmaMgr::SetupChannels() {
         }
         CHECK(i == RdmaChannel::kNumMessageBuffers);
       } else {
-        LOG(ERROR) << "Connecting to " << worker_name
-                   << ": Got " << s.error_message() << ". Retrying ("
-                   << (attempts + 1) << "/" << max_num_attempts << ")..." ;
+        LOG(ERROR) << "Connecting to " << worker_name << ": Got "
+                   << s.error_message() << ". Retrying (" << (attempts + 1)
+                   << "/" << max_num_attempts << ")...";
         if (++attempts == max_num_attempts) {
           break;
         }
@@ -159,19 +159,17 @@ bool RdmaMgr::ConnectivityCheck() {
       ibv_wc_status s = rdma_adapter_->wc_[i].status;
       // recv complete
       if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
-        CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
-                                                  rdma_adapter_->wc_[i].status)
-                                   << "(" << rdma_adapter_->wc_[i].status
-                                   << ") for PING_RECV_WRID";
+        CHECK(s == IBV_WC_SUCCESS)
+            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+            << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
         ++rcnt;
         // send complete
       } else {
         RdmaChannel* rc =
             reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
-        CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
-                                                  rdma_adapter_->wc_[i].status)
-                                   << "(" << rdma_adapter_->wc_[i].status
-                                   << ") to " << rc->remote_name_;
+        CHECK(s == IBV_WC_SUCCESS)
+            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+            << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
         ++scnt;
       }
     }  // for
@@ -238,8 +236,9 @@ int TryToReadNumaNode(ibv_device* device) {
   if (strings::safe_strto32(content, &value)) {
     if (value < 0) {
       LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
-                << value << "), but there must be at least one NUMA node"
-                            ", so returning NUMA node zero";
+                << value
+                << "), but there must be at least one NUMA node"
+                   ", so returning NUMA node zero";
       return 0;
     }
     LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
@@ -302,8 +301,8 @@ void RdmaMgr::InitAllocators() {
         &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
 
     auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
-    CHECK(visitable_allocator) << "is not visitable for instrumentation"
-                               << allocator->Name();
+    CHECK(visitable_allocator)
+        << "is not visitable for instrumentation" << allocator->Name();
     // Make sure we don't instrument the same allocator twice
     if (instrumented_.find(allocator) == std::end(instrumented_)) {
       visitable_allocator->AddAllocVisitor(alloc_visitor);
index bb5eceab27252e61fa62b6c9843111f184698d58..89d37d2f874c0b8fa7550b1c49c0e3c4106e2ee5 100644 (file)
@@ -65,13 +65,11 @@ class MklAddNOp : public OpKernel {
     TensorShape src1_shape, src2_shape;
     src1_shape = input0.shape();
     src2_shape = input1.shape();
-    if (!src1_shape.IsSameSize(src2_shape) ){
-      ctx->SetStatus( 
-          errors::InvalidArgument(
-          "Inputs to operation ", this->name(), " of type ", this->type_string(),
-          " must have the same size and shape.  Input 0: ",
-          src1_shape.DebugString(), " != input 1: ",
-          src2_shape.DebugString()));
+    if (!src1_shape.IsSameSize(src2_shape)) {
+      ctx->SetStatus(errors::InvalidArgument(
+          "Inputs to operation ", this->name(), " of type ",
+          this->type_string(), " must have the same size and shape.  Input 0: ",
+          src1_shape.DebugString(), " != input 1: ", src2_shape.DebugString()));
     }
     // handle the case of a scalar
     if (!input1_in_mkl_format && input0.dims() == 0) {
@@ -82,17 +80,16 @@ class MklAddNOp : public OpKernel {
                                 mkl_context.output_shape);
       float user_i1 = (input0.scalar<T>()());
       float user_i2 = (input1.scalar<T>()());
-      out_tensor->scalar<T>()() =
-          std::plus<float>{}(user_i1, user_i2);
+      out_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
       return;
     }
 
     mkl_context.in_dims = input1_in_mkl_format
-        ? mkl_context.input1_shape.GetDimension()
-        : input0.dims();
+                              ? mkl_context.input1_shape.GetDimension()
+                              : input0.dims();
     mkl_context.in_dims = input2_in_mkl_format
-        ? mkl_context.input2_shape.GetDimension()
-        : input1.dims();
+                              ? mkl_context.input2_shape.GetDimension()
+                              : input1.dims();
 
     // If there is nothing to compute, return.
     if (!input1_in_mkl_format && !input2_in_mkl_format) {
@@ -101,7 +98,7 @@ class MklAddNOp : public OpKernel {
         Tensor* out_tensor = nullptr;
         mkl_context.output_shape.SetMklTensor(false);
         AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape,
-                                 mkl_context.output_shape);
+                                  mkl_context.output_shape);
         return;
       }
     }
@@ -110,9 +107,9 @@ class MklAddNOp : public OpKernel {
     mkl_context.in_strides = new size_t[mkl_context.in_dims];
     // Generate size, stride for input if input is in MKL format.
     if (input1_in_mkl_format || input2_in_mkl_format) {
-      const MklShape* tmp_mkl_shape =
-        (input1_in_mkl_format) ? &mkl_context.input1_shape :
-        &mkl_context.input2_shape;
+      const MklShape* tmp_mkl_shape = (input1_in_mkl_format)
+                                          ? &mkl_context.input1_shape
+                                          : &mkl_context.input2_shape;
       for (int i = 0; i < mkl_context.in_dims; i++) {
         mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
         mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
@@ -136,32 +133,33 @@ class MklAddNOp : public OpKernel {
 
     Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor;
     mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor,
-    &mkl_tmp_input2_buf_tensor);
+                                     &mkl_tmp_input2_buf_tensor);
     Tensor* output = nullptr;
     if (input1_in_mkl_format || input2_in_mkl_format) {
-     TensorShape tf_shape;
-     mkl_context.output_shape.SetMklTensor(true);
-     mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, dnnResourceDst);
-
-     mkl_context.output_shape.SetTfLayout(
-        mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
-     if (input1_in_mkl_format == true) {
-      mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
-      mkl_context.input1_shape.GetTfToMklDimMap());
-     } else {
-      mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
-      mkl_context.input2_shape.GetTfToMklDimMap());
-     }
-     tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
-                        mkl_context.output_shape.GetMklLayout())) /
-                    sizeof(T));
-
-     AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
-                              mkl_context.output_shape);
+      TensorShape tf_shape;
+      mkl_context.output_shape.SetMklTensor(true);
+      mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise,
+                                            dnnResourceDst);
+
+      mkl_context.output_shape.SetTfLayout(
+          mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
+      if (input1_in_mkl_format == true) {
+        mkl_context.output_shape.SetTfDimOrder(
+            mkl_context.in_dims, mkl_context.input1_shape.GetTfToMklDimMap());
+      } else {
+        mkl_context.output_shape.SetTfDimOrder(
+            mkl_context.in_dims, mkl_context.input2_shape.GetTfToMklDimMap());
+      }
+      tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+                          mkl_context.output_shape.GetMklLayout())) /
+                      sizeof(T));
+
+      AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
+                                mkl_context.output_shape);
     } else {
-     const TensorShape& o_shape = input1.shape();
-     mkl_context.output_shape.SetMklTensor(false);
-     AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
+      const TensorShape& o_shape = input1.shape();
+      mkl_context.output_shape.SetMklTensor(false);
+      AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
                                 mkl_context.output_shape);
     }
 
@@ -189,18 +187,16 @@ class MklAddNOp : public OpKernel {
     void MklCreateInputLayouts(OpKernelContext* context) {
       bool input1_in_mkl_format = input1_shape.IsMklTensor();
       if (!input1_in_mkl_format) {
-        CHECK_EQ(
-            dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
-            E_SUCCESS);
+        CHECK_EQ(dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
+                 E_SUCCESS);
       } else {
         lt_input1 = static_cast<dnnLayout_t>(input1_shape.GetCurLayout());
       }
 
       bool input2_in_mkl_format = input2_shape.IsMklTensor();
       if (!input2_in_mkl_format) {
-        CHECK_EQ(
-            dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
-            E_SUCCESS);
+        CHECK_EQ(dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
+                 E_SUCCESS);
       } else {
         lt_input2 = static_cast<dnnLayout_t>(input2_shape.GetCurLayout());
       }
@@ -276,14 +272,14 @@ class MklAddNOp : public OpKernel {
       bool input2_in_mkl_format = input2_shape.IsMklTensor();
       dnnDelete_F32(Eltwise);
       if (!input1_in_mkl_format || !input2_in_mkl_format) {
-         delete [] in_sizes;
-         delete [] in_strides;
+        delete[] in_sizes;
+        delete[] in_strides;
       }
       if (!input1_in_mkl_format) {
-         dnnLayoutDelete_F32(lt_input1);
+        dnnLayoutDelete_F32(lt_input1);
       }
       if (!input2_in_mkl_format) {
-         dnnLayoutDelete_F32(lt_input2);
+        dnnLayoutDelete_F32(lt_input2);
       }
     }
   } MklAddNOpContext;
@@ -315,45 +311,44 @@ class MklAddNOp : public OpKernel {
       GetMklShape(ctx, src2_idx, &src2_mkl_shape);
       bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor();
       bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor();
-      int src1_dims_size = input1_in_mkl_format?
-       src1_mkl_shape.GetDimension(): src1_tensor.dims();
-      int src2_dims_size = input2_in_mkl_format?
-       src2_mkl_shape.GetDimension(): src2_tensor.dims();
+      int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension()
+                                                : src1_tensor.dims();
+      int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension()
+                                                : src2_tensor.dims();
       // if the shapes of two tensors are not same raise op error
       TensorShape src1_shape, src2_shape;
       src1_shape = src1_tensor.shape();
       src2_shape = src2_tensor.shape();
-      if (!src1_shape.IsSameSize(src2_shape) ){
-       ctx->SetStatus( 
-            errors::InvalidArgument(
-            "Inputs to operation ", this->name(), " of type ", this->type_string(),
+      if (!src1_shape.IsSameSize(src2_shape){
+        ctx->SetStatus(errors::InvalidArgument(
+            "Inputs to operation ", this->name(), " of type ",
+            this->type_string(),
             " must have the same size and shape.  Input 0: ",
-            src1_shape.DebugString(), " != input 1: ",
-            src2_shape.DebugString()));
+            src1_shape.DebugString(),
+            " != input 1: ", src2_shape.DebugString()));
       }
 
       if (!input1_in_mkl_format && src1_dims_size == 0) {
-         Tensor* dst_tensor = nullptr;
-         MklShape mkl_shape_dst;
-         mkl_shape_dst.SetMklTensor(false);
-         AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
-         src1_tensor.shape(), mkl_shape_dst);
-         float user_i1 = (src1_tensor.scalar<T>()());
-         float user_i2 = (src2_tensor.scalar<T>()());
-         dst_tensor->scalar<T>()() =
-           std::plus<float>{}(user_i1, user_i2);
-         return;
-       }
+        Tensor* dst_tensor = nullptr;
+        MklShape mkl_shape_dst;
+        mkl_shape_dst.SetMklTensor(false);
+        AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+                                  src1_tensor.shape(), mkl_shape_dst);
+        float user_i1 = (src1_tensor.scalar<T>()());
+        float user_i2 = (src2_tensor.scalar<T>()());
+        dst_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
+        return;
+      }
 
       // If there is nothing to compute, return.
       if (!input1_in_mkl_format && !input2_in_mkl_format) {
         if (src1_tensor.shape().num_elements() == 0) {
-           Tensor* dst_tensor = nullptr;
-           MklShape mkl_shape_dst;
-           mkl_shape_dst.SetMklTensor(false);
-           AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
-           src1_tensor.shape(), mkl_shape_dst);
-           return;
+          Tensor* dst_tensor = nullptr;
+          MklShape mkl_shape_dst;
+          mkl_shape_dst.SetMklTensor(false);
+          AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+                                    src1_tensor.shape(), mkl_shape_dst);
+          return;
         }
       }
 
@@ -362,7 +357,7 @@ class MklAddNOp : public OpKernel {
       MklDnnData<T> src2(&cpu_engine);
       MklDnnData<T> dst(&cpu_engine);
 
-      int tmp_size = input1_in_mkl_format ? src2_dims_size: src1_dims_size;
+      int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size;
       memory::dims dims(tmp_size);
       memory::dims strides(tmp_size);
       memory::desc md1({}, memory::data_undef, memory::format_undef);
@@ -392,21 +387,19 @@ class MklAddNOp : public OpKernel {
         md1 = src1_mkl_shape.GetMklLayout();
 
         memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
-        auto src1_tf_data_format = MklDnnDataFormatToTFDataFormat(
-                                    src1_mkl_data_format);
-        auto src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
-                                    src1_tf_data_format);
-        md2 = memory::desc(src2_dims, MklDnnType<T>(),
-                           src1_mkl_data_format);
+        auto src1_tf_data_format =
+            MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
+        auto src2_dims =
+            TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+        md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
       } else if (input2_in_mkl_format && !input1_in_mkl_format) {
         // Same comment as above.
         memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
-        auto src2_tf_data_format = MklDnnDataFormatToTFDataFormat(
-                                     src2_mkl_data_format);
-        auto src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
-                                    src2_tf_data_format);
-        md1 = memory::desc(src1_dims, MklDnnType<T>(),
-                           src2_mkl_data_format);
+        auto src2_tf_data_format =
+            MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
+        auto src1_dims =
+            TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+        md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
 
         md2 = src2_mkl_shape.GetMklLayout();
       } else {
@@ -480,20 +473,19 @@ class MklAddNOp : public OpKernel {
         output_mkl_shape.SetMklTensor(false);
         output_tf_shape = src1_tensor.shape();
       }
-      AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
-                                output_tf_shape, output_mkl_shape);
+      AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape,
+                                output_mkl_shape);
       dst.SetUsrMemDataHandle(dst_tensor);
 
       // Create Sum op, and submit net for execution.
       net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
       stream(stream::kind::eager).submit(net).wait();
-    } catch (mkldnn::error &e) {
+    } catch (mkldnn::errore) {
       string error_msg = "Status: " + std::to_string(e.status) +
-                       ", message: " + string(e.message) +
-                       ", in file " + string(__FILE__) + ":" +
-                       std::to_string(__LINE__);
-      OP_REQUIRES_OK(ctx, errors::Aborted("Operation received an exception:",
-                                            error_msg));
+                         ", message: " + string(e.message) + ", in file " +
+                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      OP_REQUIRES_OK(
+          ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
   }
 };
index 896d56293303b06adb554cef7e2f3ef16a5a8eda..c46eabdde103913a712c3d058aa23a627d19f5ea 100644 (file)
@@ -17,13 +17,13 @@ limitations under the License.
 #ifdef INTEL_MKL
 #ifdef INTEL_MKL_DNN
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/tensor_format.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #include "mkldnn.h"
 #include "mkldnn_types.h"
@@ -31,16 +31,14 @@ limitations under the License.
 #include "tensorflow/core/util/mkl_util.h"
 
 #include "mkldnn.hpp"
-using mkldnn::stream;
 using mkldnn::prop_kind;
 using mkldnn::softmax_forward;
+using mkldnn::stream;
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-
-
 template <typename Device, typename T>
 class MklSoftmaxOp : public OpKernel {
  public:
@@ -60,11 +58,11 @@ class MklSoftmaxOp : public OpKernel {
       MklDnnShape src_mkl_shape;
       GetMklShape(context, src_idx, &src_mkl_shape);
 
-
       // src_dims is the dimenstion of src_tensor
       // dim of the dst will also be same as src_dims
-      auto src_tf_shape = src_mkl_shape.IsMklTensor() ?
-                          src_mkl_shape.GetTfShape() : src_tensor.shape();
+      auto src_tf_shape = src_mkl_shape.IsMklTensor()
+                              ? src_mkl_shape.GetTfShape()
+                              : src_tensor.shape();
       auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
       auto output_dims = src_dims;
 
@@ -77,10 +75,10 @@ class MklSoftmaxOp : public OpKernel {
       // construct input Tf layout. For TF layout, although input shape
       // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
       // layout
-      auto src_md = src_mkl_shape.IsMklTensor()
-                    ? src_mkl_shape.GetMklLayout()
-                    : memory::desc(src_dims, MklDnnType<T>(),
-                                         memory::format::nc);
+      auto src_md =
+          src_mkl_shape.IsMklTensor()
+              ? src_mkl_shape.GetMklLayout()
+              : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
 
       // src: setting memory descriptor and op memory descriptor
       // Basically following two functions maps the TF "src_tensor" to mkl
@@ -95,8 +93,8 @@ class MklSoftmaxOp : public OpKernel {
       int axis = 1;  // axis to which softmax will be applied
       auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
                                                     src.GetOpMemDesc(), axis);
-      auto softmax_fwd_pd = softmax_forward::primitive_desc(softmax_fwd_desc,
-                                                            cpu_engine);
+      auto softmax_fwd_pd =
+          softmax_forward::primitive_desc(softmax_fwd_desc, cpu_engine);
 
       // add: output
       Tensor* output_tensor = nullptr;
@@ -136,9 +134,9 @@ class MklSoftmaxOp : public OpKernel {
       net.push_back(softmax_fwd);
       stream(stream::kind::eager).submit(net).wait();
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
-                         string(e.message) + ", in file " + string(__FILE__) +
-                         ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) + ", in file " +
+                         string(__FILE__) + ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
@@ -148,7 +146,7 @@ class MklSoftmaxOp : public OpKernel {
 
 /* Register DNN kernels for supported operations and supported types - right now
  * it is only Softmax and f32 */
-#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type)             \
+#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type)          \
   REGISTER_KERNEL_BUILDER(Name("_MklSoftmax")                       \
                               .Device(DEVICE_CPU)                   \
                               .TypeConstraint<type>("T")            \
@@ -156,7 +154,6 @@ class MklSoftmaxOp : public OpKernel {
                           MklSoftmaxOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
 
-
 }  // namespace tensorflow
 
 #endif  // INTEL_MKL_DNN
index bc30330d61c89a05096faa70793a56e2fa0f2fbe..872a6e9d1bcce09765d1531c5f2898b2badc66a7 100644 (file)
@@ -72,12 +72,12 @@ bool ReadRawFloatFileToComplexVector(
   while (offset < end) {
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
     char arr[4];
-    for (int i = 0; i < kBytesPerValue; ++i ) {
+    for (int i = 0; i < kBytesPerValue; ++i) {
       arr[3 - i] = *(data_string.data() + offset + i);
     }
     memcpy(&real_out, arr, kBytesPerValue);
     offset += kBytesPerValue;
-    for (int i = 0; i < kBytesPerValue; ++i ) {
+    for (int i = 0; i < kBytesPerValue; ++i) {
       arr[3 - i] = *(data_string.data() + offset + i);
     }
     memcpy(&imag_out, arr, kBytesPerValue);
index 6594f7ee7ba24bc193011e350ea97c37b3d5ced4..5198df7e16e020f0ee19baa387ccae899e21499a 100644 (file)
@@ -89,17 +89,17 @@ struct Transpose<CPUDevice, T, conjugate> {
                                                        out);
         break;
       case 6:
-       internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
-                                                      out);
-       break;
+        internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
+                                                       out);
+        break;
       case 7:
-       internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
-                                                      out);
-       break;
+        internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
+                                                       out);
+        break;
       case 8:
         internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
-                                                      out);
-       break;
+                                                       out);
+        break;
       default:
         TransposeSimple<T, conjugate>(d, in, perm, out);
         break;
index 7d1650f05eff0f4806fce10c05321278a0150954..f6906b0f79b86910b5354bea420d00f62ff0caf8 100644 (file)
@@ -40,10 +40,10 @@ current_path = os.path.dirname(os.path.realpath(sys.argv[0]))
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-      '--log_dir',
-      type=str,
-      default=os.path.join(current_path, 'log'),
-      help='The log directory for TensorBoard summaries.')
+    '--log_dir',
+    type=str,
+    default=os.path.join(current_path, 'log'),
+    help='The log directory for TensorBoard summaries.')
 FLAGS, unparsed = parser.parse_known_args()
 
 # Create the directory for TensorBoard variables if there is not.
@@ -81,6 +81,7 @@ def read_data(filename):
     data = tf.compat.as_str(f.read(f.namelist()[0])).split()
   return data
 
+
 vocabulary = read_data(filename)
 print('Data size', len(vocabulary))
 
@@ -106,20 +107,22 @@ def build_dataset(words, n_words):
   reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
   return data, count, dictionary, reversed_dictionary
 
+
 # Filling 4 global variables:
 # data - list of codes (integers from 0 to vocabulary_size-1).
 #   This is the original text but words are replaced by their codes
 # count - map of words(strings) to count of occurrences
 # dictionary - map of words(strings) to their codes(integers)
 # reverse_dictionary - maps codes(integers) to words(strings)
-data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
-                                                            vocabulary_size)
+data, count, dictionary, reverse_dictionary = build_dataset(
+    vocabulary, vocabulary_size)
 del vocabulary  # Hint to reduce memory.
 print('Most common words (+UNK)', count[:5])
 print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
 
 data_index = 0
 
+
 # Step 3: Function to generate a training batch for the skip-gram model.
 def generate_batch(batch_size, num_skips, skip_window):
   global data_index
@@ -149,28 +152,28 @@ def generate_batch(batch_size, num_skips, skip_window):
   data_index = (data_index + len(data) - span) % len(data)
   return batch, labels
 
+
 batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
 for i in range(8):
-  print(batch[i], reverse_dictionary[batch[i]],
-        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+  print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0],
+        reverse_dictionary[labels[i, 0]])
 
 # Step 4: Build and train a skip-gram model.
 
 batch_size = 128
 embedding_size = 128  # Dimension of the embedding vector.
-skip_window = 1       # How many words to consider left and right.
-num_skips = 2         # How many times to reuse an input to generate a label.
-num_sampled = 64      # Number of negative examples to sample.
+skip_window = 1  # How many words to consider left and right.
+num_skips = 2  # How many times to reuse an input to generate a label.
+num_sampled = 64  # Number of negative examples to sample.
 
 # We pick a random validation set to sample nearest neighbors. Here we limit the
 # validation samples to the words that have a low numeric ID, which by
 # construction are also the most frequent. These 3 variables are used only for
 # displaying model accuracy, they don't affect calculation.
-valid_size = 16     # Random set of words to evaluate similarity on.
+valid_size = 16  # Random set of words to evaluate similarity on.
 valid_window = 100  # Only pick dev samples in the head of the distribution.
 valid_examples = np.random.choice(valid_window, valid_size, replace=False)
 
-
 graph = tf.Graph()
 
 with graph.as_default():
@@ -192,8 +195,9 @@ with graph.as_default():
     # Construct the variables for the NCE loss
     with tf.name_scope('weights'):
       nce_weights = tf.Variable(
-          tf.truncated_normal([vocabulary_size, embedding_size],
-                              stddev=1.0 / math.sqrt(embedding_size)))
+          tf.truncated_normal(
+              [vocabulary_size, embedding_size],
+              stddev=1.0 / math.sqrt(embedding_size)))
     with tf.name_scope('biases'):
       nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
 
@@ -204,12 +208,13 @@ with graph.as_default():
   #   http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
   with tf.name_scope('loss'):
     loss = tf.reduce_mean(
-        tf.nn.nce_loss(weights=nce_weights,
-                       biases=nce_biases,
-                       labels=train_labels,
-                       inputs=embed,
-                       num_sampled=num_sampled,
-                       num_classes=vocabulary_size))
+        tf.nn.nce_loss(
+            weights=nce_weights,
+            biases=nce_biases,
+            labels=train_labels,
+            inputs=embed,
+            num_sampled=num_sampled,
+            num_classes=vocabulary_size))
 
   # Add the loss value as a scalar to summary.
   tf.summary.scalar('loss', loss)
@@ -221,8 +226,8 @@ with graph.as_default():
   # Compute the cosine similarity between minibatch examples and all embeddings.
   norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
   normalized_embeddings = embeddings / norm
-  valid_embeddings = tf.nn.embedding_lookup(
-      normalized_embeddings, valid_dataset)
+  valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings,
+                                            valid_dataset)
   similarity = tf.matmul(
       valid_embeddings, normalized_embeddings, transpose_b=True)
 
@@ -248,8 +253,8 @@ with tf.Session(graph=graph) as session:
 
   average_loss = 0
   for step in xrange(num_steps):
-    batch_inputs, batch_labels = generate_batch(
-        batch_size, num_skips, skip_window)
+    batch_inputs, batch_labels = generate_batch(batch_size, num_skips,
+                                                skip_window)
     feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
 
     # Define metadata variable.
@@ -259,9 +264,12 @@ with tf.Session(graph=graph) as session:
     # in the list of returned values for session.run()
     # Also, evaluate the merged op to get all summaries from the returned "summary" variable.
     # Feed metadata variable to session for visualizing the graph in TensorBoard.
-    _, summary, loss_val = session.run([optimizer, merged, loss], feed_dict=feed_dict, run_metadata=run_metadata)
+    _, summary, loss_val = session.run(
+        [optimizer, merged, loss],
+        feed_dict=feed_dict,
+        run_metadata=run_metadata)
     average_loss += loss_val
-    
+
     # Add returned summaries to writer in each step.
     writer.add_summary(summary, step)
     # Add metadata to visualize the graph for the last run.
@@ -295,7 +303,7 @@ with tf.Session(graph=graph) as session:
       f.write(reverse_dictionary[i] + '\n')
 
   # Save the model for checkpoints.
-  saver.save(session, os.path.join(FLAGS.log_dir, "model.ckpt"))
+  saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt'))
 
   # Create a configuration for visualizing embeddings with the labels in TensorBoard.
   config = projector.ProjectorConfig()
@@ -317,21 +325,24 @@ def plot_with_labels(low_dim_embs, labels, filename):
   for i, label in enumerate(labels):
     x, y = low_dim_embs[i, :]
     plt.scatter(x, y)
-    plt.annotate(label,
-                 xy=(x, y),
-                 xytext=(5, 2),
-                 textcoords='offset points',
-                 ha='right',
-                 va='bottom')
+    plt.annotate(
+        label,
+        xy=(x, y),
+        xytext=(5, 2),
+        textcoords='offset points',
+        ha='right',
+        va='bottom')
 
   plt.savefig(filename)
 
+
 try:
   # pylint: disable=g-import-not-at-top
   from sklearn.manifold import TSNE
   import matplotlib.pyplot as plt
 
-  tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
+  tsne = TSNE(
+      perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
   plot_only = 500
   low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
   labels = [reverse_dictionary[i] for i in xrange(plot_only)]
index eac1c1960df277a24f9ee8e1f72b3cbd050b13df..bd80b9dbf561de16168b05facf0086dadcda6444 100644 (file)
@@ -51,8 +51,9 @@ class BatchDatasetTest(test.TestCase):
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
-                .repeat(count).batch(batch_size).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+        .repeat(count).batch(batch_size).make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
@@ -68,7 +69,7 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(14):
-            self.assertAllEqual(component[(i*14 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                 result_component[j])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -83,12 +84,12 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(8):
-            self.assertAllEqual(component[(i*8 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                 result_component[j])
       result = sess.run(get_next)
       for component, result_component in zip(components, result):
         for j in range((14 * 7) % 8):
-          self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
+          self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                               result_component[j])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -189,33 +190,34 @@ class BatchDatasetTest(test.TestCase):
         sess.run(get_next)
 
   def testBatchShapeError(self):
+
     def generator():
       yield [1.0, 2.0, 3.0]
       yield [4.0, 5.0, 6.0]
       yield [7.0, 8.0, 9.0, 10.0]
 
-    iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
-                                                   output_shapes=[None])
-                .batch(3)
-                .make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_generator(
+            generator, dtypes.float32, output_shapes=[None]).batch(3)
+        .make_initializable_iterator())
     next_element = iterator.get_next()
 
     with self.test_session() as sess:
       sess.run(iterator.initializer)
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
-          r"Cannot batch tensors with different shapes in component 0. "
-          r"First element had shape \[3\] and element 2 had shape \[4\]."):
+          r'Cannot batch tensors with different shapes in component 0. '
+          r'First element had shape \[3\] and element 2 had shape \[4\].'):
         sess.run(next_element)
 
   def testPaddedBatchDataset(self):
     seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
     padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
-                .map(lambda x: array_ops.fill([x], x)).padded_batch(
-                    4,
-                    padded_shapes=padded_shape).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(seq_lens)
+        .map(lambda x: array_ops.fill([x], x)).padded_batch(
+            4, padded_shapes=padded_shape).make_initializable_iterator())
 
     init_op = iterator.initializer
     get_next = iterator.get_next()
@@ -223,35 +225,40 @@ class BatchDatasetTest(test.TestCase):
     with self.test_session() as sess:
       # Test with random sequence lengths, and max padding.
       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [-1],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         padded_len = np.max(result)
         self.assertEqual((4, padded_len), result.shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
       # Test with random sequence lengths, and constant padding.
-      sess.run(init_op, feed_dict={padded_shape: [25],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [25],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         self.assertEqual((4, 25), result.shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
       # Test correct handling of empty tensors.
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: [0, 0, 0, 0]})
+      sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
       result = sess.run(get_next)
       self.assertAllEqual([[], [], [], []], result)
       with self.assertRaises(errors.OutOfRangeError):
@@ -259,8 +266,7 @@ class BatchDatasetTest(test.TestCase):
 
       # Test error handling with constant sequence lengths, and
       # too-short padding.
-      sess.run(init_op, feed_dict={padded_shape: [5],
-                                   seq_lens: [6, 5, 5, 5]})
+      sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
       with self.assertRaises(errors.DataLossError):
         result = sess.run(get_next)
 
@@ -271,11 +277,13 @@ class BatchDatasetTest(test.TestCase):
     def fill_tuple(x):
       filled = array_ops.fill([x], x)
       return (filled, string_ops.as_string(filled))
-    iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
-                .padded_batch(
-                    4,
-                    padded_shapes=(padded_shape, padded_shape),
-                    padding_values=(-1, "<end>")).make_initializable_iterator())
+
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
+        .padded_batch(
+            4,
+            padded_shapes=(padded_shape, padded_shape),
+            padding_values=(-1, '<end>')).make_initializable_iterator())
 
     init_op = iterator.initializer
     get_next = iterator.get_next()
@@ -283,46 +291,46 @@ class BatchDatasetTest(test.TestCase):
     with self.test_session() as sess:
       # Test with random sequence lengths, and max padding.
       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [-1],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         padded_len = np.max(result[0])
         self.assertEqual((4, padded_len), result[0].shape)
         self.assertEqual((4, padded_len), result[1].shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[0][j, seq_len:],
                               [-1] * (padded_len - seq_len))
           self.assertAllEqual(result[1][j, :seq_len],
                               [compat.as_bytes(str(seq_len))] * seq_len)
           self.assertAllEqual(result[1][j, seq_len:],
-                              [b"<end>"] * (padded_len - seq_len))
+                              [b'<end>'] * (padded_len - seq_len))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
   def testPaddedBatchDatasetUnicode(self):
     # See GitHub issue 16149
     def generator():
-      data = [
-          [u'Простой', u'тест', u'юникода'],
-          [u'никогда', u'не', u'бывает', u'простым']]
+      data = [[u'Простой', u'тест', u'юникода'],
+              [u'никогда', u'не', u'бывает', u'простым']]
 
       for seq in data:
         yield seq, [0, 1, 2, 3]
 
     dataset = dataset_ops.Dataset.from_generator(
-        generator,
-        (dtypes.string, dtypes.int32),
+        generator, (dtypes.string, dtypes.int32),
         (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
-    padded_dataset = dataset.padded_batch(2, padded_shapes=([None], [None]),
-                                          padding_values=('', 0))
+    padded_dataset = dataset.padded_batch(
+        2, padded_shapes=([None], [None]), padding_values=('', 0))
     with self.test_session() as sess:
       next_element = padded_dataset.make_one_shot_iterator().get_next()
       sess.run(next_element)
 
-
   def testPaddedBatchDatasetShapeSpecifications(self):
     int_placeholder = array_ops.placeholder(dtypes.int32)
     float_placeholder = array_ops.placeholder(dtypes.float32)
@@ -346,15 +354,16 @@ class BatchDatasetTest(test.TestCase):
                        constant_op.constant([-1, -1], dtype=dtypes.int64),
                        constant_op.constant([37], dtype=dtypes.int64)))
 
-    for dataset in [dynamic_padding_from_tensor_shapes,
-                    dynamic_padding_from_lists,
-                    dynamic_padding_from_lists_with_minus_one,
-                    dynamic_padding_from_tensors]:
+    for dataset in [
+        dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+        dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+    ]:
       self.assertEqual([None, None], dataset.output_shapes[0].as_list())
       self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
       self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
 
   def testPaddedBatchSparseError(self):
+
     def _map_fn(i):
       return sparse_tensor.SparseTensorValue(
           indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
@@ -363,5 +372,5 @@ class BatchDatasetTest(test.TestCase):
       _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   test.main()
index b2de2e5015047ee0e64b8263aa73f207d852d7b6..f079e56b10ed484225d8f09c6eaf7cf85a02d12a 100644 (file)
@@ -74,7 +74,7 @@ def histogram_fixed_width_bins(values,
   ```
   """
   with ops.name_scope(name, 'histogram_fixed_width_bins',
-                      [values, value_range, nbins]) as scope:
+                      [values, value_range, nbins]):
     values = ops.convert_to_tensor(values, name='values')
     shape = array_ops.shape(values)
 
@@ -84,9 +84,10 @@ def histogram_fixed_width_bins(values,
     nbins_float = math_ops.cast(nbins, values.dtype)
 
     # Map tensor values that fall within value_range to [0, 1].
-    scaled_values = math_ops.truediv(values - value_range[0],
-                                     value_range[1] - value_range[0],
-                                     name='scaled_values')
+    scaled_values = math_ops.truediv(
+        values - value_range[0],
+        value_range[1] - value_range[0],
+        name='scaled_values')
 
     # map tensor values within the open interval value_range to {0,.., nbins-1},
     # values outside the open interval will be zero or less, or nbins or more.
@@ -138,5 +139,5 @@ def histogram_fixed_width(values,
   """
   with ops.name_scope(name, 'histogram_fixed_width',
                       [values, value_range, nbins]) as name:
-    return gen_math_ops._histogram_fixed_width(values, value_range, nbins,
-                                               dtype=dtype, name=name)
+    return gen_math_ops._histogram_fixed_width(  # pylint: disable=protected-access
+        values, value_range, nbins, dtype=dtype, name=name)
index 80ee09057581db7298562fc22b443f5ddee73ef8..a226ac81bb536934cd191872ffc1aca84925abc0 100644 (file)
@@ -36,7 +36,8 @@ class BinValuesFixedWidth(test.TestCase):
     values = []
     expected_bins = []
     with self.test_session():
-      bins = histogram_ops.histogram_fixed_width_bins(values, value_range, nbins=5)
+      bins = histogram_ops.histogram_fixed_width_bins(
+          values, value_range, nbins=5)
       self.assertEqual(dtypes.int32, bins.dtype)
       self.assertAllClose(expected_bins, bins.eval())
 
@@ -69,8 +70,7 @@ class BinValuesFixedWidth(test.TestCase):
     #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
     value_range = [0.0, 5.0]
     values = constant_op.constant(
-      [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]],
-      shape=(2, 3))
+        [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3))
     expected_bins = [[0, 0, 1], [2, 4, 4]]
     with self.test_session():
       bins = histogram_ops.histogram_fixed_width_bins(
@@ -140,8 +140,8 @@ class HistogramFixedWidthTest(test.TestCase):
       self.assertEqual(dtypes.int32, hist.dtype)
       self.assertAllClose(expected_bin_counts, hist.eval())
 
-      hist = histogram_ops.histogram_fixed_width(values, value_range,
-                                                 nbins=placeholder)
+      hist = histogram_ops.histogram_fixed_width(
+          values, value_range, nbins=placeholder)
       self.assertEquals(hist.shape.ndims, 1)
       self.assertIs(hist.shape[0].value, None)
       self.assertEqual(dtypes.int32, hist.dtype)
index b713c4471775ed814363e6e0d58e6470a8642ff5..76da3bed315b6584deb16c796121bbd6c5d36dab 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Implementation of image ops."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
-
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -28,7 +25,6 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import gen_nn_ops
@@ -38,7 +34,6 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.util.tf_export import tf_export
 
-
 ops.NotDifferentiable('RandomCrop')
 # TODO(b/31222613): This op may be differentiable, and there may be
 # latent bugs here.
@@ -110,8 +105,9 @@ def _ImageDimensions(image, rank):
   else:
     static_shape = image.get_shape().with_rank(rank).as_list()
     dynamic_shape = array_ops.unstack(array_ops.shape(image), rank)
-    return [s if s is not None else d
-            for s, d in zip(static_shape, dynamic_shape)]
+    return [
+        s if s is not None else d for s, d in zip(static_shape, dynamic_shape)
+    ]
 
 
 def _Check3DImage(image, require_static=True):
@@ -132,18 +128,19 @@ def _Check3DImage(image, require_static=True):
   try:
     image_shape = image.get_shape().with_rank(3)
   except ValueError:
-    raise ValueError("'image' (shape %s) must be three-dimensional." %
-                     image.shape)
+    raise ValueError(
+        "'image' (shape %s) must be three-dimensional." % image.shape)
   if require_static and not image_shape.is_fully_defined():
-    raise ValueError("'image' (shape %s) must be fully defined." %
-                     image_shape)
+    raise ValueError("'image' (shape %s) must be fully defined." % image_shape)
   if any(x == 0 for x in image_shape):
-    raise ValueError("all dims of 'image.shape' must be > 0: %s" %
-                     image_shape)
+    raise ValueError("all dims of 'image.shape' must be > 0: %s" % image_shape)
   if not image_shape.is_fully_defined():
-    return [check_ops.assert_positive(array_ops.shape(image),
-                                      ["all dims of 'image.shape' "
-                                       "must be > 0."])]
+    return [
+        check_ops.assert_positive(
+            array_ops.shape(image),
+            ["all dims of 'image.shape' "
+             'must be > 0.'])
+    ]
   else:
     return []
 
@@ -167,7 +164,7 @@ def _Assert3DImage(image):
       added that asserts the correct dynamic shape.
     """
   return control_flow_ops.with_dependencies(
-    _Check3DImage(image, require_static=False), image)
+      _Check3DImage(image, require_static=False), image)
 
 
 def _CheckAtLeast3DImage(image, require_static=True):
@@ -195,12 +192,15 @@ def _CheckAtLeast3DImage(image, require_static=True):
   if require_static and not image_shape.is_fully_defined():
     raise ValueError('\'image\' must be fully defined.')
   if any(x == 0 for x in image_shape):
-    raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
-                     image_shape)
+    raise ValueError(
+        'all dims of \'image.shape\' must be > 0: %s' % image_shape)
   if not image_shape.is_fully_defined():
-    return [check_ops.assert_positive(array_ops.shape(image),
-                                      ["all dims of 'image.shape' "
-                                       "must be > 0."])]
+    return [
+        check_ops.assert_positive(
+            array_ops.shape(image),
+            ["all dims of 'image.shape' "
+             'must be > 0.'])
+    ]
   else:
     return []
 
@@ -248,10 +248,11 @@ def random_flip_up_down(image, seed=None):
     image = _Assert3DImage(image)
     uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
     mirror_cond = math_ops.less(uniform_random, .5)
-    result = control_flow_ops.cond(mirror_cond,
-                                   lambda: array_ops.reverse(image, [0]),
-                                   lambda: image,
-                                   name=scope)
+    result = control_flow_ops.cond(
+        mirror_cond,
+        lambda: array_ops.reverse(image, [0]),
+        lambda: image,
+        name=scope)
     return fix_image_flip_shape(image, result)
 
 
@@ -279,10 +280,11 @@ def random_flip_left_right(image, seed=None):
     image = _Assert3DImage(image)
     uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
     mirror_cond = math_ops.less(uniform_random, .5)
-    result = control_flow_ops.cond(mirror_cond,
-                                   lambda: array_ops.reverse(image, [1]),
-                                   lambda: image,
-                                   name=scope)
+    result = control_flow_ops.cond(
+        mirror_cond,
+        lambda: array_ops.reverse(image, [1]),
+        lambda: image,
+        name=scope)
     return fix_image_flip_shape(image, result)
 
 
@@ -307,8 +309,8 @@ def flip_left_right(image):
   with ops.name_scope(None, 'flip_left_right', [image]) as scope:
     image = ops.convert_to_tensor(image, name='image')
     image = _Assert3DImage(image)
-    return fix_image_flip_shape(image,
-                                array_ops.reverse(image, [1], name=scope))
+    return fix_image_flip_shape(image, array_ops.reverse(
+        image, [1], name=scope))
 
 
 @tf_export('image.flip_up_down')
@@ -332,8 +334,8 @@ def flip_up_down(image):
   with ops.name_scope(None, 'flip_up_down', [image]) as scope:
     image = ops.convert_to_tensor(image, name='image')
     image = _Assert3DImage(image)
-    return fix_image_flip_shape(image,
-                                array_ops.reverse(image, [0], name=scope))
+    return fix_image_flip_shape(image, array_ops.reverse(
+        image, [0], name=scope))
 
 
 @tf_export('image.rot90')
@@ -356,19 +358,19 @@ def rot90(image, k=1, name=None):
     k = math_ops.mod(k, 4)
 
     def _rot90():
-      return array_ops.transpose(array_ops.reverse_v2(image, [1]),
-                                 [1, 0, 2])
+      return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])
+
     def _rot180():
       return array_ops.reverse_v2(image, [0, 1])
+
     def _rot270():
-      return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]),
-                                  [1])
-    cases = [(math_ops.equal(k, 1), _rot90),
-             (math_ops.equal(k, 2), _rot180),
+      return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])
+
+    cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
              (math_ops.equal(k, 3), _rot270)]
 
-    ret = control_flow_ops.case(cases, default=lambda: image, exclusive=True,
-                                name=scope)
+    ret = control_flow_ops.case(
+        cases, default=lambda: image, exclusive=True, name=scope)
     ret.set_shape([None, None, image.get_shape()[2]])
     return ret
 
@@ -518,8 +520,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
         ]), [4, 2])
     padded = array_ops.pad(image, paddings)
 
-    padded_shape = [None if _is_tensor(i) else i
-                    for i in [batch, target_height, target_width, depth]]
+    padded_shape = [
+        None if _is_tensor(i) else i
+        for i in [batch, target_height, target_width, depth]
+    ]
     padded.set_shape(padded_shape)
 
     if not is_batch:
@@ -593,12 +597,13 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
     image = control_flow_ops.with_dependencies(assert_ops, image)
 
     cropped = array_ops.slice(
-        image,
-        array_ops.stack([0, offset_height, offset_width, 0]),
+        image, array_ops.stack([0, offset_height, offset_width, 0]),
         array_ops.stack([-1, target_height, target_width, -1]))
 
-    cropped_shape = [None if _is_tensor(i) else i
-                     for i in [batch, target_height, target_width, depth]]
+    cropped_shape = [
+        None if _is_tensor(i) else i
+        for i in [batch, target_height, target_width, depth]
+    ]
     cropped.set_shape(cropped_shape)
 
     if not is_batch:
@@ -663,8 +668,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
       target_height = control_flow_ops.with_dependencies(
           assert_ops, target_height)
     if _is_tensor(target_width):
-      target_width = control_flow_ops.with_dependencies(
-          assert_ops, target_width)
+      target_width = control_flow_ops.with_dependencies(assert_ops,
+                                                        target_width)
 
     def max_(x, y):
       if _is_tensor(x) or _is_tensor(y):
@@ -709,10 +714,12 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
     _, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4)
 
     assert_ops = []
-    assert_ops += _assert(equal_(resized_height, target_height), ValueError,
-                          'resized height is not correct.')
-    assert_ops += _assert(equal_(resized_width, target_width), ValueError,
-                          'resized width is not correct.')
+    assert_ops += _assert(
+        equal_(resized_height, target_height), ValueError,
+        'resized height is not correct.')
+    assert_ops += _assert(
+        equal_(resized_width, target_width), ValueError,
+        'resized width is not correct.')
 
     resized = control_flow_ops.with_dependencies(assert_ops, resized)
 
@@ -813,22 +820,17 @@ def resize_images(images,
       return images
 
     if method == ResizeMethod.BILINEAR:
-      images = gen_image_ops.resize_bilinear(images,
-                                             size,
-                                             align_corners=align_corners)
+      images = gen_image_ops.resize_bilinear(
+          images, size, align_corners=align_corners)
     elif method == ResizeMethod.NEAREST_NEIGHBOR:
-      images = gen_image_ops.resize_nearest_neighbor(images,
-                                                     size,
-                                                     align_corners=
-                                                     align_corners)
+      images = gen_image_ops.resize_nearest_neighbor(
+          images, size, align_corners=align_corners)
     elif method == ResizeMethod.BICUBIC:
-      images = gen_image_ops.resize_bicubic(images,
-                                            size,
-                                            align_corners=align_corners)
+      images = gen_image_ops.resize_bicubic(
+          images, size, align_corners=align_corners)
     elif method == ResizeMethod.AREA:
-      images = gen_image_ops.resize_area(images,
-                                         size,
-                                         align_corners=align_corners)
+      images = gen_image_ops.resize_area(
+          images, size, align_corners=align_corners)
     else:
       raise ValueError('Resize method is not implemented.')
 
@@ -869,8 +871,9 @@ def per_image_standardization(image):
     image = math_ops.cast(image, dtype=dtypes.float32)
     image_mean = math_ops.reduce_mean(image)
 
-    variance = (math_ops.reduce_mean(math_ops.square(image)) -
-                math_ops.square(image_mean))
+    variance = (
+        math_ops.reduce_mean(math_ops.square(image)) -
+        math_ops.square(image_mean))
     variance = gen_nn_ops.relu(variance)
     stddev = math_ops.sqrt(variance)
 
@@ -971,9 +974,8 @@ def adjust_brightness(image, delta):
     orig_dtype = image.dtype
     flt_image = convert_image_dtype(image, dtypes.float32)
 
-    adjusted = math_ops.add(flt_image,
-                            math_ops.cast(delta, dtypes.float32),
-                            name=name)
+    adjusted = math_ops.add(
+        flt_image, math_ops.cast(delta, dtypes.float32), name=name)
 
     return convert_image_dtype(adjusted, orig_dtype, saturate=True)
 
@@ -1012,9 +1014,8 @@ def adjust_contrast(images, contrast_factor):
     flt_images = convert_image_dtype(images, dtypes.float32)
 
     # pylint: disable=protected-access
-    adjusted = gen_image_ops._adjust_contrastv2(flt_images,
-                                                contrast_factor=contrast_factor,
-                                                name=name)
+    adjusted = gen_image_ops._adjust_contrastv2(
+        flt_images, contrast_factor=contrast_factor, name=name)
     # pylint: enable=protected-access
 
     return convert_image_dtype(adjusted, orig_dtype, saturate=True)
@@ -1061,10 +1062,10 @@ def adjust_gamma(image, gamma=1, gain=1):
       gamma = control_flow_ops.with_dependencies(assert_op, gamma)
 
     # scale = max(dtype) - min(dtype).
-    scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0],
-                                 dtype=dtypes.float32)
+    scale = constant_op.constant(
+        image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32)
     # According to the definition of gamma correction.
-    adjusted_img = (img / scale) ** gamma * scale * gain
+    adjusted_img = (img / scale)**gamma * scale * gain
 
     return adjusted_img
 
@@ -1195,9 +1196,8 @@ def grayscale_to_rgb(images, name=None):
   with ops.name_scope(name, 'grayscale_to_rgb', [images]) as name:
     images = ops.convert_to_tensor(images, name='images')
     rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0)
-    shape_list = (
-        [array_ops.ones(rank_1,
-                        dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
+    shape_list = ([array_ops.ones(rank_1, dtype=dtypes.int32)] +
+                  [array_ops.expand_dims(3, 0)])
     multiples = array_ops.concat(shape_list, 0)
     rgb = array_ops.tile(images, multiples, name=name)
     rgb.set_shape(images.get_shape()[:-1].concatenate([3]))
@@ -1393,8 +1393,7 @@ def decode_image(contents, channels=None, name=None):
       gif_channels = 0 if channels is None else channels
       good_channels = math_ops.logical_and(
           math_ops.not_equal(gif_channels, 1, name='check_gif_channels'),
-          math_ops.not_equal(gif_channels, 4, name='check_gif_channels')
-      )
+          math_ops.not_equal(gif_channels, 4, name='check_gif_channels'))
       channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
       assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
       with ops.control_dependencies([assert_channels]):
@@ -1417,8 +1416,8 @@ def decode_image(contents, channels=None, name=None):
     def _jpeg():
       """Decodes a jpeg image."""
       jpeg_channels = 0 if channels is None else channels
-      good_channels = math_ops.not_equal(jpeg_channels, 4,
-                                         name='check_jpeg_channels')
+      good_channels = math_ops.not_equal(
+          jpeg_channels, 4, name='check_jpeg_channels')
       channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG '
                       'images')
       assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
@@ -1496,16 +1495,21 @@ def total_variation(images, name=None):
 
     # Calculate the total variation by taking the absolute value of the
     # pixel-differences and summing over the appropriate axis.
-    tot_var = (math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
-               math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
+    tot_var = (
+        math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
+        math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
 
   return tot_var
 
 
 @tf_export('image.sample_distorted_bounding_box')
-def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
-                                  seed2=None, min_object_covered=None,
-                                  aspect_ratio_range=None, area_range=None,
+def sample_distorted_bounding_box(image_size,
+                                  bounding_boxes,
+                                  seed=None,
+                                  seed2=None,
+                                  min_object_covered=None,
+                                  aspect_ratio_range=None,
+                                  area_range=None,
                                   max_attempts=None,
                                   use_image_if_no_bounding_boxes=None,
                                   name=None):
@@ -1521,10 +1525,12 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
   The output of this Op is a single bounding box that may be used to crop the
   original image. The output is returned as 3 tensors: `begin`, `size` and
   `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
-  image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+  image. The latter may be supplied to `tf.image.draw_bounding_boxes` to
+  visualize
   what the bounding box looks like.
 
-  Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+  Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`.
+  The
   bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
   height of the underlying image.
 
@@ -1552,23 +1558,27 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
   false and no bounding boxes are supplied, an error is raised.
 
   Args:
-    image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`.
+    image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
+      `int16`, `int32`, `int64`.
       1-D, containing `[height, width, channels]`.
     bounding_boxes: A `Tensor` of type `float32`.
       3-D with shape `[batch, N, 4]` describing the N bounding boxes
       associated with the image.
     seed: An optional `int`. Defaults to `0`.
       If either `seed` or `seed2` are set to non-zero, the random number
-      generator is seeded by the given `seed`.  Otherwise, it is seeded by a random
+      generator is seeded by the given `seed`.  Otherwise, it is seeded by a
+        random
       seed.
     seed2: An optional `int`. Defaults to `0`.
       A second seed to avoid seed collision.
     min_object_covered: A Tensor of type `float32`. Defaults to `0.1`.
       The cropped area of the image must contain at least this
-      fraction of any bounding box supplied. The value of this parameter should be
+      fraction of any bounding box supplied. The value of this parameter should
+        be
       non-negative. In the case of 0, the cropped area does not need to overlap
       any of the bounding boxes supplied.
-    aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 1.33]`.
+    aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75,
+      1.33]`.
       The cropped area of the image must have an aspect ratio =
       width / height within this range.
     area_range: An optional list of `floats`. Defaults to `[0.05, 1]`.
@@ -1576,32 +1586,41 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
       supplied image within in this range.
     max_attempts: An optional `int`. Defaults to `100`.
       Number of attempts at generating a cropped region of the image
-      of the specified constraints. After `max_attempts` failures, return the entire
+      of the specified constraints. After `max_attempts` failures, return the
+        entire
       image.
     use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`.
       Controls behavior if no bounding boxes supplied.
-      If true, assume an implicit bounding box covering the whole input. If false,
+      If true, assume an implicit bounding box covering the whole input. If
+        false,
       raise an error.
     name: A name for the operation (optional).
 
   Returns:
     A tuple of `Tensor` objects (begin, size, bboxes).
 
-    begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+    begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+    `[offset_height, offset_width, 0]`. Provide as input to
       `tf.slice`.
-    size: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to
+    size: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+    `[target_height, target_width, -1]`. Provide as input to
       `tf.slice`.
-    bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+    bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing
+    the distorted bounding box.
       Provide as input to `tf.image.draw_bounding_boxes`.
   """
   with ops.name_scope(name, 'sample_distorted_bounding_box'):
-    return gen_image_ops._sample_distorted_bounding_box_v2(image_size,
-                bounding_boxes, seed=seed,
-                seed2=seed2, min_object_covered=min_object_covered,
-                aspect_ratio_range=aspect_ratio_range, area_range=area_range,
-                max_attempts=max_attempts,
-                use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
-                name=name)
+    return gen_image_ops._sample_distorted_bounding_box_v2(  # pylint: disable=protected-access
+        image_size,
+        bounding_boxes,
+        seed=seed,
+        seed2=seed2,
+        min_object_covered=min_object_covered,
+        aspect_ratio_range=aspect_ratio_range,
+        area_range=area_range,
+        max_attempts=max_attempts,
+        use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
+        name=name)
 
 
 @tf_export('image.non_max_suppression')
index 2d77e260816bacbc1143acf5646bbe9533ac9844..7776ff08c4f55c43947010f313d8167596b15db7 100644 (file)
@@ -100,27 +100,29 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
     # Use dynamic rank.
     weights_rank_tensor = array_ops.rank(weights)
     rank_diff = weights_rank_tensor - array_ops.rank(predictions)
+
     def _maybe_expand_weights():
       return control_flow_ops.cond(
           math_ops.equal(rank_diff, -1),
-          lambda: array_ops.expand_dims(weights, [-1]),
-          lambda: weights)
+          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
+
     # Don't attempt squeeze if it will fail based on static check.
     if ((weights_rank is not None) and
         (not weights_shape.dims[-1].is_compatible_with(1))):
       maybe_squeeze_weights = lambda: weights
     else:
       maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])
+
     def _maybe_adjust_weights():
       return control_flow_ops.cond(
-          math_ops.equal(rank_diff, 1),
-          maybe_squeeze_weights,
+          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
           _maybe_expand_weights)
+
     # If weights are scalar, do nothing. Otherwise, try to add or remove a
     # dimension to match predictions.
     weights = control_flow_ops.cond(
-        math_ops.equal(weights_rank_tensor, 0),
-        lambda: weights, _maybe_adjust_weights)
+        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
+        _maybe_adjust_weights)
   return predictions, labels, weights
 
 
@@ -165,14 +167,14 @@ def _maybe_expand_labels(labels, predictions):
         if predictions_rank == labels_rank + 1:
           return array_ops.expand_dims(labels, -1, name=scope)
         raise ValueError(
-            'Unexpected labels shape %s for predictions shape %s.' % (
-                labels.get_shape(), predictions.get_shape()))
+            'Unexpected labels shape %s for predictions shape %s.' %
+            (labels.get_shape(), predictions.get_shape()))
 
     # Otherwise, use dynamic shape.
     return control_flow_ops.cond(
-        math_ops.equal(array_ops.rank(predictions), array_ops.rank(labels) + 1),
-        lambda: array_ops.expand_dims(labels, -1, name=scope),
-        lambda: labels)
+        math_ops.equal(array_ops.rank(predictions),
+                       array_ops.rank(labels) + 1),
+        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
 
 
 def _safe_div(numerator, denominator, name):
@@ -264,8 +266,11 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
 
 
 @tf_export('metrics.mean')
-def mean(values, weights=None, metrics_collections=None,
-         updates_collections=None, name=None):
+def mean(values,
+         weights=None,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None):
   """Computes the (weighted) mean of the given values.
 
   The `mean` function creates two local variables, `total` and `count`
@@ -340,8 +345,12 @@ def mean(values, weights=None, metrics_collections=None,
 
 
 @tf_export('metrics.accuracy')
-def accuracy(labels, predictions, weights=None, metrics_collections=None,
-             updates_collections=None, name=None):
+def accuracy(labels,
+             predictions,
+             weights=None,
+             metrics_collections=None,
+             updates_collections=None,
+             name=None):
   """Calculates how often `predictions` matches `labels`.
 
   The `accuracy` function creates two local variables, `total` and
@@ -395,12 +404,15 @@ def accuracy(labels, predictions, weights=None, metrics_collections=None,
   if labels.dtype != predictions.dtype:
     predictions = math_ops.cast(predictions, labels.dtype)
   is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
-  return mean(is_correct, weights, metrics_collections,
-              updates_collections, name or 'accuracy')
+  return mean(is_correct, weights, metrics_collections, updates_collections,
+              name or 'accuracy')
 
 
-def _confusion_matrix_at_thresholds(
-    labels, predictions, thresholds, weights=None, includes=None):
+def _confusion_matrix_at_thresholds(labels,
+                                    predictions,
+                                    thresholds,
+                                    weights=None,
+                                    includes=None):
   """Computes true_positives, false_negatives, true_negatives, false_positives.
 
   This function creates up to four local variables, `true_positives`,
@@ -498,8 +510,8 @@ def _confusion_matrix_at_thresholds(
   if weights is not None:
     weights = weights_broadcast_ops.broadcast_weights(
         math_ops.to_float(weights), predictions)
-    weights_tiled = array_ops.tile(array_ops.reshape(
-        weights, [1, -1]), [num_thresholds, 1])
+    weights_tiled = array_ops.tile(
+        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
     thresh_tiled.get_shape().assert_is_compatible_with(
         weights_tiled.get_shape())
   else:
@@ -515,8 +527,9 @@ def _confusion_matrix_at_thresholds(
         math_ops.logical_and(label_is_pos, pred_is_pos))
     if weights_tiled is not None:
       is_true_positive *= weights_tiled
-    update_ops['tp'] = state_ops.assign_add(
-        true_p, math_ops.reduce_sum(is_true_positive, 1))
+    update_ops['tp'] = state_ops.assign_add(true_p,
+                                            math_ops.reduce_sum(
+                                                is_true_positive, 1))
     values['tp'] = true_p
 
   if 'fn' in includes:
@@ -526,8 +539,9 @@ def _confusion_matrix_at_thresholds(
         math_ops.logical_and(label_is_pos, pred_is_neg))
     if weights_tiled is not None:
       is_false_negative *= weights_tiled
-    update_ops['fn'] = state_ops.assign_add(
-        false_n, math_ops.reduce_sum(is_false_negative, 1))
+    update_ops['fn'] = state_ops.assign_add(false_n,
+                                            math_ops.reduce_sum(
+                                                is_false_negative, 1))
     values['fn'] = false_n
 
   if 'tn' in includes:
@@ -537,8 +551,9 @@ def _confusion_matrix_at_thresholds(
         math_ops.logical_and(label_is_neg, pred_is_neg))
     if weights_tiled is not None:
       is_true_negative *= weights_tiled
-    update_ops['tn'] = state_ops.assign_add(
-        true_n, math_ops.reduce_sum(is_true_negative, 1))
+    update_ops['tn'] = state_ops.assign_add(true_n,
+                                            math_ops.reduce_sum(
+                                                is_true_negative, 1))
     values['tn'] = true_n
 
   if 'fp' in includes:
@@ -548,17 +563,24 @@ def _confusion_matrix_at_thresholds(
         math_ops.logical_and(label_is_neg, pred_is_pos))
     if weights_tiled is not None:
       is_false_positive *= weights_tiled
-    update_ops['fp'] = state_ops.assign_add(
-        false_p, math_ops.reduce_sum(is_false_positive, 1))
+    update_ops['fp'] = state_ops.assign_add(false_p,
+                                            math_ops.reduce_sum(
+                                                is_false_positive, 1))
     values['fp'] = false_p
 
   return values, update_ops
 
 
 @tf_export('metrics.auc')
-def auc(labels, predictions, weights=None, num_thresholds=200,
-        metrics_collections=None, updates_collections=None,
-        curve='ROC', name=None, summation_method='trapezoidal'):
+def auc(labels,
+        predictions,
+        weights=None,
+        num_thresholds=200,
+        metrics_collections=None,
+        updates_collections=None,
+        curve='ROC',
+        name=None,
+        summation_method='trapezoidal'):
   """Computes the approximate AUC via a Riemann sum.
 
   The `auc` function creates four local variables, `true_positives`,
@@ -626,14 +648,14 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
     raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                        'is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'auc', (labels, predictions, weights)):
+  with variable_scope.variable_scope(name, 'auc',
+                                     (labels, predictions, weights)):
     if curve != 'ROC' and curve != 'PR':
-      raise ValueError('curve must be either ROC or PR, %s unknown' %
-                       (curve))
+      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
     kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
-                  for i in range(num_thresholds-2)]
+    thresholds = [
+        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+    ]
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
     values, update_ops = _confusion_matrix_at_thresholds(
@@ -641,6 +663,7 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
 
     # Add epsilons to avoid dividing by 0.
     epsilon = 1.0e-6
+
     def compute_auc(tp, fn, tn, fp, name):
       """Computes the roc-auc or pr-auc based on confusion counts."""
       rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
@@ -671,11 +694,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
         raise ValueError('Invalid summation_method: %s' % summation_method)
 
     # sum up the areas of all the trapeziums
-    auc_value = compute_auc(
-        values['tp'], values['fn'], values['tn'], values['fp'], 'value')
-    update_op = compute_auc(
-        update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'],
-        'update_op')
+    auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+                            values['fp'], 'value')
+    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
+                            update_ops['tn'], update_ops['fp'], 'update_op')
 
     if metrics_collections:
       ops.add_to_collections(metrics_collections, auc_value)
@@ -687,7 +709,9 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
 
 
 @tf_export('metrics.mean_absolute_error')
-def mean_absolute_error(labels, predictions, weights=None,
+def mean_absolute_error(labels,
+                        predictions,
+                        weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
@@ -746,7 +770,10 @@ def mean_absolute_error(labels, predictions, weights=None,
 
 
 @tf_export('metrics.mean_cosine_distance')
-def mean_cosine_distance(labels, predictions, dim, weights=None,
+def mean_cosine_distance(labels,
+                         predictions,
+                         dim,
+                         weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
@@ -802,10 +829,8 @@ def mean_cosine_distance(labels, predictions, dim, weights=None,
       radial_diffs, reduction_indices=[
           dim,
       ], keepdims=True)
-  mean_distance, update_op = mean(radial_diffs, weights,
-                                  None,
-                                  None,
-                                  name or 'mean_cosine_distance')
+  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
+                                  'mean_cosine_distance')
   mean_distance = math_ops.subtract(1.0, mean_distance)
   update_op = math_ops.subtract(1.0, update_op)
 
@@ -906,8 +931,8 @@ def mean_per_class_accuracy(labels,
 
     per_class_accuracy = _safe_div(count, total, None)
 
-    mean_accuracy_v = math_ops.reduce_mean(per_class_accuracy,
-                                           name='mean_accuracy')
+    mean_accuracy_v = math_ops.reduce_mean(
+        per_class_accuracy, name='mean_accuracy')
     update_op = _safe_div(update_count_op, update_total_op, name='update_op')
 
     if metrics_collections:
@@ -975,13 +1000,14 @@ def mean_iou(labels,
     raise RuntimeError('tf.metrics.mean_iou is not supported when '
                        'eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'mean_iou', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'mean_iou',
+                                     (predictions, labels, weights)):
     # Check if shape is compatible.
     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
 
     total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                       num_classes, weights)
+
     def compute_mean_iou(name):
       """Compute the mean intersection-over-union via the confusion matrix."""
       sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
@@ -992,22 +1018,21 @@ def mean_iou(labels,
       # The mean is only computed over classes that appear in the
       # label or prediction tensor. If the denominator is 0, we need to
       # ignore the class.
-      num_valid_entries = math_ops.reduce_sum(math_ops.cast(
-          math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+      num_valid_entries = math_ops.reduce_sum(
+          math_ops.cast(
+              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
 
       # If the value of the denominator is 0, set it to 1 to avoid
       # zero division.
       denominator = array_ops.where(
-          math_ops.greater(denominator, 0),
-          denominator,
+          math_ops.greater(denominator, 0), denominator,
           array_ops.ones_like(denominator))
       iou = math_ops.div(cm_diag, denominator)
 
       # If the number of valid entries is 0 (no classes) we return 0.
       result = array_ops.where(
           math_ops.greater(num_valid_entries, 0),
-          math_ops.reduce_sum(iou, name=name) / num_valid_entries,
-          0)
+          math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
       return result
 
     mean_iou_v = compute_mean_iou('mean_iou')
@@ -1022,7 +1047,10 @@ def mean_iou(labels,
 
 
 @tf_export('metrics.mean_relative_error')
-def mean_relative_error(labels, predictions, normalizer, weights=None,
+def mean_relative_error(labels,
+                        predictions,
+                        normalizer,
+                        weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
@@ -1081,15 +1109,16 @@ def mean_relative_error(labels, predictions, normalizer, weights=None,
       predictions, normalizer)
   predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
   relative_errors = array_ops.where(
-      math_ops.equal(normalizer, 0.0),
-      array_ops.zeros_like(labels),
+      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
       math_ops.div(math_ops.abs(labels - predictions), normalizer))
   return mean(relative_errors, weights, metrics_collections,
               updates_collections, name or 'mean_relative_error')
 
 
 @tf_export('metrics.mean_squared_error')
-def mean_squared_error(labels, predictions, weights=None,
+def mean_squared_error(labels,
+                       predictions,
+                       weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
@@ -1143,13 +1172,16 @@ def mean_squared_error(labels, predictions, weights=None,
   predictions, labels, weights = _remove_squeezable_dimensions(
       predictions=predictions, labels=labels, weights=weights)
   squared_error = math_ops.square(labels - predictions)
-  return mean(squared_error, weights, metrics_collections,
-              updates_collections, name or 'mean_squared_error')
+  return mean(squared_error, weights, metrics_collections, updates_collections,
+              name or 'mean_squared_error')
 
 
 @tf_export('metrics.mean_tensor')
-def mean_tensor(values, weights=None, metrics_collections=None,
-                updates_collections=None, name=None):
+def mean_tensor(values,
+                weights=None,
+                metrics_collections=None,
+                updates_collections=None,
+                name=None):
   """Computes the element-wise (weighted) mean of the given tensors.
 
   In contrast to the `mean` function which returns a scalar with the
@@ -1216,9 +1248,8 @@ def mean_tensor(values, weights=None, metrics_collections=None,
       update_count_op = state_ops.assign_add(count, num_values)
 
     def compute_mean(total, count, name):
-      non_zero_count = math_ops.maximum(count,
-                                        array_ops.ones_like(count),
-                                        name=name)
+      non_zero_count = math_ops.maximum(
+          count, array_ops.ones_like(count), name=name)
       return math_ops.truediv(total, non_zero_count, name=name)
 
     mean_t = compute_mean(total, count, 'value')
@@ -1234,7 +1265,9 @@ def mean_tensor(values, weights=None, metrics_collections=None,
 
 
 @tf_export('metrics.percentage_below')
-def percentage_below(values, threshold, weights=None,
+def percentage_below(values,
+                     threshold,
+                     weights=None,
                      metrics_collections=None,
                      updates_collections=None,
                      name=None):
@@ -1281,14 +1314,13 @@ def percentage_below(values, threshold, weights=None,
                        'eager execution is enabled.')
 
   is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
-  return mean(is_below_threshold,
-              weights,
-              metrics_collections,
-              updates_collections,
-              name or 'percentage_below_threshold')
+  return mean(is_below_threshold, weights, metrics_collections,
+              updates_collections, name or 'percentage_below_threshold')
 
 
-def _count_condition(values, weights=None, metrics_collections=None,
+def _count_condition(values,
+                     weights=None,
+                     metrics_collections=None,
                      updates_collections=None):
   """Sums the weights of cases where the given values are True.
 
@@ -1318,8 +1350,8 @@ def _count_condition(values, weights=None, metrics_collections=None,
 
   values = math_ops.to_float(values)
   if weights is not None:
-    with ops.control_dependencies((
-        check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)):
+    with ops.control_dependencies((check_ops.assert_rank_in(
+        weights, (0, array_ops.rank(values))),)):
       weights = math_ops.to_float(weights)
       values = math_ops.multiply(values, weights)
 
@@ -1336,7 +1368,9 @@ def _count_condition(values, weights=None, metrics_collections=None,
 
 
 @tf_export('metrics.false_negatives')
-def false_negatives(labels, predictions, weights=None,
+def false_negatives(labels,
+                    predictions,
+                    weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
@@ -1372,21 +1406,24 @@ def false_negatives(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.false_negatives is not supported when '
                        'eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'false_negatives', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'false_negatives',
+                                     (predictions, labels, weights)):
 
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
         labels=math_ops.cast(labels, dtype=dtypes.bool),
         weights=weights)
-    is_false_negative = math_ops.logical_and(math_ops.equal(labels, True),
-                                             math_ops.equal(predictions, False))
+    is_false_negative = math_ops.logical_and(
+        math_ops.equal(labels, True), math_ops.equal(predictions, False))
     return _count_condition(is_false_negative, weights, metrics_collections,
                             updates_collections)
 
 
 @tf_export('metrics.false_negatives_at_thresholds')
-def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+def false_negatives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
@@ -1440,7 +1477,9 @@ def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
 
 
 @tf_export('metrics.false_positives')
-def false_positives(labels, predictions, weights=None,
+def false_positives(labels,
+                    predictions,
+                    weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
@@ -1477,21 +1516,24 @@ def false_positives(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.false_positives is not supported when '
                        'eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'false_positives', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'false_positives',
+                                     (predictions, labels, weights)):
 
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
         labels=math_ops.cast(labels, dtype=dtypes.bool),
         weights=weights)
-    is_false_positive = math_ops.logical_and(math_ops.equal(labels, False),
-                                             math_ops.equal(predictions, True))
+    is_false_positive = math_ops.logical_and(
+        math_ops.equal(labels, False), math_ops.equal(predictions, True))
     return _count_condition(is_false_positive, weights, metrics_collections,
                             updates_collections)
 
 
 @tf_export('metrics.false_positives_at_thresholds')
-def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+def false_positives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
@@ -1545,7 +1587,9 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
 
 
 @tf_export('metrics.true_negatives')
-def true_negatives(labels, predictions, weights=None,
+def true_negatives(labels,
+                   predictions,
+                   weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
@@ -1582,21 +1626,24 @@ def true_negatives(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.true_negatives is not '
                        'supported when eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'true_negatives', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'true_negatives',
+                                     (predictions, labels, weights)):
 
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
         labels=math_ops.cast(labels, dtype=dtypes.bool),
         weights=weights)
-    is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
-                                            math_ops.equal(predictions, False))
+    is_true_negative = math_ops.logical_and(
+        math_ops.equal(labels, False), math_ops.equal(predictions, False))
     return _count_condition(is_true_negative, weights, metrics_collections,
                             updates_collections)
 
 
 @tf_export('metrics.true_negatives_at_thresholds')
-def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+def true_negatives_at_thresholds(labels,
+                                 predictions,
+                                 thresholds,
+                                 weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
@@ -1650,7 +1697,9 @@ def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
 
 
 @tf_export('metrics.true_positives')
-def true_positives(labels, predictions, weights=None,
+def true_positives(labels,
+                   predictions,
+                   weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
@@ -1687,21 +1736,24 @@ def true_positives(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.true_positives is not '
                        'supported when eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'true_positives', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'true_positives',
+                                     (predictions, labels, weights)):
 
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
         labels=math_ops.cast(labels, dtype=dtypes.bool),
         weights=weights)
-    is_true_positive = math_ops.logical_and(math_ops.equal(labels, True),
-                                            math_ops.equal(predictions, True))
+    is_true_positive = math_ops.logical_and(
+        math_ops.equal(labels, True), math_ops.equal(predictions, True))
     return _count_condition(is_true_positive, weights, metrics_collections,
                             updates_collections)
 
 
 @tf_export('metrics.true_positives_at_thresholds')
-def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+def true_positives_at_thresholds(labels,
+                                 predictions,
+                                 thresholds,
+                                 weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
@@ -1755,8 +1807,11 @@ def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
 
 
 @tf_export('metrics.precision')
-def precision(labels, predictions, weights=None,
-              metrics_collections=None, updates_collections=None,
+def precision(labels,
+              predictions,
+              weights=None,
+              metrics_collections=None,
+              updates_collections=None,
               name=None):
   """Computes the precision of the predictions with respect to the labels.
 
@@ -1805,8 +1860,8 @@ def precision(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.precision is not '
                        'supported when eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'precision', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'precision',
+                                     (predictions, labels, weights)):
 
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
@@ -1814,22 +1869,27 @@ def precision(labels, predictions, weights=None,
         weights=weights)
 
     true_p, true_positives_update_op = true_positives(
-        labels, predictions, weights, metrics_collections=None,
-        updates_collections=None, name=None)
+        labels,
+        predictions,
+        weights,
+        metrics_collections=None,
+        updates_collections=None,
+        name=None)
     false_p, false_positives_update_op = false_positives(
-        labels, predictions, weights, metrics_collections=None,
-        updates_collections=None, name=None)
+        labels,
+        predictions,
+        weights,
+        metrics_collections=None,
+        updates_collections=None,
+        name=None)
 
     def compute_precision(tp, fp, name):
       return array_ops.where(
-          math_ops.greater(tp + fp, 0),
-          math_ops.div(tp, tp + fp),
-          0,
-          name)
+          math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
 
     p = compute_precision(true_p, false_p, 'value')
-    update_op = compute_precision(
-        true_positives_update_op, false_positives_update_op, 'update_op')
+    update_op = compute_precision(true_positives_update_op,
+                                  false_positives_update_op, 'update_op')
 
     if metrics_collections:
       ops.add_to_collections(metrics_collections, p)
@@ -1841,10 +1901,13 @@ def precision(labels, predictions, weights=None,
 
 
 @tf_export('metrics.precision_at_thresholds')
-def precision_at_thresholds(labels, predictions, thresholds,
+def precision_at_thresholds(labels,
+                            predictions,
+                            thresholds,
                             weights=None,
                             metrics_collections=None,
-                            updates_collections=None, name=None):
+                            updates_collections=None,
+                            name=None):
   """Computes precision values for different `thresholds` on `predictions`.
 
   The `precision_at_thresholds` function creates four local variables,
@@ -1900,12 +1963,13 @@ def precision_at_thresholds(labels, predictions, thresholds,
 
     # Avoid division by zero.
     epsilon = 1e-7
+
     def compute_precision(tp, fp, name):
       return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
 
     prec = compute_precision(values['tp'], values['fp'], 'value')
-    update_op = compute_precision(
-        update_ops['tp'], update_ops['fp'], 'update_op')
+    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+                                  'update_op')
 
     if metrics_collections:
       ops.add_to_collections(metrics_collections, prec)
@@ -1917,8 +1981,11 @@ def precision_at_thresholds(labels, predictions, thresholds,
 
 
 @tf_export('metrics.recall')
-def recall(labels, predictions, weights=None,
-           metrics_collections=None, updates_collections=None,
+def recall(labels,
+           predictions,
+           weights=None,
+           metrics_collections=None,
+           updates_collections=None,
            name=None):
   """Computes the recall of the predictions with respect to the labels.
 
@@ -1965,30 +2032,36 @@ def recall(labels, predictions, weights=None,
     raise RuntimeError('tf.metrics.recall is not supported is not '
                        'supported when eager execution is enabled.')
 
-  with variable_scope.variable_scope(
-      name, 'recall', (predictions, labels, weights)):
+  with variable_scope.variable_scope(name, 'recall',
+                                     (predictions, labels, weights)):
     predictions, labels, weights = _remove_squeezable_dimensions(
         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
         labels=math_ops.cast(labels, dtype=dtypes.bool),
         weights=weights)
 
     true_p, true_positives_update_op = true_positives(
-        labels, predictions, weights, metrics_collections=None,
-        updates_collections=None, name=None)
+        labels,
+        predictions,
+        weights,
+        metrics_collections=None,
+        updates_collections=None,
+        name=None)
     false_n, false_negatives_update_op = false_negatives(
-        labels, predictions, weights, metrics_collections=None,
-        updates_collections=None, name=None)
+        labels,
+        predictions,
+        weights,
+        metrics_collections=None,
+        updates_collections=None,
+        name=None)
 
     def compute_recall(true_p, false_n, name):
       return array_ops.where(
           math_ops.greater(true_p + false_n, 0),
-          math_ops.div(true_p, true_p + false_n),
-          0,
-          name)
+          math_ops.div(true_p, true_p + false_n), 0, name)
 
     rec = compute_recall(true_p, false_n, 'value')
-    update_op = compute_recall(
-        true_positives_update_op, false_negatives_update_op, 'update_op')
+    update_op = compute_recall(true_positives_update_op,
+                               false_negatives_update_op, 'update_op')
 
     if metrics_collections:
       ops.add_to_collections(metrics_collections, rec)
@@ -2022,8 +2095,8 @@ def _select_class_id(ids, selected_id):
   """
   ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
   if isinstance(ids, sparse_tensor.SparseTensor):
-    return sparse_ops.sparse_retain(
-        ids, math_ops.equal(ids.values, selected_id))
+    return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
+                                                        selected_id))
 
   # TODO(ptucker): Make this more efficient, maybe add a sparse version of
   # tf.equal and tf.reduce_any?
@@ -2031,12 +2104,13 @@ def _select_class_id(ids, selected_id):
   # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
   ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
   ids_last_dim = array_ops.size(ids_shape) - 1
-  filled_selected_id_shape = math_ops.reduced_shape(
-      ids_shape, array_ops.reshape(ids_last_dim, [1]))
+  filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
+                                                    array_ops.reshape(
+                                                        ids_last_dim, [1]))
 
   # Intersect `ids` with the selected ID.
-  filled_selected_id = array_ops.fill(
-      filled_selected_id_shape, math_ops.to_int64(selected_id))
+  filled_selected_id = array_ops.fill(filled_selected_id_shape,
+                                      math_ops.to_int64(selected_id))
   result = sets.set_intersection(filled_selected_id, ids)
   return sparse_tensor.SparseTensor(
       indices=result.indices, values=result.values, dense_shape=ids_shape)
@@ -2096,15 +2170,15 @@ def _sparse_true_positive_at_k(labels,
   Returns:
     A [D1, ... DN] `Tensor` of true positive counts.
   """
-  with ops.name_scope(
-      name, 'true_positives', (predictions_idx, labels, weights)):
-    labels, predictions_idx = _maybe_select_class_id(
-        labels, predictions_idx, class_id)
+  with ops.name_scope(name, 'true_positives',
+                      (predictions_idx, labels, weights)):
+    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+                                                     class_id)
     tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
     tp = math_ops.to_double(tp)
     if weights is not None:
-      with ops.control_dependencies((
-          weights_broadcast_ops.assert_broadcastable(weights, tp),)):
+      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+          weights, tp),)):
         weights = math_ops.to_double(weights)
         tp = math_ops.multiply(tp, weights)
     return tp
@@ -2148,11 +2222,12 @@ def _streaming_sparse_true_positive_at_k(labels,
   Raises:
     ValueError: If `weights` is not `None` and has an incompatible shape.
   """
-  with ops.name_scope(
-      name, _at_k_name('true_positive', k, class_id=class_id),
-      (predictions_idx, labels, weights)) as scope:
+  with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
+                      (predictions_idx, labels, weights)) as scope:
     tp = _sparse_true_positive_at_k(
-        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+        predictions_idx=predictions_idx,
+        labels=labels,
+        class_id=class_id,
         weights=weights)
     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
 
@@ -2189,18 +2264,16 @@ def _sparse_false_negative_at_k(labels,
   Returns:
     A [D1, ... DN] `Tensor` of false negative counts.
   """
-  with ops.name_scope(
-      None, 'false_negatives', (predictions_idx, labels, weights)):
-    labels, predictions_idx = _maybe_select_class_id(labels,
-                                                     predictions_idx,
+  with ops.name_scope(None, 'false_negatives',
+                      (predictions_idx, labels, weights)):
+    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                      class_id)
-    fn = sets.set_size(sets.set_difference(predictions_idx,
-                                           labels,
-                                           aminusb=False))
+    fn = sets.set_size(
+        sets.set_difference(predictions_idx, labels, aminusb=False))
     fn = math_ops.to_double(fn)
     if weights is not None:
-      with ops.control_dependencies((
-          weights_broadcast_ops.assert_broadcastable(weights, fn),)):
+      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+          weights, fn),)):
         weights = math_ops.to_double(weights)
         fn = math_ops.multiply(fn, weights)
     return fn
@@ -2244,11 +2317,12 @@ def _streaming_sparse_false_negative_at_k(labels,
   Raises:
     ValueError: If `weights` is not `None` and has an incompatible shape.
   """
-  with ops.name_scope(
-      name, _at_k_name('false_negative', k, class_id=class_id),
-      (predictions_idx, labels, weights)) as scope:
+  with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
+                      (predictions_idx, labels, weights)) as scope:
     fn = _sparse_false_negative_at_k(
-        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+        predictions_idx=predictions_idx,
+        labels=labels,
+        class_id=class_id,
         weights=weights)
     batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
 
@@ -2335,9 +2409,8 @@ def recall_at_k(labels,
     raise RuntimeError('tf.metrics.recall_at_k is not '
                        'supported when eager execution is enabled.')
 
-  with ops.name_scope(
-      name, _at_k_name('recall', k, class_id=class_id),
-      (predictions, labels, weights)) as scope:
+  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
+                      (predictions, labels, weights)) as scope:
     _, top_k_idx = nn.top_k(predictions, k)
     return recall_at_top_k(
         labels=labels,
@@ -2404,16 +2477,21 @@ def recall_at_top_k(labels,
     `predictions`, or if either `metrics_collections` or `updates_collections`
     are not a list or tuple.
   """
-  with ops.name_scope(name,
-                      _at_k_name('recall', k, class_id=class_id),
+  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
                       (predictions_idx, labels, weights)) as scope:
     labels = _maybe_expand_labels(labels, predictions_idx)
     top_k_idx = math_ops.to_int64(predictions_idx)
     tp, tp_update = _streaming_sparse_true_positive_at_k(
-        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+        predictions_idx=top_k_idx,
+        labels=labels,
+        k=k,
+        class_id=class_id,
         weights=weights)
     fn, fn_update = _streaming_sparse_false_negative_at_k(
-        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+        predictions_idx=top_k_idx,
+        labels=labels,
+        k=k,
+        class_id=class_id,
         weights=weights)
 
     metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
@@ -2427,9 +2505,13 @@ def recall_at_top_k(labels,
 
 
 @tf_export('metrics.recall_at_thresholds')
-def recall_at_thresholds(labels, predictions, thresholds,
-                         weights=None, metrics_collections=None,
-                         updates_collections=None, name=None):
+def recall_at_thresholds(labels,
+                         predictions,
+                         thresholds,
+                         weights=None,
+                         metrics_collections=None,
+                         updates_collections=None,
+                         name=None):
   """Computes various recall values for different `thresholds` on `predictions`.
 
   The `recall_at_thresholds` function creates four local variables,
@@ -2483,6 +2565,7 @@ def recall_at_thresholds(labels, predictions, thresholds,
 
     # Avoid division by zero.
     epsilon = 1e-7
+
     def compute_recall(tp, fn, name):
       return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
 
@@ -2499,7 +2582,9 @@ def recall_at_thresholds(labels, predictions, thresholds,
 
 
 @tf_export('metrics.root_mean_squared_error')
-def root_mean_squared_error(labels, predictions, weights=None,
+def root_mean_squared_error(labels,
+                            predictions,
+                            weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
@@ -2552,9 +2637,9 @@ def root_mean_squared_error(labels, predictions, weights=None,
 
   predictions, labels, weights = _remove_squeezable_dimensions(
       predictions=predictions, labels=labels, weights=weights)
-  mse, update_mse_op = mean_squared_error(
-      labels, predictions, weights, None, None,
-      name or 'root_mean_squared_error')
+  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
+                                          None, name or
+                                          'root_mean_squared_error')
 
   rmse = math_ops.sqrt(mse)
   update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2569,9 +2654,14 @@ def root_mean_squared_error(labels, predictions, weights=None,
 
 
 @tf_export('metrics.sensitivity_at_specificity')
-def sensitivity_at_specificity(
-    labels, predictions, specificity, weights=None, num_thresholds=200,
-    metrics_collections=None, updates_collections=None, name=None):
+def sensitivity_at_specificity(labels,
+                               predictions,
+                               specificity,
+                               weights=None,
+                               num_thresholds=200,
+                               metrics_collections=None,
+                               updates_collections=None,
+                               name=None):
   """Computes the specificity at a given sensitivity.
 
   The `sensitivity_at_specificity` function creates four local
@@ -2632,8 +2722,9 @@ def sensitivity_at_specificity(
   with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                      (predictions, labels, weights)):
     kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
-                  for i in range(num_thresholds-2)]
+    thresholds = [
+        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+    ]
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
     values, update_ops = _confusion_matrix_at_thresholds(
@@ -2645,8 +2736,7 @@ def sensitivity_at_specificity(
       tf_index = math_ops.cast(tf_index, dtypes.int32)
 
       # Now, we have the implicit threshold, so compute the sensitivity:
-      return math_ops.div(tp[tf_index],
-                          tp[tf_index] + fn[tf_index] + kepsilon,
+      return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
                           name)
 
     sensitivity = compute_sensitivity_at_specificity(
@@ -2685,8 +2775,8 @@ def _expand_and_tile(tensor, multiple, dim=0, name=None):
   """
   if multiple < 1:
     raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
-  with ops.name_scope(
-      name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
+  with ops.name_scope(name, 'expand_and_tile',
+                      (tensor, multiple, dim)) as scope:
     # Sparse.
     tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
     if isinstance(tensor, sparse_tensor.SparseTensor):
@@ -2786,8 +2876,8 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
   Raises:
     ValueError: if the last dimension of predictions_idx is not set.
   """
-  with ops.name_scope(
-      None, 'average_precision', (predictions_idx, labels)) as scope:
+  with ops.name_scope(None, 'average_precision',
+                      (predictions_idx, labels)) as scope:
     predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
     if predictions_idx.get_shape().ndims == 0:
       raise ValueError('The rank of predictions_idx must be at least 1.')
@@ -2824,10 +2914,12 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
     retrieved_per_k = math_ops.cumsum(
         array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
     precision_per_k = math_ops.div(
-        math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
+        math_ops.to_double(tp_per_k),
+        math_ops.to_double(retrieved_per_k),
         name='precision_per_k')
     relevant_precision_per_k = math_ops.multiply(
-        precision_per_k, math_ops.to_double(relevant_per_k),
+        precision_per_k,
+        math_ops.to_double(relevant_per_k),
         name='relevant_precision_per_k')
 
     # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
@@ -3017,9 +3109,8 @@ def average_precision_at_k(labels,
 
   if k < 1:
     raise ValueError('Invalid k=%s.' % k)
-  with ops.name_scope(
-      name, _at_k_name('average_precision', k),
-      (predictions, labels, weights)) as scope:
+  with ops.name_scope(name, _at_k_name('average_precision', k),
+                      (predictions, labels, weights)) as scope:
     # Calculate top k indices to produce [D1, ... DN, k] tensor.
     _, predictions_idx = nn.top_k(predictions, k)
     return _streaming_sparse_average_precision_at_top_k(
@@ -3060,17 +3151,16 @@ def _sparse_false_positive_at_k(labels,
   Returns:
     A [D1, ... DN] `Tensor` of false positive counts.
   """
-  with ops.name_scope(
-      None, 'false_positives', (predictions_idx, labels, weights)):
-    labels, predictions_idx = _maybe_select_class_id(labels,
-                                                     predictions_idx,
+  with ops.name_scope(None, 'false_positives',
+                      (predictions_idx, labels, weights)):
+    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                      class_id)
-    fp = sets.set_size(sets.set_difference(
-        predictions_idx, labels, aminusb=True))
+    fp = sets.set_size(
+        sets.set_difference(predictions_idx, labels, aminusb=True))
     fp = math_ops.to_double(fp)
     if weights is not None:
-      with ops.control_dependencies((
-          weights_broadcast_ops.assert_broadcastable(weights, fp),)):
+      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+          weights, fp),)):
         weights = math_ops.to_double(weights)
         fp = math_ops.multiply(fp, weights)
     return fp
@@ -3114,11 +3204,12 @@ def _streaming_sparse_false_positive_at_k(labels,
   Raises:
     ValueError: If `weights` is not `None` and has an incompatible shape.
   """
-  with ops.name_scope(
-      name, _at_k_name('false_positive', k, class_id=class_id),
-      (predictions_idx, labels, weights)) as scope:
+  with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
+                      (predictions_idx, labels, weights)) as scope:
     fp = _sparse_false_positive_at_k(
-        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+        predictions_idx=predictions_idx,
+        labels=labels,
+        class_id=class_id,
         weights=weights)
     batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
 
@@ -3190,10 +3281,16 @@ def precision_at_top_k(labels,
     labels = _maybe_expand_labels(labels, predictions_idx)
     top_k_idx = math_ops.to_int64(predictions_idx)
     tp, tp_update = _streaming_sparse_true_positive_at_k(
-        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+        predictions_idx=top_k_idx,
+        labels=labels,
+        k=k,
+        class_id=class_id,
         weights=weights)
     fp, fp_update = _streaming_sparse_false_positive_at_k(
-        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+        predictions_idx=top_k_idx,
+        labels=labels,
+        k=k,
+        class_id=class_id,
         weights=weights)
 
     metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
@@ -3323,9 +3420,14 @@ def precision_at_k(labels,
 
 
 @tf_export('metrics.specificity_at_sensitivity')
-def specificity_at_sensitivity(
-    labels, predictions, sensitivity, weights=None, num_thresholds=200,
-    metrics_collections=None, updates_collections=None, name=None):
+def specificity_at_sensitivity(labels,
+                               predictions,
+                               sensitivity,
+                               weights=None,
+                               num_thresholds=200,
+                               metrics_collections=None,
+                               updates_collections=None,
+                               name=None):
   """Computes the specificity at a given sensitivity.
 
   The `specificity_at_sensitivity` function creates four local
@@ -3386,8 +3488,9 @@ def specificity_at_sensitivity(
   with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                      (predictions, labels, weights)):
     kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
-                  for i in range(num_thresholds-2)]
+    thresholds = [
+        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+    ]
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
 
     values, update_ops = _confusion_matrix_at_thresholds(
@@ -3419,8 +3522,7 @@ def specificity_at_sensitivity(
       tf_index = math_ops.cast(tf_index, dtypes.int32)
 
       # Now, we have the implicit threshold, so compute the specificity:
-      return math_ops.div(tn[tf_index],
-                          tn[tf_index] + fp[tf_index] + kepsilon,
+      return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
                           name)
 
     specificity = compute_specificity_at_sensitivity(
index 837ee02e64ce3c4c8fd36e503562bdb75b4f5976..3268fd0e0ac312dc8a15e9fef8a14f540dcb55e1 100644 (file)
@@ -196,9 +196,12 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
       targets * -log(sigmoid(logits)) +
           (1 - targets) * -log(1 - sigmoid(logits))
 
-  A value `pos_weights > 1` decreases the false negative count, hence increasing the recall.
-  Conversely setting `pos_weights < 1` decreases the false positive count and increases the precision.
-  This can be seen from the fact that `pos_weight` is introduced as a multiplicative coefficient for the positive targets term 
+  A value `pos_weights > 1` decreases the false negative count, hence increasing
+  the recall.
+  Conversely setting `pos_weights < 1` decreases the false positive count and
+  increases the precision.
+  This can be seen from the fact that `pos_weight` is introduced as a
+  multiplicative coefficient for the positive targets term
   in the loss expression:
 
       targets * -log(sigmoid(logits)) * pos_weight +
@@ -646,9 +649,12 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
 
 
 @tf_export("nn.moments")
-def moments(x, axes,
-            shift=None,  # pylint: disable=unused-argument
-            name=None, keep_dims=False):
+def moments(
+    x,
+    axes,
+    shift=None,  # pylint: disable=unused-argument
+    name=None,
+    keep_dims=False):
   """Calculate the mean and variance of `x`.
 
   The mean and variance are calculated by aggregating the contents of `x`
@@ -692,8 +698,8 @@ def moments(x, axes,
       mean = array_ops.squeeze(mean, axes)
       variance = array_ops.squeeze(variance, axes)
     if x.dtype == dtypes.float16:
-      return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
-          variance, dtypes.float16))
+      return (math_ops.cast(mean, dtypes.float16),
+              math_ops.cast(variance, dtypes.float16))
     else:
       return (mean, variance)
 
@@ -824,8 +830,8 @@ def batch_normalization(x,
     inv = math_ops.rsqrt(variance + variance_epsilon)
     if scale is not None:
       inv *= scale
-    return x * inv + (offset - mean * inv
-                      if offset is not None else -mean * inv)
+    return x * inv + (
+        offset - mean * inv if offset is not None else -mean * inv)
 
 
 @tf_export("nn.fused_batch_norm")
index 676756402442969b56f0b21f8ed5ec6a79738055..5a45bdc1e5e1d38a34176ed9443fcd1713f38e1e 100644 (file)
@@ -131,8 +131,7 @@ class LogPoissonLossTest(test_lib.TestCase):
     y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False)
     y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True)
     y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False)
-    y_tf_stirling = nn_impl.log_poisson_loss(
-        z_np, x_np, compute_full_loss=True)
+    y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True)
     y_tf_np = self.evaluate(y_tf)
     y_tf_np_stirling = self.evaluate(y_tf_stirling)
     eps = 1e-3
@@ -773,8 +772,8 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
     def _SoftmaxCrossEntropyWithLogits(logits, targets):
       # logits, targets: float arrays of the same shape.
       assert logits.shape == targets.shape
-      stable_exp_logits = np.exp(logits - np.amax(
-          logits, axis=1, keepdims=True))
+      stable_exp_logits = np.exp(
+          logits - np.amax(logits, axis=1, keepdims=True))
       pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
       return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
 
@@ -865,8 +864,8 @@ class LeakyReluTest(test_lib.TestCase):
     batch_size = 3
     height, width = 4, 4
     np.random.seed(1)  # Make it reproducible.
-    inputs = np.random.uniform(
-        size=(batch_size, height, width, 3)).astype(np.float32)
+    inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype(
+        np.float32)
     inputs = constant_op.constant(inputs)
 
     outputs = nn_ops.leaky_relu(inputs)
@@ -884,7 +883,8 @@ class LeakyReluTest(test_lib.TestCase):
       with self.test_session() as sess:
         outputs = sess.run(outputs)
       tol = 2e-3 if dtype == np.float16 else 1e-6
-      self.assertAllClose(outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
+      self.assertAllClose(
+          outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
 
 
 class SwishTest(test_lib.TestCase):
@@ -915,7 +915,10 @@ class SwishTest(test_lib.TestCase):
 
 class MomentsTest(test_lib.TestCase):
 
-  def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+  def doOutputTest(self,
+                   input_shape,
+                   moments_axes,
+                   tol=1e-4,
                    check_gradients=False):
     for mu in [0.0, 1.0, 1e3]:
       for sigma in [1.0, 0.1]:
index 3ab0bd16fae34cc2b481cf811cc2790cae18a351..270d96a3c7c831d8c06dd86199cf2dc5dfc43421 100644 (file)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Functions for Python 2 vs. 3 compatibility.
 
 ## Conversion routines
@@ -118,7 +117,7 @@ def path_to_str(path):
   Returns:
     A `str` object.
   """
-  if hasattr(path, "__fspath__"):
+  if hasattr(path, '__fspath__'):
     path = as_str_any(path.__fspath__())
   return path
 
@@ -129,11 +128,9 @@ integral_types = (_numbers.Integral, _np.integer)
 real_types = (_numbers.Real, _np.integer, _np.floating)
 complex_types = (_numbers.Complex, _np.number)
 
-
 # Either bytes or text.
 bytes_or_text_types = (bytes, _six.text_type)
 
-
 _allowed_symbols = [
     'as_str',
     'bytes_or_text_types',
index 8eee489e2d0f3074aa4c4dd32c452b850063d3ab..38a900738786e2413f5b1dd914caaebeafc92e21 100644 (file)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """This pip smoke test verifies dependency files exist in the pip package.
 
 This script runs bazel queries to see what python files are required by the
@@ -26,13 +25,12 @@ from __future__ import print_function
 import os
 import subprocess
 
+os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
 
-os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
-
-
-PIP_PACKAGE_QUERY_EXPRESSION = \
-  'deps(//tensorflow/tools/pip_package:build_pip_package)'
+PIP_PACKAGE_QUERY_EXPRESSION = (
+    "deps(//tensorflow/tools/pip_package:build_pip_package)")
 
+# pylint: disable=g-backslash-continuation
 PY_TEST_QUERY_EXPRESSION = 'deps(\
   filter("^((?!benchmark).)*$",\
   kind(py_test,\
@@ -40,6 +38,7 @@ PY_TEST_QUERY_EXPRESSION = 'deps(\
   + //tensorflow/contrib/... \
   - //tensorflow/contrib/tensorboard/... \
   - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
+# pylint: enable=g-backslash-continuation
 
 # Hard-coded blacklist of files if not included in pip package
 # TODO(amitpatankar): Clean up blacklist.
@@ -90,15 +89,15 @@ def main():
   """
 
   # pip_package_dependencies_list is the list of included files in pip packages
-  pip_package_dependencies = subprocess.check_output([
-      'bazel', 'query', PIP_PACKAGE_QUERY_EXPRESSION])
+  pip_package_dependencies = subprocess.check_output(
+      ["bazel", "query", PIP_PACKAGE_QUERY_EXPRESSION])
   pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
   print("Pip package superset size: %d" % len(pip_package_dependencies_list))
 
   # tf_py_test_dependencies is the list of dependencies for all python
   # tests in tensorflow
-  tf_py_test_dependencies = subprocess.check_output([
-      'bazel', 'query', PY_TEST_QUERY_EXPRESSION])
+  tf_py_test_dependencies = subprocess.check_output(
+      ["bazel", "query", PY_TEST_QUERY_EXPRESSION])
   tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
   print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
 
@@ -119,8 +118,7 @@ def main():
 
       # Check if the dependency is in the pip package, the blacklist, or
       # should be ignored because of its file extension
-      if not (ignore or
-              dependency in pip_package_dependencies_list or
+      if not (ignore or dependency in pip_package_dependencies_list or
               dependency in BLACKLIST):
         missing_dependencies.append(dependency)
 
@@ -131,9 +129,9 @@ def main():
     for missing_dependency in missing_dependencies:
       print("\nMissing dependency: %s " % missing_dependency)
       print("Affected Tests:")
-      rdep_query = 'rdeps(kind(py_test, \
-      //tensorflow/python/...), %s)' % missing_dependency
-      affected_tests = subprocess.check_output(['bazel', 'query', rdep_query])
+      rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
+                    missing_dependency)
+      affected_tests = subprocess.check_output(["bazel", "query", rdep_query])
       affected_tests_list = affected_tests.split("\n")[:-2]
       print("\n".join(affected_tests_list))
 
@@ -145,5 +143,6 @@ or add them to //tensorflow/tools/pip_package/BUILD.""")
   else:
     print("TEST PASSED")
 
+
 if __name__ == "__main__":
   main()