From 3c5af56e9af3cd28c766281c34f0082ea82636ec Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=98=A4=ED=98=95=EC=84=9D/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 19 Jul 2018 18:44:10 +0900 Subject: [PATCH] Fix checkpoint to pb file converter bug (#2022) Fix checkpoint to pb file converter bug: string, colon Fix format Signed-off-by: Hyeongseok Oh --- tools/pbfile_tool/convert_ckpt_to_pb.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/tools/pbfile_tool/convert_ckpt_to_pb.py b/tools/pbfile_tool/convert_ckpt_to_pb.py index dc72075..cd43143 100644 --- a/tools/pbfile_tool/convert_ckpt_to_pb.py +++ b/tools/pbfile_tool/convert_ckpt_to_pb.py @@ -21,26 +21,24 @@ import argparse import tensorflow as tf import model_freezer_util as util + def convert(checkpoint_dir, checkpoint_file_path): - meta_path = os.path.join(checkpoint_file_path+'.meta') # Your .meta file + meta_path = os.path.join(checkpoint_file_path + '.meta') # Your .meta file output_node_name = 'Model/concat' - output_node_names = [output_node_name] # Output nodes + output_node_names = [output_node_name] # Output nodes with tf.Session() as sess: - # Restore the graph saver = tf.train.import_meta_graph(meta_path) # Load weights - saver.restore(sess,tf.train.latest_checkpoint(checkpoint_dir)) + saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir)) # save the graph into pb saved_graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph_def, - output_node_names) + sess, sess.graph_def, output_node_names) pb_path = os.path.join(checkpoint_dir, 'graph.pb') with open(pb_path, 'wb') as f: @@ -48,20 +46,19 @@ def convert(checkpoint_dir, checkpoint_file_path): # freeze (frozen_pb_path, frozen_pbtxt_path) = util.freezeGraph(pb_path, checkpoint_file_path, - output_node_name) + output_node_name) print("Freeze() Finished. Created :") print("\t-{}\n\t-{}\n".format(frozen_pb_path, frozen_pbtxt_path)) # tensor board - tensorboardLogDir = util.generateTensorboardLog( - [frozen_pb_path], [''], - os.path.join(checkpoint_dir, ".tensorboard")) + tensorboardLogDir = util.generateTensorboardLog([frozen_pb_path], [''], + os.path.join( + checkpoint_dir, ".tensorboard")) print("") print( - "\t# Tensorboard: You can view original graph and frozen graph with tensorboard." - ) + "\t# Tensorboard: You can view original graph and frozen graph with tensorboard.") print("\t Run the following:") print("\t $ tensorboard --logdir={} ".format(tensorboardLogDir)) @@ -74,14 +71,10 @@ if __name__ == "__main__": help= "directory where checkpoint files are located. pb, pbtxt will also be generated into this folder." ) - parser.add_argument( - "checkpoint_file_name, - help= - "name of checkpoint file" - ) + parser.add_argument("checkpoint_file_name", help="name of checkpoint file") args = parser.parse_args() checkpoint_dir = args.checkpoint_dir checkpoint_file_path = os.path.join(checkpoint_dir, args.checkpoint_file_name) - convert(checkpoint_dir, checkpoint_file_path): + convert(checkpoint_dir, checkpoint_file_path) -- 2.7.4