From: A. Unique TensorFlower Date: Mon, 7 May 2018 17:49:26 +0000 (-0700) Subject: Extend block sparsity support for TPUs X-Git-Tag: upstream/v1.9.0_rc1~150^2~1^2~81 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9ba26ca0d59989592051fdb5c7a2caabe4f399f3;p=platform%2Fupstream%2Ftensorflow.git Extend block sparsity support for TPUs PiperOrigin-RevId: 195685740 --- diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index ea6032e..4b7af18 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -396,14 +396,19 @@ class Pruning(object): self._block_pooling_function) with ops.name_scope(weights.op.name + '_pruning_ops'): - abs_weights = math_ops.abs( - array_ops.reshape(weights, [ - 1, - squeezed_weights.get_shape()[0], - squeezed_weights.get_shape()[1], 1 - ])) + abs_weights = math_ops.abs(squeezed_weights) + pool_window = [self._block_dim[0], self._block_dim[1]] - pooled_weights = nn_ops.pool( + pool_fn = pruning_utils.factorized_pool + + if not self._spec.use_tpu: + pool_fn = nn_ops.pool + abs_weights = array_ops.reshape( + abs_weights, + [1, abs_weights.get_shape()[0], + abs_weights.get_shape()[1], 1]) + + pooled_weights = pool_fn( abs_weights, window_shape=pool_window, pooling_type=self._block_pooling_function, @@ -411,19 +416,18 @@ class Pruning(object): padding='SAME', name=weights.op.name + '_pooled') + if pooled_weights.get_shape().ndims != 2: + pooled_weights = array_ops.squeeze(pooled_weights) + smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) - - reshaped_mask = array_ops.reshape( - new_mask, - [pooled_weights.get_shape()[1], - pooled_weights.get_shape()[2]]) updated_mask = pruning_utils.kronecker_product( - reshaped_mask, array_ops.ones(self._block_dim)) + new_mask, array_ops.ones(self._block_dim)) sliced_mask = array_ops.slice( updated_mask, [0, 0], [squeezed_weights.get_shape()[0], squeezed_weights.get_shape()[1]]) + return smoothed_threshold, array_ops.reshape(sliced_mask, array_ops.shape(weights)) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 56d3dce..ef6c6a3 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs): return math_ops.div(cdf, math_ops.reduce_max(cdf)) +def factorized_pool(input_tensor, + window_shape, + pooling_type, + strides, + padding, + name=None): + """Performs m x n pooling through a combination of 1xm and 1xn pooling. + + Args: + input_tensor: Input tensor. Must be rank 2 + window_shape: Pooling window shape + pooling_type: Either 'MAX' or 'AVG' + strides: The stride of the pooling window + padding: 'SAME' or 'VALID'. + name: Name of the op + + Returns: + A rank 2 tensor containing the pooled output + + Raises: + ValueError: if the input tensor is not rank 2 + """ + if input_tensor.get_shape().ndims != 2: + raise ValueError('factorized_pool() accepts tensors of rank 2 only') + + [height, width] = input_tensor.get_shape() + with ops.name_scope(name, 'factorized_pool'): + input_tensor_aligned = array_ops.reshape( + input_tensor, [1, 1, height, width], + name=input_tensor.op.name + '_aligned') + + height_pooling = nn_ops.pool( + input_tensor_aligned, + window_shape=[1, window_shape[0]], + pooling_type=pooling_type, + strides=[1, strides[0]], + padding=padding) + swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2]) + + width_pooling = nn_ops.pool( + swap_height_width, + window_shape=[1, window_shape[1]], + pooling_type=pooling_type, + strides=[1, strides[1]], + padding=padding) + + return array_ops.squeeze( + array_ops.transpose(width_pooling, perm=[0, 1, 3, 2])) + + def determine_partitioned_axis(partitioned_variable): partitioned_axis = 0 concatenated_variable_shape = partitioned_variable.get_shape() diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 10e1dd0..ccde5b4 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -22,8 +22,10 @@ import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -31,6 +33,30 @@ from tensorflow.python.platform import test class PruningUtilsTest(test.TestCase): + def _compare_cdf(self, values): + abs_values = math_ops.abs(values) + max_value = math_ops.reduce_max(abs_values) + with self.test_session(): + variables.global_variables_initializer().run() + cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( + abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) + cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) + self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) + + def _compare_pooling_methods(self, weights, pooling_kwargs): + with self.test_session(): + variables.global_variables_initializer().run() + pooled_weights_tf = array_ops.squeeze( + nn_ops.pool( + array_ops.reshape( + weights, + [1, weights.get_shape()[0], + weights.get_shape()[1], 1]), **pooling_kwargs)) + pooled_weights_factorized_pool = pruning_utils.factorized_pool( + weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), + pooled_weights_factorized_pool.eval()) + def testHistogram(self): width = 10 height = 10 @@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase): self.assertAllEqual(len(norm_cdf_val), nbins) self.assertAllEqual(expected_cdf, norm_cdf_val) - def _compare_cdf(self, values): - abs_values = math_ops.abs(values) - max_value = math_ops.reduce_max(abs_values) - with self.test_session(): - variables.global_variables_initializer().run() - cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( - abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) - cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) - return cdf.eval(), cdf_from_histogram.eval() - def testCDFEquivalence2D(self): width = 100 height = 100 weights = variable_scope.get_variable("weights", shape=[width, height]) - cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) - self.assertAllEqual(cdf_val, cdf_from_histogram_val) + self._compare_cdf(weights) def testCDFEquivalence4D(self): weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) - cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) - self.assertAllEqual(cdf_val, cdf_from_histogram_val) + self._compare_cdf(weights) + + def testFactorizedAvgPool(self): + weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + pooling_kwargs = { + "window_shape": [2, 4], + "pooling_type": "AVG", + "strides": [2, 4], + "padding": "SAME" + } + self._compare_pooling_methods(weights, pooling_kwargs) + + def testFactorizedMaxPool(self): + weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + pooling_kwargs = { + "window_shape": [2, 4], + "pooling_type": "MAX", + "strides": [2, 4], + "padding": "SAME" + } + self._compare_pooling_methods(weights, pooling_kwargs) if __name__ == "__main__":