The quantizer should match the patterns for partitioned variables.
author Suharsh Sivakumar <suharshs@google.com>
Fri, 18 May 2018 20:07:18 +0000 (13:07 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 20:09:59 +0000 (13:09 -0700)
PiperOrigin-RevId: 197189118
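
Background: creating a variable under a partitioner changes its read path in
the graph. With a single partition the read is VariableV2 -> Identity ->
Identity; with several partitions the shards are stitched together, giving
VariableV2 -> Identity -> ConcatV2 -> Identity. The quantizer's weight
patterns must accept all of these shapes. Below is a minimal sketch that
prints the read path of a two-way partitioned variable (it assumes the
TF 1.x internal import paths also used by the tests in this change, and that
PartitionedVariable.as_tensor is what materializes the concat):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import partitioned_variables
    from tensorflow.python.ops import variable_scope

    graph = ops.Graph()
    with graph.as_default():
      partitioner = partitioned_variables.fixed_size_partitioner(2)
      with variable_scope.variable_scope('part', partitioner=partitioner):
        weights = variable_scope.get_variable('weights', shape=[4, 4])
      # Reading the variable stitches the shards together:
      # VariableV2 -> Identity -> ConcatV2 -> Identity.
      weights.as_tensor()
      for op in graph.get_operations():
        print(op.type, op.name)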

tensorflow/contrib/quantize/BUILD
tensorflow/contrib/quantize/python/quantize.py
tensorflow/contrib/quantize/python/quantize_test.py

index b9918fd..2336361 100644 (file)
@@ -155,8 +155,10 @@ py_test(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:partitioned_variables",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
+        "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
 )
index 4e0de24..cbba726 100644 (file)
@@ -218,8 +218,19 @@ def _FindLayersToQuantize(graph):
   """
   input_pattern = graph_matcher.OpTypePattern('*')
   weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
-  weight_identity_pattern = graph_matcher.OpTypePattern(
+  weight_partition_identity_pattern = graph_matcher.OpTypePattern(
       'Identity', inputs=[weight_var_pattern])
+  weight_partition_concat_pattern = graph_matcher.OpTypePattern(
+      'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
+  weight_identity_pattern = graph_matcher.OpTypePattern(
+      'Identity',
+      inputs=[
+          graph_matcher.OneofPattern([
+              weight_partition_identity_pattern,
+              weight_partition_concat_pattern,
+              weight_var_pattern,
+          ])
+      ])
   weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
   folded_weight_pattern = graph_matcher.OpTypePattern('Mul')
 
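With the OneofPattern above, a single weight_identity_pattern now accepts any
of the three weight read paths. As a hedged sketch of how these patterns are
consumed (GraphMatcher, match_graph, and MatchResult.get_op come from the
same contrib graph_matcher module; `graph` stands for whichever tf.Graph is
being scanned):

    from tensorflow.contrib.quantize.python import graph_matcher

    var = graph_matcher.OpTypePattern('Variable|VariableV2')
    part_identity = graph_matcher.OpTypePattern('Identity', inputs=[var])
    part_concat = graph_matcher.OpTypePattern(
        'ConcatV2', inputs=[part_identity, '*', '*'])
    weight = graph_matcher.OpTypePattern(
        'Identity',
        inputs=[graph_matcher.OneofPattern(
            [part_identity, part_concat, var])])

    # One matcher now finds all three weight read paths:
    #   VariableV2 -> Identity                          (unpartitioned)
    #   VariableV2 -> Identity -> Identity              (one partition)
    #   VariableV2 -> Identity -> ConcatV2 -> Identity  (several partitions)
    for match in graph_matcher.GraphMatcher(weight).match_graph(graph):
      print(match.get_op(weight).name)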
index e7360ae..92ca4a1 100644 (file)
@@ -27,6 +27,8 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import googletest
 
 conv2d = layers.conv2d
@@ -327,6 +329,66 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     # No ops should be inserted or removed.
     self.assertEqual(op_names_before_quantize, op_names_after_quantize)
 
+  def testSinglePartitionedVariable(self):
+    self._RunTestOverParameters(self._testSinglePartitionedVariable)
+
+  def _testSinglePartitionedVariable(self, is_training):
+    # When weights are partitioned into a single partition, the weights variable
+    # is followed by an identity -> identity (an additional identity node).
+    partitioner = partitioned_variables.fixed_size_partitioner(1)
+    graph = ops.Graph()
+    with graph.as_default():
+      with variable_scope.variable_scope('part', partitioner=partitioner):
+        batch_size, height, width, depth = 5, 128, 128, 3
+        input1 = array_ops.zeros((batch_size, height, width, depth))
+        input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
+        conv = conv2d(
+            input1,
+            32, [5, 5],
+            stride=2,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=None,
+            scope='test/test')
+        node = math_ops.add(conv, input2, name='test/add')
+        node = nn_ops.relu6(node, name='test/relu6')
+
+      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+      # Check that the weight's quant node was added.
+      op_names = [op.name for op in graph.get_operations()]
+      self.assertTrue(
+          'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names)
+
+  def testMultiplePartitionedVariables(self):
+    self._RunTestOverParameters(self._testMultiplePartitionedVariables)
+
+  def _testMultiplePartitionedVariables(self, is_training):
+    # When weights are partitioned into multiple partitions, the weights variable
+    # is followed by an identity -> concat -> identity to group the partitions.
+    partitioner = partitioned_variables.fixed_size_partitioner(2)
+    graph = ops.Graph()
+    with graph.as_default():
+      with variable_scope.variable_scope('part', partitioner=partitioner):
+        batch_size, height, width, depth = 5, 128, 128, 3
+        input1 = array_ops.zeros((batch_size, height, width, depth))
+        input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
+        conv = conv2d(
+            input1,
+            32, [5, 5],
+            stride=2,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=None,
+            scope='test/test')
+        node = math_ops.add(conv, input2, name='test/add')
+        node = nn_ops.relu6(node, name='test/relu6')
+
+      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+      # Check that the weight's quant node was added.
+      op_names = [op.name for op in graph.get_operations()]
+      self.assertTrue(
+          'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names)
+
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.