Add a new histogram/cdf computation method compatible with the TPU.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 07:03:48 +0000 (00:03 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 07:06:16 +0000 (00:06 -0700)
Refactor utility functions into pruning_utils.py and add tests.

PiperOrigin-RevId: 192727737

tensorflow/contrib/model_pruning/BUILD
tensorflow/contrib/model_pruning/README.md
tensorflow/contrib/model_pruning/python/pruning.py
tensorflow/contrib/model_pruning/python/pruning_test.py
tensorflow/contrib/model_pruning/python/pruning_utils.py [new file with mode: 0644]
tensorflow/contrib/model_pruning/python/pruning_utils_test.py [new file with mode: 0644]

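For orientation, here is a minimal sketch of how the new use_tpu hyperparameter introduced by this change is meant to be used. The hparams string, the Pruning class, and conditional_mask_update_op all appear in the diff below; the surrounding training setup is assumed:

    from tensorflow.contrib.model_pruning.python import pruning

    # use_tpu=true routes the CDF computation through the new while_loop-based
    # pruning_utils.compute_cdf instead of the histogram-based version
    # (which relies on unsorted_segment_sum).
    pruning_hparams = pruning.get_pruning_hparams().parse('use_tpu=true,nbins=256')
    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
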
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index f50575b..54bd39a 100644
@@ -72,15 +72,37 @@ py_library(
 )
 
 py_library(
+    name = "pruning_utils",
+    srcs = ["python/pruning_utils.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/python:platform",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
     name = "pruning",
     srcs = ["python/pruning.py"],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
         ":core_layers",
+        ":pruning_utils",
         "//tensorflow/contrib/training:training_py",
         "//tensorflow/python:platform",
-        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "pruning_utils_test",
+    size = "small",
+    srcs = ["python/pruning_utils_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":pruning_utils",
+        "//tensorflow/python:client_testlib",
     ],
 )
 
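The new pruning_utils_test target can be run on its own from a TensorFlow source checkout:

    bazel test //tensorflow/contrib/model_pruning:pruning_utils_test
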
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 52b659c..86f4fd6 100644
@@ -45,7 +45,7 @@ The pruning library allows for specification of the following hyper parameters:
 | do_not_prune | list of strings | [""] | list of layer names that are not pruned |
 | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
 | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 255 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation |
 | block_height|integer | 1 | Number of rows in a block for block sparse matrices|
 | block_width |integer | 1 | Number of cols in a block for block sparse matrices|
 | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
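
Each row of this table corresponds to a field in the comma-separated hparams string that pruning.get_pruning_hparams().parse() consumes. A sketch with illustrative values (unspecified hyperparameters keep their defaults):

    from tensorflow.contrib.model_pruning.python import pruning

    # Illustrative values only; any subset of the hyperparameters may be given.
    pruning_hparams = pruning.get_pruning_hparams().parse(
        'nbins=256,threshold_decay=0.9,pruning_frequency=10,'
        'block_height=1,block_width=1,block_pooling_function=AVG')
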
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 5146a4a..ea6032e 100644
   # Returns a list of all the weight tensors that have been masked
   get_weights()
 
-  The Pruning class uses a proto (defined in pruning.proto) to set up the
-  parameters for a pruning specification. Here's a typical usage:
+  The Pruning class uses an HParams object to set up the
+  parameters for model pruning. Here's a typical usage:
 
-  # Initialize a pruning spec from a proto
-  pruning_spec = '/tmp/pruning.pb'
-  p = Pruning(pruning_spec)
+  # Parse pruning hyperparameters
+  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
+
+  # Create a pruning object using the pruning_hparams
+  p = pruning.Pruning(pruning_hparams)
 
   # Add mask update ops to the graph
   mask_update_op = p.conditional_mask_update_op()
 
  # A Pruning object also accepts an externally defined sparsity:
   sparsity = tf.Variable(0.5, name = "ConstantSparsity")
-  pruning_spec = '/tmp/pruning.pb'
-  p = Pruning(pruning_spec, sparsity=sparsity)
-
+  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
 """
 # pylint: disable=missing-docstring
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
+from tensorflow.contrib.model_pruning.python import pruning_utils
 from tensorflow.contrib.model_pruning.python.layers import core_layers as core
 from tensorflow.contrib.training.python.training import hparam
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
@@ -87,172 +85,18 @@ _WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
 _MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME
 
 
-def _weight_mask_variable(var, scope):
-  """Create a mask for the weights.
-
-  This function adds a variable 'mask' to the graph.
-
-  Args:
-    var: the weight variable that needs to be masked
-    scope: The variable scope of the variable var
-
-  Returns:
-    the mask variable of the same size and shape as var, initialized to all 1s.
-  """
-  with variable_scope.variable_scope(scope):
-    mask = variable_scope.get_variable(
-        'mask',
-        var.get_shape(),
-        initializer=init_ops.ones_initializer(),
-        trainable=False,
-        dtype=var.dtype)
-  return mask
-
-
-def _weight_threshold_variable(var, scope):
-  """Create a scalar threshold for the weights.
-
-  This function adds a variable
-  'threshold' to the graph.
-
-  Args:
-    var: The weight variable that needs to be masked
-    scope: The variable scope of the variable var
-
-  Returns:
-    a scalar threshold variable initialized to 0.
-  """
-  with variable_scope.variable_scope(scope):
-    threshold = variable_scope.get_variable(
-        'threshold', [],
-        initializer=init_ops.zeros_initializer(),
-        trainable=False,
-        dtype=var.dtype)
-    return threshold
-
-
-def _kronecker_product(mat1, mat2):
-  """Computes the Kronecker product of two matrices mat1 and mat2.
-
-  Args:
-    mat1: A matrix of size m x n
-    mat2: A matrix of size p x q
-  Returns:
-    Kronecker product of matrices mat1 and mat2 of size mp x nq
-  """
-
-  m1, n1 = mat1.get_shape().as_list()
-  mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
-  m2, n2 = mat2.get_shape().as_list()
-  mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
-  return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
-
-
-def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
-  """Return histogram of values.
-
-  Given the tensor `values`, this operation returns a rank 1 histogram counting
-  the number of entries in `values` that fell into every bin.  The bins are
-  equal width and determined by the arguments `value_range` and `nbins`.
-
-  Args:
-    values:  Numeric `Tensor`.
-    value_range:  Shape [2] `Tensor` of same `dtype` as `values`.
-      values <= value_range[0] will be mapped to hist[0],
-      values >= value_range[1] will be mapped to hist[-1].
-    nbins:  Scalar `int32 Tensor`.  Number of histogram bins.
-    dtype:  dtype for returned histogram.
-    name:  A name for this operation (defaults to 'histogram').
-
-  Returns:
-    A 1-D `Tensor` holding histogram of values.
-
-  """
-  with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
-    values = ops.convert_to_tensor(values, name='values')
-    values = gen_array_ops.reshape(values, [-1])
-    value_range = ops.convert_to_tensor(value_range, name='value_range')
-    nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins')
-    nbins_float = math_ops.cast(nbins, values.dtype)
-
-    # Map tensor values that fall within value_range to [0, 1].
-    scaled_values = math_ops.truediv(
-        values - value_range[0],
-        value_range[1] - value_range[0],
-        name='scaled_values')
-
-    # map tensor values within the open interval value_range to {0,.., nbins-1},
-    # values outside the open interval will be zero or less, or nbins or more.
-    indices = math_ops.floor(nbins_float * scaled_values, name='indices')
-
-    # Clip edge cases (e.g. value = value_range[1]) or "outliers."
-    indices = math_ops.cast(
-        clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32)
-
-    return math_ops.unsorted_segment_sum(
-        array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
-
-
-def _determine_partitioned_axis(partitioned_variable):
-  partitioned_axis = 0
-  concatenated_variable_shape = partitioned_variable.get_shape()
-  for partition in partitioned_variable:
-    partition_shape = partition.get_shape()
-    maybe_partitioned_axis = np.less(partition_shape,
-                                     concatenated_variable_shape)
-    # Sanity check: make sure number of partitioned axis == 1
-    if np.count_nonzero(maybe_partitioned_axis) != 1:
-      raise ValueError('Number of partitioned axes %s not equal to 1' %
-                       np.count_nonzero(maybe_partitioned_axis))
-    partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
-  return partitioned_axis
-
-
-def _variable_assign(var, new_value):
-  return state_ops.assign(var, new_value, name=var.op.name + '_assign')
-
-
-def _partitioned_variable_assign(partitioned_var, new_value):
-  """Assign op for partitioned variables.
-
-  Args:
-    partitioned_var: A partitioned tensorflow variable
-    new_value: Value to be assigned to the variable var
-
-  Returns:
-    A tensorflow op that groups the assign ops for each of the variable slices
-  """
-  # Determine which axis was used to partition the variable. Currently
-  # tensorflow allows partitioning variable only along 1 axis.
-  axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis(
-      partitioned_var)
-
-  partition_sizes = np.array(
-      [partition.get_shape()[axis] for partition in partitioned_var])
-  new_partitioned_values = array_ops.split(
-      new_value,
-      ops.convert_to_tensor(partition_sizes, dtype=np.int32),
-      axis=axis)
-  op_list = []
-  for partition in partitioned_var:
-    op_list.append(
-        _variable_assign(partition, new_partitioned_values[len(op_list)]))
-  return control_flow_ops.group(
-      *op_list, name=partitioned_var.name + '_group_assign')
-
-
 def apply_mask(x, scope=''):
   """Apply mask to a given weight tensor.
 
   Args:
     x: Input weight tensor
-    scope: The current variable scope. Defaults to ""
+    scope: The current variable scope. Defaults to "".
   Returns:
     Tensor representing masked_weights
   """
 
-  mask = _weight_mask_variable(x, scope)
-  threshold = _weight_threshold_variable(x, scope)
+  mask = pruning_utils.weight_mask_variable(x, scope)
+  threshold = pruning_utils.weight_threshold_variable(x, scope)
   # Add masked_weights in the weights namescope so as to make it easier
   # for the quantization library to add quant ops.
   masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)
@@ -335,6 +179,8 @@ def get_pruning_hparams():
     sparsity_function_exponent: float
       exponent = 1 is linearly varying sparsity between initial and final.
       exponent > 1 varies more slowly towards the end than the beginning
+    use_tpu: boolean
+      indicates whether to use TPU; defaults to False
 
     We use the following sparsity function:
 
@@ -357,7 +203,7 @@ def get_pruning_hparams():
       do_not_prune=[''],
       threshold_decay=0.9,
       pruning_frequency=10,
-      nbins=255,
+      nbins=256,
       block_height=1,
       block_width=1,
       block_pooling_function='AVG',
@@ -365,7 +211,8 @@ def get_pruning_hparams():
       target_sparsity=0.5,
       sparsity_function_begin_step=0,
       sparsity_function_end_step=100,
-      sparsity_function_exponent=3)
+      sparsity_function_exponent=3,
+      use_tpu=False)
 
 
 class Pruning(object):
@@ -414,7 +261,7 @@ class Pruning(object):
     if graph_global_step is None:
       graph_global_step = training_util.get_global_step()
 
-    return math_ops.cast(graph_global_step, np.int32)
+    return math_ops.cast(graph_global_step, dtypes.int32)
 
   def _setup_sparsity(self):
     begin_step = self._spec.sparsity_function_begin_step
@@ -429,13 +276,13 @@ class Pruning(object):
           (begin_step, end_step))
 
     with ops.name_scope(self._spec.name):
-      p = math_ops.minimum(1.0,
-                           math_ops.maximum(
-                               0.0,
-                               math_ops.div(
-                                   math_ops.cast(self._global_step - begin_step,
-                                                 np.float32),
-                                   end_step - begin_step)))
+      p = math_ops.minimum(
+          1.0,
+          math_ops.maximum(
+              0.0,
+              math_ops.div(
+                  math_ops.cast(self._global_step - begin_step, dtypes.float32),
+                  end_step - begin_step)))
       sparsity = math_ops.add(
           math_ops.multiply(initial_sparsity - target_sparsity,
                             math_ops.pow(1 - p, exponent)),
@@ -445,17 +292,18 @@ class Pruning(object):
     return sparsity
 
   def _setup_last_update_step(self):
-    with variable_scope.variable_scope(self._spec.name) as scope:
+    with variable_scope.variable_scope(
+        self._spec.name, use_resource=self._spec.use_tpu) as scope:
       try:
         last_update_step = variable_scope.get_variable(
             'last_mask_update_step', [],
             initializer=init_ops.zeros_initializer(),
             trainable=False,
-            dtype=np.int32)
+            dtype=dtypes.int32)
       except ValueError:
         scope.reuse_variables()
         last_update_step = variable_scope.get_variable(
-            'last_mask_update_step', dtype=np.int32)
+            'last_mask_update_step', dtype=dtypes.int32)
     return last_update_step
 
   def _exists_in_do_not_prune_list(self, tensor_name):
@@ -497,18 +345,16 @@ class Pruning(object):
     with ops.name_scope(weights.op.name + '_pruning_ops'):
       abs_weights = math_ops.abs(weights)
       max_value = math_ops.reduce_max(abs_weights)
-      histogram = _histogram(
-          abs_weights, [0.0, max_value],
-          nbins=self._spec.nbins,
-          dtype=np.float32)
+      cdf_fn = pruning_utils.compute_cdf_from_histogram
+      if self._spec.use_tpu:
+        cdf_fn = pruning_utils.compute_cdf
 
-      cdf = math_ops.cumsum(histogram)
-      norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram))
+      norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins)
       current_threshold = math_ops.multiply(
           math_ops.div(
               math_ops.reduce_sum(
                   math_ops.cast(
-                      math_ops.less(norm_cdf, self._sparsity), np.float32)),
+                      math_ops.less(norm_cdf, self._sparsity), dtypes.float32)),
               float(self._spec.nbins)), max_value)
 
       smoothed_threshold = math_ops.add_n([
@@ -516,7 +362,7 @@ class Pruning(object):
           math_ops.multiply(threshold, self._spec.threshold_decay)
       ])
       new_mask = math_ops.cast(
-          math_ops.greater(abs_weights, smoothed_threshold), np.float32)
+          math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32)
     return smoothed_threshold, new_mask
 
   def _maybe_update_block_mask(self, weights, threshold):
@@ -572,8 +418,8 @@ class Pruning(object):
           new_mask,
           [pooled_weights.get_shape()[1],
            pooled_weights.get_shape()[2]])
-      updated_mask = _kronecker_product(reshaped_mask,
-                                        array_ops.ones(self._block_dim))
+      updated_mask = pruning_utils.kronecker_product(
+          reshaped_mask, array_ops.ones(self._block_dim))
       sliced_mask = array_ops.slice(
           updated_mask, [0, 0],
           [squeezed_weights.get_shape()[0],
@@ -608,11 +454,12 @@ class Pruning(object):
           continue
 
       new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
-      self._assign_ops.append(_variable_assign(threshold, new_threshold))
+      self._assign_ops.append(
+          pruning_utils.variable_assign(threshold, new_threshold))
 
       self._assign_ops.append(
-          _partitioned_variable_assign(mask, new_mask)
-          if is_partitioned else _variable_assign(mask, new_mask))
+          pruning_utils.partitioned_variable_assign(mask, new_mask)
+          if is_partitioned else pruning_utils.variable_assign(mask, new_mask))
 
   def mask_update_op(self):
     with ops.name_scope(self._spec.name):
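
To make the CDF-based threshold selection in _update_mask concrete, here is a small NumPy re-derivation of the same arithmetic (a sketch only: it mirrors the norm_cdf and current_threshold computation above, ignores the threshold_decay smoothing, and the function name is illustrative):

    import numpy as np

    def threshold_from_cdf(abs_weights, sparsity, nbins=256):
      # Histogram of |w| over [0, max|w|]; the normalized cumulative sum is the CDF.
      max_value = abs_weights.max()
      hist, _ = np.histogram(abs_weights, bins=nbins, range=(0.0, max_value))
      norm_cdf = np.cumsum(hist) / float(hist.sum())
      # Mirrors reduce_sum(cast(less(norm_cdf, sparsity))) / nbins * max_value:
      # the fraction of bins still below the target sparsity, rescaled to the
      # weight-magnitude range, becomes the pruning threshold.
      return np.sum(norm_cdf < sparsity) / float(nbins) * max_value

    weights = np.random.randn(100)
    threshold = threshold_from_cdf(np.abs(weights), sparsity=0.5)
    new_mask = (np.abs(weights) > threshold).astype(np.float32)
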
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 89e6571..f80b7c5 100644
@@ -110,12 +110,12 @@ class PruningTest(test.TestCase):
       self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
       session.run(mask_update_op)
       masked_weights_val = masked_weights.eval()
-      self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+      self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
 
   def _blockMasking(self, hparams, weights, expected_mask):
 
     threshold = variables.Variable(0.0, name="threshold")
-    sparsity = variables.Variable(0.51, name="sparsity")
+    sparsity = variables.Variable(0.5, name="sparsity")
     test_spec = ",".join(hparams)
     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
 
@@ -138,7 +138,8 @@ class PruningTest(test.TestCase):
     weights_max = constant_op.constant(
         [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
          [0.0, -0.3, 0.0, -0.4]])
-    expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
+    expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+                     [1., 1., 1., 1.], [1., 1., 1., 1.]]
 
     self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
                        expected_mask)
@@ -155,7 +156,8 @@ class PruningTest(test.TestCase):
     weights_max = constant_op.constant(
         [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
           [0.0, -0.3, 0.0, -0.4]]])
-    expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]]
+    expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+                      [1., 1., 1., 1.], [1., 1., 1., 1.]]]
 
     self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
                        expected_mask)
@@ -178,11 +180,12 @@ class PruningTest(test.TestCase):
       masked_weights_val = masked_weights.eval()
       session.run(mask_update_op)
       masked_weights_val = masked_weights.eval()
-      self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+      self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
 
   def testConditionalMaskUpdate(self):
     param_list = [
-        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
+        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
+        "nbins=100"
     ]
     test_spec = ",".join(param_list)
     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
new file mode 100644
index 0000000..56d3dce
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -0,0 +1,269 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for adding pruning related ops to the graph.
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+_NBINS = 256
+
+
+def weight_mask_variable(var, scope):
+  """Create a mask for the weights.
+
+  This function adds a variable 'mask' to the graph.
+
+  Args:
+    var: the weight variable that needs to be masked
+    scope: The variable scope of the variable var
+
+  Returns:
+    the mask variable of the same size and shape as var, initialized to all 1s.
+  """
+  with variable_scope.variable_scope(scope):
+    mask = variable_scope.get_variable(
+        'mask',
+        var.get_shape(),
+        initializer=init_ops.ones_initializer(),
+        trainable=False,
+        dtype=var.dtype)
+  return mask
+
+
+def weight_threshold_variable(var, scope):
+  """Create a scalar threshold for the weights.
+
+  This function adds a variable 'threshold'
+  to the graph.
+
+  Args:
+    var: The weight variable that needs to be masked
+    scope: The variable scope of the variable var
+
+  Returns:
+    a scalar threshold variable initialized to 0.
+  """
+  with variable_scope.variable_scope(scope):
+    threshold = variable_scope.get_variable(
+        'threshold', [],
+        initializer=init_ops.zeros_initializer(),
+        trainable=False,
+        dtype=var.dtype)
+    return threshold
+
+
+def kronecker_product(mat1, mat2):
+  """Computes the Kronecker product of two matrices mat1 and mat2.
+
+  Args:
+    mat1: A matrix of size m x n
+    mat2: A matrix of size p x q
+  Returns:
+    Kronecker product of matrices mat1 and mat2 of size mp x nq
+  """
+
+  m1, n1 = mat1.get_shape().as_list()
+  mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
+  m2, n2 = mat2.get_shape().as_list()
+  mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
+  return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+
+
+def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
+  """Return histogram of values.
+
+  Given the tensor `values`, this operation returns a rank 1 histogram counting
+  the number of entries in `values` that fell into every bin.  The bins are
+  equal width and determined by the arguments `value_range` and `nbins`.
+
+  Args:
+    values:  Numeric `Tensor`.
+    value_range:  Shape [2] `Tensor` of same `dtype` as `values`.
+      values <= value_range[0] will be mapped to hist[0],
+      values >= value_range[1] will be mapped to hist[-1].
+    nbins:  Scalar `int32 Tensor`.  Number of histogram bins.
+    dtype:  dtype for returned histogram.
+    name:  A name for this operation (defaults to 'histogram').
+
+  Returns:
+    A 1-D `Tensor` holding histogram of values.
+
+  """
+  with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
+    values = ops.convert_to_tensor(values, name='values')
+    values = array_ops.reshape(values, [-1])
+    value_range = ops.convert_to_tensor(value_range, name='value_range')
+    nbins_float = np.float32(nbins)
+
+    # Map tensor values that fall within value_range to [0, 1].
+    scaled_values = math_ops.truediv(
+        values - value_range[0],
+        value_range[1] - value_range[0],
+        name='scaled_values')
+
+    # map tensor values within the open interval value_range to {0,.., nbins-1},
+    # values outside the open interval will be zero or less, or nbins or more.
+    indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+    # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+    indices = math_ops.cast(
+        clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+
+    return math_ops.unsorted_segment_sum(
+        array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
+
+
+def compute_cdf_from_histogram(values, value_range, **kwargs):
+  """Returns the normalized cumulative distribution of the given values tensor.
+
+  Computes the histogram and uses tf.cumsum to arrive at the cdf.
+
+  Args:
+    values:  Numeric `Tensor`.
+    value_range:  Shape [2] `Tensor` of same `dtype` as `values`.
+    **kwargs: keyword arguments: nbins, name
+
+  Returns:
+    A 1-D `Tensor` holding normalized cdf of values.
+
+  """
+  nbins = kwargs.get('nbins', _NBINS)
+  name = kwargs.get('name', None)
+  with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
+    histogram = _histogram(
+        values, value_range, dtype=dtypes.float32, nbins=nbins)
+    cdf = math_ops.cumsum(histogram)
+    return math_ops.div(cdf, math_ops.reduce_max(cdf))
+
+
+def compute_cdf(values, value_range, **kwargs):
+  """Returns the normalized cumulative distribution of the given values tensor.
+
+  Uses tf.while_loop to directly compute the cdf of the values. The number of
+  bins for the histogram is fixed at _NBINS=256.
+
+  Args:
+    values:  Numeric `Tensor`.
+    value_range:  Shape [2] `Tensor` of same `dtype` as `values`
+    **kwargs: keyword arguments: name
+
+  Returns:
+    A 1-D `Tensor` holding normalized cdf of values.
+
+  """
+  nbins = _NBINS
+  name = kwargs.get('name', None)
+  with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
+    values = ops.convert_to_tensor(values, name='values')
+    value_range = ops.convert_to_tensor(value_range, name='value_range')
+    nbins_float = np.float32(nbins)
+
+    # Map tensor values that fall within value_range to [0, 1].
+    scaled_values = math_ops.truediv(
+        values - value_range[0],
+        value_range[1] - value_range[0],
+        name='scaled_values')
+
+    # map tensor values within the open interval value_range to {0,.., nbins-1},
+    # values outside the open interval will be zero or less, or nbins or more.
+    indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+    # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+    indices = math_ops.cast(
+        clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+
+    cdf = array_ops.zeros(nbins)
+    i = constant_op.constant(0)
+
+    def loop_cond(loop_count, _):
+      return math_ops.less(loop_count, nbins)
+
+    def loop_body(loop_count, cdf):
+      temp = math_ops.reduce_sum(
+          math_ops.cast(
+              math_ops.less_equal(indices, loop_count), dtypes.float32))
+      cdf = math_ops.add(
+          cdf,
+          array_ops.one_hot(
+              loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+      return [loop_count + 1, cdf]
+
+    _, cdf = control_flow_ops.while_loop(
+        loop_cond, loop_body, [i, cdf], maximum_iterations=nbins)
+
+    return math_ops.div(cdf, math_ops.reduce_max(cdf))
+
+
+def determine_partitioned_axis(partitioned_variable):
+  partitioned_axis = 0
+  concatenated_variable_shape = partitioned_variable.get_shape()
+  for partition in partitioned_variable:
+    partition_shape = partition.get_shape()
+    maybe_partitioned_axis = np.less(partition_shape,
+                                     concatenated_variable_shape)
+    # Sanity check: make sure the number of partitioned axes == 1
+    if np.count_nonzero(maybe_partitioned_axis) != 1:
+      raise ValueError('Number of partitioned axes %s not equal to 1' %
+                       np.count_nonzero(maybe_partitioned_axis))
+    partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
+  return partitioned_axis
+
+
+def variable_assign(var, new_value):
+  return state_ops.assign(var, new_value, name=var.op.name + '_assign')
+
+
+def partitioned_variable_assign(partitioned_var, new_value):
+  """Assign op for partitioned variables.
+
+  Args:
+    partitioned_var: A partitioned tensorflow variable
+    new_value: Value to be assigned to the variable var
+
+  Returns:
+    A tensorflow op that groups the assign ops for each of the variable slices
+  """
+  # Determine which axis was used to partition the variable. Currently
+  # tensorflow allows partitioning variable only along 1 axis.
+  axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis(
+      partitioned_var)
+
+  partition_sizes = np.array(
+      [partition.get_shape()[axis] for partition in partitioned_var])
+  new_partitioned_values = array_ops.split(
+      new_value,
+      ops.convert_to_tensor(partition_sizes, dtype=dtypes.int32),
+      axis=axis)
+  op_list = []
+  for partition in partitioned_var:
+    op_list.append(
+        variable_assign(partition, new_partitioned_values[len(op_list)]))
+  return control_flow_ops.group(
+      *op_list, name=partitioned_var.name + '_group_assign')
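
The two CDF routines above should agree exactly: for each bin i, compute_cdf's loop body counts the entries whose bin index is at most i, which is precisely the cumulative sum of the histogram that compute_cdf_from_histogram computes. A NumPy sketch of that identity (illustrative; the real check lives in pruning_utils_test.py below):

    import numpy as np

    values = np.abs(np.random.randn(1000).astype(np.float32))
    nbins = 256
    lo, hi = 0.0, float(values.max())
    indices = np.clip(np.floor(nbins * (values - lo) / (hi - lo)),
                      0, nbins - 1).astype(np.int32)

    # compute_cdf_from_histogram route: histogram, then cumulative sum.
    hist_cdf = np.cumsum(np.bincount(indices, minlength=nbins))

    # compute_cdf route: for each bin i, count the entries with index <= i.
    direct_cdf = np.array([np.sum(indices <= i) for i in range(nbins)])

    assert (hist_cdf == direct_cdf).all()
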
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
new file mode 100644
index 0000000..10e1dd0
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utility functions in pruning_utils.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.model_pruning.python import pruning_utils
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class PruningUtilsTest(test.TestCase):
+
+  def testHistogram(self):
+    width = 10
+    height = 10
+    nbins = 100
+    expected_histogram = np.full(nbins, 1.0)
+    init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height))
+    weights = variable_scope.get_variable(
+        "weights", [width, height], initializer=init)
+    histogram = pruning_utils._histogram(
+        weights, [0, 1.0], nbins, dtype=np.float32)
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      computed_histogram = histogram.eval()
+    self.assertAllEqual(expected_histogram, computed_histogram)
+
+  def testCDF(self):
+    nbins = 5
+    weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100])
+    abs_weights = math_ops.abs(weights)
+    norm_cdf = pruning_utils.compute_cdf_from_histogram(
+        abs_weights, [0.0, 5.0], nbins=nbins)
+    expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32)
+    with self.test_session() as sess:
+      variables.global_variables_initializer().run()
+      norm_cdf_val = sess.run(norm_cdf)
+      self.assertAllEqual(len(norm_cdf_val), nbins)
+      self.assertAllEqual(expected_cdf, norm_cdf_val)
+
+  def _compare_cdf(self, values):
+    abs_values = math_ops.abs(values)
+    max_value = math_ops.reduce_max(abs_values)
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
+          abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
+      cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
+      return cdf.eval(), cdf_from_histogram.eval()
+
+  def testCDFEquivalence2D(self):
+    width = 100
+    height = 100
+    weights = variable_scope.get_variable("weights", shape=[width, height])
+    cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
+    self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+
+  def testCDFEquivalence4D(self):
+    weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
+    cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
+    self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+
+
+if __name__ == "__main__":
+  test.main()
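
As a hand check of testCDF above: with value_range [0.0, 5.0] and nbins=5, the absolute weights [1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100] fall into bins [1, 0, 1, 1, 2, 3, 4, 4, 4, 4] (entries at or above 5.0 clip into the last bin), giving the histogram [1, 3, 1, 1, 4], cumulative sum [1, 4, 5, 6, 10], and thus the normalized cdf [0.1, 0.4, 0.5, 0.6, 1.0] asserted as expected_cdf.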