# Detect unfused batch normalization nodes and fuse them.
def fuse_batch_normalization():
- pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add']
- candidates = []
-
- for node in graph_def.node:
- if node.op == pattern[len(candidates)]:
- candidates.append(node)
+ # Add_0 <-- moving_variance, add_y
+ # Rsqrt <-- Add_0
+ # Mul_0 <-- Rsqrt, gamma
+ # Mul_1 <-- input, Mul_0
+ # Mul_2 <-- moving_mean, Mul_0
+ # Sub_0 <-- beta, Mul_2
+ # Add_1 <-- Mul_1, Sub_0
+ nodesMap = {node.name: node for node in graph_def.node}
+ subgraph = ['Add',
+ ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
+ ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
+ def checkSubgraph(node, targetNode, inputs, fusedNodes):
+ op = targetNode[0]
+ if node.op == op and (len(node.input) >= len(targetNode) - 1):
+ fusedNodes.append(node)
+ for i, inpOp in enumerate(targetNode[1:]):
+ if isinstance(inpOp, list):
+ if not node.input[i] in nodesMap or \
+ not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
+ return False
+ else:
+ inputs[inpOp] = node.input[i]
+
+ return True
else:
- candidates = []
-
- if len(candidates) == len(pattern):
- inp = candidates[3].input[0]
- gamma = candidates[2].input[1]
- beta = candidates[5].input[0]
- moving_mean = candidates[4].input[0]
- moving_variance = candidates[0].input[0]
+ return False
+ nodesToRemove = []
+ for node in graph_def.node:
+ inputs = {}
+ fusedNodes = []
+ if checkSubgraph(node, subgraph, inputs, fusedNodes):
name = node.name
node.Clear()
node.name = name
node.op = 'FusedBatchNorm'
- node.input.append(inp)
- node.input.append(gamma)
- node.input.append(beta)
- node.input.append(moving_mean)
- node.input.append(moving_variance)
+ node.input.append(inputs['input'])
+ node.input.append(inputs['gamma'])
+ node.input.append(inputs['beta'])
+ node.input.append(inputs['moving_mean'])
+ node.input.append(inputs['moving_variance'])
text_format.Merge('f: 0.001', node.attr["epsilon"])
-
- for candidate in candidates[:-1]:
- graph_def.node.remove(candidate)
- candidates = []
+ nodesToRemove += fusedNodes[1:]
+ for node in nodesToRemove:
+ graph_def.node.remove(node)
fuse_batch_normalization()