From 13a39fae36bb4323b3460549e8b49a3beabe621f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=EC=83=81=EA=B7=9C/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 5 Nov 2019 17:12:30 +0900 Subject: [PATCH] [nnpackge_tools] introduce gen_golden.py (#8719) It introduces `gen_golden.py` tool. You can generate golden data with ``` $ gen_golden.py -o out_dir /your/path/to/pb ``` Then, you will get out_dir/{input,expected}.h5. No need to specify inputs and outputs. No need to build nncc. Prerequisite: tensorflow running environment. Signed-off-by: Sanggyu Lee --- tools/nnpackage_tool/gen_golden/README.md | 32 +++++++++++ tools/nnpackage_tool/gen_golden/gen_golden.py | 83 +++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 tools/nnpackage_tool/gen_golden/README.md create mode 100644 tools/nnpackage_tool/gen_golden/gen_golden.py diff --git a/tools/nnpackage_tool/gen_golden/README.md b/tools/nnpackage_tool/gen_golden/README.md new file mode 100644 index 0000000..d24202f --- /dev/null +++ b/tools/nnpackage_tool/gen_golden/README.md @@ -0,0 +1,32 @@ +# gen_golden + +`gen_golden` is a tool to generate golden data from graph def (pb). +It generates random inputs and run tensorflow, then save input and output in our h5 format. + +## Prerequisite + +Install tensorflow 1.x. It is tested with tensorflow 1.14. +TensorFlow 2.x will not work since `contrib` is removed. + +## Usage + +``` +Usage: gen_golden.py /path/to/pb + +Returns + 0 success + non-zero failure + +Options + +positional arguments: + graph_def path to graph_def (pb) + + optional arguments: + -h, --help show this help message and exit + -o,--output set output directory + +Examples: + gen_golden.py Add_000.py => generate input.h5 and expected.h5 in ./ + gen_golden.py -o ~/tmp Add_000.py => generate input.h5 and expected.h5 in ~/tmp/ +``` diff --git a/tools/nnpackage_tool/gen_golden/gen_golden.py b/tools/nnpackage_tool/gen_golden/gen_golden.py new file mode 100644 index 0000000..d1e7bec --- /dev/null +++ b/tools/nnpackage_tool/gen_golden/gen_golden.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 + +import tensorflow as tf +import os +import sys +import argparse + + +# cmd arguments parsing +def usage(): + script = os.path.basename(os.path.basename(__file__)) + print("Usage: {} path_to_pb".format(script)) + sys.exit(-1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('graph_def', type=str, help='path to graph_def (pb)') + parser.add_argument( + '-o', '--output', action='store', dest="out_dir", help="output directory") + args = parser.parse_args() + + filename = sys.argv[1] + filename = args.graph_def + + if args.out_dir: + out_dir = args.out_dir + '/' + else: + out_dir = "./" + + # import graph_def (pb) + graph = tf.compat.v1.get_default_graph() + graph_def = tf.compat.v1.GraphDef() + + with tf.io.gfile.GFile(filename, 'rb') as f: + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + # identify inputs and outputs + inputs = [] + inputs = [t for t in tf.contrib.framework.get_placeholders(graph)] + + ops = graph.get_operations() + uses = {} + + for op in ops: + uses[op.name] = 0 + + for op in ops: + for producer in op.inputs: + uses[producer.name.replace(':0', '')] = int(uses.get(producer.name, 0)) + 1 + + output_names = [k for k, v in uses.items() if v == 0] + + # gen random input and run + import numpy as np + input_tensors = {t.name: np.random.normal(size=t.shape) for t in inputs} + output_tensors = [graph.get_tensor_by_name(name + ":0") for name in output_names] + + with tf.compat.v1.Session() as sess: + output_values = sess.run(output_tensors, feed_dict=input_tensors) + + # dump input and output in h5 + import h5py + with h5py.File(out_dir + "input.h5", 'w') as hf: + name_grp = hf.create_group("name") + val_grp = hf.create_group("value") + for idx, t in enumerate(inputs): + if (t.dtype != "float32"): + print("Assertion Failed. Only float32 is supported.") + sys.exit(-1) + val_grp.create_dataset(str(idx), data=input_tensors[t.name], dtype='>f4') + name_grp.attrs[str(idx)] = t.name + + with h5py.File(out_dir + "expected.h5", 'w') as hf: + name_grp = hf.create_group("name") + val_grp = hf.create_group("value") + for idx, t in enumerate(output_tensors): + if (t.dtype != "float32"): + print("Assertion Failed. Only float32 is supported.") + sys.exit(-1) + val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4') + name_grp.attrs[str(idx)] = t.name -- 2.7.4