Run entire SSDs from TensorFlow using Intel's Inference Engine
[platform/upstream/opencv.git] / samples / dnn / tf_text_graph_ssd.py
1 # This file is a part of OpenCV project.
2 # It is a subject to the license terms in the LICENSE file found in the top-level directory
3 # of this distribution and at http://opencv.org/license.html.
4 #
5 # Copyright (C) 2018, Intel Corporation, all rights reserved.
6 # Third party copyrights are property of their respective owners.
7 #
8 # Use this script to get the text graph representation (.pbtxt) of SSD-based
9 # deep learning network trained in TensorFlow Object Detection API.
10 # Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
11 # See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
12 import tensorflow as tf
13 import argparse
14 from math import sqrt
15 from tensorflow.core.framework.node_def_pb2 import NodeDef
16 from tensorflow.tools.graph_transforms import TransformGraph
17 from google.protobuf import text_format
18
19 parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
20                                              'SSD model from TensorFlow Object Detection API. '
21                                              'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
22 parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
23 parser.add_argument('--output', required=True, help='Path to output text graph.')
24 parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
25 parser.add_argument('--min_scale', default=0.2, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
26 parser.add_argument('--max_scale', default=0.95, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
27 parser.add_argument('--num_layers', default=6, type=int, help='Hyper-parameter of ssd_anchor_generator from config file.')
28 parser.add_argument('--aspect_ratios', default=[1.0, 2.0, 0.5, 3.0, 0.333], type=float, nargs='+',
29                     help='Hyper-parameter of ssd_anchor_generator from config file.')
30 parser.add_argument('--image_width', default=300, type=int, help='Training images width.')
31 parser.add_argument('--image_height', default=300, type=int, help='Training images height.')
32 args = parser.parse_args()
33
34 # Nodes that should be kept.
35 keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
36            'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
37
38 # Nodes attributes that could be removed because they are not used during import.
39 unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
40                'Index', 'Tperm', 'is_training', 'Tpaddings']
41
42 # Node with which prefixes should be removed
43 prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/')
44
45 # Read the graph.
46 with tf.gfile.FastGFile(args.input, 'rb') as f:
47     graph_def = tf.GraphDef()
48     graph_def.ParseFromString(f.read())
49
50 inpNames = ['image_tensor']
51 outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
52 graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order'])
53
54 def getUnconnectedNodes():
55     unconnected = []
56     for node in graph_def.node:
57         unconnected.append(node.name)
58         for inp in node.input:
59             if inp in unconnected:
60                 unconnected.remove(inp)
61     return unconnected
62
63 removedNodes = []
64
65 # Detect unfused batch normalization nodes and fuse them.
66 def fuse_batch_normalization():
67     # Add_0 <-- moving_variance, add_y
68     # Rsqrt <-- Add_0
69     # Mul_0 <-- Rsqrt, gamma
70     # Mul_1 <-- input, Mul_0
71     # Mul_2 <-- moving_mean, Mul_0
72     # Sub_0 <-- beta, Mul_2
73     # Add_1 <-- Mul_1, Sub_0
74     nodesMap = {node.name: node for node in graph_def.node}
75     subgraph = ['Add',
76         ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
77         ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
78     def checkSubgraph(node, targetNode, inputs, fusedNodes):
79         op = targetNode[0]
80         if node.op == op and (len(node.input) >= len(targetNode) - 1):
81             fusedNodes.append(node)
82             for i, inpOp in enumerate(targetNode[1:]):
83                 if isinstance(inpOp, list):
84                     if not node.input[i] in nodesMap or \
85                        not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
86                         return False
87                 else:
88                     inputs[inpOp] = node.input[i]
89
90             return True
91         else:
92             return False
93
94     nodesToRemove = []
95     for node in graph_def.node:
96         inputs = {}
97         fusedNodes = []
98         if checkSubgraph(node, subgraph, inputs, fusedNodes):
99             name = node.name
100             node.Clear()
101             node.name = name
102             node.op = 'FusedBatchNorm'
103             node.input.append(inputs['input'])
104             node.input.append(inputs['gamma'])
105             node.input.append(inputs['beta'])
106             node.input.append(inputs['moving_mean'])
107             node.input.append(inputs['moving_variance'])
108             text_format.Merge('f: 0.001', node.attr["epsilon"])
109             nodesToRemove += fusedNodes[1:]
110     for node in nodesToRemove:
111         graph_def.node.remove(node)
112
113 fuse_batch_normalization()
114
115 # Removes Identity nodes
116 def removeIdentity():
117     identities = {}
118     for node in graph_def.node:
119         if node.op == 'Identity':
120             identities[node.name] = node.input[0]
121             graph_def.node.remove(node)
122
123     for node in graph_def.node:
124         for i in range(len(node.input)):
125             if node.input[i] in identities:
126                 node.input[i] = identities[node.input[i]]
127
128 removeIdentity()
129
130 # Remove extra nodes and attributes.
131 for i in reversed(range(len(graph_def.node))):
132     op = graph_def.node[i].op
133     name = graph_def.node[i].name
134
135     if (not op in keepOps) or name.startswith(prefixesToRemove):
136         if op != 'Const':
137             removedNodes.append(name)
138
139         del graph_def.node[i]
140     else:
141         for attr in unusedAttrs:
142             if attr in graph_def.node[i].attr:
143                 del graph_def.node[i].attr[attr]
144
145 # Remove references to removed nodes except Const nodes.
146 for node in graph_def.node:
147     for i in reversed(range(len(node.input))):
148         if node.input[i] in removedNodes:
149             del node.input[i]
150
151 # Connect input node to the first layer
152 assert(graph_def.node[0].op == 'Placeholder')
153 # assert(graph_def.node[1].op == 'Conv2D')
154 weights = graph_def.node[1].input[0]
155 for i in range(len(graph_def.node[1].input)):
156     graph_def.node[1].input.pop()
157 graph_def.node[1].input.append(graph_def.node[0].name)
158 graph_def.node[1].input.append(weights)
159
160 # Create SSD postprocessing head ###############################################
161
162 # Concatenate predictions of classes, predictions of bounding boxes and proposals.
163 def tensorMsg(values):
164     if all([isinstance(v, float) for v in values]):
165         dtype = 'DT_FLOAT'
166         field = 'float_val'
167     elif all([isinstance(v, int) for v in values]):
168         dtype = 'DT_INT32'
169         field = 'int_val'
170     else:
171         raise Exception('Wrong values types')
172
173     msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
174     for value in values:
175         msg += '%s: %s ' % (field, str(value))
176     return msg + '}'
177
178 def addConstNode(name, values):
179     node = NodeDef()
180     node.name = name
181     node.op = 'Const'
182     text_format.Merge(tensorMsg(values), node.attr["value"])
183     graph_def.node.extend([node])
184
185 def addConcatNode(name, inputs, axisNodeName):
186     concat = NodeDef()
187     concat.name = name
188     concat.op = 'ConcatV2'
189     for inp in inputs:
190         concat.input.append(inp)
191     concat.input.append(axisNodeName)
192     graph_def.node.extend([concat])
193
194 addConstNode('concat/axis_flatten', [-1])
195 addConstNode('PriorBox/concat/axis', [-2])
196
197 for label in ['ClassPredictor', 'BoxEncodingPredictor']:
198     concatInputs = []
199     for i in range(args.num_layers):
200         # Flatten predictions
201         flatten = NodeDef()
202         inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label)
203         flatten.input.append(inpName)
204         flatten.name = inpName + '/Flatten'
205         flatten.op = 'Flatten'
206
207         concatInputs.append(flatten.name)
208         graph_def.node.extend([flatten])
209     addConcatNode('%s/concat' % label, concatInputs, 'concat/axis_flatten')
210
211 idx = 0
212 for node in graph_def.node:
213     if node.name == ('BoxPredictor_%d/BoxEncodingPredictor/Conv2D' % idx):
214         text_format.Merge('b: true', node.attr["loc_pred_transposed"])
215         idx += 1
216 assert(idx == args.num_layers)
217
218 # Add layers that generate anchors (bounding boxes proposals).
219 scales = [args.min_scale + (args.max_scale - args.min_scale) * i / (args.num_layers - 1)
220           for i in range(args.num_layers)] + [1.0]
221
222 priorBoxes = []
223 for i in range(args.num_layers):
224     priorBox = NodeDef()
225     priorBox.name = 'PriorBox_%d' % i
226     priorBox.op = 'PriorBox'
227     priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i)
228     priorBox.input.append(graph_def.node[0].name)  # image_tensor
229
230     text_format.Merge('b: false', priorBox.attr["flip"])
231     text_format.Merge('b: false', priorBox.attr["clip"])
232
233     if i == 0:
234         widths = [0.1, args.min_scale * sqrt(2.0), args.min_scale * sqrt(0.5)]
235         heights = [0.1, args.min_scale / sqrt(2.0), args.min_scale / sqrt(0.5)]
236     else:
237         widths = [scales[i] * sqrt(ar) for ar in args.aspect_ratios]
238         heights = [scales[i] / sqrt(ar) for ar in args.aspect_ratios]
239
240         widths += [sqrt(scales[i] * scales[i + 1])]
241         heights += [sqrt(scales[i] * scales[i + 1])]
242     widths = [w * args.image_width for w in widths]
243     heights = [h * args.image_height for h in heights]
244     text_format.Merge(tensorMsg(widths), priorBox.attr["width"])
245     text_format.Merge(tensorMsg(heights), priorBox.attr["height"])
246     text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), priorBox.attr["variance"])
247
248     graph_def.node.extend([priorBox])
249     priorBoxes.append(priorBox.name)
250
251 addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')
252
253 # Sigmoid for classes predictions and DetectionOutput layer
254 sigmoid = NodeDef()
255 sigmoid.name = 'ClassPredictor/concat/sigmoid'
256 sigmoid.op = 'Sigmoid'
257 sigmoid.input.append('ClassPredictor/concat')
258 graph_def.node.extend([sigmoid])
259
260 detectionOut = NodeDef()
261 detectionOut.name = 'detection_out'
262 detectionOut.op = 'DetectionOutput'
263
264 detectionOut.input.append('BoxEncodingPredictor/concat')
265 detectionOut.input.append(sigmoid.name)
266 detectionOut.input.append('PriorBox/concat')
267
268 text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['num_classes'])
269 text_format.Merge('b: true', detectionOut.attr['share_location'])
270 text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
271 text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
272 text_format.Merge('i: 100', detectionOut.attr['top_k'])
273 text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
274 text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
275 text_format.Merge('f: 0.01', detectionOut.attr['confidence_threshold'])
276
277 graph_def.node.extend([detectionOut])
278
279 while True:
280     unconnectedNodes = getUnconnectedNodes()
281     unconnectedNodes.remove(detectionOut.name)
282     if not unconnectedNodes:
283         break
284
285     for name in unconnectedNodes:
286         for i in range(len(graph_def.node)):
287             if graph_def.node[i].name == name:
288                 del graph_def.node[i]
289                 break
290
291 # Save as text.
292 tf.train.write_graph(graph_def, "", args.output, as_text=True)