return bool(add_shift.outputs[0].consumers())
+def _FindMatchingTensor(graph, match_pattern, scope):
+  """Finds the tensor in graph that best matches match_pattern within scope.
+
+  Example: _FindMatchingTensor(graph, '/BatchNorm/moments/Squeeze',
+  'MobilenetV1/MobilenetV1/Conv2d_0/') returns:
+  Tensor('MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze:0')
+
+  Args:
+    graph: Graph to inspect.
+    match_pattern: Required suffix of the op's name; ops whose names do not
+      end with match_pattern are ignored.
+    scope: The scope of the op. Not all elements of the scope need to be
+      present in the op's name.
+
+  Returns:
+    The tensor from graph that best matches match_pattern and scope, or None
+    if no op name both ends with match_pattern and shares a part with scope.
+  """
+
+ oplist = graph.get_operations()
+ split_context = set(scope.split('/'))
+ match_dict = {}
+ for op in oplist:
+ if op.name.endswith(match_pattern):
+ split_name = op.name.split('/')
+ num_matches = len(set(split_name) & split_context)
+ if num_matches > 0:
+ match_dict[op.name] = num_matches
+  # match_dict maps each matching op name from the graph to the number of
+  # scope parts it shares with scope. We pick the name with the most matches.
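+  # Worked illustration following the docstring example (Conv2d_1 is a
+  # hypothetical sibling op): with scope 'MobilenetV1/MobilenetV1/Conv2d_0/',
+  # split_context is {'MobilenetV1', 'Conv2d_0'} (plus '' from the trailing
+  # '/'), so 'MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze' scores 2
+  # matches while 'MobilenetV1/Conv2d_1/BatchNorm/moments/Squeeze' scores
+  # only 1, and the Conv2d_0 op is selected.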
+ if match_dict:
+ max_key = max(match_dict, key=match_dict.get)
+ return graph.get_tensor_by_name(max_key + ':0')
+ else:
+ return None
+
+
def _GetBatchNormParams(graph, context, has_scaling):
"""Extracts relevant tensors for folding batch norms.
bn_decay_mean_tensor = None
bn_decay_var_tensor = None
- split_context = context.split('/')
+  # TODO(raghuramank): This code relies on string matching and needs to be
+  # updated if unfused batch norm continues to be widely used.
# Matching variable names is brittle and relies on scoping
# conventions. Fused batch norm folding is more robust. Support for unfused
# batch norms will be deprecated as we move forward. Fused batch norms allow
# for faster training and should be used whenever possible.
# The context contains part of the names of the tensors we are interested
# in folding: for MobilenetV1 the context has repetition
# (MobilenetV1/MobilenetV1/Conv2d_0), while in MobilenetV2 it does not,
# and the names of the tensors start with a single MobilenetV2
# The moving mean for example, has the name:
# MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
- # We ignore the first string (MobilenetV1 or MobilenetV2)
- # in the context to match correctly in both cases
-
- base_context = '/'.join(split_context[1:])
- oplist = graph.get_operations()
- op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
- op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y'
- op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay'
- op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay'
+  # We identify the best match for an op by checking that
+  # 1. The op name ends exactly with the match pattern.
+  # 2. It has the maximum number of matches with the context. The matching
+  #    score is the number of parts of the context (split by '/') that are
+  #    present in the parts of the tensor name (again split by '/').
+  # For example, scope = MobilenetV2/MobilenetV2/expanded_conv_3 and
+  # op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
+  # give 2 matches; the same scope against a different conv layer gives one.
+
+ op_suffix_mean = '/BatchNorm/moments/Squeeze'
+ op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
+ op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+ op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
if variable_scope.get_variable_scope().use_resource:
- op_suffix_gamma = base_context + '/BatchNorm/gamma/Read/ReadVariableOp'
+ op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
op_suffix_moving_variance = (
- base_context + '/BatchNorm/moving_variance/Read/ReadVariableOp')
- op_suffix_moving_mean = (
- base_context + '/BatchNorm/moving_mean/Read/ReadVariableOp')
+ '/BatchNorm/moving_variance/Read/ReadVariableOp')
+  op_suffix_moving_mean = '/BatchNorm/moving_mean/Read/ReadVariableOp'
else:
- op_suffix_gamma = base_context + '/BatchNorm/gamma'
- op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
- op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
+ op_suffix_gamma = '/BatchNorm/gamma'
+ op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
# Parse through list of ops to find relevant ops
- for op in oplist:
- if op.name.endswith(op_suffix_mean):
- # This is an efficient way to check for two things:
- # Is batch norm present and is it training mode?
- # Batch statistics are computed only during batch norm in training
- batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_variance):
- batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_moving_mean):
- moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_moving_variance):
- moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_epsilon):
- batch_epsilon = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_bn_decay_mean):
- bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
- if op.name.endswith(op_suffix_bn_decay_var):
- bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0')
- if has_scaling:
- if op.name.endswith(op_suffix_gamma):
- gamma_tensor = graph.get_tensor_by_name(op.name + ':0')
+
+ batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
+ batch_variance_tensor = _FindMatchingTensor(graph, op_suffix_variance,
+ context)
+ moving_mean_tensor = _FindMatchingTensor(graph, op_suffix_moving_mean,
+ context)
+ moving_variance_tensor = _FindMatchingTensor(graph, op_suffix_moving_variance,
+ context)
+ batch_epsilon = _FindMatchingTensor(graph, op_suffix_epsilon, context)
+ bn_decay_mean_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_mean,
+ context)
+ bn_decay_var_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_var,
+ context)
+  if batch_mean_tensor is None and moving_mean_tensor is None:
+    raise ValueError('Error folding unfused batch norms')
+ if has_scaling:
+ gamma_tensor = _FindMatchingTensor(graph, op_suffix_gamma, context)
if not has_scaling:
gamma_tensor = array_ops.ones(moving_mean_tensor.shape)
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import saver as saver_lib
out_depth = 3
stride = 1
activation_fn = relu
- scope = 'network/expanded_conv_1/conv'
- layer1 = conv2d(
- inputs,
- out_depth, [5, 5],
- stride=stride,
- padding='SAME',
- weights_initializer=self._WeightInit(0.09),
- activation_fn=activation_fn,
- normalizer_fn=batch_norm,
- normalizer_params=self._BatchNormParams(
- scale=has_scaling, fused=fused_batch_norm),
- scope=scope)
- # Add another layer
- scope = 'network/expanded_conv_2/conv'
-
- _ = conv2d(
- layer1,
- 2 * out_depth, [5, 5],
- stride=stride,
- padding='SAME',
- weights_initializer=self._WeightInit(0.09),
- activation_fn=activation_fn,
- normalizer_fn=batch_norm,
- normalizer_params=self._BatchNormParams(
- scale=has_scaling, fused=fused_batch_norm),
- scope=scope)
+ scope = 'topnet/testnet'
+ with variable_scope.variable_scope(scope, [inputs]):
+ layer1 = conv2d(
+ inputs,
+ out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='testnet/layer1')
+ # Add bn and relu with different scope
+ layer1 = batch_norm(
+ layer1, scale=has_scaling, fused=fused_batch_norm, scope='layer1')
+ layer1 = activation_fn(layer1)
+ layer2 = conv2d(
+ layer1,
+ 2 * out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+        activation_fn=None,
+        normalizer_fn=None,
+ scope='testnet/layer2')
+ # Add bn and relu with different scope
+ layer2 = batch_norm(
+ layer2, scale=has_scaling, fused=fused_batch_norm, scope='layer2')
+ _ = activation_fn(layer2)
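+    # With this construction, layer1's batch norm ops live under
+    # 'topnet/testnet/layer1/BatchNorm/...' while its conv ops live under
+    # 'topnet/testnet/testnet/layer1', so folding cannot assume that a
+    # batch norm shares its conv op's exact scope prefix.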
+
+ scope = 'topnet/testnet/testnet/layer2'
fold_batch_norms.FoldBatchNorms(
g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)