Extend block sparsity support for TPUs
author     A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 May 2018 17:49:26 +0000 (10:49 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Mon, 7 May 2018 23:37:48 +0000 (16:37 -0700)
PiperOrigin-RevId: 195685740

tensorflow/contrib/model_pruning/python/pruning.py
tensorflow/contrib/model_pruning/python/pruning_utils.py
tensorflow/contrib/model_pruning/python/pruning_utils_test.py
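
This change lets the block-sparsity masking path in pruning.py run on TPUs: the absolute weights are pooled with a new rank-2 pruning_utils.factorized_pool() helper instead of being reshaped to rank 4 for nn_ops.pool(), and the pooled mask is expanded back to the weight shape with kronecker_product(). Below is a hedged usage sketch (not part of the commit) of enabling block pruning with the TPU path; the hparam names (block_height, block_width, block_pooling_function, use_tpu) are assumed to exist at this revision, since pruning.py reads self._block_dim, self._block_pooling_function and self._spec.use_tpu.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning

    # Assumed hparam spellings; parse() takes a comma-separated override string.
    hparams = pruning.get_pruning_hparams().parse(
        'name=tpu_block_pruning,block_height=2,block_width=4,'
        'block_pooling_function=AVG,use_tpu=True')

    global_step = tf.train.get_or_create_global_step()
    p = pruning.Pruning(spec=hparams, global_step=global_step)

    # Recomputes the block-pooled masks (via factorized_pool on TPU) whenever
    # the pruning schedule allows an update.
    mask_update_op = p.conditional_mask_update_op()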

index ea6032e..4b7af18 100644 (file)
@@ -396,14 +396,19 @@ class Pruning(object):
                        self._block_pooling_function)
 
     with ops.name_scope(weights.op.name + '_pruning_ops'):
-      abs_weights = math_ops.abs(
-          array_ops.reshape(weights, [
-              1,
-              squeezed_weights.get_shape()[0],
-              squeezed_weights.get_shape()[1], 1
-          ]))
+      abs_weights = math_ops.abs(squeezed_weights)
+
       pool_window = [self._block_dim[0], self._block_dim[1]]
-      pooled_weights = nn_ops.pool(
+      pool_fn = pruning_utils.factorized_pool
+
+      if not self._spec.use_tpu:
+        pool_fn = nn_ops.pool
+        abs_weights = array_ops.reshape(
+            abs_weights,
+            [1, abs_weights.get_shape()[0],
+             abs_weights.get_shape()[1], 1])
+
+      pooled_weights = pool_fn(
           abs_weights,
           window_shape=pool_window,
           pooling_type=self._block_pooling_function,
@@ -411,19 +416,18 @@ class Pruning(object):
           padding='SAME',
           name=weights.op.name + '_pooled')
 
+      if pooled_weights.get_shape().ndims != 2:
+        pooled_weights = array_ops.squeeze(pooled_weights)
+
       smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                        threshold)
-
-      reshaped_mask = array_ops.reshape(
-          new_mask,
-          [pooled_weights.get_shape()[1],
-           pooled_weights.get_shape()[2]])
       updated_mask = pruning_utils.kronecker_product(
-          reshaped_mask, array_ops.ones(self._block_dim))
+          new_mask, array_ops.ones(self._block_dim))
       sliced_mask = array_ops.slice(
           updated_mask, [0, 0],
           [squeezed_weights.get_shape()[0],
            squeezed_weights.get_shape()[1]])
+
     return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                  array_ops.shape(weights))
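
The hunk above selects the pooling implementation at graph-construction time: on TPU the rank-2 absolute weights are pooled directly with the new pruning_utils.factorized_pool(), while the existing path reshapes them to rank 4 for nn_ops.pool() and squeezes the result back to rank 2. A minimal standalone sketch of that selection (not part of the commit; the helper name and signature are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning_utils

    def _pool_abs_weights(weights, block_dim, pooling_function, use_tpu):
      """Pools abs(weights) over block_dim blocks; weights must be rank 2."""
      abs_weights = tf.abs(weights)
      window = [block_dim[0], block_dim[1]]
      if use_tpu:
        # factorized_pool operates on the rank-2 tensor directly.
        pooled = pruning_utils.factorized_pool(
            abs_weights, window_shape=window, pooling_type=pooling_function,
            strides=window, padding='SAME')
      else:
        # nn_ops.pool expects a rank-4 NHWC tensor, so add batch/channel dims.
        expanded = tf.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1])
        pooled = tf.squeeze(tf.nn.pool(
            expanded, window_shape=window, pooling_type=pooling_function,
            strides=window, padding='SAME'))
      return pooled  # rank 2 in both branches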
 
index 56d3dce..ef6c6a3 100644 (file)
@@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 
@@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs):
     return math_ops.div(cdf, math_ops.reduce_max(cdf))
 
 
+def factorized_pool(input_tensor,
+                    window_shape,
+                    pooling_type,
+                    strides,
+                    padding,
+                    name=None):
+  """Performs m x n pooling through a combination of 1xm and 1xn pooling.
+
+  Args:
+    input_tensor: Input tensor. Must be rank 2.
+    window_shape: Pooling window shape.
+    pooling_type: Either 'MAX' or 'AVG'.
+    strides: The stride of the pooling window.
+    padding: 'SAME' or 'VALID'.
+    name: Name of the op.
+
+  Returns:
+    A rank 2 tensor containing the pooled output.
+
+  Raises:
+    ValueError: if the input tensor is not rank 2.
+  """
+  if input_tensor.get_shape().ndims != 2:
+    raise ValueError('factorized_pool() accepts tensors of rank 2 only')
+
+  [height, width] = input_tensor.get_shape()
+  with ops.name_scope(name, 'factorized_pool'):
+    input_tensor_aligned = array_ops.reshape(
+        input_tensor, [1, 1, height, width],
+        name=input_tensor.op.name + '_aligned')
+
+    height_pooling = nn_ops.pool(
+        input_tensor_aligned,
+        window_shape=[1, window_shape[0]],
+        pooling_type=pooling_type,
+        strides=[1, strides[0]],
+        padding=padding)
+    swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])
+
+    width_pooling = nn_ops.pool(
+        swap_height_width,
+        window_shape=[1, window_shape[1]],
+        pooling_type=pooling_type,
+        strides=[1, strides[1]],
+        padding=padding)
+
+  return array_ops.squeeze(
+      array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]))
+
+
 def determine_partitioned_axis(partitioned_variable):
   partitioned_axis = 0
   concatenated_variable_shape = partitioned_variable.get_shape()
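
A tiny NumPy illustration (not part of the commit) of why the factorization in factorized_pool() reproduces direct block pooling: for a full block, max-pooling the block equals max-pooling its per-row maxima, and averaging the block equals averaging its per-row means, since every row segment has the same width.

    import numpy as np

    block = np.array([[1., 5., 2.],
                      [7., 0., 3.]])                 # one 2 x 3 block
    row_max = block.max(axis=1)                      # pool each row -> [5., 7.]
    assert row_max.max() == block.max()              # then pool across rows

    row_avg = block.mean(axis=1)                     # per-row averages
    assert np.isclose(row_avg.mean(), block.mean())  # average of averages

The new tests in pruning_utils_test.py below compare factorized_pool() against nn_ops.pool() with assertAllClose, which tolerates the small floating-point differences the two reduction orders can produce.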
index 10e1dd0..ccde5b4 100644 (file)
@@ -22,8 +22,10 @@ import numpy as np
 
 from tensorflow.contrib.model_pruning.python import pruning_utils
 from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -31,6 +33,30 @@ from tensorflow.python.platform import test
 
 class PruningUtilsTest(test.TestCase):
 
+  def _compare_cdf(self, values):
+    abs_values = math_ops.abs(values)
+    max_value = math_ops.reduce_max(abs_values)
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
+          abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
+      cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
+      self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
+
+  def _compare_pooling_methods(self, weights, pooling_kwargs):
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      pooled_weights_tf = array_ops.squeeze(
+          nn_ops.pool(
+              array_ops.reshape(
+                  weights,
+                  [1, weights.get_shape()[0],
+                   weights.get_shape()[1], 1]), **pooling_kwargs))
+      pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+          weights, **pooling_kwargs)
+      self.assertAllClose(pooled_weights_tf.eval(),
+                          pooled_weights_factorized_pool.eval())
+
   def testHistogram(self):
     width = 10
     height = 10
@@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase):
       self.assertAllEqual(len(norm_cdf_val), nbins)
       self.assertAllEqual(expected_cdf, norm_cdf_val)
 
-  def _compare_cdf(self, values):
-    abs_values = math_ops.abs(values)
-    max_value = math_ops.reduce_max(abs_values)
-    with self.test_session():
-      variables.global_variables_initializer().run()
-      cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
-          abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
-      cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
-      return cdf.eval(), cdf_from_histogram.eval()
-
   def testCDFEquivalence2D(self):
     width = 100
     height = 100
     weights = variable_scope.get_variable("weights", shape=[width, height])
-    cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
-    self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+    self._compare_cdf(weights)
 
   def testCDFEquivalence4D(self):
     weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
-    cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
-    self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+    self._compare_cdf(weights)
+
+  def testFactorizedAvgPool(self):
+    weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+    pooling_kwargs = {
+        "window_shape": [2, 4],
+        "pooling_type": "AVG",
+        "strides": [2, 4],
+        "padding": "SAME"
+    }
+    self._compare_pooling_methods(weights, pooling_kwargs)
+
+  def testFactorizedMaxPool(self):
+    weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+    pooling_kwargs = {
+        "window_shape": [2, 4],
+        "pooling_type": "MAX",
+        "strides": [2, 4],
+        "padding": "SAME"
+    }
+    self._compare_pooling_methods(weights, pooling_kwargs)
 
 
 if __name__ == "__main__":