From 057236b52af8338938675ad7d1cecfd7b23e6f82 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=EC=A7=80=EC=98=81/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Tue, 31 Jul 2018 15:55:48 +0900 Subject: [PATCH] Add optimizeGraph to model_freezer_util (#2118) This function calls optimize_for_inference of tensorflow and generates '*_optimized.pb'. Signed-off-by: Jiyoung Yun --- .../tensorflow_model_freezer/model_freezer_util.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tools/tensorflow_model_freezer/model_freezer_util.py b/tools/tensorflow_model_freezer/model_freezer_util.py index 85d16b0..3b847f0 100644 --- a/tools/tensorflow_model_freezer/model_freezer_util.py +++ b/tools/tensorflow_model_freezer/model_freezer_util.py @@ -130,6 +130,33 @@ def savePbAndCkpt(sess, directory, fn_prefix): os.path.join(directory, 'checkoiint', fn_prefix + '.ckpt')) +def optimizeGraph(input_graph_path, input_node_name, output_node_name): + ''' this function calls optimize_for_inference of tensorflow and generates '*_optimized.pb'. + + - input_graph_path : must be a path to pb file + - input_node_name : name of input operation node + - output_node_name : name of head(top) operation node + ''' + + (directory, fn, ext) = splitDirFilenameExt(input_graph_path) + output_optimized_graph_path = os.path.join(directory, fn + '_optimized.pb') + + # Optimize for inference + input_graph_def = tf.GraphDef() + with tf.gfile.Open(input_graph_path, "rb") as f: + data = f.read() + input_graph_def.ParseFromString(data) + output_graph_def = optimize_for_inference_lib.optimize_for_inference( + input_graph_def, input_node_name.split(","), output_node_name.split(","), + tf.float32.as_datatype_enum) + + # Save the optimized graph + f = tf.gfile.FastGFile(output_optimized_graph_path, "w") + f.write(output_graph_def.SerializeToString()) + + return output_optimized_graph_path + + # -------- def freezeGraph(input_graph_path, checkpoint_path, output_node_name): ''' this function calls freeze_grapy.py of tensorflow and generates '*_frozen.pb' and '*_frozen.pbtxt'. -- 2.7.4