Fix checkpoint to pb file converter bug (#2022)
author오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 19 Jul 2018 09:44:10 +0000 (18:44 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 19 Jul 2018 09:44:10 +0000 (18:44 +0900)
Fix checkpoint to pb file converter bug: string, colon
Fix format

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/pbfile_tool/convert_ckpt_to_pb.py

index dc72075..cd43143 100644 (file)
@@ -21,26 +21,24 @@ import argparse
 import tensorflow as tf
 import model_freezer_util as util
 
+
 def convert(checkpoint_dir, checkpoint_file_path):
 
-    meta_path = os.path.join(checkpoint_file_path+'.meta') # Your .meta file
+    meta_path = os.path.join(checkpoint_file_path + '.meta')  # Your .meta file
     output_node_name = 'Model/concat'
-    output_node_names = [output_node_name]    # Output nodes
+    output_node_names = [output_node_name]  # Output nodes
 
     with tf.Session() as sess:
 
-
         # Restore the graph
         saver = tf.train.import_meta_graph(meta_path)
 
         # Load weights
-        saver.restore(sess,tf.train.latest_checkpoint(checkpoint_dir))
+        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
 
         # save the graph into pb
         saved_graph_def = tf.graph_util.convert_variables_to_constants(
-            sess,
-            sess.graph_def,
-            output_node_names)
+            sess, sess.graph_def, output_node_names)
 
         pb_path = os.path.join(checkpoint_dir, 'graph.pb')
         with open(pb_path, 'wb') as f:
@@ -48,20 +46,19 @@ def convert(checkpoint_dir, checkpoint_file_path):
 
     # freeze
     (frozen_pb_path, frozen_pbtxt_path) = util.freezeGraph(pb_path, checkpoint_file_path,
-                                                            output_node_name)
+                                                           output_node_name)
 
     print("Freeze() Finished. Created :")
     print("\t-{}\n\t-{}\n".format(frozen_pb_path, frozen_pbtxt_path))
 
     # tensor board
-    tensorboardLogDir = util.generateTensorboardLog(
-        [frozen_pb_path], [''],
-        os.path.join(checkpoint_dir, ".tensorboard"))
+    tensorboardLogDir = util.generateTensorboardLog([frozen_pb_path], [''],
+                                                    os.path.join(
+                                                        checkpoint_dir, ".tensorboard"))
 
     print("")
     print(
-        "\t# Tensorboard: You can view original graph and frozen graph with tensorboard."
-    )
+        "\t# Tensorboard: You can view original graph and frozen graph with tensorboard.")
     print("\t  Run the following:")
     print("\t  $ tensorboard --logdir={} ".format(tensorboardLogDir))
 
@@ -74,14 +71,10 @@ if __name__ == "__main__":
         help=
         "directory where checkpoint files are located. pb, pbtxt will also be generated into this folder."
     )
-    parser.add_argument(
-        "checkpoint_file_name,
-        help=
-        "name of checkpoint file"
-    )
+    parser.add_argument("checkpoint_file_name", help="name of checkpoint file")
 
     args = parser.parse_args()
     checkpoint_dir = args.checkpoint_dir
     checkpoint_file_path = os.path.join(checkpoint_dir, args.checkpoint_file_name)
 
-    convert(checkpoint_dir, checkpoint_file_path):
+    convert(checkpoint_dir, checkpoint_file_path)