From 2271f0f8c463a01af86c9e17be38e3cfc12eae11 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Mon, 5 Feb 2018 15:32:32 -0800 Subject: [PATCH] Make fold batch norm code use OneofPattern and rearrange functions to (maybe) be more readable. PiperOrigin-RevId: 184597111 --- .../contrib/quantize/python/fold_batch_norms.py | 727 ++++++++++----------- 1 file changed, 329 insertions(+), 398 deletions(-) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 8ec5334..7fa0d48 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import re from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common @@ -120,12 +121,8 @@ def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training): weights = math_ops.multiply( correction_scale, weights, name='correction_mult') - # TODO(suharshs): This naming of the following ops needs to carefully - # follow the naming expected by quantize.py. Generalize the quantize code - # to not require these delicate naming conventions. scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') - new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor) @@ -145,46 +142,6 @@ def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training): 'Unexpected inputs to op: %s' % match.output_tensor.name) -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): - """Clones layer_op with input_tensor and weight_tensor as new inputs.""" - new_layer_name = layer_op.name.split('/')[-1] + '_Fold' - if layer_op.type == 'Conv2D': - return nn_ops.conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), - name=new_layer_name) - elif layer_op.type == 'MatMul': - return math_ops.matmul( - input_tensor, - weight_tensor, - transpose_a=layer_op.get_attr('transpose_a'), - transpose_b=layer_op.get_attr('transpose_b'), - name=new_layer_name) - elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - name=new_layer_name) - else: - raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - - -@ops.RegisterGradient('FoldFusedBatchNormGrad') -def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, - unused_2): - x = op.inputs[0] - n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() - dmean_dx = grad_mean / n - dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) - return (dmean_dx + dvar_dx), None, None, None, None - - def _FindFusedBatchNorms(graph): """Finds all ops and tensors related to found FusedBatchNorms. 
@@ -203,68 +160,57 @@ def _FindFusedBatchNorms(graph): moving_average_pattern = graph_matcher.OpTypePattern('*') bn_decay_pattern = graph_matcher.OpTypePattern('*') - conv_pattern = graph_matcher.OpTypePattern( - 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + layer_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative|MatMul', + inputs=[input_pattern, weight_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. - matmul_pattern = graph_matcher.OpTypePattern( - 'MatMul', inputs=[input_pattern, weight_pattern]) matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[matmul_pattern, + 'Reshape', inputs=[layer_pattern, graph_matcher.OpTypePattern('*')]) - conv_batch_norm_pattern = graph_matcher.OpTypePattern( + batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - conv_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern - ]) - conv_moving_average_sub_pattern = graph_matcher.OpTypePattern( - 'Sub', inputs=[moving_average_pattern, conv_batch_norm_pattern]) - # TODO(suharshs): Use a OneofPattern here when available - conv_moving_average_mul_pattern = graph_matcher.OpTypePattern( - 'Mul', inputs=[conv_moving_average_sub_pattern, bn_decay_pattern]) - matmul_batch_norm_pattern = graph_matcher.OpTypePattern( - 'FusedBatchNorm', - inputs=[ - matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern + graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), + gamma_pattern, beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', - inputs=[matmul_batch_norm_pattern, - graph_matcher.OpTypePattern('*')]) - - matmul_moving_average_sub_pattern = graph_matcher.OpTypePattern( - 'Sub', inputs=[moving_average_pattern, matmul_batch_norm_pattern]) - matmul_moving_average_mul_pattern = graph_matcher.OpTypePattern( - 'Mul', inputs=[matmul_moving_average_sub_pattern, bn_decay_pattern]) - - conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) - matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) - conv_moving_average_mul_matcher = graph_matcher.GraphMatcher( - conv_moving_average_mul_pattern) - matmul_moving_average_mul_matcher = graph_matcher.GraphMatcher( - matmul_moving_average_mul_pattern) - - def _GetMovingAverageTensors(graph, moving_avg_mul_matcher, - moving_avg_sub_pattern, bn_op): - """Gets the moving mean and variance tensors and the batch norm momentum.""" - for mul_match_result in moving_avg_mul_matcher.match_graph(graph): - sub_op = mul_match_result.get_op(moving_avg_sub_pattern) - - if sub_op.inputs[1].name == bn_op.outputs[1].name: - # During training: Batch Mean is bn_op.outputs[1] - moving_mean_tensor = sub_op.inputs[0] - bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern) - if sub_op.inputs[1].name == bn_op.outputs[2].name: - # During training: Batch Var is bn_op.outputs[2] - moving_variance_tensor = sub_op.inputs[0] - bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern) - return (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor, - bn_decay_var_tensor) - - def _GetCommonTensors(match_result, bn_op, bn_input_tensor): - """Gets tensors needed for FusedBatchNormMatch from match_result.""" + 'Reshape', inputs=[batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + bn_matcher = graph_matcher.GraphMatcher( + graph_matcher.OneofPattern( + [matmul_bn_output_reshape_pattern, 
batch_norm_pattern])) + + moving_average_sub_pattern = graph_matcher.OpTypePattern( + 'Sub', inputs=[moving_average_pattern, batch_norm_pattern]) + moving_average_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern]) + + moving_avg_mul_matcher = graph_matcher.GraphMatcher( + moving_average_mul_pattern) + + for match_result in bn_matcher.match_graph(graph): + moving_mean_tensor = None + moving_variance_tensor = None + bn_decay_mean_tensor = None + bn_decay_var_tensor = None + layer_op = match_result.get_op(layer_pattern) + layer_tensor = match_result.get_tensor(layer_pattern) + bn_op = match_result.get_op(batch_norm_pattern) + batch_epsilon_tensor = bn_op.get_attr('epsilon') + + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_tensor = bn_op.outputs[0] + if layer_op.type == 'MatMul': + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + # If the matcher didn't match matmul_bn_output_reshape, there will be + # another match for this 'MatMul' later, so we can skip this one. + if output_reshape_op is None: + continue + output_tensor = output_reshape_op.outputs[0] + input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -295,40 +241,25 @@ def _FindFusedBatchNorms(graph): g = ops.get_default_graph() with g.as_default(), g.name_scope(scope + sep): n = math_ops.cast( - array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor), + array_ops.size(layer_tensor) / array_ops.size(mean_tensor), dtypes.float32) variance_tensor = math_ops.multiply( bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction') + # TODO(suharshs): Find a way to get rid of this inner match. 
+ for mul_match_result in moving_avg_mul_matcher.match_graph(graph): + sub_op = mul_match_result.get_op(moving_average_sub_pattern) + if sub_op.inputs[1].name == bn_op.outputs[1].name: + # During training: Batch Mean is bn_op.outputs[1] + moving_mean_tensor = sub_op.inputs[0] + bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern) + if sub_op.inputs[1].name == bn_op.outputs[2].name: + # During training: Batch Var is bn_op.outputs[2] + moving_variance_tensor = sub_op.inputs[0] + bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern) else: mean_tensor = match_result.get_tensor(mean_pattern) variance_tensor = match_result.get_tensor(variance_pattern) - return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) - - for match_result in conv_matcher.match_graph(graph): - moving_mean_tensor = None - moving_variance_tensor = None - bn_decay_mean_tensor = None - bn_decay_var_tensor = None - layer_op = match_result.get_op(conv_pattern) - layer_tensor = match_result.get_tensor(conv_pattern) - bn_op = match_result.get_op(conv_batch_norm_pattern) - if bn_op.get_attr('is_training'): - (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor, - bn_decay_var_tensor) = _GetMovingAverageTensors( - graph, - moving_avg_mul_matcher=conv_moving_average_mul_matcher, - moving_avg_sub_pattern=conv_moving_average_sub_pattern, - bn_op=bn_op) - output_tensor = bn_op.outputs[0] - batch_epsilon_tensor = bn_op.get_attr('epsilon') - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors( - match_result, - bn_op, - layer_tensor, - ) yield _BatchNormMatch( layer_op=layer_op, bn_op=bn_op, @@ -345,124 +276,146 @@ def _FindFusedBatchNorms(graph): bn_decay_var_tensor=bn_decay_var_tensor, batch_epsilon_tensor=batch_epsilon_tensor) - for match_result in matmul_matcher.match_graph(graph): - moving_mean_tensor = None - moving_variance_tensor = None - bn_decay_mean_tensor = None - bn_decay_var_tensor = None - layer_op = match_result.get_op(matmul_pattern) - layer_tensor = match_result.get_tensor(matmul_pattern) - bn_op = match_result.get_op(matmul_batch_norm_pattern) - if bn_op.get_attr('is_training'): - (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor, - bn_decay_var_tensor) = _GetMovingAverageTensors( - graph, - moving_avg_mul_matcher=matmul_moving_average_mul_matcher, - moving_avg_sub_pattern=matmul_moving_average_sub_pattern, - bn_op=bn_op) - - # In the MatMul case, the output of batch norm is reshaped back into a - # 2D tensor, so the output_tensor is the output of the Reshape op. 
- output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) - output_tensor = output_reshape_op.outputs[0] - batch_epsilon_tensor = bn_op.get_attr('epsilon') - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor) - yield _BatchNormMatch( - layer_op=layer_op, - bn_op=bn_op, - output_tensor=output_tensor, - input_tensor=input_tensor, - weight_tensor=weight_tensor, - gamma_tensor=gamma_tensor, - beta_tensor=beta_tensor, - mean_tensor=mean_tensor, - variance_tensor=variance_tensor, - moving_mean_tensor=moving_mean_tensor, - moving_variance_tensor=moving_variance_tensor, - bn_decay_mean_tensor=bn_decay_mean_tensor, - bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon_tensor=batch_epsilon_tensor) +def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, + fused_batch_norm): + """Computes batch norm correction params. + Before batch normalization is frozen: + We use batch statistics for batch norm. + correction_scale = sigma_b/sigma_mv + correction_recip = 1/correction_scale + correction_offset = 0 -class _BatchNormMatch(object): - """Contains all information related to a found Fused/UnfusedBatchNorm.""" + After batch normalization is frozen: + correction_scale = sigma_b/sigma_mv + correction_recip = 1 + correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv). - def __init__(self, layer_op, bn_op, output_tensor, input_tensor, - weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor, moving_mean_tensor, moving_variance_tensor, - bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor): - self._layer_op = layer_op - self._bn_op = bn_op - self._output_tensor = output_tensor - self._input_tensor = input_tensor - self._weight_tensor = weight_tensor - self._gamma_tensor = gamma_tensor - self._beta_tensor = beta_tensor - self._mean_tensor = mean_tensor - self._variance_tensor = variance_tensor - self._moving_mean_tensor = moving_mean_tensor - self._moving_variance_tensor = moving_variance_tensor - self._bn_decay_mean_tensor = bn_decay_mean_tensor - self._bn_decay_var_tensor = bn_decay_var_tensor - self._batch_epsilon_tensor = batch_epsilon_tensor + Batch norm is frozen if global_step > bn_freeze_delay. + The corrections ensure that: + a) The weights are quantized after scaling by gamma/sigma_mv. This enables + smoother training as the scaling on the weights changes slowly, rather than + jump across mini-batches + b) Changing the values of the corrections allows for one to switch between + using batch statistics to using moving mean and average, without requiring + changes to batch_norm - @property - def layer_op(self): - return self._layer_op - @property - def bn_op(self): - return self._bn_op + Args: + context: The scope under which we look for batch norm params + match: Object containg required batch norm tensors for correction + computation + freeze_batch_norm_delay: Delay in steps at which computation switches + from regular batch norm to frozen mean and variance. 
+ fused_batch_norm: Bool, true if fused batch norm is used - @property - def output_tensor(self): - return self._output_tensor + Returns: + A tuple of correction_scale, correction_recip, correction_offset + """ - @property - def input_tensor(self): - return self._input_tensor + g = ops.get_default_graph() + with g.name_scope(context + '/batch_norm_correction'): + recip_sigma_mv = math_ops.rsqrt( + match.moving_variance_tensor + match.batch_epsilon_tensor) + recip_sigma = math_ops.rsqrt( + match.variance_tensor + match.batch_epsilon_tensor) + correction_scale = math_ops.divide( + recip_sigma_mv, recip_sigma, name='scale_compute') + correction_scale = array_ops.identity( + correction_scale, name='correction_scale') + correction_recip = math_ops.reciprocal( + correction_scale, name='reciprocal_compute') + correction_offset = math_ops.multiply( + match.gamma_tensor, + match.mean_tensor * recip_sigma - + match.moving_mean_tensor * recip_sigma_mv, + name='offset_compute') - @property - def weight_tensor(self): - return self._weight_tensor + if freeze_batch_norm_delay is not None: + use_mv_avg = math_ops.greater_equal( + training_util.get_or_create_global_step(), + freeze_batch_norm_delay, + name='use_moving_average') + else: + use_mv_avg = False - @property - def gamma_tensor(self): - return self._gamma_tensor + bn_decay_zero = 0.0 + bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers()) + bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers()) - @property - def beta_tensor(self): - return self._beta_tensor + bn_decay_mean_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_mean_tensor, + name='freeze_moving_mean') + graph_editor.reroute_ts( + [bn_decay_mean_out], [match.bn_decay_mean_tensor], + can_modify=bn_decay_mean_consumers) - @property - def mean_tensor(self): - return self._mean_tensor + if fused_batch_norm is False: + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) - @property - def variance_tensor(self): - return self._variance_tensor + correction_recip = utils.smart_cond( + use_mv_avg, + lambda: array_ops.ones(correction_scale.shape), + lambda: correction_recip, + name='correction_recip') - @property - def moving_mean_tensor(self): - return self._moving_mean_tensor + correction_offset = utils.smart_cond( + use_mv_avg, + lambda: correction_offset, + lambda: array_ops.zeros(correction_offset.shape), + name='correction_offset') + return correction_scale, correction_recip, correction_offset - @property - def moving_variance_tensor(self): - return self._moving_variance_tensor - @property - def batch_epsilon_tensor(self): - return self._batch_epsilon_tensor +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + 
transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - @property - def bn_decay_mean_tensor(self): - return self._bn_decay_mean_tensor - @property - def bn_decay_var_tensor(self): - return self._bn_decay_var_tensor +@ops.RegisterGradient('FoldFusedBatchNormGrad') +def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, + unused_2): + x = op.inputs[0] + n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() + dmean_dx = grad_mean / n + dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) + return (dmean_dx + dvar_dx), None, None, None, None def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training): @@ -475,81 +428,42 @@ def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training): graph: Graph to walk and modify. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization - is_training: Bool, True if training - - Raises: - ValueError: When batch norm folding fails. - """ - input_to_ops_map = input_to_ops.InputToOps(graph) - - for bn in common.BatchNormGroups(graph): - has_scaling = _HasScaling(graph, input_to_ops_map, bn) - - # The mangling code intimately depends on BatchNorm node's internals. - original_op, folded_op = _CreateFoldedOp( - graph, - bn, - has_scaling=has_scaling, - freeze_batch_norm_delay=freeze_batch_norm_delay, - is_training=is_training) - - activation = common.GetEndpointActivationOp(graph, bn) - if activation: - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[activation]) - if nodes_modified_count != 1: - raise ValueError('Unexpected inputs to op: %s' % activation.name) - continue - - # Treat consumer ops in bypass modules differently since they have Add - # operations instead of Relu* above. - add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) - add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[add_bypass]) - if nodes_modified_count != 1: - raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) - - -def _HasScaling(graph, input_to_ops_map, bn): - r"""Checks if batch norm has scaling enabled. - - Difference between batch norm with scaling and without is that with scaling: - - Rsqrt -> mul -> mul_1 - \-> mul_2 - - where - mul multiplies gamma by inverse square root of EMA of batch variance, - mul_1 multiplies output of mul with output from the base operation - (convolution, FC or depthwise convolution), - mul_2 multiplies output of mul with EMA of batch mean, - and without scaling: - - Rsqrt -> mul - \-> mul_1 - - where - mul multiplies the inverse square root of EMA of batch variance with output - from the base operation, - mul_1 multiplies inverse square root of EMA of batch variance with EMA - of batch mean. - - Args: - graph: Graph to inspect. - input_to_ops_map: InputToOps object containing mapping from tensor's name - to ops that take it as input. - bn: Batch norm layer prefix string. 
+ is_training: Bool, True if training - Returns: - A boolean indicating whether this batch norm layer has scaling enabled. + Raises: + ValueError: When batch norm folding fails. """ - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') - rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + input_to_ops_map = input_to_ops.InputToOps(graph) - return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + for bn in common.BatchNormGroups(graph): + has_scaling = _HasScaling(graph, input_to_ops_map, bn) + + # The mangling code intimately depends on BatchNorm node's internals. + original_op, folded_op = _CreateFoldedOp( + graph, + bn, + has_scaling=has_scaling, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) + + activation = common.GetEndpointActivationOp(graph, bn) + if activation: + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[activation]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % activation.name) + continue + + # Treat consumer ops in bypass modules differently since they have Add + # operations instead of Relu* above. + add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) + add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[add_bypass]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) def _GetBatchNormParams(graph, context, has_scaling): @@ -629,107 +543,6 @@ def _GetBatchNormParams(graph, context, has_scaling): batch_epsilon_tensor=batch_epsilon_tensor) -def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, - fused_batch_norm): - """Computes batch norm correction params. - - Before batch normalization is frozen: - We use batch statistics for batch norm. - correction_scale = sigma_b/sigma_mv - correction_recip = 1/correction_scale - correction_offset = 0 - - After batch normalization is frozen: - correction_scale = sigma_b/sigma_mv - correction_recip = 1 - correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv). - - Batch norm is frozen if global_step > bn_freeze_delay. - The corrections ensure that: - a) The weights are quantized after scaling by gamma/sigma_mv. This enables - smoother training as the scaling on the weights changes slowly, rather than - jump across mini-batches - b) Changing the values of the corrections allows for one to switch between - using batch statistics to using moving mean and average, without requiring - changes to batch_norm - - - Args: - context: The scope under which we look for batch norm params - match: Object containg required batch norm tensors for correction - computation - freeze_batch_norm_delay: Delay in steps at which computation switches - from regular batch norm to frozen mean and variance. 
- fused_batch_norm: Bool, true if fused batch norm is used - - Returns: - A tuple of correction_scale, correction_recip, correction_offset - """ - - g = ops.get_default_graph() - with g.name_scope(context + 'batch_norm_correction'): - recip_sigma_mv = math_ops.rsqrt( - match.moving_variance_tensor + match.batch_epsilon_tensor) - recip_sigma = math_ops.rsqrt( - match.variance_tensor + match.batch_epsilon_tensor) - correction_scale = math_ops.divide( - recip_sigma_mv, recip_sigma, name='scale_compute') - correction_scale = array_ops.identity( - correction_scale, name='correction_scale') - correction_recip = math_ops.reciprocal( - correction_scale, name='reciprocal_compute') - correction_offset = math_ops.multiply( - match.gamma_tensor, - match.mean_tensor * recip_sigma - - match.moving_mean_tensor * recip_sigma_mv, - name='offset_compute') - - if freeze_batch_norm_delay is not None: - use_mv_avg = math_ops.greater_equal( - training_util.get_or_create_global_step(), - freeze_batch_norm_delay, - name='use_moving_average') - else: - use_mv_avg = False - - bn_decay_zero = 0.0 - bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers()) - bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers()) - - bn_decay_mean_out = utils.smart_cond( - use_mv_avg, - lambda: bn_decay_zero, - lambda: match.bn_decay_mean_tensor, - name='freeze_moving_mean') - graph_editor.reroute_ts( - [bn_decay_mean_out], [match.bn_decay_mean_tensor], - can_modify=bn_decay_mean_consumers) - - if fused_batch_norm is False: - bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) - bn_decay_var_out = utils.smart_cond( - use_mv_avg, - lambda: bn_decay_zero, - lambda: match.bn_decay_var_tensor, - name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], - can_modify=bn_decay_var_consumers) - - correction_recip = utils.smart_cond( - use_mv_avg, - lambda: array_ops.ones(correction_scale.shape), - lambda: correction_recip, - name='correction_recip') - - correction_offset = utils.smart_cond( - use_mv_avg, - lambda: correction_offset, - lambda: array_ops.zeros(correction_offset.shape), - name='correction_offset') - return correction_scale, correction_recip, correction_offset - - def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, is_training): """Folds in batch norm layer into preceding convolution or FC layer. @@ -961,3 +774,121 @@ def _AssertShapesMatch(op_name, in_tensor, out_tensor): if not in_shape.is_compatible_with(out_shape): raise ValueError('%s should not change tensor shape: input %s, ' 'output %s' % (op_name, in_shape, out_shape)) + + +def _HasScaling(graph, input_to_ops_map, bn): + r"""Checks if batch norm has scaling enabled. + + Difference between batch norm with scaling and without is that with scaling: + + Rsqrt -> mul -> mul_1 + \-> mul_2 + + where + mul multiplies gamma by inverse square root of EMA of batch variance, + mul_1 multiplies output of mul with output from the base operation + (convolution, FC or depthwise convolution), + mul_2 multiplies output of mul with EMA of batch mean, + and without scaling: + + Rsqrt -> mul + \-> mul_1 + + where + mul multiplies the inverse square root of EMA of batch variance with output + from the base operation, + mul_1 multiplies inverse square root of EMA of batch variance with EMA + of batch mean. + + Args: + graph: Graph to inspect. + input_to_ops_map: InputToOps object containing mapping from tensor's name + to ops that take it as input. 
+ bn: Batch norm layer prefix string. + + Returns: + A boolean indicating whether this batch norm layer has scaling enabled. + """ + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + + return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + + +class _BatchNormMatch(object): + """Contains all information related to a found Fused/UnfusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor, moving_mean_tensor, moving_variance_tensor, + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + self._moving_mean_tensor = moving_mean_tensor + self._moving_variance_tensor = moving_variance_tensor + self._bn_decay_mean_tensor = bn_decay_mean_tensor + self._bn_decay_var_tensor = bn_decay_var_tensor + self._batch_epsilon_tensor = batch_epsilon_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + @property + def moving_mean_tensor(self): + return self._moving_mean_tensor + + @property + def moving_variance_tensor(self): + return self._moving_variance_tensor + + @property + def batch_epsilon_tensor(self): + return self._batch_epsilon_tensor + + @property + def bn_decay_mean_tensor(self): + return self._bn_decay_mean_tensor + + @property + def bn_decay_var_tensor(self): + return self._bn_decay_var_tensor -- 2.7.4
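
A minimal NumPy sketch, separate from the patch itself, of the correction identities stated in the _ComputeBatchNormCorrections docstring: before the freeze step the folded layer should reproduce batch-norm with batch statistics, and after it, batch-norm with the moving statistics. The sketch assumes the folded output has the form y * (correction_scale * gamma / sigma_b) * correction_recip + (beta - gamma * mu_b / sigma_b) + correction_offset; that form and every variable name below are illustrative assumptions, not code taken from the patch.

import numpy as np

rng = np.random.default_rng(0)
eps = 1e-3

gamma = rng.uniform(0.5, 1.5, size=4)      # BN scale (gamma)
beta = rng.uniform(-1.0, 1.0, size=4)      # BN shift (beta)
mu_b = rng.normal(size=4)                  # batch mean (mu_b)
var_b = rng.uniform(0.5, 2.0, size=4)      # batch variance
mu_mv = rng.normal(size=4)                 # moving mean (mu_mv)
var_mv = rng.uniform(0.5, 2.0, size=4)     # moving variance

sigma_b = np.sqrt(var_b + eps)
sigma_mv = np.sqrt(var_mv + eps)
y = rng.normal(size=(8, 4))                # stand-in for the conv/FC output

correction_scale = sigma_b / sigma_mv      # as stated in the docstring

def folded_output(correction_recip, correction_offset):
    # Assumed folded form: weights pre-scaled by correction_scale, folded with
    # gamma / sigma_b, output rescaled by correction_recip and shifted by the
    # folded bias plus correction_offset.
    return (y * (correction_scale * gamma / sigma_b) * correction_recip
            + (beta - gamma * mu_b / sigma_b) + correction_offset)

# Before the freeze step: recip = 1/scale, offset = 0 -> batch statistics.
before = folded_output(1.0 / correction_scale, 0.0)
np.testing.assert_allclose(before, gamma * (y - mu_b) / sigma_b + beta, rtol=1e-6)

# After the freeze step: recip = 1, offset = gamma*(mu_b/sigma_b - mu_mv/sigma_mv)
# -> moving statistics.
after = folded_output(1.0, gamma * (mu_b / sigma_b - mu_mv / sigma_mv))
np.testing.assert_allclose(after, gamma * (y - mu_mv) / sigma_mv + beta, rtol=1e-6)

In both cases the weight-side multiplier works out to gamma / sigma_mv, which is why the docstring notes that the weights are quantized after scaling by gamma/sigma_mv and therefore change smoothly across mini-batches rather than jumping with each batch's statistics.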