Handle variations in scoping of batch norms for correct unfused batch norm folding.
authorRaghuraman Krishnamoorthi <raghuramank@google.com>
Thu, 26 Apr 2018 22:40:15 +0000 (15:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 22:43:11 +0000 (15:43 -0700)
PiperOrigin-RevId: 194465704
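
Previously, unfused batch norm folding located tensors by prepending the convolution's context to each expected op suffix, which breaks when graphs use different scoping conventions (for example, when the context repeats the top-level scope while the tensor names do not). The new _FindMatchingTensor helper instead requires an exact suffix match and then scores each candidate op by how many '/'-separated components it shares with the context, returning the tensor of the best-scoring op. A minimal standalone sketch of that scoring heuristic, using the op names from the code comments below (the _match_score helper here is a hypothetical condensation, not part of the change):

def _match_score(op_name, match_pattern, scope):
  # Score is zero unless the op name ends with the required suffix; otherwise
  # it is the number of '/'-separated scope components present in the name.
  if not op_name.endswith(match_pattern):
    return 0
  return len(set(op_name.split('/')) & set(scope.split('/')))

op_name = 'MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read'
pattern = '/BatchNorm/moving_mean/read'
print(_match_score(op_name, pattern, 'MobilenetV2/MobilenetV2/expanded_conv_3'))  # 2
print(_match_score(op_name, pattern, 'MobilenetV2/MobilenetV2/expanded_conv_2'))  # 1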

tensorflow/contrib/quantize/python/fold_batch_norms.py
tensorflow/contrib/quantize/python/fold_batch_norms_test.py

index 6f41722..1f286bc 100644 (file)
@@ -480,6 +480,43 @@ def _IsValidUnfusedBatchNorm(graph, context):
   return bool(add_shift.outputs[0].consumers())
 
 
+def _FindMatchingTensor(graph, match_pattern, scope):
+  """Finds best match of ops matching match_pattern with scope.
+
+     Example: _FindMatchingTensor(graph,'/BatchNorm/moments/Squeeze',
+     'MobilenetV1/MobilenetV1/Conv2d_0/') returns:
+      Tensor('MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze')
+
+  Args:
+    graph: Graph to inspect.
+    match_pattern: Part of the name of the op that we need to match, should
+    be present in the op's name
+    scope: The scope of the op. All the elements of the scope need not be
+    present in the op's name.
+
+  Returns:
+    Tensor from graph that provides the best match to the match_pattern and
+    scope
+  """
+
+  oplist = graph.get_operations()
+  split_context = set(scope.split('/'))
+  match_dict = {}
+  for op in oplist:
+    if op.name.endswith(match_pattern):
+      split_name = op.name.split('/')
+      num_matches = len(set(split_name) & split_context)
+      if num_matches > 0:
+        match_dict[op.name] = num_matches
+  # match_dict maps matching op names to the number of scope components they
+  # share with scope. We pick the op name with the most matches.
+  if match_dict:
+    max_key = max(match_dict, key=match_dict.get)
+    return graph.get_tensor_by_name(max_key + ':0')
+  else:
+    return None
+
+
 def _GetBatchNormParams(graph, context, has_scaling):
   """Extracts relevant tensors for folding batch norms.
 
@@ -500,7 +537,8 @@ def _GetBatchNormParams(graph, context, has_scaling):
   bn_decay_mean_tensor = None
   bn_decay_var_tensor = None
 
-  split_context = context.split('/')
+  # TODO(raghuramank): This code relies on string matching and needs to be
+  # updated if unfused batch norm continues to be widely used.
   # Matching variable names is brittle and relies on scoping
   # conventions. Fused batch norm folding is more robust. Support for unfused
   # batch norms will be deprecated as we move forward. Fused batch norms allow
@@ -518,49 +556,48 @@ def _GetBatchNormParams(graph, context, has_scaling):
   # and the names of the tensors start with a single MobilenetV2
   # The moving mean for example, has the name:
   # MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
-  # We ignore the first string (MobilenetV1 or MobilenetV2)
-  # in the context to match correctly in both cases
-
-  base_context = '/'.join(split_context[1:])
-  oplist = graph.get_operations()
-  op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
-  op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
-  op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y'
-  op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay'
-  op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay'
+  # We identify the best match for an op by checking that:
+  # 1. The suffix of the op's name exactly matches the pattern.
+  # 2. The op has the maximum number of matches with the context. The matching
+  # score is the number of parts of the context (split by /) that are also
+  # present in the parts of the op's name (again split by /).
+  # For example: scope = MobilenetV2/MobilenetV2/expanded_conv_3 and
+  # op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
+  # have 2 matches; a scope for a different conv layer has only one match.
+
+  op_suffix_mean = '/BatchNorm/moments/Squeeze'
+  op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
+  op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+  op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
+  op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
 
   if variable_scope.get_variable_scope().use_resource:
-    op_suffix_gamma = base_context + '/BatchNorm/gamma/Read/ReadVariableOp'
+    op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
     op_suffix_moving_variance = (
-        base_context + '/BatchNorm/moving_variance/Read/ReadVariableOp')
-    op_suffix_moving_mean = (
-        base_context + '/BatchNorm/moving_mean/Read/ReadVariableOp')
+        '/BatchNorm/moving_variance/Read/ReadVariableOp')
+    op_suffix_moving_mean = '/BatchNorm/moving_mean/Read/ReadVariableOp'
   else:
-    op_suffix_gamma = base_context + '/BatchNorm/gamma'
-    op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
-    op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
+    op_suffix_gamma = '/BatchNorm/gamma'
+    op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
+    op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
   # Parse through list of ops to find relevant ops
-  for op in oplist:
-    if op.name.endswith(op_suffix_mean):
-      # This is an efficient way to check for two things:
-      # Is batch norm present and is it training mode?
-      # Batch statistics are computed only during batch norm in training
-      batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_variance):
-      batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_moving_mean):
-      moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_moving_variance):
-      moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_epsilon):
-      batch_epsilon = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_bn_decay_mean):
-      bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_bn_decay_var):
-      bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if has_scaling:
-      if op.name.endswith(op_suffix_gamma):
-        gamma_tensor = graph.get_tensor_by_name(op.name + ':0')
+
+  batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
+  batch_variance_tensor = _FindMatchingTensor(graph, op_suffix_variance,
+                                              context)
+  moving_mean_tensor = _FindMatchingTensor(graph, op_suffix_moving_mean,
+                                           context)
+  moving_variance_tensor = _FindMatchingTensor(graph, op_suffix_moving_variance,
+                                               context)
+  batch_epsilon = _FindMatchingTensor(graph, op_suffix_epsilon, context)
+  bn_decay_mean_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_mean,
+                                             context)
+  bn_decay_var_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_var,
+                                            context)
+  if batch_mean_tensor is None and moving_mean_tensor is None:
+    raise ValueError('Error folding unfused batch norms')
+  if has_scaling:
+    gamma_tensor = _FindMatchingTensor(graph, op_suffix_gamma, context)
 
   if not has_scaling:
     gamma_tensor = array_ops.ones(moving_mean_tensor.shape)
index 64e8142..fa5e11b 100644 (file)
@@ -31,6 +31,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.training import saver as saver_lib
@@ -157,32 +158,38 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       out_depth = 3
       stride = 1
       activation_fn = relu
-      scope = 'network/expanded_conv_1/conv'
-      layer1 = conv2d(
-          inputs,
-          out_depth, [5, 5],
-          stride=stride,
-          padding='SAME',
-          weights_initializer=self._WeightInit(0.09),
-          activation_fn=activation_fn,
-          normalizer_fn=batch_norm,
-          normalizer_params=self._BatchNormParams(
-              scale=has_scaling, fused=fused_batch_norm),
-          scope=scope)
-      # Add another layer
-      scope = 'network/expanded_conv_2/conv'
-
-      _ = conv2d(
-          layer1,
-          2 * out_depth, [5, 5],
-          stride=stride,
-          padding='SAME',
-          weights_initializer=self._WeightInit(0.09),
-          activation_fn=activation_fn,
-          normalizer_fn=batch_norm,
-          normalizer_params=self._BatchNormParams(
-              scale=has_scaling, fused=fused_batch_norm),
-          scope=scope)
+      scope = 'topnet/testnet'
+      with variable_scope.variable_scope(scope, [inputs]):
+        layer1 = conv2d(
+            inputs,
+            out_depth, [5, 5],
+            stride=stride,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=None,
+            normalizer_fn=None,
+            scope='testnet/layer1')
+        # Add batch norm and relu under a different scope
+        layer1 = batch_norm(
+            layer1, scale=has_scaling, fused=fused_batch_norm, scope='layer1')
+        layer1 = activation_fn(layer1)
+        layer2 = conv2d(
+            layer1,
+            2 * out_depth, [5, 5],
+            stride=stride,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=activation_fn,
+            normalizer_fn=batch_norm,
+            normalizer_params=self._BatchNormParams(
+                scale=has_scaling, fused=fused_batch_norm),
+            scope='testnet/layer2')
+        # Add batch norm and relu under a different scope
+        layer2 = batch_norm(
+            layer2, scale=has_scaling, fused=fused_batch_norm, scope='layer2')
+        _ = activation_fn(layer2)
+
+      scope = 'topnet/testnet/testnet/layer2'
 
       fold_batch_norms.FoldBatchNorms(
           g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
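
The updated test builds its layers under nested variable scopes so that batch norm ops no longer sit directly under a single, predictable scope. A rough, condensed sketch of that pattern (assuming the TF 1.x tf.contrib.layers and tf.contrib.quantize APIs; the input shape and filter counts are illustrative):

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.quantize.python import fold_batch_norms

graph = tf.Graph()
with graph.as_default():
  inputs = tf.placeholder(tf.float32, shape=(1, 8, 8, 3))
  with tf.variable_scope('topnet/testnet'):
    # Convolution with an unfused batch norm as its normalizer, under nested
    # scopes that repeat the 'testnet' component; the folding context becomes
    # 'topnet/testnet/testnet/layer2'.
    net = layers.conv2d(
        inputs, 3, [5, 5], stride=1, padding='SAME',
        normalizer_fn=layers.batch_norm,
        normalizer_params={'scale': True, 'fused': False},
        scope='testnet/layer2')
    # A second batch norm and relu under a different scope, as in the test.
    net = layers.batch_norm(net, scale=True, fused=False, scope='layer2')
    net = tf.nn.relu(net)
  # Folding now locates the batch norm tensors by suffix matching plus
  # scope-overlap scoring rather than exact prefix construction.
  fold_batch_norms.FoldBatchNorms(graph, is_training=True)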