Fuse tf.nn.l2_normalize layer
[platform/upstream/opencv.git] / modules / dnn / misc / quantize_face_detector.py
1 import argparse
2 import cv2 as cv
3 import tensorflow as tf
4 import numpy as np
5 import struct
6
7 from tensorflow.python.tools import optimize_for_inference_lib
8 from tensorflow.tools.graph_transforms import TransformGraph
9 from tensorflow.core.framework.node_def_pb2 import NodeDef
10 from google.protobuf import text_format
11
# Command-line interface: paths of the Caffe model to read and of the
# TensorFlow files to write, plus two optional weight-compression switches.
parser = argparse.ArgumentParser(description="Use this script to create TensorFlow graph "
                                             "with weights from OpenCV's face detection network. "
                                             "Only backbone part of SSD model is converted this way. "
                                             "Look for .pbtxt configuration file at "
                                             "https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/opencv_face_detector.pbtxt")
parser.add_argument('--model', help='Path to .caffemodel weights', required=True)
parser.add_argument('--proto', help='Path to .prototxt Caffe model definition', required=True)
parser.add_argument('--pb', help='Path to output .pb TensorFlow model', required=True)
parser.add_argument('--pbtxt', help='Path to output .pbtxt TensorFlow graph', required=True)
parser.add_argument('--quantize', help='Quantize weights to uint8', action='store_true')
parser.add_argument('--fp16', help='Convert weights to half precision floats', action='store_true')
args = parser.parse_args()

# The two compression modes cannot be combined. Report it as a usage error
# instead of an assert, which would be stripped when Python runs with -O.
if args.quantize and args.fp16:
    parser.error('--quantize and --fp16 are mutually exclusive')

# Element type used for the input placeholder and for every converted weight.
dtype = tf.float16 if args.fp16 else tf.float32

################################################################################
# Source network: layer weights are read from it and it also produces the
# reference outputs used by the correctness check further down.
cvNet = cv.dnn.readNetFromCaffe(args.proto, args.model)
31
def dnnLayer(name):
    """Return the layer object of `cvNet` whose name is `name`.

    The numeric id returned by `cvNet.getLayerId` is converted with `long`
    on Python 2 (as before) and with `int` on Python 3, where `long` does
    not exist and the original code raised NameError.
    """
    try:
        toLayerId = long  # Python 2
    except NameError:
        toLayerId = int  # Python 3: int is already arbitrary precision
    return cvNet.getLayer(toLayerId(cvNet.getLayerId(name)))
34
def scale(x, name):
    """Multiply `x` by blob 0 of OpenCV layer `name` and, if the layer has a
    second blob, add it as a bias.

    Both blobs are flattened and stored as variables named 'mul' and 'add'
    inside a variable scope called `name`.
    """
    with tf.variable_scope(name):
        blobs = dnnLayer(name).blobs
        weights = tf.Variable(blobs[0].flatten(), dtype=dtype, name='mul')
        if len(blobs) <= 1:
            # Multiplier only: the product op itself is given the layer name.
            return tf.multiply(x, weights, name)
        bias = tf.Variable(blobs[1].flatten(), dtype=dtype, name='add')
        return tf.nn.bias_add(tf.multiply(x, weights), bias)
44
def conv(x, name, stride=1, pad='SAME', dilation=1, activ=None):
    """Build a 2-D convolution whose weights come from OpenCV layer `name`.

    x        -- input tensor.
    name     -- OpenCV layer name; also used as the variable scope.
    stride   -- spatial stride (must be 1 when `dilation` is not 1).
    pad      -- TensorFlow padding mode, 'SAME' or 'VALID'.
    dilation -- dilation rate; any value other than 1 selects atrous_conv2d.
    activ    -- optional callable applied to the result.
    """
    with tf.variable_scope(name):
        blobs = dnnLayer(name).blobs
        # Kernel axes are permuted (0, 1, 2, 3) -> (2, 3, 1, 0); presumably
        # Caffe's (out, in, h, w) to TensorFlow's (h, w, in, out) -- confirm.
        kernel = tf.Variable(blobs[0].transpose(2, 3, 1, 0), dtype=dtype, name='weights')
        if dilation != 1:
            assert(stride == 1)
            out = tf.nn.atrous_conv2d(x, kernel, rate=dilation, padding=pad)
        else:
            out = tf.nn.conv2d(x, filter=kernel, strides=(1, stride, stride, 1), padding=pad)

        # A second blob, when present, is the bias.
        if len(blobs) > 1:
            bias = tf.Variable(blobs[1].flatten(), dtype=dtype, name='bias')
            out = tf.nn.bias_add(out, bias)

        if activ:
            return activ(out)
        return out
59
def batch_norm(x, name):
    """Build an inference-mode fused batch normalization from OpenCV layer `name`.

    The layer must have at least three blobs, read here as mean, 'std' and a
    scale blob; an optional fourth blob is used as beta. The result is cast
    back to the script-wide `dtype` when it differs.
    """
    with tf.variable_scope(name):
        # TensorFlow's batch normalization does not accept fp16 input, so the
        # tensor is cast to fp32 here; the cast is removed in the frozen graph.
        if x.dtype != tf.float32:
            x = tf.cast(x, tf.float32)

        blobs = dnnLayer(name).blobs
        assert(len(blobs) >= 3)

        meanData = blobs[0].flatten()
        stdData = blobs[1].flatten()
        scaleData = blobs[2].flatten()

        epsilon = 1e-5
        hasBias = len(blobs) > 3
        hasWeights = scaleData.shape != (1,)

        # With neither per-element weights nor a bias, the third blob holds a
        # single factor by which the stored statistics are divided.
        if not (hasWeights or hasBias):
            meanData /= scaleData[0]
            stdData /= scaleData[0]

        mean = tf.Variable(meanData, dtype=tf.float32, name='mean')
        # Passed in the `variance` position of fused_batch_norm below.
        std = tf.Variable(stdData, dtype=tf.float32, name='std')

        gammaInit = scaleData if hasWeights else np.ones(mean.shape)
        gamma = tf.Variable(gammaInit, dtype=tf.float32, name='gamma')

        betaInit = blobs[3].flatten() if hasBias else np.zeros(mean.shape)
        beta = tf.Variable(betaInit, dtype=tf.float32, name='beta')

        outputs = tf.nn.fused_batch_norm(x, gamma, beta, mean, std, epsilon,
                                         is_training=False)
        normalized = outputs[0]
        if normalized.dtype == dtype:
            return normalized
        return tf.cast(normalized, dtype)
91
def l2norm(x, name):
    """L2-normalize `x` along axis 3 and multiply by blob 0 of OpenCV layer `name`."""
    with tf.variable_scope(name):
        blobs = dnnLayer(name).blobs
        weights = tf.Variable(blobs[0].flatten(), dtype=dtype, name='mul')
        normalized = tf.nn.l2_normalize(x, 3, epsilon=1e-10)
        return normalized * weights
97
### Graph definition ###########################################################
# NOTE: the order in which ops are created here determines the automatically
# generated node names, so statements must not be reordered.

# Explicit zero padding of the two middle axes of a 4-D tensor, applied before
# the strided 'VALID' convolutions below.
pad3 = [[0, 0], [3, 3], [3, 3], [0, 0]]
pad1 = [[0, 0], [1, 1], [1, 1], [0, 0]]

# Stem.
inp = tf.placeholder(dtype, shape=[1, 300, 300, 3], name='data')
data_bn = batch_norm(inp, 'data_bn')
data_scale = scale(data_bn, 'data_scale')
data_scale = tf.pad(data_scale, pad3)
conv1_h = conv(data_scale, 'conv1_h', stride=2, pad='VALID')
conv1_bn_h = batch_norm(conv1_h, 'conv1_bn_h')
conv1_scale_h = scale(conv1_bn_h, 'conv1_scale_h')
conv1_relu = tf.nn.relu(conv1_scale_h)
conv1_pool = tf.layers.max_pooling2d(
    conv1_relu, pool_size=(3, 3), strides=(2, 2), padding='SAME', name='conv1_pool')

# Residual block "layer_64_1".
layer_64_1_conv1_h = conv(conv1_pool, 'layer_64_1_conv1_h')
layer_64_1_bn2_h = batch_norm(layer_64_1_conv1_h, 'layer_64_1_bn2_h')
layer_64_1_scale2_h = scale(layer_64_1_bn2_h, 'layer_64_1_scale2_h')
layer_64_1_relu2 = tf.nn.relu(layer_64_1_scale2_h)
layer_64_1_conv2_h = conv(layer_64_1_relu2, 'layer_64_1_conv2_h')
layer_64_1_sum = layer_64_1_conv2_h + conv1_pool

# Residual block "layer_128_1" (stride 2, with an expand branch).
layer_128_1_bn1_h = batch_norm(layer_64_1_sum, 'layer_128_1_bn1_h')
layer_128_1_scale1_h = scale(layer_128_1_bn1_h, 'layer_128_1_scale1_h')
layer_128_1_relu1 = tf.nn.relu(layer_128_1_scale1_h)
layer_128_1_conv1_h = conv(layer_128_1_relu1, 'layer_128_1_conv1_h', stride=2)
layer_128_1_bn2 = batch_norm(layer_128_1_conv1_h, 'layer_128_1_bn2')
layer_128_1_scale2 = scale(layer_128_1_bn2, 'layer_128_1_scale2')
layer_128_1_relu2 = tf.nn.relu(layer_128_1_scale2)
layer_128_1_conv2 = conv(layer_128_1_relu2, 'layer_128_1_conv2')
layer_128_1_conv_expand_h = conv(layer_128_1_relu1, 'layer_128_1_conv_expand_h', stride=2)
layer_128_1_sum = layer_128_1_conv2 + layer_128_1_conv_expand_h

# Residual block "layer_256_1" (explicit padding + 'VALID' strided conv).
layer_256_1_bn1 = batch_norm(layer_128_1_sum, 'layer_256_1_bn1')
layer_256_1_scale1 = scale(layer_256_1_bn1, 'layer_256_1_scale1')
layer_256_1_relu1 = tf.nn.relu(layer_256_1_scale1)
layer_256_1_conv1 = tf.pad(layer_256_1_relu1, pad1)
layer_256_1_conv1 = conv(layer_256_1_conv1, 'layer_256_1_conv1', stride=2, pad='VALID')
layer_256_1_bn2 = batch_norm(layer_256_1_conv1, 'layer_256_1_bn2')
layer_256_1_scale2 = scale(layer_256_1_bn2, 'layer_256_1_scale2')
layer_256_1_relu2 = tf.nn.relu(layer_256_1_scale2)
layer_256_1_conv2 = conv(layer_256_1_relu2, 'layer_256_1_conv2')
layer_256_1_conv_expand = conv(layer_256_1_relu1, 'layer_256_1_conv_expand', stride=2)
layer_256_1_sum = layer_256_1_conv2 + layer_256_1_conv_expand

# Residual block "layer_512_1" (second conv is dilated).
layer_512_1_bn1 = batch_norm(layer_256_1_sum, 'layer_512_1_bn1')
layer_512_1_scale1 = scale(layer_512_1_bn1, 'layer_512_1_scale1')
layer_512_1_relu1 = tf.nn.relu(layer_512_1_scale1)
layer_512_1_conv1_h = conv(layer_512_1_relu1, 'layer_512_1_conv1_h')
layer_512_1_bn2_h = batch_norm(layer_512_1_conv1_h, 'layer_512_1_bn2_h')
layer_512_1_scale2_h = scale(layer_512_1_bn2_h, 'layer_512_1_scale2_h')
layer_512_1_relu2 = tf.nn.relu(layer_512_1_scale2_h)
layer_512_1_conv2_h = conv(layer_512_1_relu2, 'layer_512_1_conv2_h', dilation=2)
layer_512_1_conv_expand_h = conv(layer_512_1_relu1, 'layer_512_1_conv_expand_h')
layer_512_1_sum = layer_512_1_conv2_h + layer_512_1_conv_expand_h

# Final normalization of the backbone.
last_bn_h = batch_norm(layer_512_1_sum, 'last_bn_h')
last_scale_h = scale(last_bn_h, 'last_scale_h')
fc7 = tf.nn.relu(last_scale_h, name='last_relu')

# Extra feature layers, each followed by ReLU.
conv6_1_h = conv(fc7, 'conv6_1_h', activ=tf.nn.relu)
conv6_2_h = conv(conv6_1_h, 'conv6_2_h', stride=2, activ=tf.nn.relu)
conv7_1_h = conv(conv6_2_h, 'conv7_1_h', activ=tf.nn.relu)
conv7_2_h = tf.pad(conv7_1_h, pad1)
conv7_2_h = conv(conv7_2_h, 'conv7_2_h', stride=2, pad='VALID', activ=tf.nn.relu)
conv8_1_h = conv(conv7_2_h, 'conv8_1_h', pad='SAME', activ=tf.nn.relu)
conv8_2_h = conv(conv8_1_h, 'conv8_2_h', pad='SAME', activ=tf.nn.relu)
conv9_1_h = conv(conv8_2_h, 'conv9_1_h', activ=tf.nn.relu)
conv9_2_h = conv(conv9_1_h, 'conv9_2_h', pad='SAME', activ=tf.nn.relu)

# L2-normalized copy of the layer_256_1 activation, used as a detection source.
conv4_3_norm = l2norm(layer_256_1_relu1, 'conv4_3_norm')
166
### Locations and confidences ##################################################
locations = []
confidences = []
# Names of all reshape layers that are later rewritten into 'Flatten' nodes.
flattenLayersNames = []

# Feature maps feeding the detection heads, paired with their layer-name prefix.
sources = [(conv4_3_norm, 'conv4_3_norm'), (fc7, 'fc7'), (conv6_2_h, 'conv6_2'),
           (conv7_2_h, 'conv7_2'), (conv8_2_h, 'conv8_2'), (conv9_2_h, 'conv9_2')]

# All location heads are built first, then all confidence heads.
for top, suffix in [(locations, '_mbox_loc'), (confidences, '_mbox_conf')]:
    for bottom, prefix in sources:
        name = prefix + suffix
        flat = tf.layers.flatten(conv(bottom, name))
        # Tensor names look like '<op name>:<output index>'; keep the op name.
        flattenLayersNames.append(flat.name.split(':')[0])
        top.append(flat)

mbox_loc = tf.concat(locations, axis=-1, name='mbox_loc')
mbox_conf = tf.concat(confidences, axis=-1, name='mbox_conf')

# Softmax is taken over pairs of scores, then the result is flattened back to
# the original per-sample length.
total = int(np.prod(mbox_conf.shape[1:]))
mbox_conf_reshape = tf.reshape(mbox_conf, [-1, 2], name='mbox_conf_reshape')
mbox_conf_softmax = tf.nn.softmax(mbox_conf_reshape, name='mbox_conf_softmax')
mbox_conf_flatten = tf.reshape(mbox_conf_softmax, [-1, total], name='mbox_conf_flatten')
flattenLayersNames.append('mbox_conf_flatten')
187
# Initialize the variables, compare TensorFlow against OpenCV on random input,
# then freeze, optimize and serialize the graph to args.pb.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    ### Check correctness ######################################################
    out_nodes = ['mbox_loc', 'mbox_conf_flatten']
    inp_nodes = [inp.name[:inp.name.find(':')]]

    np.random.seed(2701)
    inputData = np.random.standard_normal([1, 3, 300, 300]).astype(np.float32)

    cvNet.setInput(inputData)
    outDNN = cvNet.forward(out_nodes)

    # The TensorFlow graph takes the same data with axes permuted (0, 2, 3, 1).
    outTF = sess.run([mbox_loc, mbox_conf_flatten], feed_dict={inp: inputData.transpose(0, 2, 3, 1)})
    # Single-argument print(...) produces the same text on Python 2 and 3.
    print('Max diff @ locations:  %e' % np.max(np.abs(outDNN[0] - outTF[0])))
    print('Max diff @ confidence: %e' % np.max(np.abs(outDNN[1] - outTF[1])))

    # Save a graph
    graph_def = sess.graph.as_graph_def()

    # Freeze graph. Replaces variables with constants.
    graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, out_nodes)
    # Optimize graph. Removes training-only ops and unused nodes.
    graph_def = optimize_for_inference_lib.optimize_for_inference(graph_def, inp_nodes, out_nodes, dtype.as_datatype_enum)
    # Fuse constant operations.
    transforms = ["fold_constants(ignore_errors=True)"]
    if args.quantize:
        transforms += ["quantize_weights(minimum_size=0)"]
    transforms += ["sort_by_execution_order"]
    graph_def = TransformGraph(graph_def, inp_nodes, out_nodes, transforms)

    # By default, float16 weights are stored in the tensor's repeated field
    # `half_val`. It has type int32, with leading zeros for the unused bytes,
    # and is varint-encoded: each byte carries 7 bits of the value and one bit
    # that marks the end of the encoding. A float16 may therefore take 1, 2 or
    # 3 bytes depending on its value. To improve compression, all `half_val`
    # values are moved into `tensor_content`, using exactly 2 bytes each.
    for node in graph_def.node:
        if 'value' in node.attr:
            halfs = node.attr["value"].tensor.half_val
            if not node.attr["value"].tensor.tensor_content and halfs:
                node.attr["value"].tensor.tensor_content = struct.pack('H' * len(halfs), *halfs)
                node.attr["value"].tensor.ClearField('half_val')

    # Serialize
    with tf.gfile.FastGFile(args.pb, 'wb') as f:
        f.write(graph_def.SerializeToString())
235
236
237 ################################################################################
238 # Write a text graph representation
239 ################################################################################
def tensorMsg(values):
    """Return a protobuf text-format message for a 1-D DT_FLOAT tensor.

    values -- sequence of numbers; each one is written as a `float_val`
              entry formatted with '%f', and the sequence length is written
              as the single dimension size.
    """
    header = 'tensor { dtype: DT_FLOAT tensor_shape { dim { size: %d } }' % len(values)
    # Join once instead of growing the string with += on every iteration.
    body = ''.join('float_val: %f ' % value for value in values)
    return header + body + '}'
245
# Remove Const nodes and unused attributes.
# Iterate from the end so that deleting a node does not shift the indices of
# the nodes still to be visited.
unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
               'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
               'Tpaddings']
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op in ['Const', 'Dequantize']:
        del graph_def.node[i]
        # Index i now refers to a different node (or is out of range when the
        # last node was removed), so there is nothing left to clean up here.
        continue
    node = graph_def.node[i]
    for attr in unusedAttrs:
        if attr in node.attr:
            del node.attr[attr]
255
# Append prior box generators: one 'PriorBox' node per detection source, each
# taking the source feature map and the network input as its two inputs.
min_sizes = [30, 60, 111, 162, 213, 264]
max_sizes = [60, 111, 162, 213, 264, 315]
steps = [8, 16, 32, 64, 100, 300]
aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
layers = [conv4_3_norm, fc7, conv6_2_h, conv7_2_h, conv8_2_h, conv9_2_h]
for i, layer in enumerate(layers):
    priorBox = NodeDef()
    priorBox.name = 'PriorBox_%d' % i
    priorBox.op = 'PriorBox'
    # Tensor names look like '<op name>:<output index>'; keep the op name.
    priorBox.input.append(layer.name.split(':')[0])
    priorBox.input.append(inp_nodes[0])  # data

    # (attribute name, text-format value) pairs, merged in this order.
    priorBoxAttrs = [
        ("min_size", 'i: %d' % min_sizes[i]),
        ("max_size", 'i: %d' % max_sizes[i]),
        ("flip", 'b: true'),
        ("clip", 'b: false'),
        ("aspect_ratio", tensorMsg(aspect_ratios[i])),
        ("variance", tensorMsg([0.1, 0.1, 0.2, 0.2])),
        ("step", 'f: %f' % steps[i]),
        ("offset", 'f: 0.5'),
    ]
    for attrName, attrText in priorBoxAttrs:
        text_format.Merge(attrText, priorBox.attr[attrName])
    graph_def.node.extend([priorBox])
278
# Concatenate prior boxes: a 'ConcatV2' node over the six PriorBox outputs.
# Its last input is the axis input already used by 'mbox_loc'.
concat = NodeDef()
concat.name = 'mbox_priorbox'
concat.op = 'ConcatV2'
priorBoxNames = ['PriorBox_%d' % i for i in range(6)]
concat.input.extend(priorBoxNames + ['mbox_loc/axis'])
graph_def.node.extend([concat])
287
# DetectionOutput layer: consumes locations, confidences and prior boxes.
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

for inputName in ('mbox_loc', 'mbox_conf_flatten', 'mbox_priorbox'):
    detectionOut.input.append(inputName)

# (attribute name, text-format value) pairs, merged in this order.
detectionOutAttrs = [
    ('num_classes', 'i: 2'),
    ('share_location', 'b: true'),
    ('background_label_id', 'i: 0'),
    ('nms_threshold', 'f: 0.45'),
    ('top_k', 'i: 400'),
    ('code_type', 's: "CENTER_SIZE"'),
    ('keep_top_k', 'i: 200'),
    ('confidence_threshold', 'f: 0.01'),
]
for attrName, attrText in detectionOutAttrs:
    text_format.Merge(attrText, detectionOut.attr[attrName])

graph_def.node.extend([detectionOut])
307
# Replace the L2Normalization subgraph with a single node.
# Step 1: drop the intermediate ops that tf.nn.l2_normalize expanded into.
fusedNodeNames = ('conv4_3_norm/l2_normalize/Square',
                  'conv4_3_norm/l2_normalize/Sum',
                  'conv4_3_norm/l2_normalize/Maximum',
                  'conv4_3_norm/l2_normalize/Rsqrt')
# Walk backwards so deletions do not shift the indices still to be visited.
for i in range(len(graph_def.node) - 1, -1, -1):
    if graph_def.node[i].name in fusedNodeNames:
        del graph_def.node[i]

# Step 2: turn the remaining 'conv4_3_norm/l2_normalize' node into an
# 'L2Normalize' op fed by the original activation and the reduction axis.
for node in graph_def.node:
    if node.name != 'conv4_3_norm/l2_normalize':
        continue
    node.op = 'L2Normalize'
    for _ in range(2):
        node.input.pop()
    node.input.append(layer_256_1_relu1.name)
    node.input.append('conv4_3_norm/l2_normalize/Sum/reduction_indices')
    break
323
# Constant int32 shape [0, -1, 2]; it is later wired in as the shape input of
# 'mbox_conf_reshape'.
softmaxShapeText = (
    'tensor {'
    '  dtype: DT_INT32'
    '  tensor_shape { dim { size: 3 } }'
    '  int_val: 0'
    '  int_val: -1'
    '  int_val: 2'
    '}'
)
softmaxShape = NodeDef()
softmaxShape.name = 'reshape_before_softmax'
softmaxShape.op = 'Const'
text_format.Merge(softmaxShapeText, softmaxShape.attr["value"])
graph_def.node.extend([softmaxShape])
336
# Final per-node fix-ups for the text graph, then write it to args.pbtxt.
for node in graph_def.node:
    if node.name == 'mbox_conf_reshape':
        # Use the 3-element constant shape instead of the original shape input.
        node.input[1] = softmaxShape.name
        continue
    if node.name == 'mbox_conf_softmax':
        text_format.Merge('i: 2', node.attr['axis'])
        continue
    if node.name not in flattenLayersNames:
        continue
    # Reshape -> Flatten: keep only the first input and drop the second.
    node.op = 'Flatten'
    firstInput = node.input[0]
    for _ in range(2):
        node.input.pop()
    node.input.append(firstInput)

tf.train.write_graph(graph_def, "", args.pbtxt, as_text=True)
348
349 tf.train.write_graph(graph_def, "", args.pbtxt, as_text=True)