from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
weights = math_ops.multiply(
correction_scale, weights, name='correction_mult')
- # TODO(suharshs): This naming of the following ops needs to carefully
- # follow the naming expected by quantize.py. Generalize the quantize code
- # to not require these delicate naming conventions.
scaled_weight_tensor = math_ops.multiply(
weights, multiplier_tensor, name='mul_fold')
-
new_layer_tensor = _CloneWithNewOperands(
match.layer_op, match.input_tensor, scaled_weight_tensor)
'Unexpected inputs to op: %s' % match.output_tensor.name)
-def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
- """Clones layer_op with input_tensor and weight_tensor as new inputs."""
- new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
- if layer_op.type == 'Conv2D':
- return nn_ops.conv2d(
- input_tensor,
- weight_tensor,
- strides=layer_op.get_attr('strides'),
- padding=layer_op.get_attr('padding'),
- use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
- data_format=layer_op.get_attr('data_format'),
- name=new_layer_name)
- elif layer_op.type == 'MatMul':
- return math_ops.matmul(
- input_tensor,
- weight_tensor,
- transpose_a=layer_op.get_attr('transpose_a'),
- transpose_b=layer_op.get_attr('transpose_b'),
- name=new_layer_name)
- elif layer_op.type == 'DepthwiseConv2dNative':
- return nn.depthwise_conv2d(
- input_tensor,
- weight_tensor,
- strides=layer_op.get_attr('strides'),
- padding=layer_op.get_attr('padding'),
- name=new_layer_name)
- else:
- raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
-
-
-@ops.RegisterGradient('FoldFusedBatchNormGrad')
-def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
- unused_2):
- x = op.inputs[0]
- n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements()
- dmean_dx = grad_mean / n
- dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
- return (dmean_dx + dvar_dx), None, None, None, None
-
-
def _FindFusedBatchNorms(graph):
"""Finds all ops and tensors related to found FusedBatchNorms.
moving_average_pattern = graph_matcher.OpTypePattern('*')
bn_decay_pattern = graph_matcher.OpTypePattern('*')
- conv_pattern = graph_matcher.OpTypePattern(
- 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
+ layer_pattern = graph_matcher.OpTypePattern(
+ 'Conv2D|DepthwiseConv2dNative|MatMul',
+ inputs=[input_pattern, weight_pattern])
# MatMul has a Reshape between it and FusedBatchNorm.
- matmul_pattern = graph_matcher.OpTypePattern(
- 'MatMul', inputs=[input_pattern, weight_pattern])
matmul_reshape_pattern = graph_matcher.OpTypePattern(
- 'Reshape', inputs=[matmul_pattern,
+ 'Reshape', inputs=[layer_pattern,
graph_matcher.OpTypePattern('*')])
- conv_batch_norm_pattern = graph_matcher.OpTypePattern(
+ batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
- conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
- variance_pattern
- ])
- conv_moving_average_sub_pattern = graph_matcher.OpTypePattern(
- 'Sub', inputs=[moving_average_pattern, conv_batch_norm_pattern])
- # TODO(suharshs): Use a OneofPattern here when available
- conv_moving_average_mul_pattern = graph_matcher.OpTypePattern(
- 'Mul', inputs=[conv_moving_average_sub_pattern, bn_decay_pattern])
- matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
- 'FusedBatchNorm',
- inputs=[
- matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
- variance_pattern
+ graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]),
+ gamma_pattern, beta_pattern, mean_pattern, variance_pattern
])
matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
- 'Reshape',
- inputs=[matmul_batch_norm_pattern,
- graph_matcher.OpTypePattern('*')])
-
- matmul_moving_average_sub_pattern = graph_matcher.OpTypePattern(
- 'Sub', inputs=[moving_average_pattern, matmul_batch_norm_pattern])
- matmul_moving_average_mul_pattern = graph_matcher.OpTypePattern(
- 'Mul', inputs=[matmul_moving_average_sub_pattern, bn_decay_pattern])
-
- conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
- matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
- conv_moving_average_mul_matcher = graph_matcher.GraphMatcher(
- conv_moving_average_mul_pattern)
- matmul_moving_average_mul_matcher = graph_matcher.GraphMatcher(
- matmul_moving_average_mul_pattern)
-
- def _GetMovingAverageTensors(graph, moving_avg_mul_matcher,
- moving_avg_sub_pattern, bn_op):
- """Gets the moving mean and variance tensors and the batch norm momentum."""
- for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
- sub_op = mul_match_result.get_op(moving_avg_sub_pattern)
-
- if sub_op.inputs[1].name == bn_op.outputs[1].name:
- # During training: Batch Mean is bn_op.outputs[1]
- moving_mean_tensor = sub_op.inputs[0]
- bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
- if sub_op.inputs[1].name == bn_op.outputs[2].name:
- # During training: Batch Var is bn_op.outputs[2]
- moving_variance_tensor = sub_op.inputs[0]
- bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
- return (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
- bn_decay_var_tensor)
-
- def _GetCommonTensors(match_result, bn_op, bn_input_tensor):
- """Gets tensors needed for FusedBatchNormMatch from match_result."""
+ 'Reshape', inputs=[batch_norm_pattern,
+ graph_matcher.OpTypePattern('*')])
+
+ bn_matcher = graph_matcher.GraphMatcher(
+ graph_matcher.OneofPattern(
+ [matmul_bn_output_reshape_pattern, batch_norm_pattern]))
+
+ moving_average_sub_pattern = graph_matcher.OpTypePattern(
+ 'Sub', inputs=[moving_average_pattern, batch_norm_pattern])
+ moving_average_mul_pattern = graph_matcher.OpTypePattern(
+ 'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern])
+
+ moving_avg_mul_matcher = graph_matcher.GraphMatcher(
+ moving_average_mul_pattern)
+
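+  # For each FusedBatchNorm match, collect the tensors needed to fold it into
+  # the preceding layer and, when training, the moving-average update tensors.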
+ for match_result in bn_matcher.match_graph(graph):
+ moving_mean_tensor = None
+ moving_variance_tensor = None
+ bn_decay_mean_tensor = None
+ bn_decay_var_tensor = None
+ layer_op = match_result.get_op(layer_pattern)
+ layer_tensor = match_result.get_tensor(layer_pattern)
+ bn_op = match_result.get_op(batch_norm_pattern)
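+    # Note: 'epsilon' is a scalar attr on the FusedBatchNorm op, not a tensor.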
+    batch_epsilon_tensor = bn_op.get_attr('epsilon')
+
+ # In the MatMul case, the output of batch norm is reshaped back into a
+ # 2D tensor, so the output_tensor is the output of the Reshape op.
+ output_tensor = bn_op.outputs[0]
+ if layer_op.type == 'MatMul':
+ output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
+ # If the matcher didn't match matmul_bn_output_reshape, there will be
+ # another match for this 'MatMul' later, so we can skip this one.
+ if output_reshape_op is None:
+ continue
+ output_tensor = output_reshape_op.outputs[0]
+
input_tensor = match_result.get_tensor(input_pattern)
weight_tensor = match_result.get_tensor(weight_pattern)
gamma_tensor = match_result.get_tensor(gamma_pattern)
g = ops.get_default_graph()
with g.as_default(), g.name_scope(scope + sep):
n = math_ops.cast(
- array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor),
+ array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
dtypes.float32)
variance_tensor = math_ops.multiply(
bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
+ # TODO(suharshs): Find a way to get rid of this inner match.
+ for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
+ sub_op = mul_match_result.get_op(moving_average_sub_pattern)
+ if sub_op.inputs[1].name == bn_op.outputs[1].name:
+ # During training: Batch Mean is bn_op.outputs[1]
+ moving_mean_tensor = sub_op.inputs[0]
+ bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
+ if sub_op.inputs[1].name == bn_op.outputs[2].name:
+ # During training: Batch Var is bn_op.outputs[2]
+ moving_variance_tensor = sub_op.inputs[0]
+ bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
else:
mean_tensor = match_result.get_tensor(mean_pattern)
variance_tensor = match_result.get_tensor(variance_pattern)
- return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor)
-
- for match_result in conv_matcher.match_graph(graph):
- moving_mean_tensor = None
- moving_variance_tensor = None
- bn_decay_mean_tensor = None
- bn_decay_var_tensor = None
- layer_op = match_result.get_op(conv_pattern)
- layer_tensor = match_result.get_tensor(conv_pattern)
- bn_op = match_result.get_op(conv_batch_norm_pattern)
- if bn_op.get_attr('is_training'):
- (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
- bn_decay_var_tensor) = _GetMovingAverageTensors(
- graph,
- moving_avg_mul_matcher=conv_moving_average_mul_matcher,
- moving_avg_sub_pattern=conv_moving_average_sub_pattern,
- bn_op=bn_op)
- output_tensor = bn_op.outputs[0]
- batch_epsilon_tensor = bn_op.get_attr('epsilon')
- (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor) = _GetCommonTensors(
- match_result,
- bn_op,
- layer_tensor,
- )
yield _BatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
bn_decay_var_tensor=bn_decay_var_tensor,
batch_epsilon_tensor=batch_epsilon_tensor)
- for match_result in matmul_matcher.match_graph(graph):
- moving_mean_tensor = None
- moving_variance_tensor = None
- bn_decay_mean_tensor = None
- bn_decay_var_tensor = None
- layer_op = match_result.get_op(matmul_pattern)
- layer_tensor = match_result.get_tensor(matmul_pattern)
- bn_op = match_result.get_op(matmul_batch_norm_pattern)
- if bn_op.get_attr('is_training'):
- (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
- bn_decay_var_tensor) = _GetMovingAverageTensors(
- graph,
- moving_avg_mul_matcher=matmul_moving_average_mul_matcher,
- moving_avg_sub_pattern=matmul_moving_average_sub_pattern,
- bn_op=bn_op)
-
- # In the MatMul case, the output of batch norm is reshaped back into a
- # 2D tensor, so the output_tensor is the output of the Reshape op.
- output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
- output_tensor = output_reshape_op.outputs[0]
- batch_epsilon_tensor = bn_op.get_attr('epsilon')
- (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor)
- yield _BatchNormMatch(
- layer_op=layer_op,
- bn_op=bn_op,
- output_tensor=output_tensor,
- input_tensor=input_tensor,
- weight_tensor=weight_tensor,
- gamma_tensor=gamma_tensor,
- beta_tensor=beta_tensor,
- mean_tensor=mean_tensor,
- variance_tensor=variance_tensor,
- moving_mean_tensor=moving_mean_tensor,
- moving_variance_tensor=moving_variance_tensor,
- bn_decay_mean_tensor=bn_decay_mean_tensor,
- bn_decay_var_tensor=bn_decay_var_tensor,
- batch_epsilon_tensor=batch_epsilon_tensor)
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
+ fused_batch_norm):
+ """Computes batch norm correction params.
+ Before batch normalization is frozen:
+ We use batch statistics for batch norm.
+ correction_scale = sigma_b/sigma_mv
+ correction_recip = 1/correction_scale
+ correction_offset = 0
-class _BatchNormMatch(object):
- """Contains all information related to a found Fused/UnfusedBatchNorm."""
+ After batch normalization is frozen:
+ correction_scale = sigma_b/sigma_mv
+ correction_recip = 1
+ correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
- def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
- weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor, moving_mean_tensor, moving_variance_tensor,
- bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor):
- self._layer_op = layer_op
- self._bn_op = bn_op
- self._output_tensor = output_tensor
- self._input_tensor = input_tensor
- self._weight_tensor = weight_tensor
- self._gamma_tensor = gamma_tensor
- self._beta_tensor = beta_tensor
- self._mean_tensor = mean_tensor
- self._variance_tensor = variance_tensor
- self._moving_mean_tensor = moving_mean_tensor
- self._moving_variance_tensor = moving_variance_tensor
- self._bn_decay_mean_tensor = bn_decay_mean_tensor
- self._bn_decay_var_tensor = bn_decay_var_tensor
- self._batch_epsilon_tensor = batch_epsilon_tensor
+  Batch norm is frozen once global_step >= freeze_batch_norm_delay.
+ The corrections ensure that:
+ a) The weights are quantized after scaling by gamma/sigma_mv. This enables
+  smoother training as the scaling on the weights changes slowly, rather than
+  jumping across mini-batches.
+  b) Changing the values of the corrections allows one to switch from using
+  batch statistics to using the moving mean and variance, without requiring
+  changes to batch_norm.
- @property
- def layer_op(self):
- return self._layer_op
- @property
- def bn_op(self):
- return self._bn_op
+ Args:
+ context: The scope under which we look for batch norm params
+    match: Object containing required batch norm tensors for correction
+ computation
+ freeze_batch_norm_delay: Delay in steps at which computation switches
+ from regular batch norm to frozen mean and variance.
+ fused_batch_norm: Bool, true if fused batch norm is used
- @property
- def output_tensor(self):
- return self._output_tensor
+ Returns:
+ A tuple of correction_scale, correction_recip, correction_offset
+ """
- @property
- def input_tensor(self):
- return self._input_tensor
+ g = ops.get_default_graph()
+ with g.name_scope(context + '/batch_norm_correction'):
+ recip_sigma_mv = math_ops.rsqrt(
+ match.moving_variance_tensor + match.batch_epsilon_tensor)
+ recip_sigma = math_ops.rsqrt(
+ match.variance_tensor + match.batch_epsilon_tensor)
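+    # correction_scale = (1/sigma_mv) / (1/sigma_b) = sigma_b / sigma_mv.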
+ correction_scale = math_ops.divide(
+ recip_sigma_mv, recip_sigma, name='scale_compute')
+ correction_scale = array_ops.identity(
+ correction_scale, name='correction_scale')
+ correction_recip = math_ops.reciprocal(
+ correction_scale, name='reciprocal_compute')
+ correction_offset = math_ops.multiply(
+ match.gamma_tensor,
+ match.mean_tensor * recip_sigma -
+ match.moving_mean_tensor * recip_sigma_mv,
+ name='offset_compute')
- @property
- def weight_tensor(self):
- return self._weight_tensor
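+    # use_mv_avg becomes True once global_step >= freeze_batch_norm_delay, at
+    # which point the corrections switch over to the moving statistics.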
+ if freeze_batch_norm_delay is not None:
+ use_mv_avg = math_ops.greater_equal(
+ training_util.get_or_create_global_step(),
+ freeze_batch_norm_delay,
+ name='use_moving_average')
+ else:
+ use_mv_avg = False
- @property
- def gamma_tensor(self):
- return self._gamma_tensor
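+    # Rerouting the batch norm decay to zero (below) stops the moving mean and
+    # variance from being updated once batch norm is frozen.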
+ bn_decay_zero = 0.0
+ bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
+    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
- @property
- def beta_tensor(self):
- return self._beta_tensor
+ bn_decay_mean_out = utils.smart_cond(
+ use_mv_avg,
+ lambda: bn_decay_zero,
+ lambda: match.bn_decay_mean_tensor,
+ name='freeze_moving_mean')
+ graph_editor.reroute_ts(
+ [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ can_modify=bn_decay_mean_consumers)
- @property
- def mean_tensor(self):
- return self._mean_tensor
+ if fused_batch_norm is False:
+ bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
+ bn_decay_var_out = utils.smart_cond(
+ use_mv_avg,
+ lambda: bn_decay_zero,
+ lambda: match.bn_decay_var_tensor,
+ name='freeze_moving_var')
+ graph_editor.reroute_ts(
+ [bn_decay_var_out], [match.bn_decay_var_tensor],
+ can_modify=bn_decay_var_consumers)
- @property
- def variance_tensor(self):
- return self._variance_tensor
+ correction_recip = utils.smart_cond(
+ use_mv_avg,
+ lambda: array_ops.ones(correction_scale.shape),
+ lambda: correction_recip,
+ name='correction_recip')
- @property
- def moving_mean_tensor(self):
- return self._moving_mean_tensor
+ correction_offset = utils.smart_cond(
+ use_mv_avg,
+ lambda: correction_offset,
+ lambda: array_ops.zeros(correction_offset.shape),
+ name='correction_offset')
+ return correction_scale, correction_recip, correction_offset
- @property
- def moving_variance_tensor(self):
- return self._moving_variance_tensor
- @property
- def batch_epsilon_tensor(self):
- return self._batch_epsilon_tensor
+def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
+ """Clones layer_op with input_tensor and weight_tensor as new inputs."""
+ new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
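+  # Rebuild the layer op with the folded weights, copying the original op's
+  # attributes so the new op runs with the same configuration.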
+ if layer_op.type == 'Conv2D':
+ return nn_ops.conv2d(
+ input_tensor,
+ weight_tensor,
+ strides=layer_op.get_attr('strides'),
+ padding=layer_op.get_attr('padding'),
+ use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
+ data_format=layer_op.get_attr('data_format'),
+ name=new_layer_name)
+ elif layer_op.type == 'MatMul':
+ return math_ops.matmul(
+ input_tensor,
+ weight_tensor,
+ transpose_a=layer_op.get_attr('transpose_a'),
+ transpose_b=layer_op.get_attr('transpose_b'),
+ name=new_layer_name)
+ elif layer_op.type == 'DepthwiseConv2dNative':
+ return nn.depthwise_conv2d(
+ input_tensor,
+ weight_tensor,
+ strides=layer_op.get_attr('strides'),
+ padding=layer_op.get_attr('padding'),
+ name=new_layer_name)
+ else:
+ raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
- @property
- def bn_decay_mean_tensor(self):
- return self._bn_decay_mean_tensor
- @property
- def bn_decay_var_tensor(self):
- return self._bn_decay_var_tensor
+@ops.RegisterGradient('FoldFusedBatchNormGrad')
+def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
+ unused_2):
+ x = op.inputs[0]
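+  # Route gradients flowing into the batch mean and variance outputs back to
+  # the input x; n is the number of elements reduced into each statistic.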
+ n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements()
+ dmean_dx = grad_mean / n
+ dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
+ return (dmean_dx + dvar_dx), None, None, None, None
def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
graph: Graph to walk and modify.
freeze_batch_norm_delay: How many steps to wait before freezing
moving mean and variance and using them for batch normalization
- is_training: Bool, True if training
-
- Raises:
- ValueError: When batch norm folding fails.
- """
- input_to_ops_map = input_to_ops.InputToOps(graph)
-
- for bn in common.BatchNormGroups(graph):
- has_scaling = _HasScaling(graph, input_to_ops_map, bn)
-
- # The mangling code intimately depends on BatchNorm node's internals.
- original_op, folded_op = _CreateFoldedOp(
- graph,
- bn,
- has_scaling=has_scaling,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- is_training=is_training)
-
- activation = common.GetEndpointActivationOp(graph, bn)
- if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
- if nodes_modified_count != 1:
- raise ValueError('Unexpected inputs to op: %s' % activation.name)
- continue
-
- # Treat consumer ops in bypass modules differently since they have Add
- # operations instead of Relu* above.
- add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
- add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
- if nodes_modified_count != 1:
- raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
-
-
-def _HasScaling(graph, input_to_ops_map, bn):
- r"""Checks if batch norm has scaling enabled.
-
- Difference between batch norm with scaling and without is that with scaling:
-
- Rsqrt -> mul -> mul_1
- \-> mul_2
-
- where
- mul multiplies gamma by inverse square root of EMA of batch variance,
- mul_1 multiplies output of mul with output from the base operation
- (convolution, FC or depthwise convolution),
- mul_2 multiplies output of mul with EMA of batch mean,
- and without scaling:
-
- Rsqrt -> mul
- \-> mul_1
-
- where
- mul multiplies the inverse square root of EMA of batch variance with output
- from the base operation,
- mul_1 multiplies inverse square root of EMA of batch variance with EMA
- of batch mean.
-
- Args:
- graph: Graph to inspect.
- input_to_ops_map: InputToOps object containing mapping from tensor's name
- to ops that take it as input.
- bn: Batch norm layer prefix string.
+ is_training: Bool, True if training
- Returns:
- A boolean indicating whether this batch norm layer has scaling enabled.
+ Raises:
+ ValueError: When batch norm folding fails.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
- rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
+ input_to_ops_map = input_to_ops.InputToOps(graph)
- return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
+ for bn in common.BatchNormGroups(graph):
+ has_scaling = _HasScaling(graph, input_to_ops_map, bn)
+
+ # The mangling code intimately depends on BatchNorm node's internals.
+ original_op, folded_op = _CreateFoldedOp(
+ graph,
+ bn,
+ has_scaling=has_scaling,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
+
+ activation = common.GetEndpointActivationOp(graph, bn)
+ if activation:
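+      # Redirect the activation to consume the folded output instead of the
+      # original batch norm output.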
+ nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
+ [original_op.outputs[0]],
+ can_modify=[activation])
+ if nodes_modified_count != 1:
+ raise ValueError('Unexpected inputs to op: %s' % activation.name)
+ continue
+
+ # Treat consumer ops in bypass modules differently since they have Add
+ # operations instead of Relu* above.
+ add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+ add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+ nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
+ [original_op.outputs[0]],
+ can_modify=[add_bypass])
+ if nodes_modified_count != 1:
+ raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
def _GetBatchNormParams(graph, context, has_scaling):
batch_epsilon_tensor=batch_epsilon_tensor)
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
- fused_batch_norm):
- """Computes batch norm correction params.
-
- Before batch normalization is frozen:
- We use batch statistics for batch norm.
- correction_scale = sigma_b/sigma_mv
- correction_recip = 1/correction_scale
- correction_offset = 0
-
- After batch normalization is frozen:
- correction_scale = sigma_b/sigma_mv
- correction_recip = 1
- correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
-
- Batch norm is frozen if global_step > bn_freeze_delay.
- The corrections ensure that:
- a) The weights are quantized after scaling by gamma/sigma_mv. This enables
- smoother training as the scaling on the weights changes slowly, rather than
- jump across mini-batches
- b) Changing the values of the corrections allows for one to switch between
- using batch statistics to using moving mean and average, without requiring
- changes to batch_norm
-
-
- Args:
- context: The scope under which we look for batch norm params
- match: Object containg required batch norm tensors for correction
- computation
- freeze_batch_norm_delay: Delay in steps at which computation switches
- from regular batch norm to frozen mean and variance.
- fused_batch_norm: Bool, true if fused batch norm is used
-
- Returns:
- A tuple of correction_scale, correction_recip, correction_offset
- """
-
- g = ops.get_default_graph()
- with g.name_scope(context + 'batch_norm_correction'):
- recip_sigma_mv = math_ops.rsqrt(
- match.moving_variance_tensor + match.batch_epsilon_tensor)
- recip_sigma = math_ops.rsqrt(
- match.variance_tensor + match.batch_epsilon_tensor)
- correction_scale = math_ops.divide(
- recip_sigma_mv, recip_sigma, name='scale_compute')
- correction_scale = array_ops.identity(
- correction_scale, name='correction_scale')
- correction_recip = math_ops.reciprocal(
- correction_scale, name='reciprocal_compute')
- correction_offset = math_ops.multiply(
- match.gamma_tensor,
- match.mean_tensor * recip_sigma -
- match.moving_mean_tensor * recip_sigma_mv,
- name='offset_compute')
-
- if freeze_batch_norm_delay is not None:
- use_mv_avg = math_ops.greater_equal(
- training_util.get_or_create_global_step(),
- freeze_batch_norm_delay,
- name='use_moving_average')
- else:
- use_mv_avg = False
-
- bn_decay_zero = 0.0
- bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
- bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())
-
- bn_decay_mean_out = utils.smart_cond(
- use_mv_avg,
- lambda: bn_decay_zero,
- lambda: match.bn_decay_mean_tensor,
- name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
- can_modify=bn_decay_mean_consumers)
-
- if fused_batch_norm is False:
- bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
- bn_decay_var_out = utils.smart_cond(
- use_mv_avg,
- lambda: bn_decay_zero,
- lambda: match.bn_decay_var_tensor,
- name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
- can_modify=bn_decay_var_consumers)
-
- correction_recip = utils.smart_cond(
- use_mv_avg,
- lambda: array_ops.ones(correction_scale.shape),
- lambda: correction_recip,
- name='correction_recip')
-
- correction_offset = utils.smart_cond(
- use_mv_avg,
- lambda: correction_offset,
- lambda: array_ops.zeros(correction_offset.shape),
- name='correction_offset')
- return correction_scale, correction_recip, correction_offset
-
-
def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
is_training):
"""Folds in batch norm layer into preceding convolution or FC layer.
if not in_shape.is_compatible_with(out_shape):
raise ValueError('%s should not change tensor shape: input %s, '
'output %s' % (op_name, in_shape, out_shape))
+
+
+def _HasScaling(graph, input_to_ops_map, bn):
+ r"""Checks if batch norm has scaling enabled.
+
+ Difference between batch norm with scaling and without is that with scaling:
+
+ Rsqrt -> mul -> mul_1
+ \-> mul_2
+
+ where
+ mul multiplies gamma by inverse square root of EMA of batch variance,
+ mul_1 multiplies output of mul with output from the base operation
+ (convolution, FC or depthwise convolution),
+ mul_2 multiplies output of mul with EMA of batch mean,
+ and without scaling:
+
+ Rsqrt -> mul
+ \-> mul_1
+
+ where
+ mul multiplies the inverse square root of EMA of batch variance with output
+ from the base operation,
+ mul_1 multiplies inverse square root of EMA of batch variance with EMA
+ of batch mean.
+
+ Args:
+ graph: Graph to inspect.
+ input_to_ops_map: InputToOps object containing mapping from tensor's name
+ to ops that take it as input.
+ bn: Batch norm layer prefix string.
+
+ Returns:
+ A boolean indicating whether this batch norm layer has scaling enabled.
+ """
+ rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
+ rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
+
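+  # With scaling, Rsqrt feeds a single Mul (gamma * rsqrt); without scaling,
+  # Rsqrt feeds two Muls directly, so one Mul consumer implies scaling.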
+ return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
+
+
+class _BatchNormMatch(object):
+ """Contains all information related to a found Fused/UnfusedBatchNorm."""
+
+ def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
+ weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
+ variance_tensor, moving_mean_tensor, moving_variance_tensor,
+ bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor):
+ self._layer_op = layer_op
+ self._bn_op = bn_op
+ self._output_tensor = output_tensor
+ self._input_tensor = input_tensor
+ self._weight_tensor = weight_tensor
+ self._gamma_tensor = gamma_tensor
+ self._beta_tensor = beta_tensor
+ self._mean_tensor = mean_tensor
+ self._variance_tensor = variance_tensor
+ self._moving_mean_tensor = moving_mean_tensor
+ self._moving_variance_tensor = moving_variance_tensor
+ self._bn_decay_mean_tensor = bn_decay_mean_tensor
+ self._bn_decay_var_tensor = bn_decay_var_tensor
+ self._batch_epsilon_tensor = batch_epsilon_tensor
+
+ @property
+ def layer_op(self):
+ return self._layer_op
+
+ @property
+ def bn_op(self):
+ return self._bn_op
+
+ @property
+ def output_tensor(self):
+ return self._output_tensor
+
+ @property
+ def input_tensor(self):
+ return self._input_tensor
+
+ @property
+ def weight_tensor(self):
+ return self._weight_tensor
+
+ @property
+ def gamma_tensor(self):
+ return self._gamma_tensor
+
+ @property
+ def beta_tensor(self):
+ return self._beta_tensor
+
+ @property
+ def mean_tensor(self):
+ return self._mean_tensor
+
+ @property
+ def variance_tensor(self):
+ return self._variance_tensor
+
+ @property
+ def moving_mean_tensor(self):
+ return self._moving_mean_tensor
+
+ @property
+ def moving_variance_tensor(self):
+ return self._moving_variance_tensor
+
+ @property
+ def batch_epsilon_tensor(self):
+ return self._batch_epsilon_tensor
+
+ @property
+ def bn_decay_mean_tensor(self):
+ return self._bn_decay_mean_tensor
+
+ @property
+ def bn_decay_var_tensor(self):
+ return self._bn_decay_var_tensor