Drop name_scope from operation names during quantization to avoid doubling it up.
author     A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Mar 2018 14:27:16 +0000 (07:27 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 14:31:58 +0000 (07:31 -0700)
PiperOrigin-RevId: 189737746

tensorflow/contrib/quantize/python/common.py
tensorflow/contrib/quantize/python/quantize.py
tensorflow/contrib/quantize/python/quantize_test.py

diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index 3138149..bf648e1 100644
@@ -123,3 +123,11 @@ def CreateOrGetQuantizationStep():
         # normal variables to return a tensor of the same name.
         return array_ops.identity(
             state_ops.assign_add(quantization_step_tensor, 1))
+
+
+def DropStringPrefix(s, prefix):
+  """If the string starts with this prefix, drops it."""
+  if s.startswith(prefix):
+    return s[len(prefix):]
+  else:
+    return s
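
For reference, a minimal sketch of how the new helper behaves (the example
strings are illustrative, not taken from the commit):

    from tensorflow.contrib.quantize.python import common

    # Prefix present: dropped exactly once.
    common.DropStringPrefix('TPUReplicate/loop/Conv/act_quant',
                            'TPUReplicate/loop/')  # -> 'Conv/act_quant'
    # Prefix absent: the string is returned unchanged.
    common.DropStringPrefix('Conv/act_quant',
                            'TPUReplicate/loop/')  # -> 'Conv/act_quant'
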
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 9780e6d..2b5b877 100644
@@ -367,6 +367,12 @@ def _InsertQuantOp(context,
       consumer operation.
   """
   name_prefix = _AddContextToName(context, name)
+  # This is needed on TPU, where name_scope == 'TPUReplicate/loop' and
+  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it,
+  # variables are created as TPUReplicate/loop/TPUReplicate/loop/...,
+  # which breaks things later.
+  name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/')
+
   inputs = producer.outputs[0]
   if moving_avg:
     quant = (
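
The doubling described in the comment above can be reproduced outside the
quantize rewriter. This is a simplified sketch (graph and op names are
illustrative, not from the commit):

    import tensorflow as tf
    from tensorflow.python.framework import ops

    graph = tf.Graph()
    with graph.as_default():
      with graph.name_scope('TPUReplicate/loop'):
        # _AddContextToName returns a fully qualified prefix such as:
        name_prefix = 'TPUReplicate/loop/Conv/act_quant'
        # Creating an op with that name inside the active scope doubles it:
        t = tf.constant(0.0, name=name_prefix)
        print(t.op.name)  # TPUReplicate/loop/TPUReplicate/loop/Conv/act_quant
        # Dropping the active scope first, as _InsertQuantOp now does:
        scope = ops.get_name_scope() + '/'  # 'TPUReplicate/loop/'
        t2 = tf.constant(0.0, name=name_prefix[len(scope):])
        print(t2.op.name)  # TPUReplicate/loop/Conv/act_quant
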
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 8e60f4b..216310a 100644
@@ -164,6 +164,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       self.assertTrue('FakeQuantWithMinMaxVars' in
                       [i.op.type for i in bypass_tensor.op.inputs])
 
+  def testWithNameScope(self):
+    self._RunTestOverParameters(self._TestWithNameScope)
+
+  def _TestWithNameScope(self, is_training):
+    graph = ops.Graph()
+    with graph.as_default():
+      with graph.name_scope('name_scope'):
+        batch_size, height, width, depth = 5, 128, 128, 3
+        input1 = array_ops.zeros((batch_size, height, width, depth))
+        _ = conv2d(
+            input1,
+            32, [5, 5],
+            stride=2,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=None,
+            scope='test')
+
+        quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+    for op in graph.get_operations():
+      self.assertFalse(op.name.startswith('name_scope/name_scope/'),
+                       'Broken op: %s' % op.name)
+
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.