From d381948cee5d6e2365693a5d90d328d2b2b7e170 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 4 May 2018 07:55:18 +0300 Subject: [PATCH] Update script to generate MobileNet-SSD V2 text graph --- samples/dnn/tf_text_graph_ssd.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py index 178e8de..57c3e04 100644 --- a/samples/dnn/tf_text_graph_ssd.py +++ b/samples/dnn/tf_text_graph_ssd.py @@ -13,6 +13,7 @@ import tensorflow as tf import argparse from math import sqrt from tensorflow.core.framework.node_def_pb2 import NodeDef +from tensorflow.tools.graph_transforms import TransformGraph from google.protobuf import text_format parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' @@ -32,7 +33,7 @@ args = parser.parse_args() # Nodes that should be kept. keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm', - 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool'] + 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity'] # Nodes attributes that could be removed because they are not used during import. unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu', @@ -46,6 +47,10 @@ with tf.gfile.FastGFile(args.input, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) +inpNames = ['image_tensor'] +outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'] +graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order']) + def getUnconnectedNodes(): unconnected = [] for node in graph_def.node: @@ -98,6 +103,7 @@ def removeIdentity(): for node in graph_def.node: if node.op == 'Identity': identities[node.name] = node.input[0] + graph_def.node.remove(node) for node in graph_def.node: for i in range(len(node.input)): -- 2.7.4