[pb2nnpkg] introduce pb_select_graph (#8936)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Thu, 14 Nov 2019 05:27:59 +0000 (14:27 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 14 Nov 2019 05:27:59 +0000 (14:27 +0900)
pb2nnpkg uses python script to select subgraph.
It does not require you to build tensorflow any longer.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
tools/nnpackage_tool/pb2nnpkg/README.md
tools/nnpackage_tool/pb2nnpkg/pb2nnpkg.sh
tools/nnpackage_tool/pb2nnpkg/pb_select_graph.py [new file with mode: 0755]

index b5fa084..1b0e164 100644 (file)
@@ -10,15 +10,8 @@ Install tensorflow >= 1.12. It is tested with tensorflow 1.13, 1.14 and 2.0.
 
 Install node. (Any version will do. I recommend you to use `nvm`.)
 
-Build tensorflow's transform_graph.
-
-```
-$ bazel build tensorflow/tools/graph_transforms:transform_graph
-```
-
 Set environmet variables from usage below.
 
-
 ## usage
 
 ```
@@ -39,7 +32,6 @@ Environment variables:
                    (default=./build/externals/FLATBUFFERS/build/flatc)
    tflite_schema   path to tflite schema (i.e. schema.fbs)
    circle_schema   path to tflite schema (i.e. schema.fbs)
-   tensorflow      path to tensorflow source
 ```
 
 ## example
index 6bef46c..56bff45 100755 (executable)
@@ -65,13 +65,7 @@ inputs=$2
 outputs=$3
 suffix=${3//\//_}
 
-${tensorflow}/bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
---in_graph=$1 \
---out_graph=$name.$suffix.pb \
---inputs=${inputs} \
---outputs=${outputs} \
---transforms='remove_nodes(op=Identity, op=CheckNumerics) strip_unused_nodes'
-
+${script_dir}/pb_select_graph.py $1 $2 $3 $name.$suffix
 tflite_convert --output_file=$name.$suffix.tflite --graph_def_file=$name.$suffix.pb --input_arrays=${inputs} --output_arrays=${outputs}
 ${flatc} --defaults-json --strict-json -t ${tflite_schema} -- $name.$suffix.tflite
 node tools/nnpackage_tool/tflite2circle/fuse_instance_norm.js $name.$suffix.json
diff --git a/tools/nnpackage_tool/pb2nnpkg/pb_select_graph.py b/tools/nnpackage_tool/pb2nnpkg/pb_select_graph.py
new file mode 100755 (executable)
index 0000000..2a30eb2
--- /dev/null
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+
+import tensorflow as tf
+import tensorflow.tools.graph_transforms as gt
+import os
+import sys
+import argparse
+
+
+# cmd arguments parsing
+def usage():
+    script = os.path.basename(os.path.basename(__file__))
+    print("Usage: {} path_to_pb".format(script))
+    sys.exit(-1)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('graph_def', type=str, help='path to graph_def (pb)')
+    parser.add_argument('input_names', type=str, help='input tensor names separated by ,')
+    parser.add_argument(
+        'output_names', type=str, help='output tensor names separated by ,')
+    parser.add_argument(
+        'graph_outname', type=str, help='graph_def base name for selected subgraph')
+    parser.add_argument(
+        '-o', '--output', action='store', dest="out_dir", help="output directory")
+    args = parser.parse_args()
+
+    filename = args.graph_def
+    input_names = args.input_names.split(",")
+    output_names = args.output_names.split(",")
+    newfilename = args.graph_outname
+
+    if args.out_dir:
+        out_dir = args.out_dir + '/'
+    else:
+        out_dir = "./"
+
+    # import graph_def (pb)
+    graph = tf.compat.v1.get_default_graph()
+    graph_def = tf.compat.v1.GraphDef()
+
+    with tf.io.gfile.GFile(filename, 'rb') as f:
+        graph_def.ParseFromString(f.read())
+        tf.import_graph_def(graph_def, name='')
+
+    transforms = ['remove_nodes(op=Identity, op=CheckNumerics)', 'strip_unused_nodes']
+
+    selected_graph_def = tf.tools.graph_transforms.TransformGraph(
+        graph_def, input_names, output_names, transforms)
+
+    tf.io.write_graph(selected_graph_def, out_dir, newfilename + ".pb", as_text=False)