import tensorflow as tf
import model_freezer_util as util
+
def convert(checkpoint_dir, checkpoint_file_path):
- meta_path = os.path.join(checkpoint_file_path+'.meta') # Your .meta file
+ meta_path = os.path.join(checkpoint_file_path + '.meta') # Your .meta file
output_node_name = 'Model/concat'
- output_node_names = [output_node_name] # Output nodes
+ output_node_names = [output_node_name] # Output nodes
with tf.Session() as sess:
-
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
- saver.restore(sess,tf.train.latest_checkpoint(checkpoint_dir))
+ saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
# save the graph into pb
saved_graph_def = tf.graph_util.convert_variables_to_constants(
- sess,
- sess.graph_def,
- output_node_names)
+ sess, sess.graph_def, output_node_names)
pb_path = os.path.join(checkpoint_dir, 'graph.pb')
with open(pb_path, 'wb') as f:
# freeze
(frozen_pb_path, frozen_pbtxt_path) = util.freezeGraph(pb_path, checkpoint_file_path,
- output_node_name)
+ output_node_name)
print("Freeze() Finished. Created :")
print("\t-{}\n\t-{}\n".format(frozen_pb_path, frozen_pbtxt_path))
# tensor board
- tensorboardLogDir = util.generateTensorboardLog(
- [frozen_pb_path], [''],
- os.path.join(checkpoint_dir, ".tensorboard"))
+ tensorboardLogDir = util.generateTensorboardLog([frozen_pb_path], [''],
+ os.path.join(
+ checkpoint_dir, ".tensorboard"))
print("")
print(
- "\t# Tensorboard: You can view original graph and frozen graph with tensorboard."
- )
+ "\t# Tensorboard: You can view original graph and frozen graph with tensorboard.")
print("\t Run the following:")
print("\t $ tensorboard --logdir={} ".format(tensorboardLogDir))
help=
"directory where checkpoint files are located. pb, pbtxt will also be generated into this folder."
)
- parser.add_argument(
- "checkpoint_file_name,
- help=
- "name of checkpoint file"
- )
+ parser.add_argument("checkpoint_file_name", help="name of checkpoint file")
args = parser.parse_args()
checkpoint_dir = args.checkpoint_dir
checkpoint_file_path = os.path.join(checkpoint_dir, args.checkpoint_file_name)
- convert(checkpoint_dir, checkpoint_file_path):
+ convert(checkpoint_dir, checkpoint_file_path)