Install node. (Any version will do. I recommend you to use `nvm`.)
-Build tensorflow's transform_graph.
-
-```
-$ bazel build tensorflow/tools/graph_transforms:transform_graph
-```
-
Set environmet variables from usage below.
-
## usage
```
(default=./build/externals/FLATBUFFERS/build/flatc)
tflite_schema path to tflite schema (i.e. schema.fbs)
circle_schema path to tflite schema (i.e. schema.fbs)
- tensorflow path to tensorflow source
```
## example
outputs=$3
suffix=${3//\//_}
-${tensorflow}/bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
---in_graph=$1 \
---out_graph=$name.$suffix.pb \
---inputs=${inputs} \
---outputs=${outputs} \
---transforms='remove_nodes(op=Identity, op=CheckNumerics) strip_unused_nodes'
-
+${script_dir}/pb_select_graph.py $1 $2 $3 $name.$suffix
tflite_convert --output_file=$name.$suffix.tflite --graph_def_file=$name.$suffix.pb --input_arrays=${inputs} --output_arrays=${outputs}
${flatc} --defaults-json --strict-json -t ${tflite_schema} -- $name.$suffix.tflite
node tools/nnpackage_tool/tflite2circle/fuse_instance_norm.js $name.$suffix.json
--- /dev/null
+#!/usr/bin/env python3
+
+import tensorflow as tf
+import tensorflow.tools.graph_transforms as gt
+import os
+import sys
+import argparse
+
+
+# cmd arguments parsing
+def usage():
+ script = os.path.basename(os.path.basename(__file__))
+ print("Usage: {} path_to_pb".format(script))
+ sys.exit(-1)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('graph_def', type=str, help='path to graph_def (pb)')
+ parser.add_argument('input_names', type=str, help='input tensor names separated by ,')
+ parser.add_argument(
+ 'output_names', type=str, help='output tensor names separated by ,')
+ parser.add_argument(
+ 'graph_outname', type=str, help='graph_def base name for selected subgraph')
+ parser.add_argument(
+ '-o', '--output', action='store', dest="out_dir", help="output directory")
+ args = parser.parse_args()
+
+ filename = args.graph_def
+ input_names = args.input_names.split(",")
+ output_names = args.output_names.split(",")
+ newfilename = args.graph_outname
+
+ if args.out_dir:
+ out_dir = args.out_dir + '/'
+ else:
+ out_dir = "./"
+
+ # import graph_def (pb)
+ graph = tf.compat.v1.get_default_graph()
+ graph_def = tf.compat.v1.GraphDef()
+
+ with tf.io.gfile.GFile(filename, 'rb') as f:
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name='')
+
+ transforms = ['remove_nodes(op=Identity, op=CheckNumerics)', 'strip_unused_nodes']
+
+ selected_graph_def = tf.tools.graph_transforms.TransformGraph(
+ graph_def, input_names, output_names, transforms)
+
+ tf.io.write_graph(selected_graph_def, out_dir, newfilename + ".pb", as_text=False)