Fix batch normalization fusion from TensorFlow's SSDs
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 23 May 2018 13:49:31 +0000 (16:49 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 23 May 2018 13:49:31 +0000 (16:49 +0300)
samples/dnn/tf_text_graph_ssd.py

index 57c3e04..f571047 100644 (file)
@@ -64,36 +64,51 @@ removedNodes = []
 
 # Detect unfused batch normalization nodes and fuse them.
 def fuse_batch_normalization():
-    pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add']
-    candidates = []
-
-    for node in graph_def.node:
-        if node.op == pattern[len(candidates)]:
-            candidates.append(node)
+    # Add_0 <-- moving_variance, add_y
+    # Rsqrt <-- Add_0
+    # Mul_0 <-- Rsqrt, gamma
+    # Mul_1 <-- input, Mul_0
+    # Mul_2 <-- moving_mean, Mul_0
+    # Sub_0 <-- beta, Mul_2
+    # Add_1 <-- Mul_1, Sub_0
+    nodesMap = {node.name: node for node in graph_def.node}
+    subgraph = ['Add',
+        ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
+        ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
+    def checkSubgraph(node, targetNode, inputs, fusedNodes):
+        op = targetNode[0]
+        if node.op == op and (len(node.input) >= len(targetNode) - 1):
+            fusedNodes.append(node)
+            for i, inpOp in enumerate(targetNode[1:]):
+                if isinstance(inpOp, list):
+                    if not node.input[i] in nodesMap or \
+                       not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
+                        return False
+                else:
+                    inputs[inpOp] = node.input[i]
+
+            return True
         else:
-            candidates = []
-
-        if len(candidates) == len(pattern):
-            inp = candidates[3].input[0]
-            gamma = candidates[2].input[1]
-            beta = candidates[5].input[0]
-            moving_mean = candidates[4].input[0]
-            moving_variance = candidates[0].input[0]
+            return False
 
+    nodesToRemove = []
+    for node in graph_def.node:
+        inputs = {}
+        fusedNodes = []
+        if checkSubgraph(node, subgraph, inputs, fusedNodes):
             name = node.name
             node.Clear()
             node.name = name
             node.op = 'FusedBatchNorm'
-            node.input.append(inp)
-            node.input.append(gamma)
-            node.input.append(beta)
-            node.input.append(moving_mean)
-            node.input.append(moving_variance)
+            node.input.append(inputs['input'])
+            node.input.append(inputs['gamma'])
+            node.input.append(inputs['beta'])
+            node.input.append(inputs['moving_mean'])
+            node.input.append(inputs['moving_variance'])
             text_format.Merge('f: 0.001', node.attr["epsilon"])
-
-            for candidate in candidates[:-1]:
-                graph_def.node.remove(candidate)
-            candidates = []
+            nodesToRemove += fusedNodes[1:]
+    for node in nodesToRemove:
+        graph_def.node.remove(node)
 
 fuse_batch_normalization()