width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
+first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
+first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
print('Number of classes: %d' % num_classes)
print('Scales: %s' % str(scales))
removeIdentity(graph_def)
def to_remove(name, op):
- return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
+ return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
+ (name.startswith('CropAndResize') and op != 'CropAndResize')
removeUnusedNodesAndAttrs(to_remove, graph_def)
detectionOut.addAttr('num_classes', 2)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
-detectionOut.addAttr('nms_threshold', 0.7)
+detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
-detectionOut.addAttr('keep_top_k', 100)
+detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
detectionOut.addAttr('clip', True)
graph_def.node.extend([detectionOut])
# Save as text.
+cropAndResizeNodesNames = []
for node in reversed(topNodes):
if node.op != 'CropAndResize':
graph_def.node.extend([node])
topNodes.pop()
else:
+ cropAndResizeNodesNames.append(node.name)
if numCropAndResize == 1:
break
else:
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
- 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
+ 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
+ 'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
+ 'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
+ 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
del graph_def.node[i]
for node in graph_def.node:
- if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
+ if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
+ node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
node.op = 'Flatten'
node.input.pop()
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
node.addAttr('loc_pred_transposed', True)
+ if node.name.startswith('MaxPool2D'):
+ assert(node.op == 'MaxPool')
+ assert(len(cropAndResizeNodesNames) == 2)
+ node.input = [cropAndResizeNodesNames[0]]
+ del cropAndResizeNodesNames[0]
+
################################################################################
### Postprocessing
################################################################################
for node in reversed(topNodes):
graph_def.node.extend([node])
+ if node.name.startswith('MaxPool2D'):
+ assert(node.op == 'MaxPool')
+ assert(len(cropAndResizeNodesNames) == 1)
+ node.input = [cropAndResizeNodesNames[0]]
+
for i in reversed(range(len(graph_def.node))):
if graph_def.node[i].op == 'CropAndResize':
graph_def.node[i].input.insert(1, 'detection_out_final')