Update Mask-RCNN networks generator
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 13 Nov 2018 10:22:39 +0000 (13:22 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 13 Nov 2018 10:22:39 +0000 (13:22 +0300)
samples/dnn/tf_text_graph_mask_rcnn.py

index b92d462..aaefe45 100644 (file)
@@ -38,6 +38,8 @@ aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
 width_stride = float(grid_anchor_generator['width_stride'][0])
 height_stride = float(grid_anchor_generator['height_stride'][0])
 features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
+first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
+first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
 
 print('Number of classes: %d' % num_classes)
 print('Scales:            %s' % str(scales))
@@ -53,7 +55,8 @@ graph_def = parseTextGraph(args.output)
 removeIdentity(graph_def)
 
 def to_remove(name, op):
-    return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
+    return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
+           (name.startswith('CropAndResize') and op != 'CropAndResize')
 
 removeUnusedNodesAndAttrs(to_remove, graph_def)
 
@@ -123,20 +126,22 @@ detectionOut.input.append('proposals')
 detectionOut.addAttr('num_classes', 2)
 detectionOut.addAttr('share_location', True)
 detectionOut.addAttr('background_label_id', 0)
-detectionOut.addAttr('nms_threshold', 0.7)
+detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
 detectionOut.addAttr('top_k', 6000)
 detectionOut.addAttr('code_type', "CENTER_SIZE")
-detectionOut.addAttr('keep_top_k', 100)
+detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
 detectionOut.addAttr('clip', True)
 
 graph_def.node.extend([detectionOut])
 
 # Save as text.
+cropAndResizeNodesNames = []
 for node in reversed(topNodes):
     if node.op != 'CropAndResize':
         graph_def.node.extend([node])
         topNodes.pop()
     else:
+        cropAndResizeNodesNames.append(node.name)
         if numCropAndResize == 1:
             break
         else:
@@ -166,11 +171,15 @@ for i in reversed(range(len(graph_def.node))):
 
     if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                   'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
-                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
+                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
+                                  'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
+                                  'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
+                                  'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
         del graph_def.node[i]
 
 for node in graph_def.node:
-    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
+    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
+       node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
         node.op = 'Flatten'
         node.input.pop()
 
@@ -178,6 +187,12 @@ for node in graph_def.node:
                      'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
         node.addAttr('loc_pred_transposed', True)
 
+    if node.name.startswith('MaxPool2D'):
+        assert(node.op == 'MaxPool')
+        assert(len(cropAndResizeNodesNames) == 2)
+        node.input = [cropAndResizeNodesNames[0]]
+        del cropAndResizeNodesNames[0]
+
 ################################################################################
 ### Postprocessing
 ################################################################################
@@ -223,6 +238,11 @@ graph_def.node.extend([detectionOut])
 for node in reversed(topNodes):
     graph_def.node.extend([node])
 
+    if node.name.startswith('MaxPool2D'):
+        assert(node.op == 'MaxPool')
+        assert(len(cropAndResizeNodesNames) == 1)
+        node.input = [cropAndResizeNodesNames[0]]
+
 for i in reversed(range(len(graph_def.node))):
     if graph_def.node[i].op == 'CropAndResize':
         graph_def.node[i].input.insert(1, 'detection_out_final')