From 3ab696e7e7e5c422acaa2fb2f3a938ce14effc9c Mon Sep 17 00:00:00 2001
From: Raghuraman Krishnamoorthi
Date: Thu, 26 Apr 2018 15:40:15 -0700
Subject: [PATCH] Handle variations in scoping of batch norms for correct
 unfused batch norm folding.

PiperOrigin-RevId: 194465704
---
 .../contrib/quantize/python/fold_batch_norms.py    | 115 ++++++++++++++-------
 .../quantize/python/fold_batch_norms_test.py       |  59 ++++++-----
 2 files changed, 109 insertions(+), 65 deletions(-)

diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 6f41722..1f286bc 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -480,6 +480,43 @@ def _IsValidUnfusedBatchNorm(graph, context):
   return bool(add_shift.outputs[0].consumers())
 
 
+def _FindMatchingTensor(graph, match_pattern, scope):
+  """Finds the op that best matches match_pattern under the given scope.
+
+  Example: _FindMatchingTensor(graph, '/BatchNorm/moments/Squeeze',
+    'MobilenetV1/MobilenetV1/Conv2d_0/') returns:
+      Tensor('MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze')
+
+  Args:
+    graph: Graph to inspect.
+    match_pattern: Required suffix of the op's name; an op is considered a
+      candidate only if its name ends with this pattern.
+    scope: The scope of the op. Not all elements of the scope need to be
+      present in the op's name.
+
+  Returns:
+    Tensor from graph that best matches match_pattern and scope, or None if
+    no candidate matches.
+  """
+
+  oplist = graph.get_operations()
+  split_context = set(scope.split('/'))
+  match_dict = {}
+  for op in oplist:
+    if op.name.endswith(match_pattern):
+      split_name = op.name.split('/')
+      num_matches = len(set(split_name) & split_context)
+      if num_matches > 0:
+        match_dict[op.name] = num_matches
+  # match_dict contains matching op names from the graph, with values being
+  # the number of scope components they share. Pick the key with the most.
+  if match_dict:
+    max_key = max(match_dict, key=match_dict.get)
+    return graph.get_tensor_by_name(max_key + ':0')
+  else:
+    return None
+
+
 def _GetBatchNormParams(graph, context, has_scaling):
   """Extracts relevant tensors for folding batch norms.
 
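The scoring idea behind _FindMatchingTensor can be exercised without a
TensorFlow graph. The sketch below is a minimal, self-contained rendering of
the same heuristic on plain strings; the find_best_match helper and the op
names are illustrative, not part of the patch:

    def find_best_match(op_names, match_pattern, scope):
      """Returns the name that ends with match_pattern and shares the most
      '/'-separated components with scope, or None if nothing matches."""
      scope_parts = set(scope.split('/'))
      scores = {}
      for name in op_names:
        if name.endswith(match_pattern):
          overlap = len(set(name.split('/')) & scope_parts)
          if overlap > 0:
            scores[name] = overlap
      return max(scores, key=scores.get) if scores else None

    ops = ['MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze',
           'MobilenetV1/Conv2d_1/BatchNorm/moments/Squeeze']
    # The repeated 'MobilenetV1' in the scope is counted once, since the
    # comparison is between sets of components.
    print(find_best_match(ops, '/BatchNorm/moments/Squeeze',
                          'MobilenetV1/MobilenetV1/Conv2d_0/'))
    # -> MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze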
@@ -500,7 +537,8 @@ def _GetBatchNormParams(graph, context, has_scaling):
   bn_decay_mean_tensor = None
   bn_decay_var_tensor = None
 
-  split_context = context.split('/')
+  # TODO(raghuramank): This code relies on string matching and needs to be
+  # updated if unfused batch norm continues to be widely used.
   # Matching variable names is brittle and relies on scoping
   # conventions. Fused batch norm folding is more robust. Support for unfused
   # batch norms will be deprecated as we move forward. Fused batch norms allow
@@ -518,49 +556,48 @@ def _GetBatchNormParams(graph, context, has_scaling):
   # and the names of the tensors start with a single MobilenetV2
   # The moving mean, for example, has the name:
   # MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
 
-  # We ignore the first string (MobilenetV1 or MobilenetV2)
-  # in the context to match correctly in both cases
-
-  base_context = '/'.join(split_context[1:])
-  oplist = graph.get_operations()
-  op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
-  op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
-  op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y'
-  op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay'
-  op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay'
+  # We identify the best match for an op by checking:
+  # 1. The suffix of the op's name exactly matches the pattern.
+  # 2. The number of matches with the context is maximal. The matching
+  #    score is given by the number of parts of the context (split by /)
+  #    that are present in the parts of the tensor name (again split by /).
+  # For example, scope = MobilenetV2/MobilenetV2/expanded_conv_3 and
+  # op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
+  # have two matches; a scope naming a different conv layer has only one.
+
+  op_suffix_mean = '/BatchNorm/moments/Squeeze'
+  op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
+  op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+  op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
+  op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
   if variable_scope.get_variable_scope().use_resource:
-    op_suffix_gamma = base_context + '/BatchNorm/gamma/Read/ReadVariableOp'
+    op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
     op_suffix_moving_variance = (
-        base_context + '/BatchNorm/moving_variance/Read/ReadVariableOp')
-    op_suffix_moving_mean = (
-        base_context + '/BatchNorm/moving_mean/Read/ReadVariableOp')
+        '/BatchNorm/moving_variance/Read/ReadVariableOp')
+    op_suffix_moving_mean = '/BatchNorm/moving_mean/Read/ReadVariableOp'
   else:
-    op_suffix_gamma = base_context + '/BatchNorm/gamma'
-    op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
-    op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
   # Parse through list of ops to find relevant ops
-  for op in oplist:
-    if op.name.endswith(op_suffix_mean):
-      # This is an efficient way to check for two things:
-      # Is batch norm present and is it training mode?
-      # Batch statistics are computed only during batch norm in training
-      batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_variance):
-      batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_moving_mean):
-      moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_moving_variance):
-      moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_epsilon):
-      batch_epsilon = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_bn_decay_mean):
-      bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if op.name.endswith(op_suffix_bn_decay_var):
-      bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0')
-    if has_scaling:
-      if op.name.endswith(op_suffix_gamma):
-        gamma_tensor = graph.get_tensor_by_name(op.name + ':0')
+    op_suffix_gamma = '/BatchNorm/gamma'
+    op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
+    op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+
+  batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
+  batch_variance_tensor = _FindMatchingTensor(graph, op_suffix_variance,
+                                              context)
+  moving_mean_tensor = _FindMatchingTensor(graph, op_suffix_moving_mean,
+                                           context)
+  moving_variance_tensor = _FindMatchingTensor(graph, op_suffix_moving_variance,
+                                               context)
+  batch_epsilon = _FindMatchingTensor(graph, op_suffix_epsilon, context)
+  bn_decay_mean_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_mean,
+                                             context)
+  bn_decay_var_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_var,
+                                            context)
+  if batch_mean_tensor is None and moving_mean_tensor is None:
+    raise ValueError('Error folding unfused batch norms')
+  if has_scaling:
+    gamma_tensor = _FindMatchingTensor(graph, op_suffix_gamma, context)
 
   if not has_scaling:
     gamma_tensor = array_ops.ones(moving_mean_tensor.shape)
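To make the new scoring concrete, here is the MobilenetV2 example from the
comment above worked out with the illustrative find_best_match helper from
the earlier sketch (the op names are hypothetical):

    scope = 'MobilenetV2/MobilenetV2/expanded_conv_3'
    ops = [
        # Shares 'MobilenetV2' and 'expanded_conv_3' with scope: score 2.
        'MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read',
        # Shares only 'MobilenetV2' with scope: score 1.
        'MobilenetV2/expanded_conv_2/depthwise/BatchNorm/moving_mean/read',
    ]
    assert find_best_match(ops, '/BatchNorm/moving_mean/read', scope) == ops[0]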
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index 64e8142..fa5e11b 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.training import saver as saver_lib
@@ -157,32 +158,38 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     out_depth = 3
     stride = 1
     activation_fn = relu
-    scope = 'network/expanded_conv_1/conv'
-    layer1 = conv2d(
-        inputs,
-        out_depth, [5, 5],
-        stride=stride,
-        padding='SAME',
-        weights_initializer=self._WeightInit(0.09),
-        activation_fn=activation_fn,
-        normalizer_fn=batch_norm,
-        normalizer_params=self._BatchNormParams(
-            scale=has_scaling, fused=fused_batch_norm),
-        scope=scope)
-    # Add another layer
-    scope = 'network/expanded_conv_2/conv'
-
-    _ = conv2d(
-        layer1,
-        2 * out_depth, [5, 5],
-        stride=stride,
-        padding='SAME',
-        weights_initializer=self._WeightInit(0.09),
-        activation_fn=activation_fn,
-        normalizer_fn=batch_norm,
-        normalizer_params=self._BatchNormParams(
-            scale=has_scaling, fused=fused_batch_norm),
-        scope=scope)
+    scope = 'topnet/testnet'
+    with variable_scope.variable_scope(scope, [inputs]):
+      layer1 = conv2d(
+          inputs,
+          out_depth, [5, 5],
+          stride=stride,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=None,
+          normalizer_fn=None,
+          scope='testnet/layer1')
+      # Add bn and relu with a different scope
+      layer1 = batch_norm(
+          layer1, scale=has_scaling, fused=fused_batch_norm, scope='layer1')
+      layer1 = activation_fn(layer1)
+      layer2 = conv2d(
+          layer1,
+          2 * out_depth, [5, 5],
+          stride=stride,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=activation_fn,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(
+              scale=has_scaling, fused=fused_batch_norm),
+          scope='testnet/layer2')
+      # Add bn and relu with a different scope
+      layer2 = batch_norm(
+          layer2, scale=has_scaling, fused=fused_batch_norm, scope='layer2')
+      _ = activation_fn(layer2)
+
+    scope = 'topnet/testnet/testnet/layer2'
     fold_batch_norms.FoldBatchNorms(
         g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
-- 
2.7.4
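The restructured test deliberately nests scopes so that batch norm op names
overlap the folding context to different degrees. A minimal sketch of the
scope layout it produces, assuming a TF 1.x environment where
tf.contrib.layers is available (the printed variable names in the comments
follow from the scoping rules and are shown for illustration only):

    import tensorflow as tf
    from tensorflow.contrib.layers import batch_norm, conv2d

    g = tf.Graph()
    with g.as_default():
      inputs = tf.placeholder(tf.float32, [8, 128, 128, 3])
      with tf.variable_scope('topnet/testnet'):
        # Conv under a nested 'testnet/layer1' scope, no bn of its own.
        net = conv2d(inputs, 3, [5, 5], activation_fn=None,
                     normalizer_fn=None, scope='testnet/layer1')
        # Batch norm added separately under its own 'layer1' scope.
        net = batch_norm(net, scale=True, scope='layer1')
      for v in g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        print(v.name)
      # topnet/testnet/testnet/layer1/weights:0   <- conv, repeated 'testnet'
      # topnet/testnet/layer1/BatchNorm/gamma:0   <- bn, single 'testnet'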