--- /dev/null
+# gen_golden
+
+`gen_golden` is a tool to generate golden data from graph def (pb).
+It generates random inputs and run tensorflow, then save input and output in our h5 format.
+
+## Prerequisite
+
+Install tensorflow 1.x. It is tested with tensorflow 1.14.
+TensorFlow 2.x will not work since `contrib` is removed.
+
+## Usage
+
+```
+Usage: gen_golden.py /path/to/pb
+
+Returns
+ 0 success
+ non-zero failure
+
+Options
+
+positional arguments:
+ graph_def path to graph_def (pb)
+
+ optional arguments:
+ -h, --help show this help message and exit
+ -o,--output set output directory
+
+Examples:
+ gen_golden.py Add_000.py => generate input.h5 and expected.h5 in ./
+ gen_golden.py -o ~/tmp Add_000.py => generate input.h5 and expected.h5 in ~/tmp/
+```
--- /dev/null
+#!/usr/bin/env python3
+
+import tensorflow as tf
+import os
+import sys
+import argparse
+
+
+# cmd arguments parsing
+def usage():
+ script = os.path.basename(os.path.basename(__file__))
+ print("Usage: {} path_to_pb".format(script))
+ sys.exit(-1)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('graph_def', type=str, help='path to graph_def (pb)')
+ parser.add_argument(
+ '-o', '--output', action='store', dest="out_dir", help="output directory")
+ args = parser.parse_args()
+
+ filename = sys.argv[1]
+ filename = args.graph_def
+
+ if args.out_dir:
+ out_dir = args.out_dir + '/'
+ else:
+ out_dir = "./"
+
+ # import graph_def (pb)
+ graph = tf.compat.v1.get_default_graph()
+ graph_def = tf.compat.v1.GraphDef()
+
+ with tf.io.gfile.GFile(filename, 'rb') as f:
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name='')
+
+ # identify inputs and outputs
+ inputs = []
+ inputs = [t for t in tf.contrib.framework.get_placeholders(graph)]
+
+ ops = graph.get_operations()
+ uses = {}
+
+ for op in ops:
+ uses[op.name] = 0
+
+ for op in ops:
+ for producer in op.inputs:
+ uses[producer.name.replace(':0', '')] = int(uses.get(producer.name, 0)) + 1
+
+ output_names = [k for k, v in uses.items() if v == 0]
+
+ # gen random input and run
+ import numpy as np
+ input_tensors = {t.name: np.random.normal(size=t.shape) for t in inputs}
+ output_tensors = [graph.get_tensor_by_name(name + ":0") for name in output_names]
+
+ with tf.compat.v1.Session() as sess:
+ output_values = sess.run(output_tensors, feed_dict=input_tensors)
+
+ # dump input and output in h5
+ import h5py
+ with h5py.File(out_dir + "input.h5", 'w') as hf:
+ name_grp = hf.create_group("name")
+ val_grp = hf.create_group("value")
+ for idx, t in enumerate(inputs):
+ if (t.dtype != "float32"):
+ print("Assertion Failed. Only float32 is supported.")
+ sys.exit(-1)
+ val_grp.create_dataset(str(idx), data=input_tensors[t.name], dtype='>f4')
+ name_grp.attrs[str(idx)] = t.name
+
+ with h5py.File(out_dir + "expected.h5", 'w') as hf:
+ name_grp = hf.create_group("name")
+ val_grp = hf.create_group("value")
+ for idx, t in enumerate(output_tensors):
+ if (t.dtype != "float32"):
+ print("Assertion Failed. Only float32 is supported.")
+ sys.exit(-1)
+ val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4')
+ name_grp.attrs[str(idx)] = t.name