Make fold batch norm code use OneofPattern and rearrange functions to (maybe) be...
author    Suharsh Sivakumar <suharshs@google.com>
          Mon, 5 Feb 2018 23:32:32 +0000 (15:32 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Mon, 5 Feb 2018 23:39:01 +0000 (15:39 -0800)
PiperOrigin-RevId: 184597111
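
The core of this change is that a single GraphMatcher built around
graph_matcher.OneofPattern now covers both the Conv2D/DepthwiseConv2dNative case
and the MatMul case (whose FusedBatchNorm sits behind a Reshape), replacing the
two parallel pattern sets and the duplicated per-matcher loops. Below is a
minimal sketch of the semantics the new loop relies on; the names are
illustrative, graph is assumed to be the default tf.Graph, and it assumes (as
the new code does) that match_result.get_op returns None for an alternative that
did not take part in a given match:

from tensorflow.python.framework import ops
from tensorflow.contrib.quantize.python import graph_matcher

graph = ops.get_default_graph()

layer_pattern = graph_matcher.OpTypePattern(
    'Conv2D|DepthwiseConv2dNative|MatMul',
    inputs=[graph_matcher.OpTypePattern('*'),
            graph_matcher.OpTypePattern('*')])
reshape_pattern = graph_matcher.OpTypePattern(
    'Reshape', inputs=[layer_pattern, graph_matcher.OpTypePattern('*')])
# Either the layer feeds its consumer directly, or (MatMul) via a Reshape.
either_pattern = graph_matcher.OneofPattern([reshape_pattern, layer_pattern])

for match_result in graph_matcher.GraphMatcher(either_pattern).match_graph(graph):
  layer_op = match_result.get_op(layer_pattern)
  # Only one alternative participates in a given match; get_op on the other
  # returns None. The rewritten _FindFusedBatchNorms uses this to skip the
  # duplicate MatMul match that lacks the trailing Reshape.
  if layer_op.type == 'MatMul' and match_result.get_op(reshape_pattern) is None:
    continue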

tensorflow/contrib/quantize/python/fold_batch_norms.py

index 8ec5334..7fa0d48 100644
@@ -17,6 +17,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
 import re
 from tensorflow.contrib import graph_editor
 from tensorflow.contrib.quantize.python import common
@@ -120,12 +121,8 @@ def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
         weights = math_ops.multiply(
             correction_scale, weights, name='correction_mult')
 
-      # TODO(suharshs): This naming of the following ops needs to carefully
-      # follow the naming expected by quantize.py. Generalize the quantize code
-      # to not require these delicate naming conventions.
       scaled_weight_tensor = math_ops.multiply(
           weights, multiplier_tensor, name='mul_fold')
-
       new_layer_tensor = _CloneWithNewOperands(
           match.layer_op, match.input_tensor, scaled_weight_tensor)
 
@@ -145,46 +142,6 @@ def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
             'Unexpected inputs to op: %s' % match.output_tensor.name)
 
 
-def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
-  """Clones layer_op with input_tensor and weight_tensor as new inputs."""
-  new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
-  if layer_op.type == 'Conv2D':
-    return nn_ops.conv2d(
-        input_tensor,
-        weight_tensor,
-        strides=layer_op.get_attr('strides'),
-        padding=layer_op.get_attr('padding'),
-        use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
-        data_format=layer_op.get_attr('data_format'),
-        name=new_layer_name)
-  elif layer_op.type == 'MatMul':
-    return math_ops.matmul(
-        input_tensor,
-        weight_tensor,
-        transpose_a=layer_op.get_attr('transpose_a'),
-        transpose_b=layer_op.get_attr('transpose_b'),
-        name=new_layer_name)
-  elif layer_op.type == 'DepthwiseConv2dNative':
-    return nn.depthwise_conv2d(
-        input_tensor,
-        weight_tensor,
-        strides=layer_op.get_attr('strides'),
-        padding=layer_op.get_attr('padding'),
-        name=new_layer_name)
-  else:
-    raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
-
-
-@ops.RegisterGradient('FoldFusedBatchNormGrad')
-def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
-                            unused_2):
-  x = op.inputs[0]
-  n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements()
-  dmean_dx = grad_mean / n
-  dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
-  return (dmean_dx + dvar_dx), None, None, None, None
-
-
 def _FindFusedBatchNorms(graph):
   """Finds all ops and tensors related to found FusedBatchNorms.
 
@@ -203,68 +160,57 @@ def _FindFusedBatchNorms(graph):
 
   moving_average_pattern = graph_matcher.OpTypePattern('*')
   bn_decay_pattern = graph_matcher.OpTypePattern('*')
-  conv_pattern = graph_matcher.OpTypePattern(
-      'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
+  layer_pattern = graph_matcher.OpTypePattern(
+      'Conv2D|DepthwiseConv2dNative|MatMul',
+      inputs=[input_pattern, weight_pattern])
   # MatMul has a Reshape between it and FusedBatchNorm.
-  matmul_pattern = graph_matcher.OpTypePattern(
-      'MatMul', inputs=[input_pattern, weight_pattern])
   matmul_reshape_pattern = graph_matcher.OpTypePattern(
-      'Reshape', inputs=[matmul_pattern,
+      'Reshape', inputs=[layer_pattern,
                          graph_matcher.OpTypePattern('*')])
 
-  conv_batch_norm_pattern = graph_matcher.OpTypePattern(
+  batch_norm_pattern = graph_matcher.OpTypePattern(
       'FusedBatchNorm',
       inputs=[
-          conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
-          variance_pattern
-      ])
-  conv_moving_average_sub_pattern = graph_matcher.OpTypePattern(
-      'Sub', inputs=[moving_average_pattern, conv_batch_norm_pattern])
-  # TODO(suharshs): Use a OneofPattern here when available
-  conv_moving_average_mul_pattern = graph_matcher.OpTypePattern(
-      'Mul', inputs=[conv_moving_average_sub_pattern, bn_decay_pattern])
-  matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
-      'FusedBatchNorm',
-      inputs=[
-          matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
-          variance_pattern
+          graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]),
+          gamma_pattern, beta_pattern, mean_pattern, variance_pattern
       ])
   matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
-      'Reshape',
-      inputs=[matmul_batch_norm_pattern,
-              graph_matcher.OpTypePattern('*')])
-
-  matmul_moving_average_sub_pattern = graph_matcher.OpTypePattern(
-      'Sub', inputs=[moving_average_pattern, matmul_batch_norm_pattern])
-  matmul_moving_average_mul_pattern = graph_matcher.OpTypePattern(
-      'Mul', inputs=[matmul_moving_average_sub_pattern, bn_decay_pattern])
-
-  conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
-  matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
-  conv_moving_average_mul_matcher = graph_matcher.GraphMatcher(
-      conv_moving_average_mul_pattern)
-  matmul_moving_average_mul_matcher = graph_matcher.GraphMatcher(
-      matmul_moving_average_mul_pattern)
-
-  def _GetMovingAverageTensors(graph, moving_avg_mul_matcher,
-                               moving_avg_sub_pattern, bn_op):
-    """Gets the moving mean and variance tensors and the batch norm momentum."""
-    for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
-      sub_op = mul_match_result.get_op(moving_avg_sub_pattern)
-
-      if sub_op.inputs[1].name == bn_op.outputs[1].name:
-        # During training: Batch Mean is bn_op.outputs[1]
-        moving_mean_tensor = sub_op.inputs[0]
-        bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
-      if sub_op.inputs[1].name == bn_op.outputs[2].name:
-        # During training: Batch Var is bn_op.outputs[2]
-        moving_variance_tensor = sub_op.inputs[0]
-        bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
-    return (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
-            bn_decay_var_tensor)
-
-  def _GetCommonTensors(match_result, bn_op, bn_input_tensor):
-    """Gets tensors needed for FusedBatchNormMatch from match_result."""
+      'Reshape', inputs=[batch_norm_pattern,
+                         graph_matcher.OpTypePattern('*')])
+
+  bn_matcher = graph_matcher.GraphMatcher(
+      graph_matcher.OneofPattern(
+          [matmul_bn_output_reshape_pattern, batch_norm_pattern]))
+
+  moving_average_sub_pattern = graph_matcher.OpTypePattern(
+      'Sub', inputs=[moving_average_pattern, batch_norm_pattern])
+  moving_average_mul_pattern = graph_matcher.OpTypePattern(
+      'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern])
+
+  moving_avg_mul_matcher = graph_matcher.GraphMatcher(
+      moving_average_mul_pattern)
+
+  for match_result in bn_matcher.match_graph(graph):
+    moving_mean_tensor = None
+    moving_variance_tensor = None
+    bn_decay_mean_tensor = None
+    bn_decay_var_tensor = None
+    layer_op = match_result.get_op(layer_pattern)
+    layer_tensor = match_result.get_tensor(layer_pattern)
+    bn_op = match_result.get_op(batch_norm_pattern)
+    batch_epsilon_tensor = bn_op.get_attr('epsilon')
+
+    # In the MatMul case, the output of batch norm is reshaped back into a
+    # 2D tensor, so the output_tensor is the output of the Reshape op.
+    output_tensor = bn_op.outputs[0]
+    if layer_op.type == 'MatMul':
+      output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
+      # If the matcher didn't match matmul_bn_output_reshape, there will be
+      # another match for this 'MatMul' later, so we can skip this one.
+      if output_reshape_op is None:
+        continue
+      output_tensor = output_reshape_op.outputs[0]
+
     input_tensor = match_result.get_tensor(input_pattern)
     weight_tensor = match_result.get_tensor(weight_pattern)
     gamma_tensor = match_result.get_tensor(gamma_pattern)
@@ -295,40 +241,25 @@ def _FindFusedBatchNorms(graph):
       g = ops.get_default_graph()
       with g.as_default(), g.name_scope(scope + sep):
         n = math_ops.cast(
-            array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor),
+            array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
             dtypes.float32)
         variance_tensor = math_ops.multiply(
             bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
+      # TODO(suharshs): Find a way to get rid of this inner match.
+      for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
+        sub_op = mul_match_result.get_op(moving_average_sub_pattern)
+        if sub_op.inputs[1].name == bn_op.outputs[1].name:
+          # During training: Batch Mean is bn_op.outputs[1]
+          moving_mean_tensor = sub_op.inputs[0]
+          bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
+        if sub_op.inputs[1].name == bn_op.outputs[2].name:
+          # During training: Batch Var is bn_op.outputs[2]
+          moving_variance_tensor = sub_op.inputs[0]
+          bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
     else:
       mean_tensor = match_result.get_tensor(mean_pattern)
       variance_tensor = match_result.get_tensor(variance_pattern)
-    return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
-            variance_tensor)
-
-  for match_result in conv_matcher.match_graph(graph):
-    moving_mean_tensor = None
-    moving_variance_tensor = None
-    bn_decay_mean_tensor = None
-    bn_decay_var_tensor = None
-    layer_op = match_result.get_op(conv_pattern)
-    layer_tensor = match_result.get_tensor(conv_pattern)
-    bn_op = match_result.get_op(conv_batch_norm_pattern)
-    if bn_op.get_attr('is_training'):
-      (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
-       bn_decay_var_tensor) = _GetMovingAverageTensors(
-           graph,
-           moving_avg_mul_matcher=conv_moving_average_mul_matcher,
-           moving_avg_sub_pattern=conv_moving_average_sub_pattern,
-           bn_op=bn_op)
 
-    output_tensor = bn_op.outputs[0]
-    batch_epsilon_tensor = bn_op.get_attr('epsilon')
-    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
-     variance_tensor) = _GetCommonTensors(
-         match_result,
-         bn_op,
-         layer_tensor,
-     )
     yield _BatchNormMatch(
         layer_op=layer_op,
         bn_op=bn_op,
@@ -345,124 +276,146 @@ def _FindFusedBatchNorms(graph):
         bn_decay_var_tensor=bn_decay_var_tensor,
         batch_epsilon_tensor=batch_epsilon_tensor)
 
-  for match_result in matmul_matcher.match_graph(graph):
-    moving_mean_tensor = None
-    moving_variance_tensor = None
-    bn_decay_mean_tensor = None
-    bn_decay_var_tensor = None
-    layer_op = match_result.get_op(matmul_pattern)
-    layer_tensor = match_result.get_tensor(matmul_pattern)
-    bn_op = match_result.get_op(matmul_batch_norm_pattern)
-    if bn_op.get_attr('is_training'):
-      (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
-       bn_decay_var_tensor) = _GetMovingAverageTensors(
-           graph,
-           moving_avg_mul_matcher=matmul_moving_average_mul_matcher,
-           moving_avg_sub_pattern=matmul_moving_average_sub_pattern,
-           bn_op=bn_op)
-
-    # In the MatMul case, the output of batch norm is reshaped back into a
-    # 2D tensor, so the output_tensor is the output of the Reshape op.
-    output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
-    output_tensor = output_reshape_op.outputs[0]
-    batch_epsilon_tensor = bn_op.get_attr('epsilon')
 
-    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
-     variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor)
-    yield _BatchNormMatch(
-        layer_op=layer_op,
-        bn_op=bn_op,
-        output_tensor=output_tensor,
-        input_tensor=input_tensor,
-        weight_tensor=weight_tensor,
-        gamma_tensor=gamma_tensor,
-        beta_tensor=beta_tensor,
-        mean_tensor=mean_tensor,
-        variance_tensor=variance_tensor,
-        moving_mean_tensor=moving_mean_tensor,
-        moving_variance_tensor=moving_variance_tensor,
-        bn_decay_mean_tensor=bn_decay_mean_tensor,
-        bn_decay_var_tensor=bn_decay_var_tensor,
-        batch_epsilon_tensor=batch_epsilon_tensor)
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
+                                 fused_batch_norm):
+  """Computes batch norm correction params.
 
+     Before batch normalization is frozen:
+     We use batch statistics for batch norm.
+       correction_scale = sigma_b/sigma_mv
+       correction_recip = 1/correction_scale
+       correction_offset = 0
 
-class _BatchNormMatch(object):
-  """Contains all information related to a found Fused/UnfusedBatchNorm."""
+     After batch normalization is frozen:
+      correction_scale = sigma_b/sigma_mv
+      correction_recip = 1
+      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
 
-  def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
-               weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
-               variance_tensor, moving_mean_tensor, moving_variance_tensor,
-               bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor):
-    self._layer_op = layer_op
-    self._bn_op = bn_op
-    self._output_tensor = output_tensor
-    self._input_tensor = input_tensor
-    self._weight_tensor = weight_tensor
-    self._gamma_tensor = gamma_tensor
-    self._beta_tensor = beta_tensor
-    self._mean_tensor = mean_tensor
-    self._variance_tensor = variance_tensor
-    self._moving_mean_tensor = moving_mean_tensor
-    self._moving_variance_tensor = moving_variance_tensor
-    self._bn_decay_mean_tensor = bn_decay_mean_tensor
-    self._bn_decay_var_tensor = bn_decay_var_tensor
-    self._batch_epsilon_tensor = batch_epsilon_tensor
+     Batch norm is frozen if global_step > bn_freeze_delay.
+     The corrections ensure that:
+     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
+     smoother training as the scaling on the weights changes slowly, rather
+     than jumping across mini-batches.
+     b) Changing the values of the corrections allows one to switch between
+     using batch statistics and using the moving mean and variance, without
+     requiring changes to batch_norm.
 
-  @property
-  def layer_op(self):
-    return self._layer_op
 
-  @property
-  def bn_op(self):
-    return self._bn_op
+  Args:
+    context: The scope under which we look for batch norm params
+    match: Object containing required batch norm tensors for correction
+      computation
+    freeze_batch_norm_delay: Delay in steps at which computation switches
+      from regular batch norm to frozen mean and variance.
+    fused_batch_norm: Bool, true if fused batch norm is used
 
-  @property
-  def output_tensor(self):
-    return self._output_tensor
+  Returns:
+    A tuple of correction_scale, correction_recip, correction_offset
+  """
 
-  @property
-  def input_tensor(self):
-    return self._input_tensor
+  g = ops.get_default_graph()
+  with g.name_scope(context + '/batch_norm_correction'):
+    recip_sigma_mv = math_ops.rsqrt(
+        match.moving_variance_tensor + match.batch_epsilon_tensor)
+    recip_sigma = math_ops.rsqrt(
+        match.variance_tensor + match.batch_epsilon_tensor)
+    correction_scale = math_ops.divide(
+        recip_sigma_mv, recip_sigma, name='scale_compute')
+    correction_scale = array_ops.identity(
+        correction_scale, name='correction_scale')
+    correction_recip = math_ops.reciprocal(
+        correction_scale, name='reciprocal_compute')
+    correction_offset = math_ops.multiply(
+        match.gamma_tensor,
+        match.mean_tensor * recip_sigma -
+        match.moving_mean_tensor * recip_sigma_mv,
+        name='offset_compute')
 
-  @property
-  def weight_tensor(self):
-    return self._weight_tensor
+    if freeze_batch_norm_delay is not None:
+      use_mv_avg = math_ops.greater_equal(
+          training_util.get_or_create_global_step(),
+          freeze_batch_norm_delay,
+          name='use_moving_average')
+    else:
+      use_mv_avg = False
 
-  @property
-  def gamma_tensor(self):
-    return self._gamma_tensor
+    bn_decay_zero = 0.0
+    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
+    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
 
-  @property
-  def beta_tensor(self):
-    return self._beta_tensor
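+    # Once use_mv_avg becomes True, the matched bn_decay factor is rerouted to
+    # 0, which zeroes the (moving_average - batch_statistic) * bn_decay update
+    # term and effectively freezes the moving mean (and, for unfused batch
+    # norm, the moving variance below).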
+    bn_decay_mean_out = utils.smart_cond(
+        use_mv_avg,
+        lambda: bn_decay_zero,
+        lambda: match.bn_decay_mean_tensor,
+        name='freeze_moving_mean')
+    graph_editor.reroute_ts(
+        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+        can_modify=bn_decay_mean_consumers)
 
-  @property
-  def mean_tensor(self):
-    return self._mean_tensor
+    if fused_batch_norm is False:
+      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
+      bn_decay_var_out = utils.smart_cond(
+          use_mv_avg,
+          lambda: bn_decay_zero,
+          lambda: match.bn_decay_var_tensor,
+          name='freeze_moving_var')
+      graph_editor.reroute_ts(
+          [bn_decay_var_out], [match.bn_decay_var_tensor],
+          can_modify=bn_decay_var_consumers)
 
-  @property
-  def variance_tensor(self):
-    return self._variance_tensor
+    correction_recip = utils.smart_cond(
+        use_mv_avg,
+        lambda: array_ops.ones(correction_scale.shape),
+        lambda: correction_recip,
+        name='correction_recip')
 
-  @property
-  def moving_mean_tensor(self):
-    return self._moving_mean_tensor
+    correction_offset = utils.smart_cond(
+        use_mv_avg,
+        lambda: correction_offset,
+        lambda: array_ops.zeros(correction_offset.shape),
+        name='correction_offset')
+  return correction_scale, correction_recip, correction_offset
 
-  @property
-  def moving_variance_tensor(self):
-    return self._moving_variance_tensor
 
-  @property
-  def batch_epsilon_tensor(self):
-    return self._batch_epsilon_tensor
+def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
+  """Clones layer_op with input_tensor and weight_tensor as new inputs."""
+  new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
+  if layer_op.type == 'Conv2D':
+    return nn_ops.conv2d(
+        input_tensor,
+        weight_tensor,
+        strides=layer_op.get_attr('strides'),
+        padding=layer_op.get_attr('padding'),
+        use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
+        data_format=layer_op.get_attr('data_format'),
+        name=new_layer_name)
+  elif layer_op.type == 'MatMul':
+    return math_ops.matmul(
+        input_tensor,
+        weight_tensor,
+        transpose_a=layer_op.get_attr('transpose_a'),
+        transpose_b=layer_op.get_attr('transpose_b'),
+        name=new_layer_name)
+  elif layer_op.type == 'DepthwiseConv2dNative':
+    return nn.depthwise_conv2d(
+        input_tensor,
+        weight_tensor,
+        strides=layer_op.get_attr('strides'),
+        padding=layer_op.get_attr('padding'),
+        name=new_layer_name)
+  else:
+    raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
 
-  @property
-  def bn_decay_mean_tensor(self):
-    return self._bn_decay_mean_tensor
 
-  @property
-  def bn_decay_var_tensor(self):
-    return self._bn_decay_var_tensor
+@ops.RegisterGradient('FoldFusedBatchNormGrad')
+def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
+                            unused_2):
+  x = op.inputs[0]
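+  # op.inputs[0] is the batch norm input and op.outputs[1] its batch mean; n is
+  # the number of input elements per mean/variance element, so d(mean)/dx = 1/n
+  # and, with the Bessel-corrected variance, d(var)/dx = 2*(x - mean)/(n - 1).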
+  n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements()
+  dmean_dx = grad_mean / n
+  dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
+  return (dmean_dx + dvar_dx), None, None, None, None
 
 
 def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
@@ -475,81 +428,42 @@ def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
     graph: Graph to walk and modify.
     freeze_batch_norm_delay: How many steps to wait before freezing
     moving mean and variance and using them for batch normalization
-    is_training: Bool, True if training
-
-  Raises:
-    ValueError: When batch norm folding fails.
-  """
-  input_to_ops_map = input_to_ops.InputToOps(graph)
-
-  for bn in common.BatchNormGroups(graph):
-    has_scaling = _HasScaling(graph, input_to_ops_map, bn)
-
-    # The mangling code intimately depends on BatchNorm node's internals.
-    original_op, folded_op = _CreateFoldedOp(
-        graph,
-        bn,
-        has_scaling=has_scaling,
-        freeze_batch_norm_delay=freeze_batch_norm_delay,
-        is_training=is_training)
-
-    activation = common.GetEndpointActivationOp(graph, bn)
-    if activation:
-      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
-                                                     [original_op.outputs[0]],
-                                                     can_modify=[activation])
-      if nodes_modified_count != 1:
-        raise ValueError('Unexpected inputs to op: %s' % activation.name)
-      continue
-
-    # Treat consumer ops in bypass modules differently since they have Add
-    # operations instead of Relu* above.
-    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
-    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
-    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
-                                                   [original_op.outputs[0]],
-                                                   can_modify=[add_bypass])
-    if nodes_modified_count != 1:
-      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
-
-
-def _HasScaling(graph, input_to_ops_map, bn):
-  r"""Checks if batch norm  has scaling enabled.
-
-  Difference between batch norm with scaling and without is that with scaling:
-
-  Rsqrt -> mul -> mul_1
-              \-> mul_2
-
-  where
-    mul multiplies gamma by inverse square root of EMA of batch variance,
-    mul_1 multiplies output of mul with output from the base operation
-      (convolution, FC or depthwise convolution),
-    mul_2 multiplies output of mul with EMA of batch mean,
-  and without scaling:
-
-  Rsqrt -> mul
-       \-> mul_1
-
-  where
-    mul multiplies the inverse square root of EMA of batch variance with output
-      from the base operation,
-    mul_1 multiplies inverse square root of EMA of batch variance with EMA
-      of batch mean.
-
-  Args:
-    graph: Graph to inspect.
-    input_to_ops_map: InputToOps object containing mapping from tensor's name
-      to ops that take it as input.
-    bn: Batch norm layer prefix string.
+    is_training: Bool, True if training
 
-  Returns:
-    A boolean indicating whether this batch norm layer has scaling enabled.
+  Raises:
+    ValueError: When batch norm folding fails.
   """
-  rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
-  rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
+  input_to_ops_map = input_to_ops.InputToOps(graph)
 
-  return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
+  for bn in common.BatchNormGroups(graph):
+    has_scaling = _HasScaling(graph, input_to_ops_map, bn)
+
+    # The mangling code intimately depends on BatchNorm node's internals.
+    original_op, folded_op = _CreateFoldedOp(
+        graph,
+        bn,
+        has_scaling=has_scaling,
+        freeze_batch_norm_delay=freeze_batch_norm_delay,
+        is_training=is_training)
+
+    activation = common.GetEndpointActivationOp(graph, bn)
+    if activation:
+      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
+                                                     [original_op.outputs[0]],
+                                                     can_modify=[activation])
+      if nodes_modified_count != 1:
+        raise ValueError('Unexpected inputs to op: %s' % activation.name)
+      continue
+
+    # Treat consumer ops in bypass modules differently since they have Add
+    # operations instead of Relu* above.
+    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
+                                                   [original_op.outputs[0]],
+                                                   can_modify=[add_bypass])
+    if nodes_modified_count != 1:
+      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
 
 
 def _GetBatchNormParams(graph, context, has_scaling):
@@ -629,107 +543,6 @@ def _GetBatchNormParams(graph, context, has_scaling):
       batch_epsilon_tensor=batch_epsilon_tensor)
 
 
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
-                                 fused_batch_norm):
-  """Computes batch norm correction params.
-
-     Before batch normalization is frozen:
-     We use batch statistics for batch norm.
-       correction_scale = sigma_b/sigma_mv
-       correction_recip = 1/correction_scale
-       correction_offset = 0
-
-     After batch normalization is frozen:
-      correction_scale = sigma_b/sigma_mv
-      correction_recip = 1
-      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
-
-     Batch norm is frozen if global_step > bn_freeze_delay.
-     The corrections ensure that:
-     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
-     smoother training as the scaling on the weights changes slowly, rather than
-     jump across mini-batches
-     b) Changing the values of the corrections allows for one to switch between
-     using batch statistics to using moving mean and average, without requiring
-     changes to batch_norm
-
-
-  Args:
-    context: The scope under which we look for batch norm params
-    match: Object containg required batch norm tensors for correction
-      computation
-    freeze_batch_norm_delay: Delay in steps at which computation switches
-      from regular batch norm to frozen mean and variance.
-    fused_batch_norm: Bool, true if fused batch norm is used
-
-  Returns:
-    A tuple of correction_scale, correction_recip, correction_offset
-  """
-
-  g = ops.get_default_graph()
-  with g.name_scope(context + 'batch_norm_correction'):
-    recip_sigma_mv = math_ops.rsqrt(
-        match.moving_variance_tensor + match.batch_epsilon_tensor)
-    recip_sigma = math_ops.rsqrt(
-        match.variance_tensor + match.batch_epsilon_tensor)
-    correction_scale = math_ops.divide(
-        recip_sigma_mv, recip_sigma, name='scale_compute')
-    correction_scale = array_ops.identity(
-        correction_scale, name='correction_scale')
-    correction_recip = math_ops.reciprocal(
-        correction_scale, name='reciprocal_compute')
-    correction_offset = math_ops.multiply(
-        match.gamma_tensor,
-        match.mean_tensor * recip_sigma -
-        match.moving_mean_tensor * recip_sigma_mv,
-        name='offset_compute')
-
-    if freeze_batch_norm_delay is not None:
-      use_mv_avg = math_ops.greater_equal(
-          training_util.get_or_create_global_step(),
-          freeze_batch_norm_delay,
-          name='use_moving_average')
-    else:
-      use_mv_avg = False
-
-    bn_decay_zero = 0.0
-    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
-    bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())
-
-    bn_decay_mean_out = utils.smart_cond(
-        use_mv_avg,
-        lambda: bn_decay_zero,
-        lambda: match.bn_decay_mean_tensor,
-        name='freeze_moving_mean')
-    graph_editor.reroute_ts(
-        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
-        can_modify=bn_decay_mean_consumers)
-
-    if fused_batch_norm is False:
-      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
-      bn_decay_var_out = utils.smart_cond(
-          use_mv_avg,
-          lambda: bn_decay_zero,
-          lambda: match.bn_decay_var_tensor,
-          name='freeze_moving_var')
-      graph_editor.reroute_ts(
-          [bn_decay_var_out], [match.bn_decay_var_tensor],
-          can_modify=bn_decay_var_consumers)
-
-    correction_recip = utils.smart_cond(
-        use_mv_avg,
-        lambda: array_ops.ones(correction_scale.shape),
-        lambda: correction_recip,
-        name='correction_recip')
-
-    correction_offset = utils.smart_cond(
-        use_mv_avg,
-        lambda: correction_offset,
-        lambda: array_ops.zeros(correction_offset.shape),
-        name='correction_offset')
-  return correction_scale, correction_recip, correction_offset
-
-
 def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
                     is_training):
   """Folds in batch norm layer into preceding convolution or FC layer.
@@ -961,3 +774,121 @@ def _AssertShapesMatch(op_name, in_tensor, out_tensor):
   if not in_shape.is_compatible_with(out_shape):
     raise ValueError('%s should not change tensor shape: input %s, '
                      'output %s' % (op_name, in_shape, out_shape))
+
+
+def _HasScaling(graph, input_to_ops_map, bn):
+  r"""Checks if batch norm  has scaling enabled.
+
+  Difference between batch norm with scaling and without is that with scaling:
+
+  Rsqrt -> mul -> mul_1
+              \-> mul_2
+
+  where
+    mul multiplies gamma by inverse square root of EMA of batch variance,
+    mul_1 multiplies output of mul with output from the base operation
+      (convolution, FC or depthwise convolution),
+    mul_2 multiplies output of mul with EMA of batch mean,
+  and without scaling:
+
+  Rsqrt -> mul
+       \-> mul_1
+
+  where
+    mul multiplies the inverse square root of EMA of batch variance with output
+      from the base operation,
+    mul_1 multiplies inverse square root of EMA of batch variance with EMA
+      of batch mean.
+
+  Args:
+    graph: Graph to inspect.
+    input_to_ops_map: InputToOps object containing mapping from tensor's name
+      to ops that take it as input.
+    bn: Batch norm layer prefix string.
+
+  Returns:
+    A boolean indicating whether this batch norm layer has scaling enabled.
+  """
+  rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
+  rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
+
+  return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
+
+
+class _BatchNormMatch(object):
+  """Contains all information related to a found Fused/UnfusedBatchNorm."""
+
+  def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
+               weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
+               variance_tensor, moving_mean_tensor, moving_variance_tensor,
+               bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor):
+    self._layer_op = layer_op
+    self._bn_op = bn_op
+    self._output_tensor = output_tensor
+    self._input_tensor = input_tensor
+    self._weight_tensor = weight_tensor
+    self._gamma_tensor = gamma_tensor
+    self._beta_tensor = beta_tensor
+    self._mean_tensor = mean_tensor
+    self._variance_tensor = variance_tensor
+    self._moving_mean_tensor = moving_mean_tensor
+    self._moving_variance_tensor = moving_variance_tensor
+    self._bn_decay_mean_tensor = bn_decay_mean_tensor
+    self._bn_decay_var_tensor = bn_decay_var_tensor
+    self._batch_epsilon_tensor = batch_epsilon_tensor
+
+  @property
+  def layer_op(self):
+    return self._layer_op
+
+  @property
+  def bn_op(self):
+    return self._bn_op
+
+  @property
+  def output_tensor(self):
+    return self._output_tensor
+
+  @property
+  def input_tensor(self):
+    return self._input_tensor
+
+  @property
+  def weight_tensor(self):
+    return self._weight_tensor
+
+  @property
+  def gamma_tensor(self):
+    return self._gamma_tensor
+
+  @property
+  def beta_tensor(self):
+    return self._beta_tensor
+
+  @property
+  def mean_tensor(self):
+    return self._mean_tensor
+
+  @property
+  def variance_tensor(self):
+    return self._variance_tensor
+
+  @property
+  def moving_mean_tensor(self):
+    return self._moving_mean_tensor
+
+  @property
+  def moving_variance_tensor(self):
+    return self._moving_variance_tensor
+
+  @property
+  def batch_epsilon_tensor(self):
+    return self._batch_epsilon_tensor
+
+  @property
+  def bn_decay_mean_tensor(self):
+    return self._bn_decay_mean_tensor
+
+  @property
+  def bn_decay_var_tensor(self):
+    return self._bn_decay_var_tensor
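
For reference, a small self-contained numeric sketch of the two regimes
described in the _ComputeBatchNormCorrections docstring above. The statistics
are made up for illustration and the snippet is plain Python, not part of the
change itself:

import math

gamma, epsilon = 1.0, 1e-3
mu_b, var_b = 0.2, 4.0    # batch statistics for the current mini-batch
mu_mv, var_mv = 0.0, 1.0  # moving-average (frozen) statistics

sigma_b = math.sqrt(var_b + epsilon)
sigma_mv = math.sqrt(var_mv + epsilon)
correction_scale = sigma_b / sigma_mv  # weights are pre-scaled by this (~2.0)

# Before freezing: correction_recip undoes the scale, so the layer still
# behaves as if it used batch statistics.
before_freeze = (correction_scale, 1.0 / correction_scale, 0.0)

# After freezing: keep the scale and shift by the batch/moving mismatch.
after_freeze = (correction_scale, 1.0,
                gamma * (mu_b / sigma_b - mu_mv / sigma_mv))

print(before_freeze)  # approx (2.0, 0.5, 0.0)
print(after_freeze)   # approx (2.0, 1.0, 0.1)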