Update script to generate MobileNet-SSD V2 text graph
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 4 May 2018 04:55:18 +0000 (07:55 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 4 May 2018 04:55:18 +0000 (07:55 +0300)
samples/dnn/tf_text_graph_ssd.py

index 178e8de..57c3e04 100644 (file)
@@ -13,6 +13,7 @@ import tensorflow as tf
 import argparse
 from math import sqrt
 from tensorflow.core.framework.node_def_pb2 import NodeDef
+from tensorflow.tools.graph_transforms import TransformGraph
 from google.protobuf import text_format
 
 parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
@@ -32,7 +33,7 @@ args = parser.parse_args()
 
 # Nodes that should be kept.
 keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
-           'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool']
+           'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
 
 # Nodes attributes that could be removed because they are not used during import.
 unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
@@ -46,6 +47,10 @@ with tf.gfile.FastGFile(args.input, 'rb') as f:
     graph_def = tf.GraphDef()
     graph_def.ParseFromString(f.read())
 
+inpNames = ['image_tensor']
+outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
+graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order'])
+
 def getUnconnectedNodes():
     unconnected = []
     for node in graph_def.node:
@@ -98,6 +103,7 @@ def removeIdentity():
     for node in graph_def.node:
         if node.op == 'Identity':
             identities[node.name] = node.input[0]
+            graph_def.node.remove(node)
 
     for node in graph_def.node:
         for i in range(len(node.input)):