From: Dmitry Kurtaev Date: Wed, 23 May 2018 13:49:31 +0000 (+0300) Subject: Fix batch normalization fusion from TensorFlow's SSDs X-Git-Tag: accepted/tizen/6.0/unified/20201030.111113~1^2~604^2~17^2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3346a7ac0ad21520d3e20cb9335334ba7499bf15;p=platform%2Fupstream%2Fopencv.git Fix batch normalization fusion from TensorFlow's SSDs --- diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py index 57c3e04..f571047 100644 --- a/samples/dnn/tf_text_graph_ssd.py +++ b/samples/dnn/tf_text_graph_ssd.py @@ -64,36 +64,51 @@ removedNodes = [] # Detect unfused batch normalization nodes and fuse them. def fuse_batch_normalization(): - pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add'] - candidates = [] - - for node in graph_def.node: - if node.op == pattern[len(candidates)]: - candidates.append(node) + # Add_0 <-- moving_variance, add_y + # Rsqrt <-- Add_0 + # Mul_0 <-- Rsqrt, gamma + # Mul_1 <-- input, Mul_0 + # Mul_2 <-- moving_mean, Mul_0 + # Sub_0 <-- beta, Mul_2 + # Add_1 <-- Mul_1, Sub_0 + nodesMap = {node.name: node for node in graph_def.node} + subgraph = ['Add', + ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']], + ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]] + def checkSubgraph(node, targetNode, inputs, fusedNodes): + op = targetNode[0] + if node.op == op and (len(node.input) >= len(targetNode) - 1): + fusedNodes.append(node) + for i, inpOp in enumerate(targetNode[1:]): + if isinstance(inpOp, list): + if not node.input[i] in nodesMap or \ + not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes): + return False + else: + inputs[inpOp] = node.input[i] + + return True else: - candidates = [] - - if len(candidates) == len(pattern): - inp = candidates[3].input[0] - gamma = candidates[2].input[1] - beta = candidates[5].input[0] - moving_mean = candidates[4].input[0] - moving_variance = candidates[0].input[0] + return False + nodesToRemove = [] + for node in graph_def.node: + inputs = {} + fusedNodes = [] + if checkSubgraph(node, subgraph, inputs, fusedNodes): name = node.name node.Clear() node.name = name node.op = 'FusedBatchNorm' - node.input.append(inp) - node.input.append(gamma) - node.input.append(beta) - node.input.append(moving_mean) - node.input.append(moving_variance) + node.input.append(inputs['input']) + node.input.append(inputs['gamma']) + node.input.append(inputs['beta']) + node.input.append(inputs['moving_mean']) + node.input.append(inputs['moving_variance']) text_format.Merge('f: 0.001', node.attr["epsilon"]) - - for candidate in candidates[:-1]: - graph_def.node.remove(candidate) - candidates = [] + nodesToRemove += fusedNodes[1:] + for node in nodesToRemove: + graph_def.node.remove(node) fuse_batch_normalization()