Add optimizeGraph to model_freezer_util (#2118)
author Jiyoung Yun / Motion Control Lab (SR) / Engineer / Samsung Electronics <jy910.yun@samsung.com>
Tue, 31 Jul 2018 06:55:48 +0000 (15:55 +0900)
committer Hyeongseok Oh / Motion Control Lab (SR) / Staff Engineer / Samsung Electronics <hseok82.oh@samsung.com>
Tue, 31 Jul 2018 06:55:48 +0000 (15:55 +0900)
This function calls TensorFlow's optimize_for_inference and generates
'*_optimized.pb' in the same directory as the input graph.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
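
For context, a minimal usage sketch of the new helper is shown below (not part of the patch). It assumes the module-level imports that model_freezer_util.py already relies on (os, tensorflow, and optimize_for_inference_lib from tensorflow.python.tools); the graph path and node names are hypothetical placeholders.

    import model_freezer_util as util

    # Optimize a previously frozen graph; the helper writes
    # '<fn>_optimized.pb' next to the input file and returns its path.
    optimized_pb_path = util.optimizeGraph(
        '/tmp/models/mymodel_frozen.pb',  # hypothetical frozen graph (*.pb)
        'input',                          # input node name(s), comma-separated
        'output')                         # head (output) node name(s), comma-separated
    print(optimized_pb_path)              # e.g. /tmp/models/mymodel_frozen_optimized.pb
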
tools/tensorflow_model_freezer/model_freezer_util.py

index 85d16b0..3b847f0 100644 (file)
@@ -130,6 +130,33 @@ def savePbAndCkpt(sess, directory, fn_prefix):
             os.path.join(directory, 'checkoiint', fn_prefix + '.ckpt'))
 
 
+def optimizeGraph(input_graph_path, input_node_name, output_node_name):
+    ''' this function calls optimize_for_inference of tensorflow and generates '*_optimized.pb'.
+
+      - input_graph_path : must be a path to a pb file
+      - input_node_name  : name of the input operation node (comma-separated if there are several)
+      - output_node_name : name of the head (top) operation node (comma-separated if there are several)
+    '''
+
+    (directory, fn, ext) = splitDirFilenameExt(input_graph_path)
+    output_optimized_graph_path = os.path.join(directory, fn + '_optimized.pb')
+
+    # Load the input graph and optimize it for inference
+    input_graph_def = tf.GraphDef()
+    with tf.gfile.Open(input_graph_path, "rb") as f:
+        data = f.read()
+        input_graph_def.ParseFromString(data)
+        output_graph_def = optimize_for_inference_lib.optimize_for_inference(
+            input_graph_def, input_node_name.split(","), output_node_name.split(","),
+            tf.float32.as_datatype_enum)
+
+    # Save the optimized graph in binary mode; SerializeToString() returns bytes
+    with tf.gfile.GFile(output_optimized_graph_path, "wb") as f:
+        f.write(output_graph_def.SerializeToString())
+
+    return output_optimized_graph_path
+
+
 # --------
 def freezeGraph(input_graph_path, checkpoint_path, output_node_name):
     ''' this function calls freeze_graph.py of tensorflow and generates '*_frozen.pb' and '*_frozen.pbtxt'.