[nnpackage_tool] introduce gen_golden.py (#8719)
author Sanggyu Lee/On-Device Lab(SR)/Principal Engineer/Samsung Electronics <sg5.lee@samsung.com>
Tue, 5 Nov 2019 08:12:30 +0000 (17:12 +0900)
committer Hyeongseok Oh/On-Device Lab(SR)/Staff Engineer/Samsung Electronics <hseok82.oh@samsung.com>
Tue, 5 Nov 2019 08:12:30 +0000 (17:12 +0900)
This introduces the `gen_golden.py` tool.
You can generate golden data with

```
$ gen_golden.py -o out_dir /your/path/to/pb
```

Then, you will get out_dir/{input,expected}.h5.

No need to specify inputs and outputs.
No need to build nncc.

Prerequisite: a working TensorFlow environment.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
tools/nnpackage_tool/gen_golden/README.md [new file with mode: 0644]
tools/nnpackage_tool/gen_golden/gen_golden.py [new file with mode: 0644]

diff --git a/tools/nnpackage_tool/gen_golden/README.md b/tools/nnpackage_tool/gen_golden/README.md
new file mode 100644 (file)
index 0000000..d24202f
--- /dev/null
@@ -0,0 +1,32 @@
+# gen_golden
+
+`gen_golden` is a tool to generate golden data from a graph def (pb).
+It generates random inputs, runs TensorFlow, and then saves the inputs and outputs in our h5 format.
+
+## Prerequisite
+
+Install TensorFlow 1.x. The tool has been tested with TensorFlow 1.14.
+TensorFlow 2.x will not work since `contrib` was removed.
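+
+A quick sanity check for the environment (a minimal sketch; it only verifies
+that a TensorFlow 1.x release with `contrib` is importable):
+
+```
+import tensorflow as tf
+
+print(tf.__version__)      # expected to start with "1.", e.g. 1.14.0
+import tensorflow.contrib  # fails on TensorFlow 2.x, where contrib was removed
+```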
+
+## Usage
+
+```
+Usage: gen_golden.py [-o OUT_DIR] /path/to/pb
+
+Returns
+     0       success
+  non-zero   failure
+
+Options
+
+positional arguments:
+  graph_def             path to graph_def (pb)
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -o, --output          set output directory
+
+Examples:
+    gen_golden.py Add_000.pb          => generate input.h5 and expected.h5 in ./
+    gen_golden.py -o ~/tmp Add_000.pb => generate input.h5 and expected.h5 in ~/tmp/
+```
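+
+Both files share the same layout: a `value` group with datasets `0`, `1`, ...
+holding the tensor data (stored as big-endian float32), and a `name` group whose
+attributes `0`, `1`, ... hold the corresponding tensor names. A minimal sketch
+for inspecting the generated files with `h5py` (assuming they were written to
+the current directory):
+
+```
+import h5py
+
+with h5py.File('input.h5', 'r') as hf:
+    for idx in hf['value']:
+        name = hf['name'].attrs[idx]   # tensor name, e.g. "Placeholder:0"
+        data = hf['value'][idx][:]     # numpy array, dtype '>f4'
+        print(idx, name, data.shape)
+```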
diff --git a/tools/nnpackage_tool/gen_golden/gen_golden.py b/tools/nnpackage_tool/gen_golden/gen_golden.py
new file mode 100644 (file)
index 0000000..d1e7bec
--- /dev/null
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+
+import h5py
+import numpy as np
+import tensorflow as tf
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('graph_def', type=str, help='path to graph_def (pb)')
+    parser.add_argument(
+        '-o', '--output', action='store', dest="out_dir", help="output directory")
+    args = parser.parse_args()
+
+    filename = args.graph_def
+
+    if args.out_dir:
+        out_dir = args.out_dir + '/'
+    else:
+        out_dir = "./"
+
+    # import graph_def (pb)
+    graph = tf.compat.v1.get_default_graph()
+    graph_def = tf.compat.v1.GraphDef()
+
+    with tf.io.gfile.GFile(filename, 'rb') as f:
+        graph_def.ParseFromString(f.read())
+        tf.import_graph_def(graph_def, name='')
+
+    # identify inputs and outputs
+    inputs = tf.contrib.framework.get_placeholders(graph)
+
+    # An op whose output is never consumed by another op is treated as
+    # a graph output.
+    ops = graph.get_operations()
+    uses = {op.name: 0 for op in ops}
+
+    for op in ops:
+        for producer in op.inputs:
+            producer_op = producer.name.split(':')[0]
+            uses[producer_op] = uses.get(producer_op, 0) + 1
+
+    output_names = [k for k, v in uses.items() if v == 0]
+
+    # gen random input and run
+    input_tensors = {t.name: np.random.normal(size=t.shape) for t in inputs}
+    output_tensors = [graph.get_tensor_by_name(name + ":0") for name in output_names]
+
+    with tf.compat.v1.Session() as sess:
+        output_values = sess.run(output_tensors, feed_dict=input_tensors)
+
+    # dump input and output in h5
+    with h5py.File(out_dir + "input.h5", 'w') as hf:
+        name_grp = hf.create_group("name")
+        val_grp = hf.create_group("value")
+        for idx, t in enumerate(inputs):
+            if (t.dtype != "float32"):
+                print("Assertion Failed. Only float32 is supported.")
+                sys.exit(-1)
+            val_grp.create_dataset(str(idx), data=input_tensors[t.name], dtype='>f4')
+            name_grp.attrs[str(idx)] = t.name
+
+    with h5py.File(out_dir + "expected.h5", 'w') as hf:
+        name_grp = hf.create_group("name")
+        val_grp = hf.create_group("value")
+        for idx, t in enumerate(output_tensors):
+            if (t.dtype != "float32"):
+                print("Assertion Failed. Only float32 is supported.")
+                sys.exit(-1)
+            val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4')
+            name_grp.attrs[str(idx)] = t.name